Spaces:
Runtime error
import os
from typing import Any, Dict, Union

from allennlp.common.params import Params
from long_coref.coref.prediction import CorefPredictor
from long_coref.coref.utils import ArchiveContent
# Name of the extracted model archive directory, resolved relative to the
# ``path`` given to ``LongCorefPipeline``.
CHECKPOINT = "coref-spanbert-large-2021.03.10"


class LongCorefPipeline:
    """Coreference-resolution pipeline built on ``CorefPredictor``.

    The model is loaded from the extracted archive directory
    ``<path>/<CHECKPOINT>/``, which must contain ``weights.th`` and
    ``config.json``.
    """

    def __init__(self, path=""):
        """Load the coreference predictor from an extracted archive.

        Args:
            path: Directory that contains the ``CHECKPOINT`` folder. With the
                default empty string, the archive path is relative, i.e.
                resolved against the current working directory.
        """
        # Build the archive directory once; weights and config live inside it.
        archive_dir = os.path.join(path, CHECKPOINT)
        archive_content = ArchiveContent(
            archive_dir=archive_dir,
            weight_path=os.path.join(archive_dir, "weights.th"),
            config=Params.from_file(os.path.join(archive_dir, "config.json")),
        )
        self.predictor = CorefPredictor.from_extracted_archive(archive_content)

    def __call__(self, data: Union[str, Dict[str, Any]]) -> Dict[str, Any]:
        """Resolve coreferences in a document.

        Args:
            data: Either the document text itself, or a request payload dict
                whose ``"inputs"`` key holds the document text. Other keys
                (for example ``"date"``) are ignored. Paragraphs in the text
                are separated by a blank line, i.e. two consecutive newline
                characters.

        Returns:
            The result of ``to_dict()`` on the predictor's output, presumably
            a JSON-serializable dict (NOTE(review): confirm against
            ``long_coref``).

        Raises:
            KeyError: If ``data`` is a dict without an ``"inputs"`` key.
        """
        # Accept the payload shape described above without mutating the
        # caller's dict; plain strings are handled exactly as before.
        text = data["inputs"] if isinstance(data, dict) else data
        paragraphs = text.split("\n\n")
        prediction = self.predictor.resolve_paragraphs(paragraphs)
        return prediction.to_dict()