File size: 361 Bytes
f47c8f7 |
1 2 3 4 5 6 7 8 9 10 |
import torch
from sentence_transformers.models import Transformer as BaseTransformer
class JasperTransformer(BaseTransformer):
    """Transformer module that uses the wrapped model's output as the sentence embedding.

    ``forward`` stores whatever ``self.auto_model`` returns under the
    ``"sentence_embedding"`` key, without producing token embeddings first.

    NOTE(review): this presumes ``auto_model`` (likely a custom model loaded
    with remote code) already returns per-sentence vectors rather than raw
    transformer outputs — confirm against the model implementation.
    """

    def forward(self, features: dict[str, torch.Tensor], **kwargs) -> dict[str, torch.Tensor]:
        """Run the wrapped model and attach its result as the sentence embedding.

        Args:
            features: Tokenized inputs. Every entry is passed to
                ``self.auto_model`` as a keyword argument, so a key the model
                does not accept will make the call fail (unless the model
                takes arbitrary keyword arguments).
            **kwargs: Extra keyword arguments passed to ``self.auto_model``
                together with ``features``. A name present in both raises
                ``TypeError`` (duplicate keyword argument).

        Returns:
            The same ``features`` dict that was passed in, mutated in place so
            that ``"sentence_embedding"`` holds the model's return value.
        """
        embeddings = self.auto_model(**features, **kwargs)
        # Stored on the incoming dict (not a copy), which is then returned.
        features["sentence_embedding"] = embeddings
        return features
|