import torch
import matplotlib.pyplot as plt
import numpy as np
import torch.nn as nn
import torch.nn.functional as F
import random
import time
import math
import tiktoken
import inspect
import os
from dataclasses import dataclass, asdict
from huggingface_hub import PyTorchModelHubMixin
from typing import Optional
from torch.distributed import init_process_group, destroy_process_group
from torch.nn.parallel import DistributedDataParallel as DDP
import torch.distributed as dist
@dataclass
class ModelConfig:
vocab_size: int
    num_dims: int # model (embedding) dimension
num_heads: int # number of query heads
num_kv_heads: int # number of key/value heads
num_layers: int # total transformer layers
ffn_hidden_dims: int # hidden dimension for FFN/FFNwMoE
context_len: int # maximum context length
use_cache: bool # enable KV-caching
use_flash: bool # use Flash Attention
use_moe: bool # enable mixture-of-experts
moe_num_experts: int # total number of experts
moe_active_experts: int # number of experts per token (top_k)
moe_eps: float = 1e-6 # epsilon for router stability
moe_aux_loss_coef: float = 0.01 # coefficient for auxiliary loss
moe_shared_experts: int = 0 # number of shared experts (DeepSeekMoE)
use_lossfreebalance: bool = False # use Auxiliary-loss-free load balancing strategy for mixture-of-experts from DeepSeek https://arxiv.org/pdf/2408.15664
rmsnorm_eps: float = 1e-6
rope_theta: float = 1e5
ffn_dim_multiplier: Optional[int] = None # optional multiplier to compute ffn_hidden_dims
def items(self):
"""Return dict items for PyTorchModelHubMixin compatibility"""
return asdict(self).items()
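# A minimal instantiation sketch (hypothetical values, not project defaults):
#     config = ModelConfig(
#         vocab_size=50304, num_dims=512, num_heads=8, num_kv_heads=4,
#         num_layers=8, ffn_hidden_dims=4 * 512, context_len=1024,
#         use_cache=True, use_flash=False, use_moe=False,
#         moe_num_experts=1, moe_active_experts=1,
#     )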
# Helper function for grouped-query attention (GQA)
def repeat_kv(vct: torch.Tensor, n_times: int):
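    """Repeat each key/value head n_times along the head axis so grouped K/V heads
    line up with the query heads, e.g. (B, T, 4, 64) -> (B, T, 16, 64) for n_times=4.
    Expects the input laid out as (batch, seq_len, num_kv_heads, head_dim)."""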
c_batch_size, c_context_len, num_kv_heads, c_dim = vct.shape
if n_times == 1:
return vct
else:
return (
vct[:, :, :, None, :]
.expand(c_batch_size, c_context_len, num_kv_heads, n_times, c_dim)
.reshape(c_batch_size, c_context_len, num_kv_heads * n_times, c_dim)
)
class Rotary(nn.Module):
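    """Rotary positional embedding (RoPE) table.

    Builds cos/sin tables from inverse frequencies 1 / theta^(2i/d) (d = head_dim)
    and caches them, recomputing only when the sequence length changes."""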
def __init__(self, config):
super(Rotary, self).__init__()
inv_freq = 1.0 / (config.rope_theta ** (torch.arange(0, config.num_dims // config.num_heads, 2).float() / (config.num_dims // config.num_heads)))
self.register_buffer('inv_freq', inv_freq, persistent=False)
self.seq_len_saved = None
self.cos_saved = None
self.sin_saved = None
def forward(self, x, seq_dim=1):
seq_len = x.size(seq_dim)
# Only recompute the cosine and sine matrices if the sequence length has changed.
if seq_len != self.seq_len_saved:
self.seq_len_saved = seq_len
pos = torch.arange(seq_len, device=x.device, dtype=self.inv_freq.dtype)
# Compute the outer product between positions and inverse frequencies.
freqs = torch.einsum("i,j->ij", pos, self.inv_freq) # (seq_len, inv_freq.shape[0])
# Duplicate the freqs along the last dimension to create pairs.
emb = torch.cat((freqs, freqs), dim=-1)
self.cos_saved = emb.cos()
self.sin_saved = emb.sin()
return self.cos_saved, self.sin_saved
class RMSNorm(torch.nn.Module):
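    """Root-mean-square LayerNorm with a learnable gain:
    g * x / sqrt(mean(x^2, dim=-1) + eps), computed in float32 and cast back."""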
def __init__(self, config):
super().__init__()
self.g = nn.Parameter(torch.ones(config.num_dims))
self.eps = config.rmsnorm_eps
def _norm(self, x):
return x * torch.rsqrt(x.pow(2).mean(-1, keepdim=True) + self.eps)
def forward(self, x):
return self.g * self._norm(x.float()).type_as(x)
class GroupedQueryAttention(nn.Module):
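    """Grouped-query attention with RoPE, optional KV caching and optional Flash attention.

    num_heads query heads share num_kv_heads key/value heads; during cached single-token
    decoding, keys/values are written to and read from a per-layer cache of length
    config.context_len."""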
def __init__(self, config):
super().__init__()
self.config = config
self.use_cache = config.use_cache
self.use_flash = config.use_flash
self.num_heads = config.num_heads
self.num_kv_heads = config.num_heads if config.num_kv_heads is None else config.num_kv_heads
self.num_rep = self.num_heads // self.num_kv_heads
self.head_dim = config.num_dims // self.num_heads
self.wq = nn.Linear(config.num_dims, config.num_dims, bias=False)
nn.init.normal_(self.wq.weight, mean=0, std=1/math.sqrt(config.num_dims))
self.wk = nn.Linear(config.num_dims, self.num_kv_heads * self.head_dim, bias=False)
nn.init.normal_(self.wk.weight, mean=0, std=1/math.sqrt(config.num_dims))
self.wv = nn.Linear(config.num_dims, self.num_kv_heads * self.head_dim, bias=False)
nn.init.normal_(self.wv.weight, mean=0, std=1/math.sqrt(config.num_dims))
self.wo = nn.Linear(config.num_dims, config.num_dims, bias=False)
self.cache_k = None
self.cache_v = None
def rotate_half(self, x):
half = x.shape[-1] // 2
first_half, second_half = x[..., :half], x[..., half:]
return torch.cat([-second_half, first_half], dim=-1)
def apply_rotary_pos(self, q, k, cos, sin):
q_rot = q * cos + self.rotate_half(q) * sin
k_rot = k * cos + self.rotate_half(k) * sin
return q_rot, k_rot
def update_kv_cache(self, batch_size, start_pos, context_len, keys, values, device):
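        """Lazily allocate per-layer K/V caches of shape
        (batch, context_len, num_kv_heads, head_dim), write the new keys/values at
        positions [start_pos, start_pos + context_len), and return the cached K/V
        up to the current position."""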
# Initialize cache if not exist
if self.cache_k is None:
self.cache_k = torch.zeros(
(batch_size, self.config.context_len, self.num_kv_heads, self.head_dim),
device=device
)
self.cache_v = torch.zeros(
(batch_size, self.config.context_len, self.num_kv_heads, self.head_dim),
device=device
)
# Update cache
self.cache_k[:batch_size, start_pos:start_pos + context_len] = keys
self.cache_v[:batch_size, start_pos:start_pos + context_len] = values
return (self.cache_k[:batch_size, :start_pos + context_len],
self.cache_v[:batch_size, :start_pos + context_len])
def forward(self, x, cos, sin, start_pos = 0, use_cache=None, rope_position_offset: int = 0):
        c_batch_size, c_context_len, c_dim = x.shape # c_context_len == 1 in the cached decoding branch
        # If use_cache is None, fall back to the config value; otherwise use the caller's value.
effective_use_cache = use_cache if use_cache is not None else self.use_cache
#print(f"effective_use_cache: {effective_use_cache}, c_context_len: {c_context_len}, x.shape: {x.shape}")
if effective_use_cache and c_context_len == 1:
# Cache branch
q = self.wq(x[:, -1, :])
k = self.wk(x[:, -1, :])
v = self.wv(x[:, -1, :])
            q = q.view(c_batch_size, c_context_len, self.num_heads, self.head_dim).transpose(1, 2)    # (B, qh, 1, hs)
            k = k.view(c_batch_size, c_context_len, self.num_kv_heads, self.head_dim).transpose(1, 2) # (B, kh, 1, hs)
            v = v.view(c_batch_size, c_context_len, self.num_kv_heads, self.head_dim).transpose(1, 2) # (B, vh, 1, hs)
# freqs_complex = freqs_complex[-1:]
# queries = apply_rotary_pos(q, freqs_complex, device=x.device)
# keys = apply_rotary_pos(k, freqs_complex, device=x.device)
            # Compute cos/sin directly for this single absolute position
inv_freq = 1.0 / (self.config.rope_theta ** (torch.arange(0, self.head_dim, 2, device=x.device).float() / self.head_dim))
actual_rope_pos = start_pos + rope_position_offset
pos = torch.tensor([actual_rope_pos], device=x.device, dtype=inv_freq.dtype)
freqs = torch.einsum("i,j->ij", pos, inv_freq)
emb = torch.cat((freqs, freqs), dim=-1)
cos_pos = emb.cos().unsqueeze(0) # [1, 1, head_dim]
sin_pos = emb.sin().unsqueeze(0) # [1, 1, head_dim]
queries, keys = self.apply_rotary_pos(q, k, cos_pos, sin_pos)
cached_keys, cached_values = self.update_kv_cache(batch_size=c_batch_size, start_pos=start_pos, context_len=c_context_len, keys=keys.transpose(1,2), values=v.transpose(1,2), device=x.device)
keys, v = cached_keys.transpose(1,2), cached_values.transpose(1,2)
else:
# Non-cache branch (process the entire sequence normally)
q = self.wq(x)
k = self.wk(x)
v = self.wv(x)
q = q.view(c_batch_size, c_context_len, self.num_heads, self.head_dim).transpose(1, 2) # B, qh, T, hs
k = k.view(c_batch_size, c_context_len, self.num_kv_heads, self.head_dim).transpose(1, 2) # B, kh, T, hs
v = v.view(c_batch_size, c_context_len, self.num_kv_heads, self.head_dim).transpose(1, 2) # B, vh, T, hs
queries, keys = self.apply_rotary_pos(q, k, cos, sin)
# queries = apply_rotary_pos(q, freqs_complex, device=x.device)
# keys = apply_rotary_pos(k, freqs_complex, device=x.device)
if effective_use_cache: _k, _v = self.update_kv_cache(batch_size=c_batch_size, start_pos=start_pos, context_len=c_context_len, keys=keys.transpose(1,2), values=v.transpose(1,2), device=x.device)
if self.use_flash:
# For cache processing, we need to ensure proper GQA handling
if effective_use_cache and x.shape[1] == 1:
# Incremental processing: manually expand keys/values for GQA
# Only expand if keys have fewer heads than queries (GQA case)
if keys.size(1) != queries.size(1):
# Transpose to [batch, seq_len, heads, dim] for repeat_kv function
keys_for_repeat = keys.transpose(1, 2) # [B, T, H, D]
v_for_repeat = v.transpose(1, 2) # [B, T, H, D]
keys_expanded_temp = repeat_kv(keys_for_repeat, self.num_rep)
values_expanded_temp = repeat_kv(v_for_repeat, self.num_rep)
# Transpose back to [batch, heads, seq_len, dim]
keys_expanded = keys_expanded_temp.transpose(1, 2) # [B, H, T, D]
values_expanded = values_expanded_temp.transpose(1, 2) # [B, H, T, D]
else:
keys_expanded = keys
values_expanded = v
# Manual attention for incremental cache case
attention = torch.matmul(queries, keys_expanded.transpose(-2, -1)) * (1.0 / math.sqrt(self.head_dim))
total_length = keys_expanded.size(2)
mask = torch.arange(total_length, device=attention.device).unsqueeze(0) <= (start_pos + x.shape[1] - 1)
mask = mask.unsqueeze(0).unsqueeze(0)
attention = attention.masked_fill(~mask, float("-inf"))
attention = F.softmax(attention, dim=-1)
output = torch.matmul(attention, values_expanded)
else:
# Non-incremental: use flash attention normally
output = F.scaled_dot_product_attention(queries, keys, v, is_causal=True, enable_gqa=True)
else: # Calculate Grouped Query Attention manually
            # repeat_kv expects (B, T, kv_heads, head_dim), so transpose around the call
            keys = repeat_kv(keys.transpose(1, 2), self.num_rep).transpose(1, 2)
            values = repeat_kv(v.transpose(1, 2), self.num_rep).transpose(1, 2)
attention = torch.matmul(queries, keys.transpose(-2, -1)) * (1.0 / math.sqrt(self.head_dim))
if effective_use_cache and x.shape[1] == 1:
total_length = keys.size(2)
# For autoregressive generation, the query (which is at the latest position) should only attend to keys at indices <= current token.
# Create a mask: allowed positions are indices < total_length (i.e. all in the cache)
mask = torch.arange(total_length, device=attention.device).unsqueeze(0) <= (start_pos + x.shape[1] - 1)
mask = mask.unsqueeze(0).unsqueeze(0) # shape: (1, 1, 1, total_length)
attention = attention.masked_fill(~mask, float("-inf"))
attention = F.softmax(attention, dim=-1)
output = torch.matmul(attention, values)
else: # Do not use kv_cache
                # Explicit boolean causal mask (the previous tril/==0 trick could also mask valid zero scores)
                causal_mask = torch.triu(torch.ones(c_context_len, c_context_len, device=attention.device, dtype=torch.bool), diagonal=1)
                attention = attention.masked_fill(causal_mask, float("-inf"))
                attention = F.softmax(attention, dim=-1).type_as(queries)
                output = torch.matmul(attention, values)
output = output.transpose(2, 1).contiguous().view(c_batch_size, c_context_len, c_dim)
return self.wo(output)
class FeedForward(nn.Module):
"""
Default Feed Forward Layer.
"""
def __init__(self, config):
super().__init__()
self.hidden_dim = config.ffn_hidden_dims
self.w1 = nn.Linear(config.num_dims, self.hidden_dim, bias=False)
self.w2 = nn.Linear(self.hidden_dim, config.num_dims, bias=False)
self.w3 = nn.Linear(config.num_dims, self.hidden_dim, bias=False)
self.act = nn.SiLU()
def forward(self, x: torch.Tensor):
return self.w2(self.act(self.w1(x)) * self.w3(x)), None
class FFNwMoE(nn.Module):
"""
Feed Forward with MoE with optional shared experts.
Returns after forward:
output: Combined outputs from experts
aux_loss: Auxiliary loss tensor or routing metadata
"""
def __init__(self, config: ModelConfig):
super().__init__()
self.hidden_dim = config.ffn_hidden_dims
self.moe_active_experts = config.moe_active_experts # top_k
self.moe_aux_loss_coef = config.moe_aux_loss_coef
self.moe_eps = config.moe_eps
self.moe_shared_experts = config.moe_shared_experts
self.num_experts = config.moe_num_experts
self.use_lossfreebalance = config.use_lossfreebalance
self.router = nn.Linear(config.num_dims, self.num_experts, bias=False)
self.experts = nn.ModuleList()
for _ in range(self.num_experts):
self.experts.append(
nn.ModuleList([
nn.Linear(config.num_dims, self.hidden_dim, bias=False),
nn.Linear(self.hidden_dim, config.num_dims, bias=False),
nn.Linear(config.num_dims, self.hidden_dim, bias=False)
]))
# shared experts (for DeepSeekMoE)
self.shared_experts = nn.ModuleList()
for _ in range(self.moe_shared_experts):
self.shared_experts.append(
nn.ModuleList([
nn.Linear(config.num_dims, self.hidden_dim, bias=False),
nn.Linear(self.hidden_dim, config.num_dims, bias=False),
nn.Linear(config.num_dims, self.hidden_dim, bias=False)
]))
# Auxiliary-loss-free load balancing strategy for mixture-of-experts from DeepSeek https://arxiv.org/pdf/2408.15664
if self.use_lossfreebalance:
self.expert_biases = nn.Parameter(torch.zeros(self.num_experts))
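            # Per the referenced paper, these per-expert biases are meant to be adjusted by
            # the training loop according to expert load; this module only adds them to the
            # router logits in _compute_aux_loss.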
def forward(self, x: torch.Tensor):
c_batch_size, c_context_len, c_dim = x.shape
x_flat = x.view(-1, c_dim) #c_batch_size * c_context_len, c_dim
router_out = self.router(x_flat)
router_probs = F.softmax(router_out, dim=-1)
_, topk_indices = router_out.topk(self.moe_active_experts, dim=-1)
aux_loss, topk_probs = self._compute_aux_loss(router_out, router_probs, topk_indices)
output = self._compute_expert_outputs(x_flat, topk_indices, topk_probs, router_probs)
return output.view(c_batch_size, c_context_len, c_dim), aux_loss
def _compute_aux_loss(self, router_out, router_probs, topk_indices):
"""
Computes the auxiliary loss based on whether loss-free balancing is used or not.
"""
if not self.use_lossfreebalance:
topk_probs, _ = router_probs.topk(self.moe_active_experts, dim=-1)
expert_mask = F.one_hot(topk_indices[:, 0], self.num_experts).float()
density = expert_mask.mean(dim=0)
router_prob_mean = router_probs.mean(dim=0)
aux_loss = self.moe_aux_loss_coef * torch.sum(density * router_prob_mean) * self.num_experts
else: # if use_lossfreebalance
router_out = router_out + self.expert_biases
router_probs = torch.sigmoid(router_out) # from https://arxiv.org/pdf/2408.15664 paper
topk_probs = router_probs.gather(-1, topk_indices)
topk_probs = topk_probs / topk_probs.sum(dim=-1, keepdim=True)
# In the case of Auxiliary-loss-free load balancing we pass router_probs, topk_indices as aux_loss for further calculations
aux_loss = (router_probs, topk_indices)
return aux_loss, topk_probs
def _compute_expert_outputs(self, x_flat, topk_indices, topk_probs, router_probs):
"""
Compute the output of the experts and shared experts if needed
"""
output = torch.zeros_like(x_flat)
for i in range(self.moe_active_experts):
expert_index = topk_indices[:, i]
expert_probs = topk_probs[:, i]
for expert_id in range(self.num_experts):
                idx = (expert_id == expert_index).nonzero(as_tuple=True)[0]
if idx.numel() == 0:
continue
x_for_expert = x_flat[idx]
w1, w2, w3 = self.experts[expert_id]
expert_output = w2(F.silu(w1(x_for_expert)) * w3(x_for_expert))
output[idx] += expert_output * expert_probs[idx].unsqueeze(-1)
# shared experts(for DeepSeekMoE)
for shared_expert_id in range(self.moe_shared_experts):
w1, w2, w3 = self.shared_experts[shared_expert_id]
expert_output = w2(F.silu(w1(x_flat)) * w3(x_flat))
output = output + expert_output
return output
class Block(nn.Module):
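    """Pre-norm transformer block: x + Attention(RMSNorm(x)) followed by
    x + FFN(RMSNorm(x)), where the FFN is dense or MoE depending on config.use_moe.
    Returns (hidden_states, aux_loss)."""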
def __init__(self, config):
super().__init__()
self.attention = GroupedQueryAttention(config)
if config.use_moe:
self.ffn = FFNwMoE(config)
else:
self.ffn = FeedForward(config)
        self.norm_attention = nn.RMSNorm(config.num_dims, eps=config.rmsnorm_eps) # you can also use the local RMSNorm(config)
        self.norm_ffn = nn.RMSNorm(config.num_dims, eps=config.rmsnorm_eps) # you can also use the local RMSNorm(config)
def forward(self, x, cos, sin, start_pos, use_cache=None, rope_position_offset: int = 0):
x = x + self.attention(
self.norm_attention(x),
cos, sin, start_pos, use_cache=use_cache, rope_position_offset=rope_position_offset
)
ffn_out, aux_loss = self.ffn(
self.norm_ffn(x)
)
x = x + ffn_out
return x, aux_loss
class Transformer(nn.Module, PyTorchModelHubMixin): # PyTorchModelHubMixin adds save_pretrained/from_pretrained/push_to_hub support (safetensors weights)
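    """Decoder-only transformer: token embedding (tied with the output head),
    num_layers pre-norm Blocks with RoPE/GQA and dense or MoE FFNs, a final RMSNorm
    and a linear LM head. forward() returns (logits, loss, ce_loss); generate()
    samples with temperature, top-k, top-p and an optional repetition penalty."""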
def __init__(self, config: ModelConfig, **kwargs):
super().__init__()
self.vocab_size = config.vocab_size
self.num_dims = config.num_dims
self.num_heads = config.num_heads
self.context_len = config.context_len
self.use_moe = config.use_moe
self.use_lossfreebalance = config.use_lossfreebalance and self.use_moe
self.num_layers = config.num_layers
self.rotary_emb = Rotary(config)
        # Calculation of hidden_dim for FFN/FFNwMoE
        # (currently unused: FeedForward/FFNwMoE read config.ffn_hidden_dims directly)
        # multiple_of = 4
        # ffn_dim_multiplier = config.ffn_dim_multiplier
        hidden_dim = 4 * config.num_dims
# hidden_dim = int(2 * config.num_dims / 3)
# if ffn_dim_multiplier is not None:
# hidden_dim = int(ffn_dim_multiplier * hidden_dim)
# config.ffn_hidden_dims = multiple_of * ((hidden_dim + multiple_of - 1) // multiple_of)
self.tokens_embedding = nn.Embedding(self.vocab_size, self.num_dims)
self.blocks = nn.ModuleList()
for _ in range(self.num_layers):
self.blocks.append(Block(config))
        self.norm = nn.RMSNorm(config.num_dims, eps=config.rmsnorm_eps) # you can also use the local RMSNorm(config)
self.ll_head = nn.Linear(self.num_dims, self.vocab_size, bias=False)
self.tokens_embedding.weight = self.ll_head.weight
# torch.nn.init.normal_(self.ll_head.weight, mean=0.0, std=0.02)
# torch.nn.init.normal_(self.tokens_embedding.weight, mean=0.0, std=0.02)
# self.freqs_complex = None # precompute_theta_pos_frequencies(self.num_dims // self.num_heads, self.context_len * 2, device=config.device)
def forward(self, x: torch.Tensor, targets: Optional[torch.Tensor] = None, start_pos: int = 0, use_cache=None, rope_position_offset: int = 0):
_, seq_len = x.shape
x = self.tokens_embedding(x)
cos, sin = self.rotary_emb(x, seq_dim=1)
# if self.freqs_complex == None:
# self.freqs_complex = precompute_theta_pos_frequencies(self.num_dims // self.num_heads, self.context_len * 2, device=x.device)
# freqs_complex = self.freqs_complex[start_pos:start_pos + seq_len]
total_aux_loss = 0
for block in self.blocks:
x, aux_loss = block(x, cos, sin, start_pos=start_pos, use_cache=use_cache, rope_position_offset=rope_position_offset)
if self.use_moe and not self.use_lossfreebalance:
total_aux_loss += aux_loss
x = self.norm(x)
logits = self.ll_head(x)
if targets is None:
loss = None
ce_loss = None
else:
c_batch_size, c_context_len, c_dim = logits.shape
logits = logits.view(c_batch_size*c_context_len, c_dim)
targets = targets.view(c_batch_size*c_context_len)
ce_loss = F.cross_entropy(logits, targets)
            if self.use_moe and not self.use_lossfreebalance:
                loss = ce_loss + total_aux_loss # here ce_loss is the loss without the auxiliary term
            else:
                # With auxiliary-loss-free balancing, (router_probs, topk_indices) from the last
                # block is returned in place of ce_loss; this branch also runs when MoE is
                # disabled, in which case aux_loss is None.
                loss = ce_loss
                ce_loss = aux_loss
return logits, loss, ce_loss
@torch.no_grad()
def generate(self, x: torch.Tensor, max_tokens: int, temperature: float = 1.0, top_k: int = 50,
top_p: float = 1.0, repetition_penalty: float = 1.0, use_cache: bool = False):
"""
        Generate up to max_tokens new tokens autoregressively, appended to x.
Args:
x: Input token IDs [batch_size, seq_len]
max_tokens: Maximum number of tokens to generate
temperature: Sampling temperature (higher = more random)
top_k: Keep only top k tokens (set to None to disable)
top_p: Nucleus sampling threshold (cumulative probability)
repetition_penalty: Penalty for repeating tokens (>1.0 reduces repetition)
use_cache: Whether to use KV caching for efficiency
"""
        initial_length = x.shape[1] # record the length of the initial prompt
for c_tkn_pos in range(max_tokens):
if use_cache:
                if c_tkn_pos == 0:
                    rope_offset = 0 # first step: run the whole prompt through the model
                    logits, _, ce_loss = self.forward(x, start_pos=0, use_cache=use_cache, rope_position_offset=rope_offset)
                else:
                    # start_pos is the absolute sequence position (used both as the cache slot and the RoPE position)
                    actual_start_pos = initial_length + c_tkn_pos - 1
                    rope_offset = 0 # start_pos already points at the correct position, so no extra offset is needed
                    logits, _, ce_loss = self.forward(x[:, -1:], start_pos=actual_start_pos, use_cache=use_cache, rope_position_offset=rope_offset)
else:
logits, _, ce_loss = self.forward(x, use_cache=use_cache)
logits = logits[:, -1, :] / temperature
# Apply repetition penalty
if repetition_penalty != 1.0:
logits = self._apply_repetition_penalty(logits, x, repetition_penalty)
# Apply top-k filtering
if top_k is not None and top_k > 0:
logits = self._apply_top_k(logits, top_k)
# Apply top-p (nucleus) filtering
if top_p < 1.0:
logits = self._apply_top_p(logits, top_p)
probs = F.softmax(logits, dim=-1)
next_token = torch.multinomial(probs, num_samples=1)
x = torch.cat((x, next_token), dim=1)
return x
def _apply_repetition_penalty(self, logits: torch.Tensor, input_ids: torch.Tensor, penalty: float):
"""Apply repetition penalty to logits based on previous tokens"""
batch_size, vocab_size = logits.shape
for batch_idx in range(batch_size):
for token_id in input_ids[batch_idx].unique():
if logits[batch_idx, token_id] < 0:
logits[batch_idx, token_id] *= penalty
else:
logits[batch_idx, token_id] /= penalty
return logits
def _apply_top_k(self, logits: torch.Tensor, top_k: int):
"""Apply top-k filtering to logits"""
top_k = min(top_k, logits.size(-1))
tkl, idx = torch.topk(logits, top_k)
logits[logits < tkl[:, [-1]]] = -float('Inf')
return logits
def _apply_top_p(self, logits: torch.Tensor, top_p: float):
"""Apply top-p (nucleus) filtering to logits"""
sorted_logits, sorted_indices = torch.sort(logits, descending=True)
cumulative_probs = torch.cumsum(F.softmax(sorted_logits, dim=-1), dim=-1)
# Remove tokens with cumulative probability above the threshold
sorted_indices_to_remove = cumulative_probs > top_p
# Shift the indices to the right to keep also the first token above the threshold
sorted_indices_to_remove[..., 1:] = sorted_indices_to_remove[..., :-1].clone()
sorted_indices_to_remove[..., 0] = 0
for batch_idx in range(logits.shape[0]):
indices_to_remove = sorted_indices[batch_idx][sorted_indices_to_remove[batch_idx]]
logits[batch_idx][indices_to_remove] = -float('Inf')
return logits
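# A minimal generation sketch (assumes a trained `model` instance and the GPT-2 BPE from
# tiktoken; the names below are illustrative, not defined in this module):
#     enc = tiktoken.get_encoding("gpt2")
#     prompt = torch.tensor([enc.encode("Hello")], dtype=torch.long)
#     out = model.generate(prompt, max_tokens=32, temperature=0.8, top_k=50, use_cache=True)
#     print(enc.decode(out[0].tolist()))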
def main():
# config = ModelConfig(
# vocab_size = 50304,
# num_dims = 1024,
# num_heads = 16,
# num_kv_heads = 4,
# num_layers = 16,
# ffn_hidden_dims = 1024 * 4,
# rmsnorm_eps = 1e-6,
# rope_theta = 1e5,
# context_len = 1024,
# use_cache = False,
# use_flash = False,
# use_moe = False,
# moe_num_experts = 6,
# moe_active_experts = 1,
# moe_eps = 1e-6,
# moe_aux_loss_coef = 0.01,
# moe_shared_experts = 0,
# use_lossfreebalance = False,
# )
# device = 'cuda' if torch.cuda.is_available() else 'cpu'
# SEED = 1337
# torch.manual_seed(SEED)
# if device == 'cuda':
# torch.cuda.manual_seed(SEED)
# model = Transformer(config)
# model = model.to(device)
# model = torch.compile(model)
# print(sum(p.numel() for p in model.parameters())/1e6, 'M parameters')
pass
if __name__ == "__main__":
main()