import random
import torch
import torch.nn as nn
import torch.nn.functional as F
from transformers import Qwen3PreTrainedModel, Qwen3Config, Qwen3Model
from transformers.models.qwen3.modeling_qwen3 import Qwen3MLP


class TokenCompressor(nn.Module):
"""
Adaptive Token Compression Module
For sequences exceeding the threshold length, use adaptive_avg_pool1d for compression
Compressed length = threshold + excess_part * compression_ratio
"""

    def __init__(self, length_threshold: int = 512, compression_ratio: float = 0.3):
        super().__init__()
        self.length_threshold = length_threshold
        self.compression_ratio = compression_ratio

    def forward(
        self, token_embeddings: torch.Tensor, attention_mask: torch.Tensor
    ) -> tuple[torch.Tensor, torch.Tensor]:
        """
        Perform adaptive compression on token embeddings.

        Args:
            token_embeddings: [batch_size, seq_len, hidden_size]
            attention_mask: [batch_size, seq_len]
        Returns:
            compressed_embeddings: Compressed embeddings
            compressed_mask: Compressed attention mask
        """
        padding_side = 'right' if (attention_mask[:, -1] == 0).any() else 'left'
        compressed_embeddings_list = []
        compressed_masks_list = []
        for text_idx in range(token_embeddings.shape[0]):
            # Get the effective (non-padding) length of the current sample
            real_length = int(attention_mask[text_idx].sum().item())
            if real_length <= self.length_threshold:
                # Extract valid token embeddings based on padding direction
                if padding_side == 'left':
                    # Left padding: valid tokens are on the right
                    valid_embeddings = token_embeddings[text_idx:text_idx + 1, -real_length:, :]
                else:
                    # Right padding: valid tokens are on the left
                    valid_embeddings = token_embeddings[text_idx:text_idx + 1, :real_length, :]
                compressed_embeddings_list.append(valid_embeddings)
                compressed_masks_list.append([1] * real_length)
            else:
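                # Long sequence: keep a threshold-sized budget and shrink only the
                # excess tokens by compression_ratio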
                target_length = int(
                    self.length_threshold + (real_length - self.length_threshold) * self.compression_ratio
                )
                # Extract valid token embeddings based on padding direction
                if padding_side == 'left':
                    # Left padding: valid tokens are on the right
                    valid_embeddings = token_embeddings[text_idx:text_idx + 1, -real_length:, :]
                else:
                    # Right padding: valid tokens are on the left
                    valid_embeddings = token_embeddings[text_idx:text_idx + 1, :real_length, :]
                # Use adaptive_avg_pool1d over the sequence dimension:
                # [1, real_length, hidden] -> [1, target_length, hidden]
                compressed_embeddings_list.append(
                    F.adaptive_avg_pool1d(
                        valid_embeddings.transpose(1, 2), target_length
                    ).transpose(1, 2)
                )
                compressed_masks_list.append([1] * target_length)
        # Reassemble token_embeddings and attention_mask, padding every sample to
        # the longest compressed length in the batch on the original padding side
        new_seq_len = max(len(_mask) for _mask in compressed_masks_list)
        new_attention_mask = torch.tensor(
            [
                _mask + [0] * (new_seq_len - len(_mask))
                if padding_side == "right"
                else [0] * (new_seq_len - len(_mask)) + _mask
                for _mask in compressed_masks_list
            ],
            dtype=torch.long,
            device=token_embeddings.device
        )
        # Generate new token_embeddings by copying each compressed sequence into a
        # zero-initialized (padded) tensor
        batch_size = token_embeddings.shape[0]
        hidden_size = token_embeddings.shape[2]
        new_token_embeddings = torch.zeros(
            batch_size, new_seq_len, hidden_size,
            dtype=token_embeddings.dtype,
            device=token_embeddings.device
        )
        for idx, compressed_emb in enumerate(compressed_embeddings_list):
            seq_len = compressed_emb.shape[1]
            if padding_side == "right":
                new_token_embeddings[idx, :seq_len, :] = compressed_emb.squeeze(0)
            else:
                new_token_embeddings[idx, -seq_len:, :] = compressed_emb.squeeze(0)
        return new_token_embeddings, new_attention_mask


class JasperV2Encoder(Qwen3PreTrainedModel):
    def __init__(self, config: Qwen3Config):
        super().__init__(config)
        self.model = Qwen3Model(config)
        # Extra MLP applied to the raw token embeddings before compression
        self.jasper_mlp = Qwen3MLP(config=config)
        # Projects the mean-pooled hidden state to the final 2048-dim embedding
        self.linear_1 = nn.Linear(in_features=config.hidden_size, out_features=2048, bias=True)
        self.token_compressor = TokenCompressor(length_threshold=80, compression_ratio=0.5)
        self.post_init()
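
    # With the settings above (length_threshold=80, compression_ratio=0.5), a
    # 200-token input is compressed to int(80 + (200 - 80) * 0.5) = 140 positions
    # before it enters the transformer stack.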

    def forward(
        self,
        input_ids: torch.Tensor,
        attention_mask: torch.Tensor,
        *args,
        **kwargs
    ) -> torch.Tensor:
        # token_embeddings: [batch_size, seq_len, hidden_size]
        token_embeddings = self.model.embed_tokens(input_ids)
        token_embeddings = self.jasper_mlp(token_embeddings)
        # Allow callers to override the compression ratio for this forward pass
        self.token_compressor.compression_ratio = kwargs.get(
            "compression_ratio",
            self.token_compressor.compression_ratio
        )
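        # Compress the MLP-transformed embeddings before running the full
        # transformer; the attention mask is rebuilt to match the new lengths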
        compressed_token_embeddings, attention_mask = self.token_compressor(token_embeddings, attention_mask)
        compressed_token_embeddings = self.model(
            inputs_embeds=compressed_token_embeddings, attention_mask=attention_mask
        )["last_hidden_state"]
        # Generate the sentence vector by mean pooling over valid (non-padding) positions
        input_mask_expanded = (
            attention_mask.unsqueeze(-1)
            .expand(compressed_token_embeddings.size())
            .to(compressed_token_embeddings.dtype)
        )
        sum_embeddings = torch.sum(compressed_token_embeddings * input_mask_expanded, 1)
        sum_mask = input_mask_expanded.sum(1)
        sum_mask = torch.clamp(sum_mask, min=1e-9)
        vector = sum_embeddings / sum_mask
        return self.linear_1(vector)
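

# ---------------------------------------------------------------------------
# Minimal usage sketch (illustrative only). The checkpoint path below is a
# placeholder, not a published model id; it is assumed to contain weights for
# JasperV2Encoder together with a compatible Qwen3 tokenizer.
# ---------------------------------------------------------------------------
if __name__ == "__main__":
    from transformers import AutoTokenizer

    checkpoint = "path/to/jasper-v2-checkpoint"  # placeholder path
    tokenizer = AutoTokenizer.from_pretrained(checkpoint)
    encoder = JasperV2Encoder.from_pretrained(checkpoint)
    encoder.eval()

    batch = tokenizer(
        ["a short query", "a much longer passage that may get compressed"],
        padding=True,
        return_tensors="pt",
    )
    with torch.no_grad():
        embeddings = encoder(
            input_ids=batch["input_ids"],
            attention_mask=batch["attention_mask"],
            compression_ratio=0.5,  # optional per-call override handled in forward()
        )
    print(embeddings.shape)  # torch.Size([2, 2048])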