Commit ·
e7e3b60
1
Parent(s): 044ffe8
Create pipeline.py
Browse files- pipeline.py +21 -0
pipeline.py
ADDED
|
@@ -0,0 +1,21 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
from typing import Dict, List, Any
|
| 2 |
+
|
| 3 |
+
from punctuators.models.punc_cap_seg_model import PunctCapSegConfigONNX, PunctCapSegModelONNX
|
| 4 |
+
|
| 5 |
+
|
| 6 |
+
class PreTrainedPipeline():
|
| 7 |
+
def __init__(self, path: str):
|
| 8 |
+
cfg: PunctCapSegConfigONNX = PunctCapSegConfigONNX(
|
| 9 |
+
directory=path,
|
| 10 |
+
spe_filename="spe_32k_lc_en.model",
|
| 11 |
+
model_filename="punct_cap_seg_en.onnx",
|
| 12 |
+
config_filename="config.yaml",
|
| 13 |
+
)
|
| 14 |
+
self._punctuator: PunctCapSegModelONNX = PunctCapSegModelONNX(cfg)
|
| 15 |
+
|
| 16 |
+
def __call__(self, data: str) -> List[Dict]:
|
| 17 |
+
# Use list to generate a batch of size 1
|
| 18 |
+
pred_texts: List[List[str]] = self._punctuator.infer([data])
|
| 19 |
+
# Can't figure out how to make the text gen widget print multiple lines; use a '\n' for now.
|
| 20 |
+
outputs: List[Dict] = [{"generated_text": " \\n ".join(pred_texts[0])}]
|
| 21 |
+
return outputs
|