# my-tokenizer / inference.py
# Uploaded by Max1798 via huggingface_hub (commit 7ec131e, verified)
from transformers import Pipeline, PreTrainedTokenizer, AutoTokenizer
from typing import Dict, Union, List
import torch
class TokenizerPipeline(Pipeline):
    """Tokenize-only pipeline: runs the tokenizer and returns tokens or ids.

    The forward stage is a pass-through — no model inference is performed.
    """

    def __init__(self, **kwargs):
        super().__init__(**kwargs)

    def _sanitize_parameters(self, **kwargs):
        # Route caller kwargs to the three pipeline stages:
        # padding/truncation -> preprocess, return_tokens -> postprocess,
        # nothing for the (no-op) forward stage.
        pre_kwargs = {
            key: kwargs[key]
            for key in ("padding", "truncation")
            if key in kwargs
        }
        post_kwargs = {}
        if "return_tokens" in kwargs:
            post_kwargs["return_tokens"] = kwargs["return_tokens"]
        return pre_kwargs, {}, post_kwargs

    def preprocess(self, inputs, **kwargs) -> Dict:
        # Tokenize the input text into PyTorch tensors.
        return self.tokenizer(inputs, return_tensors="pt", **kwargs)

    def _forward(self, inputs) -> Dict:
        # Pass the tokenized inputs straight through (no model).
        return inputs

    def postprocess(self, model_outputs, **kwargs) -> Dict:
        # Convert the first (and only expected) sequence into a readable form:
        # token strings by default, raw id list when return_tokens is falsy.
        input_ids = model_outputs["input_ids"][0]
        if not kwargs.get("return_tokens", True):
            return {"input_ids": input_ids.tolist()}
        return {"tokens": self.tokenizer.convert_ids_to_tokens(input_ids)}
# Key step: create and export the pipeline instance (module-level side effect).
# Loads the tokenizer files from the current directory (the repo root on the Hub).
tokenizer = AutoTokenizer.from_pretrained(".")
# NOTE(review): transformers.Pipeline.__init__ normally expects a `model`
# argument; constructing with only a tokenizer may raise depending on the
# transformers version — TODO confirm against the installed version.
pipeline = TokenizerPipeline(tokenizer=tokenizer)
# Optional: typed accessor so Hugging Face tooling can resolve the pipeline.
def get_pipeline() -> Pipeline:
    """Return the module-level pipeline instance."""
    return pipeline