# empathy/src/infer.py
# Provenance: rhasan — "fixed module load", commit c031815
"""
FROM https://github.com/hasan-rakibul/UPLME/tree/main
"""
import torch
from transformers import AutoTokenizer
from src.paired_texts_modelling import LitPairedTextModel
# Module-level inference state. All three are None until load_model() is
# called; predict() reads them, so load_model() must run first.
_device = None      # torch.device selected by load_model() (cuda if available, else cpu)
_model = None       # LitPairedTextModel loaded from checkpoint, in eval mode
_tokeniser = None   # HuggingFace tokenizer matching the model's PLM backbone
def load_model(ckpt_path: str, plm_name: str = "roberta-base") -> None:
    """Load the checkpointed model and its tokenizer into module globals.

    Must be called once before predict(). Populates _device, _model and
    _tokeniser as a side effect.

    Args:
        ckpt_path: Path to the Lightning checkpoint to load.
        plm_name: HuggingFace name of the pretrained-LM backbone whose
            tokenizer is loaded. Defaults to "roberta-base", matching the
            original hard-coded value; it should match the backbone the
            checkpoint was trained with.
    """
    global _model, _tokeniser, _device
    _device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    # eval() disables dropout etc.; inference-only usage.
    _model = LitPairedTextModel.load_from_checkpoint(ckpt_path).to(_device).eval()
    _tokeniser = AutoTokenizer.from_pretrained(
        plm_name,
        use_fast=True,
        # The first word is tokenised differently without a prefix space,
        # but adding one might decrease performance, so False (09/24).
        add_prefix_space=False,
    )
@torch.inference_mode()
def predict(essay: str, article: str) -> tuple[float, float]:
    """Predict an empathy score and its variance for an (essay, article) pair.

    Args:
        essay: The essay text (first segment of the tokenized pair).
        article: The article text (second segment of the tokenized pair).

    Returns:
        (mean, variance) as Python floats. Assumes the model returns a
        3-tuple whose first two elements are scalar tensors — TODO confirm
        against LitPairedTextModel.forward.

    Raises:
        RuntimeError: If load_model() has not been called yet.
    """
    if _model is None or _tokeniser is None:
        # Fail fast with a clear message instead of an opaque
        # "'NoneType' object is not callable" from the bare globals.
        raise RuntimeError("Model not loaded; call load_model(ckpt_path) first.")
    max_length = 512
    toks = _tokeniser(
        essay,
        article,
        truncation=True,
        max_length=max_length,
        return_tensors="pt"
    ).to(_device)
    mean, var, _ = _model(toks)
    return mean.item(), var.item()