File size: 361 Bytes
f47c8f7
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
import torch
from sentence_transformers.models import Transformer as BaseTransformer


class JasperTransformer(BaseTransformer):
    """Sentence-Transformers ``Transformer`` module whose model output is the sentence embedding.

    ``forward`` is overridden so that whatever ``self.auto_model`` returns is
    written straight into the ``"sentence_embedding"`` entry of the features
    dict, rather than being handled by the base class's own ``forward``.
    """

    def forward(self, features: dict[str, torch.Tensor], **kwargs) -> dict[str, torch.Tensor]:
        """Run the wrapped model and attach its output as the sentence embedding.

        Args:
            features: Batch features. Every entry is passed to ``self.auto_model``
                as a keyword argument.
            **kwargs: Additional keyword arguments, also passed to
                ``self.auto_model``.

        Returns:
            The same ``features`` object, mutated in place, with the model
            output stored under the ``"sentence_embedding"`` key.
        """
        # NOTE(review): the entire features dict is unpacked into the model
        # call, so any non-input keys it carries reach ``auto_model`` as well.
        # Presumably the custom Jasper model returns pooled embeddings directly
        # (hence no separate Pooling step) — confirm against the model code.
        embedding = self.auto_model(**features, **kwargs)
        features["sentence_embedding"] = embedding
        return features