import os
import math
import time
import random
import inspect
from dataclasses import dataclass, asdict
from typing import Optional

import numpy as np
import matplotlib.pyplot as plt
import tiktoken

import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.distributed as dist
from torch.distributed import init_process_group, destroy_process_group
from torch.nn.parallel import DistributedDataParallel as DDP

from huggingface_hub import PyTorchModelHubMixin


@dataclass
class ModelConfig:
    # Core transformer dimensions.
    vocab_size: int

    num_dims: int
    num_heads: int
    num_kv_heads: int
    num_layers: int
    ffn_hidden_dims: int

    context_len: int
    use_cache: bool
    use_flash: bool
    use_moe: bool

    # Mixture-of-Experts settings (only consulted when use_moe=True).
    moe_num_experts: int
    moe_active_experts: int
    moe_eps: float = 1e-6
    moe_aux_loss_coef: float = 0.01
    moe_shared_experts: int = 0
    use_lossfreebalance: bool = False

    # Normalization and RoPE hyperparameters.
    rmsnorm_eps: float = 1e-6
    rope_theta: float = 1e5

    ffn_dim_multiplier: Optional[int] = None  # not referenced in this file

    def items(self):
        """Return dict items for PyTorchModelHubMixin compatibility."""
        return asdict(self).items()


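# Example: one plausible way to instantiate the config for a small model. The helper
# name and every value below are illustrative assumptions made for this sketch, not
# defaults that ship with this module.
def _example_small_config() -> ModelConfig:
    return ModelConfig(
        vocab_size=50304,
        num_dims=768,
        num_heads=12,
        num_kv_heads=4,          # grouped-query attention: 12 query heads share 4 KV heads
        num_layers=12,
        ffn_hidden_dims=4 * 768,
        context_len=1024,
        use_cache=False,
        use_flash=True,
        use_moe=False,
        moe_num_experts=1,       # unused when use_moe=False, but the field is required
        moe_active_experts=1,
    )

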
def repeat_kv(vct: torch.Tensor, n_times: int):
    """Repeat each KV head n_times along the head axis.

    Expects vct of shape (batch, seq_len, num_kv_heads, head_dim) and returns
    (batch, seq_len, num_kv_heads * n_times, head_dim), so the grouped KV heads can
    be matched against the larger number of query heads.
    """
    c_batch_size, c_context_len, num_kv_heads, c_dim = vct.shape
    if n_times == 1:
        return vct
    else:
        return (
            vct[:, :, :, None, :]
            .expand(c_batch_size, c_context_len, num_kv_heads, n_times, c_dim)
            .reshape(c_batch_size, c_context_len, num_kv_heads * n_times, c_dim)
        )


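# Quick shape illustration for repeat_kv (a throwaway sketch added for clarity; the
# helper below is not referenced anywhere else in this file):
def _example_repeat_kv_shapes() -> torch.Size:
    kv = torch.randn(2, 10, 4, 64)      # (batch, seq_len, num_kv_heads, head_dim)
    return repeat_kv(kv, 3).shape       # torch.Size([2, 10, 12, 64])

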
class Rotary(nn.Module):
    """Rotary positional embedding (RoPE) helper that caches the cos/sin tables per sequence length."""

    def __init__(self, config):
        super().__init__()

        head_dim = config.num_dims // config.num_heads
        inv_freq = 1.0 / (config.rope_theta ** (torch.arange(0, head_dim, 2).float() / head_dim))
        self.register_buffer('inv_freq', inv_freq, persistent=False)
        self.seq_len_saved = None
        self.cos_saved = None
        self.sin_saved = None

    def forward(self, x, seq_dim=1):
        seq_len = x.size(seq_dim)

        # Recompute the tables only when the sequence length changes.
        if seq_len != self.seq_len_saved:
            self.seq_len_saved = seq_len
            pos = torch.arange(seq_len, device=x.device, dtype=self.inv_freq.dtype)

            freqs = torch.einsum("i,j->ij", pos, self.inv_freq)

            emb = torch.cat((freqs, freqs), dim=-1)
            self.cos_saved = emb.cos()
            self.sin_saved = emb.sin()

        # Both tables have shape (seq_len, head_dim) and broadcast over (batch, heads, seq, head_dim).
        return self.cos_saved, self.sin_saved


class RMSNorm(torch.nn.Module):
    """Root-mean-square layer norm: y = g * x / sqrt(mean(x^2) + eps), with no mean subtraction."""

    def __init__(self, config):
        super().__init__()
        self.g = nn.Parameter(torch.ones(config.num_dims))
        self.eps = config.rmsnorm_eps

    def _norm(self, x):
        return x * torch.rsqrt(x.pow(2).mean(-1, keepdim=True) + self.eps)

    def forward(self, x):
        # Normalize in float32 for stability, then cast back to the input dtype.
        return self.g * self._norm(x.float()).type_as(x)


class GroupedQueryAttention(nn.Module):
    def __init__(self, config):
        super().__init__()
        self.config = config
        self.use_cache = config.use_cache
        self.use_flash = config.use_flash

        self.num_heads = config.num_heads
        self.num_kv_heads = config.num_heads if config.num_kv_heads is None else config.num_kv_heads

        # How many query heads share each KV head.
        self.num_rep = self.num_heads // self.num_kv_heads
        self.head_dim = config.num_dims // self.num_heads

        self.wq = nn.Linear(config.num_dims, config.num_dims, bias=False)
        nn.init.normal_(self.wq.weight, mean=0, std=1/math.sqrt(config.num_dims))
        self.wk = nn.Linear(config.num_dims, self.num_kv_heads * self.head_dim, bias=False)
        nn.init.normal_(self.wk.weight, mean=0, std=1/math.sqrt(config.num_dims))
        self.wv = nn.Linear(config.num_dims, self.num_kv_heads * self.head_dim, bias=False)
        nn.init.normal_(self.wv.weight, mean=0, std=1/math.sqrt(config.num_dims))

        self.wo = nn.Linear(config.num_dims, config.num_dims, bias=False)

        # KV cache tensors, allocated lazily on the first cached forward pass.
        self.cache_k = None
        self.cache_v = None

    def rotate_half(self, x):
        # Swap and negate the two halves of the last dimension: (x1, x2) -> (-x2, x1).
        half = x.shape[-1] // 2
        first_half, second_half = x[..., :half], x[..., half:]
        return torch.cat([-second_half, first_half], dim=-1)

    def apply_rotary_pos(self, q, k, cos, sin):
        # Standard RoPE: x_rot = x * cos + rotate_half(x) * sin.
        q_rot = q * cos + self.rotate_half(q) * sin
        k_rot = k * cos + self.rotate_half(k) * sin
        return q_rot, k_rot

    def update_kv_cache(self, batch_size, start_pos, context_len, keys, values, device):
        # Lazily allocate the cache buffers for the full configured context length.
        if self.cache_k is None:
            self.cache_k = torch.zeros(
                (batch_size, self.config.context_len, self.num_kv_heads, self.head_dim),
                device=device
            )
            self.cache_v = torch.zeros(
                (batch_size, self.config.context_len, self.num_kv_heads, self.head_dim),
                device=device
            )

        # Write the new positions and return every cached position up to the current one.
        self.cache_k[:batch_size, start_pos:start_pos + context_len] = keys
        self.cache_v[:batch_size, start_pos:start_pos + context_len] = values

        return (self.cache_k[:batch_size, :start_pos + context_len],
                self.cache_v[:batch_size, :start_pos + context_len])

    def forward(self, x, cos, sin, start_pos=0, use_cache=None, rope_position_offset: int = 0):
        c_batch_size, c_context_len, c_dim = x.shape

        # A per-call use_cache overrides the config default.
        effective_use_cache = use_cache if use_cache is not None else self.use_cache

        if effective_use_cache and c_context_len == 1:
            # Decode step: project only the newest token.
            q = self.wq(x[:, -1, :])
            k = self.wk(x[:, -1, :])
            v = self.wv(x[:, -1, :])

            q = q.view(c_batch_size, c_context_len, self.num_heads, self.head_dim).transpose(1, 2)
            k = k.view(c_batch_size, c_context_len, self.num_kv_heads, self.head_dim).transpose(1, 2)
            v = v.view(c_batch_size, c_context_len, self.num_kv_heads, self.head_dim).transpose(1, 2)

            # Recompute cos/sin for the single absolute position of this token instead of
            # reusing the passed-in tables, which are indexed from position zero.
            inv_freq = 1.0 / (self.config.rope_theta ** (torch.arange(0, self.head_dim, 2, device=x.device).float() / self.head_dim))
            actual_rope_pos = start_pos + rope_position_offset
            pos = torch.tensor([actual_rope_pos], device=x.device, dtype=inv_freq.dtype)
            freqs = torch.einsum("i,j->ij", pos, inv_freq)
            emb = torch.cat((freqs, freqs), dim=-1)
            cos_pos = emb.cos().unsqueeze(0)
            sin_pos = emb.sin().unsqueeze(0)

            queries, keys = self.apply_rotary_pos(q, k, cos_pos, sin_pos)
            cached_keys, cached_values = self.update_kv_cache(batch_size=c_batch_size, start_pos=start_pos, context_len=c_context_len, keys=keys.transpose(1, 2), values=v.transpose(1, 2), device=x.device)
            keys, v = cached_keys.transpose(1, 2), cached_values.transpose(1, 2)

        else:
            # Prefill / training path: project the whole sequence.
            q = self.wq(x)
            k = self.wk(x)
            v = self.wv(x)

            q = q.view(c_batch_size, c_context_len, self.num_heads, self.head_dim).transpose(1, 2)
            k = k.view(c_batch_size, c_context_len, self.num_kv_heads, self.head_dim).transpose(1, 2)
            v = v.view(c_batch_size, c_context_len, self.num_kv_heads, self.head_dim).transpose(1, 2)

            queries, keys = self.apply_rotary_pos(q, k, cos, sin)

            # Populate the cache during prefill so later single-token steps can extend it.
            if effective_use_cache:
                _k, _v = self.update_kv_cache(batch_size=c_batch_size, start_pos=start_pos, context_len=c_context_len, keys=keys.transpose(1, 2), values=v.transpose(1, 2), device=x.device)

        if self.use_flash:
            if effective_use_cache and x.shape[1] == 1:
                # Single-token decode against the cache: fall back to explicit attention
                # because the query length (1) differs from the key length.
                if keys.size(1) != queries.size(1):
                    # repeat_kv expects (batch, seq, kv_heads, head_dim), so move the head
                    # axis out of position 1 before expanding and back afterwards.
                    keys_for_repeat = keys.transpose(1, 2)
                    v_for_repeat = v.transpose(1, 2)

                    keys_expanded_temp = repeat_kv(keys_for_repeat, self.num_rep)
                    values_expanded_temp = repeat_kv(v_for_repeat, self.num_rep)

                    keys_expanded = keys_expanded_temp.transpose(1, 2)
                    values_expanded = values_expanded_temp.transpose(1, 2)
                else:
                    keys_expanded = keys
                    values_expanded = v

                attention = torch.matmul(queries, keys_expanded.transpose(-2, -1)) * (1.0 / math.sqrt(self.head_dim))
                total_length = keys_expanded.size(2)
                # The new token may attend to every cached position up to and including itself.
                mask = torch.arange(total_length, device=attention.device).unsqueeze(0) <= (start_pos + x.shape[1] - 1)
                mask = mask.unsqueeze(0).unsqueeze(0)
                attention = attention.masked_fill(~mask, float("-inf"))
                attention = F.softmax(attention, dim=-1)
                output = torch.matmul(attention, values_expanded)
            else:
                # Fused kernel path; enable_gqa lets SDPA broadcast the grouped KV heads.
                output = F.scaled_dot_product_attention(queries, keys, v, is_causal=True, enable_gqa=True)

        else:
            # Manual attention path. keys/v are (batch, kv_heads, seq, head_dim) here, so
            # transpose around repeat_kv, which expects the head axis in position 2.
            keys = repeat_kv(keys.transpose(1, 2), self.num_rep).transpose(1, 2)
            values = repeat_kv(v.transpose(1, 2), self.num_rep).transpose(1, 2)

            attention = torch.matmul(queries, keys.transpose(-2, -1)) * (1.0 / math.sqrt(self.head_dim))

            if effective_use_cache and x.shape[1] == 1:
                total_length = keys.size(2)

                mask = torch.arange(total_length, device=attention.device).unsqueeze(0) <= (start_pos + x.shape[1] - 1)
                mask = mask.unsqueeze(0).unsqueeze(0)
                attention = attention.masked_fill(~mask, float("-inf"))
                attention = F.softmax(attention, dim=-1)
                output = torch.matmul(attention, values)

            else:
                # Causal masking with an explicit boolean lower-triangular matrix (masking on
                # "score == 0" would also hide genuine zero-valued scores).
                causal_mask = torch.tril(torch.ones(c_context_len, c_context_len, device=attention.device, dtype=torch.bool))
                attention = attention.masked_fill(~causal_mask, float("-inf"))

                attention = F.softmax(attention, dim=-1).type_as(queries)
                output = torch.matmul(attention, values)

        output = output.transpose(2, 1).contiguous().view(c_batch_size, c_context_len, c_dim)
        return self.wo(output)


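# Shape sanity-check sketch for GroupedQueryAttention (illustrative only: the tiny,
# hypothetical config below exists purely so the tensors are cheap to build):
def _example_attention_shapes() -> torch.Size:
    cfg = ModelConfig(vocab_size=128, num_dims=64, num_heads=8, num_kv_heads=2,
                      num_layers=1, ffn_hidden_dims=128, context_len=32,
                      use_cache=False, use_flash=False, use_moe=False,
                      moe_num_experts=1, moe_active_experts=1)
    attn = GroupedQueryAttention(cfg)
    rotary = Rotary(cfg)
    x = torch.randn(2, 16, cfg.num_dims)          # (batch, seq_len, num_dims)
    cos, sin = rotary(x, seq_dim=1)
    return attn(x, cos, sin).shape                # torch.Size([2, 16, 64])

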
class FeedForward(nn.Module):
    """
    Default SwiGLU-style feed-forward layer.

    Returns an (output, None) tuple so it exposes the same interface as FFNwMoE.
    """
    def __init__(self, config):
        super().__init__()

        self.hidden_dim = config.ffn_hidden_dims

        self.w1 = nn.Linear(config.num_dims, self.hidden_dim, bias=False)
        self.w2 = nn.Linear(self.hidden_dim, config.num_dims, bias=False)
        self.w3 = nn.Linear(config.num_dims, self.hidden_dim, bias=False)
        self.act = nn.SiLU()

    def forward(self, x: torch.Tensor):
        return self.w2(self.act(self.w1(x)) * self.w3(x)), None


class FFNwMoE(nn.Module):
    """
    Mixture-of-Experts feed-forward layer with optional shared experts.

    forward() returns:
        output: combined expert outputs, same shape as the input
        aux_loss: auxiliary load-balancing loss tensor, or (router_probs, topk_indices)
                  routing metadata when loss-free balancing is enabled
    """
    def __init__(self, config: ModelConfig):
        super().__init__()
        self.hidden_dim = config.ffn_hidden_dims

        self.moe_active_experts = config.moe_active_experts
        self.moe_aux_loss_coef = config.moe_aux_loss_coef
        self.moe_eps = config.moe_eps
        self.moe_shared_experts = config.moe_shared_experts
        self.num_experts = config.moe_num_experts

        self.use_lossfreebalance = config.use_lossfreebalance

        # The router scores every token against every expert; each expert is a SwiGLU FFN
        # stored as [w1, w2, w3].
        self.router = nn.Linear(config.num_dims, self.num_experts, bias=False)
        self.experts = nn.ModuleList()
        for _ in range(self.num_experts):
            self.experts.append(
                nn.ModuleList([
                    nn.Linear(config.num_dims, self.hidden_dim, bias=False),
                    nn.Linear(self.hidden_dim, config.num_dims, bias=False),
                    nn.Linear(config.num_dims, self.hidden_dim, bias=False)
                ]))

        # Shared experts are applied to every token regardless of routing.
        self.shared_experts = nn.ModuleList()
        for _ in range(self.moe_shared_experts):
            self.shared_experts.append(
                nn.ModuleList([
                    nn.Linear(config.num_dims, self.hidden_dim, bias=False),
                    nn.Linear(self.hidden_dim, config.num_dims, bias=False),
                    nn.Linear(config.num_dims, self.hidden_dim, bias=False)
                ]))

        # Loss-free balancing steers routing with per-expert biases instead of an
        # auxiliary loss term.
        if self.use_lossfreebalance:
            self.expert_biases = nn.Parameter(torch.zeros(self.num_experts))

    def forward(self, x: torch.Tensor):
        c_batch_size, c_context_len, c_dim = x.shape
        x_flat = x.view(-1, c_dim)

        # Route every token: raw scores, softmax probabilities, and top-k expert choices.
        router_out = self.router(x_flat)
        router_probs = F.softmax(router_out, dim=-1)

        _, topk_indices = router_out.topk(self.moe_active_experts, dim=-1)

        aux_loss, topk_probs = self._compute_aux_loss(router_out, router_probs, topk_indices)

        output = self._compute_expert_outputs(x_flat, topk_indices, topk_probs, router_probs)

        return output.view(c_batch_size, c_context_len, c_dim), aux_loss

    def _compute_aux_loss(self, router_out, router_probs, topk_indices):
        """
        Compute the auxiliary balancing loss, or, when loss-free balancing is enabled,
        return the routing metadata needed to update the expert biases outside this module.
        """
        if not self.use_lossfreebalance:
            topk_probs, _ = router_probs.topk(self.moe_active_experts, dim=-1)
            # Switch-style balancing term: coef * num_experts * sum_e(density_e * mean_router_prob_e),
            # where density is measured on each token's top-1 expert.
            expert_mask = F.one_hot(topk_indices[:, 0], self.num_experts).float()
            density = expert_mask.mean(dim=0)
            router_prob_mean = router_probs.mean(dim=0)
            aux_loss = self.moe_aux_loss_coef * torch.sum(density * router_prob_mean) * self.num_experts

        else:
            # Loss-free balancing: bias the router scores, use sigmoid gating, and
            # renormalize the selected experts' probabilities.
            router_out = router_out + self.expert_biases
            router_probs = torch.sigmoid(router_out)
            topk_probs = router_probs.gather(-1, topk_indices)
            topk_probs = topk_probs / topk_probs.sum(dim=-1, keepdim=True)

            # No scalar loss here; the caller receives the routing statistics instead.
            aux_loss = (router_probs, topk_indices)
        return aux_loss, topk_probs

    def _compute_expert_outputs(self, x_flat, topk_indices, topk_probs, router_probs):
        """
        Combine the outputs of the selected experts, weighted by their routing
        probabilities, and add the shared experts, if any.
        """
        output = torch.zeros_like(x_flat)

        for i in range(self.moe_active_experts):
            expert_index = topk_indices[:, i]
            expert_probs = topk_probs[:, i]

            for expert_id in range(self.num_experts):
                # Tokens whose i-th routing choice is this expert.
                idx = (expert_id == expert_index).nonzero().squeeze()

                if idx.numel() == 0:
                    continue
                x_for_expert = x_flat[idx]
                w1, w2, w3 = self.experts[expert_id]

                expert_output = w2(F.silu(w1(x_for_expert)) * w3(x_for_expert))
                output[idx] += expert_output * expert_probs[idx].unsqueeze(-1)

        # Shared experts process every token and are added without routing weights.
        for shared_expert_id in range(self.moe_shared_experts):
            w1, w2, w3 = self.shared_experts[shared_expert_id]
            expert_output = w2(F.silu(w1(x_flat)) * w3(x_flat))
            output = output + expert_output

        return output


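# Minimal routing sketch for FFNwMoE (illustrative; the tiny hypothetical config below
# is chosen only to keep the example cheap to run):
def _example_moe_forward():
    cfg = ModelConfig(vocab_size=128, num_dims=32, num_heads=4, num_kv_heads=2,
                      num_layers=1, ffn_hidden_dims=64, context_len=16,
                      use_cache=False, use_flash=False, use_moe=True,
                      moe_num_experts=4, moe_active_experts=2)
    moe = FFNwMoE(cfg)
    out, aux_loss = moe(torch.randn(2, 8, cfg.num_dims))
    return out.shape, aux_loss      # (torch.Size([2, 8, 32]), scalar balancing-loss tensor)

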
class Block(nn.Module):
    def __init__(self, config):
        super().__init__()

        self.attention = GroupedQueryAttention(config)
        if config.use_moe:
            self.ffn = FFNwMoE(config)
        else:
            self.ffn = FeedForward(config)

        # Pre-norm layers use PyTorch's built-in nn.RMSNorm; the hand-written RMSNorm
        # class above is an equivalent reference implementation that is not used here.
        self.norm_attention = nn.RMSNorm(config.num_dims, config.rmsnorm_eps)
        self.norm_ffn = nn.RMSNorm(config.num_dims, config.rmsnorm_eps)

    def forward(self, x, cos, sin, start_pos, use_cache=None, rope_position_offset: int = 0):
        x = x + self.attention(
            self.norm_attention(x),
            cos, sin, start_pos, use_cache=use_cache, rope_position_offset=rope_position_offset
        )

        ffn_out, aux_loss = self.ffn(
            self.norm_ffn(x)
        )
        x = x + ffn_out
        return x, aux_loss


class Transformer(nn.Module, PyTorchModelHubMixin):
    def __init__(self, config: ModelConfig, **kwargs):
        super().__init__()

        self.vocab_size = config.vocab_size
        self.num_dims = config.num_dims
        self.num_heads = config.num_heads
        self.context_len = config.context_len
        self.use_moe = config.use_moe
        self.use_lossfreebalance = config.use_lossfreebalance and self.use_moe

        self.num_layers = config.num_layers
        self.rotary_emb = Rotary(config)

        hidden_dim = 4 * config.num_dims  # unused; the FFN width comes from config.ffn_hidden_dims

        self.tokens_embedding = nn.Embedding(self.vocab_size, self.num_dims)

        self.blocks = nn.ModuleList()
        for _ in range(self.num_layers):
            self.blocks.append(Block(config))

        self.norm = nn.RMSNorm(config.num_dims, config.rmsnorm_eps)
        self.ll_head = nn.Linear(self.num_dims, self.vocab_size, bias=False)

        # Weight tying: the input embedding and the output projection share one matrix.
        self.tokens_embedding.weight = self.ll_head.weight

    def forward(self, x: torch.Tensor, targets: Optional[torch.Tensor] = None, start_pos: int = 0, use_cache=None, rope_position_offset: int = 0):
        _, seq_len = x.shape

        x = self.tokens_embedding(x)
        cos, sin = self.rotary_emb(x, seq_dim=1)

        total_aux_loss = 0

        for block in self.blocks:
            x, aux_loss = block(x, cos, sin, start_pos=start_pos, use_cache=use_cache, rope_position_offset=rope_position_offset)
            if self.use_moe and not self.use_lossfreebalance:
                total_aux_loss += aux_loss

        x = self.norm(x)
        logits = self.ll_head(x)

        if targets is None:
            loss = None
            ce_loss = None
        else:
            c_batch_size, c_context_len, c_dim = logits.shape
            logits = logits.view(c_batch_size * c_context_len, c_dim)
            targets = targets.view(c_batch_size * c_context_len)
            ce_loss = F.cross_entropy(logits, targets)

            if self.use_moe and not self.use_lossfreebalance:
                # Standard MoE training: total loss = cross-entropy + summed auxiliary losses.
                loss = ce_loss + total_aux_loss
            else:
                # Otherwise the third return slot carries the last block's aux output:
                # (router_probs, topk_indices) under loss-free balancing, or None for dense FFNs.
                loss = ce_loss
                ce_loss = aux_loss

        return logits, loss, ce_loss

    @torch.no_grad()
    def generate(self, x: torch.Tensor, max_tokens: int, temperature: float = 1.0, top_k: int = 50,
                 top_p: float = 1.0, repetition_penalty: float = 1.0, use_cache: bool = False):
        """
        Generate up to max_tokens new tokens, appending them to x.

        Args:
            x: Input token IDs [batch_size, seq_len]
            max_tokens: Maximum number of tokens to generate
            temperature: Sampling temperature (higher = more random)
            top_k: Keep only the top k tokens (set to None to disable)
            top_p: Nucleus sampling threshold (cumulative probability)
            repetition_penalty: Penalty for repeating tokens (>1.0 reduces repetition)
            use_cache: Whether to use KV caching for efficiency
        """
        initial_length = x.shape[1]
        for c_tkn_pos in range(max_tokens):
            if use_cache:
                if c_tkn_pos == 0:
                    # Prefill: run the whole prompt once to populate the KV cache.
                    rope_offset = 0
                    logits, _, ce_loss = self.forward(x, start_pos=0, use_cache=use_cache, rope_position_offset=rope_offset)
                else:
                    # Decode: feed only the newest token at its absolute position.
                    actual_start_pos = initial_length + c_tkn_pos - 1
                    rope_offset = 0
                    logits, _, ce_loss = self.forward(x[:, -1:], start_pos=actual_start_pos, use_cache=use_cache, rope_position_offset=rope_offset)
            else:
                logits, _, ce_loss = self.forward(x, use_cache=use_cache)

            logits = logits[:, -1, :] / temperature

            if repetition_penalty != 1.0:
                logits = self._apply_repetition_penalty(logits, x, repetition_penalty)

            if top_k is not None and top_k > 0:
                logits = self._apply_top_k(logits, top_k)

            if top_p < 1.0:
                logits = self._apply_top_p(logits, top_p)

            probs = F.softmax(logits, dim=-1)
            next_token = torch.multinomial(probs, num_samples=1)
            x = torch.cat((x, next_token), dim=1)
        return x

    def _apply_repetition_penalty(self, logits: torch.Tensor, input_ids: torch.Tensor, penalty: float):
        """Apply a repetition penalty to the logits of previously seen tokens."""
        batch_size, vocab_size = logits.shape

        for batch_idx in range(batch_size):
            for token_id in input_ids[batch_idx].unique():
                if logits[batch_idx, token_id] < 0:
                    logits[batch_idx, token_id] *= penalty
                else:
                    logits[batch_idx, token_id] /= penalty

        return logits

    def _apply_top_k(self, logits: torch.Tensor, top_k: int):
        """Apply top-k filtering to logits."""
        top_k = min(top_k, logits.size(-1))
        tkl, idx = torch.topk(logits, top_k)
        logits[logits < tkl[:, [-1]]] = -float('Inf')
        return logits

    def _apply_top_p(self, logits: torch.Tensor, top_p: float):
        """Apply top-p (nucleus) filtering to logits."""
        sorted_logits, sorted_indices = torch.sort(logits, descending=True)
        cumulative_probs = torch.cumsum(F.softmax(sorted_logits, dim=-1), dim=-1)

        sorted_indices_to_remove = cumulative_probs > top_p

        # Shift right so the first token that crosses the threshold is still kept.
        sorted_indices_to_remove[..., 1:] = sorted_indices_to_remove[..., :-1].clone()
        sorted_indices_to_remove[..., 0] = 0

        for batch_idx in range(logits.shape[0]):
            indices_to_remove = sorted_indices[batch_idx][sorted_indices_to_remove[batch_idx]]
            logits[batch_idx][indices_to_remove] = -float('Inf')

        return logits


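# End-to-end usage sketch (hedged: the helper name, the tiny hypothetical config, the
# "gpt2" encoding choice, and the prompt are all illustrative assumptions; the weights
# are untrained, so the continuation will be gibberish):
def _example_generate() -> str:
    enc = tiktoken.get_encoding("gpt2")
    cfg = ModelConfig(vocab_size=enc.n_vocab, num_dims=128, num_heads=4, num_kv_heads=2,
                      num_layers=2, ffn_hidden_dims=256, context_len=128,
                      use_cache=True, use_flash=False, use_moe=False,
                      moe_num_experts=1, moe_active_experts=1)
    model = Transformer(cfg)
    prompt = torch.tensor([enc.encode("Hello, world")], dtype=torch.long)
    out = model.generate(prompt, max_tokens=8, temperature=1.0, top_k=50, use_cache=True)
    return enc.decode(out[0].tolist())

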
def main():
    pass


if __name__ == "__main__":
    main()