File size: 1,062 Bytes
6b3d060
 
 
 
 
 
 
c031815
6b3d060
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
"""
Source: https://github.com/hasan-rakibul/UPLME/tree/main
"""

import torch
from transformers import AutoTokenizer

from src.paired_texts_modelling import LitPairedTextModel

# Module-level inference state: all three start unset, are populated by
# load_model(), and are read by predict().
_device = _model = _tokeniser = None

def load_model(ckpt_path: str, plm_name: str = "roberta-base"):
    """Load the checkpointed model and its tokeniser into module-level state.

    Must be called before ``predict``, which reads the globals set here
    (``_device``, ``_model`` and ``_tokeniser``).

    Args:
        ckpt_path: Path to a ``LitPairedTextModel`` checkpoint.
        plm_name: Hugging Face name of the pretrained language model whose
            tokeniser is loaded. Defaults to ``"roberta-base"``; presumably
            it should match the backbone the checkpoint was trained with.
    """
    global _model, _tokeniser, _device
    _device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    # map_location remaps the stored tensors onto the chosen device, so a
    # checkpoint saved on GPU can still be restored on a CPU-only host.
    _model = (
        LitPairedTextModel.load_from_checkpoint(ckpt_path, map_location=_device)
        .to(_device)
        .eval()
    )
    _tokeniser = AutoTokenizer.from_pretrained(
        plm_name,
        use_fast=True,
        # Without a prefix space the first word is tokenised differently from
        # later words. Enabling it might hurt performance, so it is kept
        # False (decision from 09/24).
        add_prefix_space=False,
    )

@torch.inference_mode()
def predict(
    essay: str, article: str, max_length: int = 512
) -> tuple[float, float]:
    """Score an (essay, article) text pair with the loaded model.

    Args:
        essay: First text of the pair.
        article: Second text of the pair.
        max_length: Maximum token length of the encoded pair; longer inputs
            are truncated. Defaults to 512.

    Returns:
        ``(mean, var)`` as Python floats, taken from the first two outputs
        of the model.

    Raises:
        RuntimeError: If ``load_model`` has not been called yet.
    """
    # Fail with a clear message instead of "'NoneType' object is not callable".
    if _model is None or _tokeniser is None:
        raise RuntimeError("load_model() must be called before predict()")
    toks = _tokeniser(
        essay,
        article,
        truncation=True,
        max_length=max_length,
        return_tensors="pt",
    ).to(_device)
    # The model yields three outputs; only the first two are used here.
    mean, var, _ = _model(toks)
    # .item() requires single-element tensors. A single pair per call
    # presumably gives batch size 1 — confirm against the model's outputs.
    return mean.item(), var.item()