""" FROM https://github.com/hasan-rakibul/UPLME/tree/main """ import torch from transformers import AutoTokenizer from src.paired_texts_modelling import LitPairedTextModel _device = None _model = None _tokeniser = None def load_model(ckpt_path: str): global _model, _tokeniser, _device plm_name = "roberta-base" _device = torch.device("cuda" if torch.cuda.is_available() else "cpu") _model = LitPairedTextModel.load_from_checkpoint(ckpt_path).to(_device).eval() _tokeniser = AutoTokenizer.from_pretrained( plm_name, use_fast=True, add_prefix_space=False # the first word is tokenised differently if not a prefix space, but it might decrease performance, so False (09/24) ) @torch.inference_mode() def predict(essay: str, article: str) -> tuple[float, float]: max_length = 512 toks = _tokeniser( essay, article, truncation=True, max_length=max_length, return_tensors="pt" ).to(_device) mean, var, _ = _model(toks) return mean.item(), var.item()