import random

import torch
import torch.nn as nn
import torch.nn.functional as F
from transformers import Qwen3PreTrainedModel, Qwen3Config, Qwen3Model
from transformers.models.qwen3.modeling_qwen3 import Qwen3MLP


class TokenCompressor(nn.Module):
    """
    Adaptive token compression module.

    Sequences whose real (unpadded) length exceeds ``length_threshold`` are compressed
    with ``F.adaptive_avg_pool1d``; shorter sequences are passed through unchanged.

    Compressed length = length_threshold + (real_length - length_threshold) * compression_ratio
    """

    def __init__(self, length_threshold: int = 512, compression_ratio: float = 0.3):
        super().__init__()
        self.length_threshold = length_threshold
        self.compression_ratio = compression_ratio

    def forward(
        self, token_embeddings: torch.Tensor, attention_mask: torch.Tensor
    ) -> tuple[torch.Tensor, torch.Tensor]:
        """
        Adaptively compress token embeddings along the sequence dimension.

        Args:
            token_embeddings: [batch_size, seq_len, hidden_size]
            attention_mask: [batch_size, seq_len]

        Returns:
            compressed_embeddings: Compressed embeddings, [batch_size, new_seq_len, hidden_size]
            compressed_mask: Compressed attention mask, [batch_size, new_seq_len]
        """
        # Infer the padding side: if any sequence ends in a padding token, the batch is right-padded.
        padding_side = 'right' if (attention_mask[:, -1] == 0).any() else 'left'

        compressed_embeddings_list = []
        compressed_masks_list = []
        for text_idx in range(token_embeddings.shape[0]):
            real_length = int(attention_mask[text_idx].sum().item())
            if real_length <= self.length_threshold:
                # Short sequence: keep the valid tokens unchanged.
                if padding_side == 'left':
                    valid_embeddings = token_embeddings[text_idx:text_idx + 1, -real_length:, :]
                else:
                    valid_embeddings = token_embeddings[text_idx:text_idx + 1, :real_length, :]
                compressed_embeddings_list.append(valid_embeddings)
                compressed_masks_list.append([1] * real_length)
            else:
                # Long sequence: keep a budget of `length_threshold` positions and shrink
                # the excess by `compression_ratio`.
                target_length = int(
                    self.length_threshold + (real_length - self.length_threshold) * self.compression_ratio
                )
                if padding_side == 'left':
                    valid_embeddings = token_embeddings[text_idx:text_idx + 1, -real_length:, :]
                else:
                    valid_embeddings = token_embeddings[text_idx:text_idx + 1, :real_length, :]
                # adaptive_avg_pool1d pools over the last dimension, so move seq_len there and back.
                compressed_embeddings_list.append(
                    F.adaptive_avg_pool1d(
                        valid_embeddings.transpose(1, 2), target_length
                    ).transpose(1, 2)
                )
                compressed_masks_list.append([1] * target_length)

        # Re-pad the per-sample results to a common length, preserving the original padding side.
        new_seq_len = max(len(_mask) for _mask in compressed_masks_list)
        new_attention_mask = torch.tensor(
            [
                _mask + [0] * (new_seq_len - len(_mask))
                if padding_side == "right"
                else [0] * (new_seq_len - len(_mask)) + _mask
                for _mask in compressed_masks_list
            ],
            dtype=torch.long,
            device=token_embeddings.device
        )

        batch_size = token_embeddings.shape[0]
        hidden_size = token_embeddings.shape[2]
        new_token_embeddings = torch.zeros(
            batch_size, new_seq_len, hidden_size,
            dtype=token_embeddings.dtype,
            device=token_embeddings.device
        )
        for idx, compressed_emb in enumerate(compressed_embeddings_list):
            seq_len = compressed_emb.shape[1]
            if padding_side == "right":
                new_token_embeddings[idx, :seq_len, :] = compressed_emb.squeeze(0)
            else:
                new_token_embeddings[idx, -seq_len:, :] = compressed_emb.squeeze(0)

        return new_token_embeddings, new_attention_mask


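# Illustrative sketch, not part of the original module: `_demo_token_compressor` is a
# hypothetical helper that checks the compression arithmetic on random inputs. With
# length_threshold=4 and compression_ratio=0.5, a sequence of real length 10 is pooled
# down to 4 + (10 - 4) * 0.5 = 7 positions, while a sequence of real length 3 passes
# through unchanged and is re-padded to the new batch length.
def _demo_token_compressor() -> None:
    compressor = TokenCompressor(length_threshold=4, compression_ratio=0.5)
    token_embeddings = torch.randn(2, 10, 8)      # [batch_size=2, seq_len=10, hidden_size=8]
    attention_mask = torch.ones(2, 10, dtype=torch.long)
    attention_mask[1, 3:] = 0                     # second sequence has only 3 real tokens (right padding)
    compressed_embeddings, compressed_mask = compressor(token_embeddings, attention_mask)
    assert compressed_embeddings.shape == (2, 7, 8)
    assert compressed_mask.sum(dim=1).tolist() == [7, 3]

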
class JasperV2Encoder(Qwen3PreTrainedModel):
    """
    Qwen3-based text encoder: embeds tokens, compresses the sequence with
    TokenCompressor, runs the Qwen3 backbone, mean-pools over valid positions,
    and projects to a 2048-dimensional vector.
    """

    def __init__(self, config: Qwen3Config):
        super().__init__(config)
        self.model = Qwen3Model(config)
        self.jasper_mlp = Qwen3MLP(config=config)
        self.linear_1 = nn.Linear(in_features=config.hidden_size, out_features=2048, bias=True)
        self.token_compressor = TokenCompressor(length_threshold=80, compression_ratio=0.5)
        self.post_init()

    def forward(
        self,
        input_ids: torch.Tensor,
        attention_mask: torch.Tensor,
        *args,
        **kwargs
    ) -> torch.Tensor:
        # Embed the tokens and pass them through the extra MLP before compression.
        token_embeddings = self.model.embed_tokens(input_ids)
        token_embeddings = self.jasper_mlp(token_embeddings)

        # Allow callers to override the compression ratio per forward pass via kwargs.
        self.token_compressor.compression_ratio = kwargs.get(
            "compression_ratio",
            self.token_compressor.compression_ratio
        )
        compressed_token_embeddings, attention_mask = self.token_compressor(token_embeddings, attention_mask)
        compressed_token_embeddings = self.model(
            inputs_embeds=compressed_token_embeddings, attention_mask=attention_mask
        )["last_hidden_state"]

        # Mean pooling over valid (non-padding) positions.
        input_mask_expanded = (
            attention_mask.unsqueeze(-1).expand(compressed_token_embeddings.size()).to(
                compressed_token_embeddings.dtype)
        )
        sum_embeddings = torch.sum(compressed_token_embeddings * input_mask_expanded, 1)
        sum_mask = input_mask_expanded.sum(1)
        sum_mask = torch.clamp(sum_mask, min=1e-9)
        vector = sum_embeddings / sum_mask
        return self.linear_1(vector)


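if __name__ == "__main__":
    # Illustrative usage sketch, not part of the original module. It assumes a
    # transformers release that ships the Qwen3 classes imported above; the config
    # sizes and sequence lengths below are arbitrary placeholders, and the weights
    # are randomly initialised rather than loaded from a checkpoint.
    config = Qwen3Config(
        vocab_size=1000,
        hidden_size=64,
        intermediate_size=128,
        num_hidden_layers=2,
        num_attention_heads=4,
        num_key_value_heads=2,
        head_dim=16,
        max_position_embeddings=512,
    )
    encoder = JasperV2Encoder(config).eval()
    input_ids = torch.randint(0, config.vocab_size, (2, 100))
    attention_mask = torch.ones(2, 100, dtype=torch.long)
    attention_mask[1, 60:] = 0  # second sequence is right-padded: 60 real tokens
    with torch.no_grad():
        embeddings = encoder(input_ids=input_ids, attention_mask=attention_mask, compression_ratio=0.5)
    # With length_threshold=80, the first sequence is compressed to 80 + (100 - 80) * 0.5 = 90 positions.
    print(embeddings.shape)  # torch.Size([2, 2048])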