|
|
from transformers import Pipeline, PreTrainedTokenizer, AutoTokenizer |
|
|
from typing import Dict, Union, List |
|
|
import torch |
|
|
|
|
|
class TokenizerPipeline(Pipeline):
    """A `transformers` Pipeline that only tokenizes — no model forward pass.

    `_forward` is an identity function, so the pipeline's output is the
    tokenizer's output for the given text. `postprocess` then returns either
    the token strings (default) or the raw input ids as a plain list,
    controlled by the ``return_tokens`` flag.

    NOTE: the redundant ``__init__`` that only delegated to
    ``super().__init__(**kwargs)`` was removed; the inherited constructor
    has identical behavior, so ``TokenizerPipeline(tokenizer=...)`` still works.
    """

    # Keyword arguments that are routed to the tokenizer in `preprocess`.
    _PREPROCESS_KEYS = ("padding", "truncation")

    def _sanitize_parameters(self, **kwargs):
        """Split caller kwargs into (preprocess, forward, postprocess) dicts.

        Recognized keys:
            padding, truncation -> forwarded to the tokenizer call (preprocess)
            return_tokens       -> selects the postprocess output format

        Only keys the caller actually supplied are propagated, so tokenizer
        and postprocess defaults stay in effect otherwise. Unrecognized keys
        are silently ignored, matching the base `Pipeline` contract.
        """
        preprocess_kwargs = {
            key: kwargs[key] for key in self._PREPROCESS_KEYS if key in kwargs
        }

        postprocess_kwargs = {}
        if "return_tokens" in kwargs:
            postprocess_kwargs["return_tokens"] = kwargs["return_tokens"]

        # Middle dict is forward kwargs — empty because `_forward` is a no-op.
        return preprocess_kwargs, {}, postprocess_kwargs

    def preprocess(self, inputs, **kwargs) -> Dict:
        """Tokenize `inputs` into PyTorch tensors.

        Returns the tokenizer's encoding (a dict-like with at least
        ``input_ids``); `kwargs` may carry ``padding`` / ``truncation``.
        """
        return self.tokenizer(inputs, return_tensors="pt", **kwargs)

    def _forward(self, inputs) -> Dict:
        """Identity pass-through: this pipeline runs no model."""
        return inputs

    def postprocess(self, model_outputs, **kwargs) -> Dict:
        """Convert the tokenizer encoding into a JSON-serializable dict.

        Returns ``{"tokens": [...]}`` when ``return_tokens`` is truthy
        (the default), otherwise ``{"input_ids": [...]}``.

        NOTE(review): only the first sequence in the batch is used
        (``input_ids[0]``) — any additional batched sequences are dropped;
        confirm callers only ever pass a single text.
        """
        input_ids = model_outputs["input_ids"][0]

        if kwargs.get("return_tokens", True):
            return {"tokens": self.tokenizer.convert_ids_to_tokens(input_ids)}
        return {"input_ids": input_ids.tolist()}
|
|
|
|
|
|
|
|
# Import-time side effect: loads tokenizer files from the current working
# directory ("."), so the process must be started from the directory that
# contains tokenizer.json / vocab files — TODO confirm deployment layout.
tokenizer = AutoTokenizer.from_pretrained(".")


# Module-level singleton pipeline, built once at import and shared by all
# callers via get_pipeline().
pipeline = TokenizerPipeline(tokenizer=tokenizer)
|
|
|
|
|
|
|
|
def get_pipeline() -> Pipeline:
    """Return the module-level TokenizerPipeline singleton."""
    return pipeline
|
|
|