from typing import Dict, List, Optional, Tuple, Callable, Union

import torch
import torch.distributed as dist
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
import torch.utils.checkpoint
from torch.utils.cpp_extension import load
| |
|
# Machine epsilon of float32; `norm` passes it to rms_norm as the stabilizer.
eps = torch.finfo(torch.float32).eps
| |
|
def norm(x: torch.Tensor) -> torch.Tensor:
    """Return ``x`` RMS-normalized over its last dimension.

    No learnable weight is applied; the module-level ``eps`` is used as the
    numerical stabilizer.
    """
    last_dim = x.size(-1)
    return torch.rms_norm(x, (last_dim,), eps=eps)
| |
|
| | class Rotary(nn.Module): |
| | def __init__(self, dim: int, max_seq_len: int): |
| | super().__init__() |
| | |
| | angular_freq = (1 / 1024) ** torch.linspace(0, 1, steps=dim//4, dtype=torch.float32) |
| | angular_freq = torch.cat([angular_freq, angular_freq.new_zeros(dim//4)]) |
| | t = torch.arange(max_seq_len, dtype=torch.float32) |
| | theta = torch.einsum("i,j -> ij", t, angular_freq) |
| | self.cos = nn.Buffer(theta.cos(), persistent=False) |
| | self.sin = nn.Buffer(theta.sin(), persistent=False) |
| |
|
| | def forward(self, x_BTHD: torch.Tensor): |
| | assert self.cos.size(0) >= x_BTHD.size(-3) |
| | cos, sin = self.cos[None, :x_BTHD.size(-3), None, :], self.sin[None, :x_BTHD.size(-3), None, :] |
| | x1, x2 = x_BTHD.to(dtype=torch.float32).chunk(2, dim=-1) |
| | y1 = x1 * cos + x2 * sin |
| | y2 = x1 * (-sin) + x2 * cos |
| | return torch.cat((y1, y2), 3).type_as(x_BTHD) |
| |
|
class CausalSoftmaxAttention(nn.Module):
    """Pre-norm causal multi-head softmax attention with a sigmoid output gate.

    The input is RMS-normalized, projected to queries, keys, values and a gate;
    queries and keys are RMS-normalized per head and rotary-encoded before
    causal scaled-dot-product attention. The attention output is multiplied by
    the sigmoid gate and passed through ``o_proj``. ``o_proj`` is initialised to
    zero, so a freshly constructed module outputs zeros.

    Args:
        layer_id: Index of this layer; stored on the module.
        layers: Accepted for signature parity with the other blocks; unused here.
        num_heads: Number of attention heads; must divide ``input_dims``.
        vocab_size: Accepted for signature parity; unused here.
        input_dims: Model width ``C``.
        hidden_dims: Accepted for signature parity; unused here.
        max_seq_len: Number of positions the rotary tables cover. Longer
            sequences trip the assertion inside ``Rotary.forward``.
    """

    def __init__(
        self,
        layer_id: int,
        layers: int,
        num_heads: int,
        vocab_size: int,
        input_dims: int,
        hidden_dims: Union[int, None] = None,
        max_seq_len: int = 2048,
    ):
        super().__init__()

        # Check divisibility before deriving head_dim from it.
        assert input_dims % num_heads == 0, (
            f"input_dims ({input_dims}) must be divisible by num_heads ({num_heads})"
        )

        self.layer_id = layer_id
        self.num_heads = num_heads
        self.head_dim = input_dims // num_heads

        C = input_dims
        init_bounds = 0.5 / (C ** 0.5)

        # Creation order is kept as-is so seeded initialisation is reproducible.
        self.q_proj = nn.Linear(C, C, bias=False)
        self.k_proj = nn.Linear(C, C, bias=False)
        self.v_proj = nn.Linear(C, C, bias=False)
        self.g_proj = nn.Linear(C, C, bias=False)
        self.o_proj = nn.Linear(C, C, bias=False)

        self.rotary = Rotary(self.head_dim, max_seq_len)

        for proj in (self.q_proj, self.k_proj, self.v_proj, self.g_proj):
            nn.init.uniform_(proj.weight, -init_bounds, init_bounds)
        # Zero output projection: the layer contributes nothing at initialisation.
        nn.init.zeros_(self.o_proj.weight)

    def _project(self, x: torch.Tensor):
        """Normalize ``x`` and return rotary-encoded q/k, v, and the sigmoid gate."""
        B, T, _ = x.size()
        H, N = self.num_heads, self.head_dim

        x = norm(x)

        q = self.q_proj(x).view(B, T, H, N)
        k = self.k_proj(x).view(B, T, H, N)
        v = self.v_proj(x).view(B, T, H, N)
        g = self.g_proj(x).sigmoid()

        q, k = norm(q), norm(k)
        q, k = self.rotary(q), self.rotary(k)

        return (q, k, v, g)

    def forward(self, x):
        """Apply gated causal attention to ``x`` of shape ``(B, T, C)``.

        Returns a tensor of the same shape. The residual connection is not
        added here.
        """
        B, T, C = x.size()

        # Recompute the projections during backward instead of storing their
        # intermediate activations.
        q, k, v, g = torch.utils.checkpoint.checkpoint(self._project, x, use_reentrant=False)

        # scaled_dot_product_attention expects (B, H, T, N).
        attn = F.scaled_dot_product_attention(
            q.transpose(1, 2),
            k.transpose(1, 2),
            v.transpose(1, 2),
            is_causal=True,
        )
        attn = attn.transpose(1, 2).contiguous().view(B, T, C)

        return self.o_proj(attn * g)
| | |
class MLP(nn.Module):
    """Pre-norm feed-forward layer with a squared-ReLU activation.

    Computes ``v_proj(relu(k_proj(norm(x))) ** 2)``. ``v_proj`` is initialised
    to zero, so a freshly constructed module outputs zeros.

    Args:
        layer_id: Index of this layer; stored on the module.
        layers: Accepted for signature parity with the other blocks; unused here.
        num_heads: Accepted for signature parity; unused here.
        vocab_size: Accepted for signature parity; unused here.
        input_dims: Model width ``C``.
        hidden_dims: Width of the hidden layer; ``None`` (or 0) selects ``4 * C``.
    """

    def __init__(
        self,
        layer_id: int,
        layers: int,
        num_heads: int,
        vocab_size: int,
        input_dims: int,
        hidden_dims: Union[int, None] = None,
    ):
        super().__init__()

        self.layer_id = layer_id

        C = input_dims
        hidden_dims = hidden_dims or 4 * C
        init_bounds = 0.5 / (C ** 0.5)

        self.k_proj = nn.Linear(C, hidden_dims, bias=False)
        self.v_proj = nn.Linear(hidden_dims, C, bias=False)

        nn.init.uniform_(self.k_proj.weight, -init_bounds, init_bounds)
        # Zero output projection: the layer contributes nothing at initialisation.
        nn.init.zeros_(self.v_proj.weight)

    def _feed_forward(self, x: torch.Tensor) -> torch.Tensor:
        """Normalize ``x`` and apply the squared-ReLU feed-forward transform."""
        x = norm(x)
        hidden = torch.relu(self.k_proj(x)).square()
        return self.v_proj(hidden)

    def forward(self, x):
        """Apply the feed-forward layer to ``x`` of shape ``(..., C)``.

        Returns a tensor of the same shape. The residual connection is not
        added here.
        """
        # Recompute the hidden activations during backward instead of storing them.
        return torch.utils.checkpoint.checkpoint(self._feed_forward, x, use_reentrant=False)
| |
|
class SoftmaxBlock(nn.Module):
    """One residual layer: attention sub-layer followed by an MLP sub-layer.

    Every constructor argument is forwarded unchanged to both
    ``CausalSoftmaxAttention`` and ``MLP``; ``layer_id`` is also stored on the
    block itself.
    """

    def __init__(
        self,
        layer_id: int,
        layers: int,
        num_heads: int,
        vocab_size: int,
        input_dims: int,
        hidden_dims: Union[int, None] = None,
    ):
        super().__init__()
        self.layer_id = layer_id

        sublayer_args = (layer_id, layers, num_heads, vocab_size, input_dims, hidden_dims)
        self.att = CausalSoftmaxAttention(*sublayer_args)
        self.ffn = MLP(*sublayer_args)

    def forward(self, x):
        """Return ``x`` after adding the attention and then the MLP residuals."""
        attended = x + self.att(x)
        return attended + self.ffn(attended)
| |
|
class Transformer(nn.Module):
    """Stack of ``SoftmaxBlock`` layers with a tied embedding / output matrix.

    Args:
        layers: Number of ``SoftmaxBlock`` layers to stack.
        num_heads: Forwarded to every block.
        vocab_size: Number of rows in the embedding table (and of output logits).
        input_dims: Model width.
        hidden_dims: Forwarded to every block.
        dtype: Accepted but not used by this constructor.
    """

    def __init__(
        self,
        layers: int,
        num_heads: int,
        vocab_size: int,
        input_dims: int,
        hidden_dims: Union[int, None] = None,
        dtype = None
    ):
        super().__init__()

        # Small uniform init; the same matrix is reused to produce the logits.
        self.emb = nn.Embedding(vocab_size, input_dims)
        self.emb.weight.data.uniform_(-1e-4, 1e-4)

        self.blocks = nn.ModuleList()
        for layer_id in range(layers):
            self.blocks.append(
                SoftmaxBlock(layer_id, layers, num_heads, vocab_size, input_dims, hidden_dims)
            )

    def forward(self, idx):
        """Map token indices ``idx`` to logits over the vocabulary."""
        hidden = norm(self.emb(idx))

        for block in self.blocks:
            hidden = block(hidden)

        # Project back onto the embedding matrix (weight tying).
        return F.linear(norm(hidden), self.emb.weight)