|
|
""" |
|
|
FROM https://github.com/hasan-rakibul/UPLME/tree/main |
|
|
""" |
|
|
|
|
|
import torch |
|
|
from transformers import AutoTokenizer |
|
|
|
|
|
from src.paired_texts_modelling import LitPairedTextModel |
|
|
|
|
|
# Module-level singletons populated by `load_model`; all None until it is called.
_device = None  # torch.device chosen at load time (cuda if available, else cpu)


_model = None  # LitPairedTextModel restored from checkpoint, in eval mode


_tokeniser = None  # Hugging Face tokenizer used to encode (essay, article) pairs
|
|
|
|
|
def load_model(ckpt_path: str, plm_name: str = "roberta-base") -> None:
    """Load the checkpointed paired-text model and its tokeniser into module globals.

    Must be called once before `predict`.

    Args:
        ckpt_path: Path to the Lightning checkpoint to restore.
        plm_name: Hugging Face model name for the tokeniser. Defaults to
            "roberta-base", the value the checkpoint was trained with.
    """
    global _model, _tokeniser, _device

    # Prefer GPU when available; otherwise fall back to CPU.
    _device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

    # eval() disables dropout/batch-norm updates; gradients are disabled
    # separately via @torch.inference_mode() on `predict`.
    _model = LitPairedTextModel.load_from_checkpoint(ckpt_path).to(_device).eval()

    _tokeniser = AutoTokenizer.from_pretrained(
        plm_name,
        use_fast=True,
        add_prefix_space=False,
    )
|
|
|
|
|
@torch.inference_mode()
def predict(essay: str, article: str) -> tuple[float, float]:
    """Score one (essay, article) pair with the loaded model.

    Args:
        essay: Essay text (first segment of the tokenised pair).
        article: Article text (second segment of the pair).

    Returns:
        A ``(mean, variance)`` tuple of Python floats — the model's predicted
        score and its predictive variance.

    Raises:
        RuntimeError: If `load_model` has not been called yet.
    """
    # Fail fast with a clear message instead of a cryptic NoneType error.
    if _model is None or _tokeniser is None:
        raise RuntimeError("Model not initialised: call load_model(ckpt_path) first.")

    max_length = 512  # RoBERTa's maximum supported sequence length

    toks = _tokeniser(
        essay,
        article,
        truncation=True,
        max_length=max_length,
        return_tensors="pt",
    ).to(_device)

    mean, var, _ = _model(toks)

    # .item() assumes a batch of one and scalar mean/var outputs.
    return mean.item(), var.item()
|
|
|
|
|
|