# Jasper-Token-Compression-600M / modeling_qwen3_jasper.py
# Uploaded by infgrad ("Upload modeling_qwen3_jasper.py"), commit 06a100f (verified).
import random
import torch
import torch.nn as nn
import torch.nn.functional as F
from transformers import Qwen3PreTrainedModel, Qwen3Config, Qwen3Model
from transformers.models.qwen3.modeling_qwen3 import Qwen3MLP
class TokenCompressor(nn.Module):
    """
    Adaptive token compression module.

    A sequence whose number of valid (unmasked) tokens exceeds
    ``length_threshold`` is shortened along the sequence axis with
    ``F.adaptive_avg_pool1d`` to

        max(1, int(length_threshold + (length - length_threshold) * compression_ratio))

    tokens. Sequences at or below the threshold pass through unchanged. The
    results are re-padded with zeros to the longest resulting length, on the
    same side (left or right) as the input padding.
    """

    def __init__(self, length_threshold: int = 512, compression_ratio: float = 0.3):
        """
        Args:
            length_threshold: Valid-token count up to which a sequence is left
                uncompressed.
            compression_ratio: Fraction of the tokens beyond ``length_threshold``
                that survives compression.
        """
        super().__init__()
        self.length_threshold = length_threshold
        self.compression_ratio = compression_ratio

    def _compress(self, valid: torch.Tensor) -> torch.Tensor:
        """Pool a ``[length, hidden_size]`` slice to its target length if it exceeds the threshold."""
        length = valid.shape[0]
        if length <= self.length_threshold:
            return valid
        # int() truncates; clamp so adaptive_avg_pool1d never gets an output size of 0.
        target_length = max(
            1,
            int(self.length_threshold + (length - self.length_threshold) * self.compression_ratio),
        )
        # adaptive_avg_pool1d pools the last axis of a [N, C, L] tensor, so the
        # sequence axis is moved last: [length, H] -> [1, H, length].
        pooled = F.adaptive_avg_pool1d(valid.t().unsqueeze(0), target_length)
        return pooled.squeeze(0).t()  # [target_length, H]

    def forward(
        self, token_embeddings: torch.Tensor, attention_mask: torch.Tensor
    ) -> tuple[torch.Tensor, torch.Tensor]:
        """
        Perform adaptive compression on token embeddings.

        Args:
            token_embeddings: [batch_size, seq_len, hidden_size]
            attention_mask: [batch_size, seq_len]; non-zero marks a valid token.
                Each row's valid tokens are assumed to be contiguous (ordinary
                left or right padding).

        Returns:
            compressed_embeddings: [batch_size, new_seq_len, hidden_size], zero
                at padded positions, same dtype/device as ``token_embeddings``.
            compressed_mask: [batch_size, new_seq_len] ``torch.long`` mask.
        """
        batch_size, _, hidden_size = token_embeddings.shape
        valid_mask = attention_mask.bool()

        # Per-row length and first valid position, fetched with one host sync
        # each instead of one per row. argmax returns the first maximal index
        # (0 for an all-padding row, whose length is 0 anyway).
        lengths = valid_mask.sum(dim=1).tolist()
        starts = valid_mask.long().argmax(dim=1).tolist()

        # Infer the padding side only from rows that contain tokens: an
        # all-padding row ends in 0 regardless of side and would otherwise
        # make a left-padded batch look right-padded.
        non_empty = valid_mask.any(dim=1)
        pad_right = bool((~valid_mask[non_empty, -1]).any())

        pieces = [
            self._compress(token_embeddings[row, start:start + length])
            for row, (start, length) in enumerate(zip(starts, lengths))
        ]

        # Reassemble into a padded batch.
        new_seq_len = max((piece.shape[0] for piece in pieces), default=0)
        new_token_embeddings = token_embeddings.new_zeros(batch_size, new_seq_len, hidden_size)
        new_attention_mask = torch.zeros(
            batch_size, new_seq_len, dtype=torch.long, device=token_embeddings.device
        )
        for row, piece in enumerate(pieces):
            count = piece.shape[0]
            # Explicit start index (not `-count:`) so an empty piece stays empty.
            begin = 0 if pad_right else new_seq_len - count
            new_token_embeddings[row, begin:begin + count] = piece
            new_attention_mask[row, begin:begin + count] = 1
        return new_token_embeddings, new_attention_mask
class JasperV2Encoder(Qwen3PreTrainedModel):
    """
    Qwen3-based sentence encoder with token compression.

    forward() runs: embed_tokens -> jasper_mlp -> token_compressor ->
    Qwen3Model (on the compressed embeddings) -> masked mean pooling ->
    linear_1, producing one 2048-dim vector per input sequence.
    """

    def __init__(self, config: Qwen3Config):
        super().__init__(config)
        self.model = Qwen3Model(config)
        # MLP applied to the raw token embeddings before compression.
        self.jasper_mlp = Qwen3MLP(config=config)
        # Projects the pooled hidden_size vector to the 2048-dim output embedding.
        self.linear_1 = nn.Linear(in_features=config.hidden_size, out_features=2048, bias=True)
        # Default compression: tokens beyond the first 80 are pooled to half their count.
        self.token_compressor = TokenCompressor(length_threshold=80, compression_ratio=0.5)
        self.post_init()

    def forward(
        self,
        input_ids: torch.Tensor,
        attention_mask: torch.Tensor,
        *args,
        **kwargs
    ) -> torch.Tensor:
        """
        Encode a batch of token sequences into sentence vectors.

        Args:
            input_ids: Token ids, [batch_size, seq_len].
            attention_mask: Non-zero for real tokens, 0 for padding, [batch_size, seq_len].
            *args: Ignored.
            **kwargs: ``compression_ratio`` (optional) overrides the compressor's
                ratio for this call only; other keys are ignored.

        Returns:
            Sentence embeddings, [batch_size, 2048].
        """
        # [batch_size, seq_len, hidden_size]
        token_embeddings = self.jasper_mlp(self.model.embed_tokens(input_ids))

        # Apply a per-call compression_ratio without leaking it into later
        # calls: the configured value is restored even if compression raises.
        compressor = self.token_compressor
        default_ratio = compressor.compression_ratio
        compressor.compression_ratio = kwargs.get("compression_ratio", default_ratio)
        try:
            compressed_embeddings, attention_mask = compressor(token_embeddings, attention_mask)
        finally:
            compressor.compression_ratio = default_ratio

        hidden_states = self.model(
            inputs_embeds=compressed_embeddings, attention_mask=attention_mask
        )["last_hidden_state"]

        # Masked mean pooling over the compressed sequence; the clamp avoids a
        # division by zero for a row with no valid tokens.
        mask = attention_mask.unsqueeze(-1).to(hidden_states.dtype)
        summed = (hidden_states * mask).sum(dim=1)
        counts = torch.clamp(mask.sum(dim=1), min=1e-9)
        return self.linear_1(summed / counts)