| """Nanochat model implementation and inference utilities.""" | |
| from __future__ import annotations | |
| import json | |
| import math | |
| import pickle | |
| from dataclasses import dataclass | |
| from pathlib import Path | |
| from typing import TYPE_CHECKING | |
| import torch | |
| import torch.nn.functional as F | |
| from torch import nn | |
| if TYPE_CHECKING: | |
| from collections.abc import Generator | |


@dataclass
class GPTConfig:
    """Configuration for GPT model architecture.

    Attributes:
        sequence_len: Maximum sequence length
        vocab_size: Size of vocabulary
        n_layer: Number of transformer layers
        n_head: Number of attention heads
        n_kv_head: Number of key-value heads
        n_embd: Embedding dimension
    """

    sequence_len: int = 1024
    vocab_size: int = 50304
    n_layer: int = 12
    n_head: int = 6
    n_kv_head: int = 6
    n_embd: int = 768
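

# Example usage: the field defaults above give a small baseline model; in practice
# NanochatModel._load_model overrides them from the checkpoint's meta_*.json via
# GPTConfig(**meta["model_config"]).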


def norm(x: torch.Tensor) -> torch.Tensor:
    """Apply RMS normalization to input tensor."""
    return F.rms_norm(x, (x.size(-1),))


_EXPECTED_NDIM = 4


def apply_rotary_emb(
    x: torch.Tensor,
    cos: torch.Tensor,
    sin: torch.Tensor,
) -> torch.Tensor:
    """Apply rotary positional embeddings to input tensor.

    Args:
        x: Input tensor of shape (batch, seq_len, n_heads, head_dim)
        cos: Cosine component of rotary embeddings
        sin: Sine component of rotary embeddings

    Returns:
        Tensor with rotary embeddings applied
    """
    assert x.ndim == _EXPECTED_NDIM
    d = x.shape[3] // 2
    x1, x2 = x[..., :d], x[..., d:]
    y1 = x1 * cos + x2 * sin
    y2 = x1 * (-sin) + x2 * cos
    return torch.cat([y1, y2], 3).to(x.dtype)
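

# Note: apply_rotary_emb uses the "half-split" rotation convention (the first half of
# head_dim is paired with the second half rather than interleaving even/odd channels).
# The cos/sin tensors produced by GPT._precompute_rotary_embeddings have shape
# (1, seq_len, 1, head_dim // 2), so they broadcast over the batch and head dimensions,
# and the final .to(x.dtype) restores the input dtype after the bfloat16 multiply.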


def repeat_kv(x: torch.Tensor, n_rep: int) -> torch.Tensor:
    """Repeat key/value tensors for multi-head attention.

    Args:
        x: Input tensor of shape (batch, n_kv_heads, seq_len, head_dim)
        n_rep: Number of times to repeat

    Returns:
        Tensor with repeated key/value heads
    """
    if n_rep == 1:
        return x
    bs, n_kv_heads, slen, head_dim = x.shape
    return (
        x[:, :, None, :, :]
        .expand(bs, n_kv_heads, n_rep, slen, head_dim)
        .reshape(bs, n_kv_heads * n_rep, slen, head_dim)
    )
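

# With the default config (n_head == n_kv_head == 6), n_rep is 1 and repeat_kv is a
# no-op. A grouped-query configuration such as n_head=6, n_kv_head=2 (hypothetical,
# for illustration) would expand each KV head 3x so the K/V shapes match Q for
# scaled_dot_product_attention.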


class CausalSelfAttention(nn.Module):
    """Causal self-attention with rotary position embeddings."""

    def __init__(self, config: GPTConfig, layer_idx: int) -> None:
        """Initialize attention layer.

        Args:
            config: Model configuration
            layer_idx: Layer index for KV cache
        """
        super().__init__()
        self.layer_idx = layer_idx
        self.n_head = config.n_head
        self.n_kv_head = config.n_kv_head
        self.n_embd = config.n_embd
        self.head_dim = self.n_embd // self.n_head
        assert self.n_embd % self.n_head == 0
        assert self.n_kv_head <= self.n_head
        assert self.n_head % self.n_kv_head == 0
        self.c_q = nn.Linear(self.n_embd, self.n_head * self.head_dim, bias=False)
        self.c_k = nn.Linear(self.n_embd, self.n_kv_head * self.head_dim, bias=False)
        self.c_v = nn.Linear(self.n_embd, self.n_kv_head * self.head_dim, bias=False)
        self.c_proj = nn.Linear(self.n_embd, self.n_embd, bias=False)

    def forward(
        self,
        x: torch.Tensor,
        cos_sin: tuple[torch.Tensor, torch.Tensor],
        kv_cache: object | None,
    ) -> torch.Tensor:
        """Forward pass of attention layer.

        Args:
            x: Input tensor
            cos_sin: Tuple of (cos, sin) rotary embeddings
            kv_cache: Optional KV cache for generation

        Returns:
            Output tensor after attention
        """
        b, t, _c = x.size()
        q = self.c_q(x).view(b, t, self.n_head, self.head_dim)
        k = self.c_k(x).view(b, t, self.n_kv_head, self.head_dim)
        v = self.c_v(x).view(b, t, self.n_kv_head, self.head_dim)
        cos, sin = cos_sin
        q, k = apply_rotary_emb(q, cos, sin), apply_rotary_emb(k, cos, sin)
        q, k = norm(q), norm(k)
        q, k, v = q.transpose(1, 2), k.transpose(1, 2), v.transpose(1, 2)
        if kv_cache is not None:
            k, v = kv_cache.insert_kv(self.layer_idx, k, v)
        tq = q.size(2)
        tk = k.size(2)
        nrep = self.n_head // self.n_kv_head
        k, v = repeat_kv(k, nrep), repeat_kv(v, nrep)
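        # Three attention cases:
        #   1. No cache, or query length == key length (full prefill): plain causal
        #      attention.
        #   2. Single-token decode against a cache: the new query may attend to every
        #      cached key, so no mask is needed.
        #   3. Multi-token append to a non-empty cache: build an explicit mask that
        #      allows the full cached prefix plus a causal pattern over the new block.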
        if kv_cache is None or tq == tk:
            y = F.scaled_dot_product_attention(q, k, v, is_causal=True)
        elif tq == 1:
            y = F.scaled_dot_product_attention(q, k, v, is_causal=False)
        else:
            attn_mask = torch.zeros((tq, tk), dtype=torch.bool, device=q.device)
            prefix_len = tk - tq
            if prefix_len > 0:
                attn_mask[:, :prefix_len] = True
            attn_mask[:, prefix_len:] = torch.tril(
                torch.ones((tq, tq), dtype=torch.bool, device=q.device),
            )
            y = F.scaled_dot_product_attention(q, k, v, attn_mask=attn_mask)
        y = y.transpose(1, 2).contiguous().view(b, t, -1)
        return self.c_proj(y)


class MLP(nn.Module):
    """Multi-layer perceptron with squared ReLU activation."""

    def __init__(self, config: GPTConfig) -> None:
        """Initialize MLP layer.

        Args:
            config: Model configuration
        """
        super().__init__()
        self.c_fc = nn.Linear(config.n_embd, 4 * config.n_embd, bias=False)
        self.c_proj = nn.Linear(4 * config.n_embd, config.n_embd, bias=False)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Forward pass of MLP.

        Args:
            x: Input tensor

        Returns:
            Output tensor after MLP transformation
        """
        x = self.c_fc(x)
        x = F.relu(x).square()
        return self.c_proj(x)


class Block(nn.Module):
    """Transformer block with attention and MLP."""

    def __init__(self, config: GPTConfig, layer_idx: int) -> None:
        """Initialize transformer block.

        Args:
            config: Model configuration
            layer_idx: Layer index
        """
        super().__init__()
        self.attn = CausalSelfAttention(config, layer_idx)
        self.mlp = MLP(config)

    def forward(
        self,
        x: torch.Tensor,
        cos_sin: tuple[torch.Tensor, torch.Tensor],
        kv_cache: object | None,
    ) -> torch.Tensor:
        """Forward pass of transformer block.

        Args:
            x: Input tensor
            cos_sin: Tuple of (cos, sin) rotary embeddings
            kv_cache: Optional KV cache for generation

        Returns:
            Output tensor after block transformation
        """
        x = x + self.attn(norm(x), cos_sin, kv_cache)
        return x + self.mlp(norm(x))


class GPT(nn.Module):
    """GPT model with rotary position embeddings."""

    def __init__(self, config: GPTConfig) -> None:
        """Initialize GPT model.

        Args:
            config: Model configuration
        """
        super().__init__()
        self.config = config
        self.transformer = nn.ModuleDict(
            {
                "wte": nn.Embedding(config.vocab_size, config.n_embd),
                "h": nn.ModuleList(
                    [Block(config, layer_idx) for layer_idx in range(config.n_layer)],
                ),
            },
        )
        self.lm_head = nn.Linear(config.n_embd, config.vocab_size, bias=False)
        self.rotary_seq_len = config.sequence_len * 10
        head_dim = config.n_embd // config.n_head
        cos, sin = self._precompute_rotary_embeddings(self.rotary_seq_len, head_dim)
        self.register_buffer("cos", cos, persistent=False)
        self.register_buffer("sin", sin, persistent=False)
        self.transformer.wte.to(dtype=torch.bfloat16)

    def init_weights(self) -> None:
        """Initialize model weights."""
        self.apply(self._init_weights)
        torch.nn.init.zeros_(self.lm_head.weight)
        for block in self.transformer.h:
            torch.nn.init.zeros_(block.mlp.c_proj.weight)
            torch.nn.init.zeros_(block.attn.c_proj.weight)
        head_dim = self.config.n_embd // self.config.n_head
        cos, sin = self._precompute_rotary_embeddings(self.rotary_seq_len, head_dim)
        self.cos, self.sin = cos, sin

    def _init_weights(self, module: nn.Module) -> None:
        """Initialize weights for a single module.

        Args:
            module: Module to initialize
        """
        if isinstance(module, nn.Linear):
            fan_out = module.weight.size(0)
            fan_in = module.weight.size(1)
            std = 1.0 / math.sqrt(fan_in) * min(1.0, math.sqrt(fan_out / fan_in))
            torch.nn.init.normal_(module.weight, mean=0.0, std=std)
            if module.bias is not None:
                torch.nn.init.zeros_(module.bias)
        elif isinstance(module, nn.Embedding):
            torch.nn.init.normal_(module.weight, mean=0.0, std=1.0)

    def _precompute_rotary_embeddings(
        self,
        seq_len: int,
        head_dim: int,
        base: int = 10000,
        device: torch.device | str | None = None,
    ) -> tuple[torch.Tensor, torch.Tensor]:
        """Precompute rotary position embeddings.

        Args:
            seq_len: Maximum sequence length
            head_dim: Dimension of attention heads
            base: Base for frequency calculation
            device: Device to place tensors on

        Returns:
            Tuple of (cos, sin) tensors for rotary embeddings
        """
        if device is None:
            device = self.transformer.wte.weight.device
        channel_range = torch.arange(0, head_dim, 2, dtype=torch.float32, device=device)
        inv_freq = 1.0 / (base ** (channel_range / head_dim))
        t = torch.arange(seq_len, dtype=torch.float32, device=device)
        freqs = torch.outer(t, inv_freq)
        cos, sin = freqs.cos(), freqs.sin()
        cos, sin = cos.bfloat16(), sin.bfloat16()
        return cos[None, :, None, :], sin[None, :, None, :]

    def forward(
        self,
        idx: torch.Tensor,
        targets: torch.Tensor | None = None,
        kv_cache: object | None = None,
    ) -> torch.Tensor:
        """Forward pass of GPT model.

        Args:
            idx: Input token indices
            targets: Target token indices (unused in this implementation)
            kv_cache: Optional KV cache for generation

        Returns:
            Logits for next token prediction
        """
        _b, t = idx.size()
        assert self.cos.size(1) >= t
        t0 = 0 if kv_cache is None else kv_cache.get_pos()
        cos_sin = self.cos[:, t0 : t0 + t], self.sin[:, t0 : t0 + t]
        x = self.transformer.wte(idx)
        x = norm(x)
        for block in self.transformer.h:
            x = block(x, cos_sin, kv_cache)
        x = norm(x)
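        # Logit soft-capping: softcap * tanh(logits / softcap) smoothly bounds the
        # final logits to the range (-softcap, softcap).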
        softcap = 15
        logits = self.lm_head(x)
        return softcap * torch.tanh(logits / softcap)


class NanochatModel:
    """Wrapper class for loading and running inference with the nanochat model."""

    def __init__(self, model_dir: str, device: str = "cpu") -> None:
        """Initialize the NanochatModel.

        Args:
            model_dir: Directory containing model files
            device: Device to run inference on (default: "cpu")
        """
        self.device = torch.device(device)
        self.model_dir = model_dir
        self.model = self._load_model()
        self.enc = self._load_tokenizer()
        self._setup_special_tokens()

    def _load_model(self) -> GPT:
        """Load the model from the model directory."""
        model_dir_path = Path(self.model_dir)
        model_files = list(model_dir_path.glob("model_*.pt"))
        if not model_files:
            msg = f"No model files found in {self.model_dir}"
            raise FileNotFoundError(msg)
        model_file = model_files[0]
        meta_files = list(model_dir_path.glob("meta_*.json"))
        if not meta_files:
            msg = f"No meta files found in {self.model_dir}"
            raise FileNotFoundError(msg)
        meta_file = meta_files[0]
        with meta_file.open() as f:
            meta = json.load(f)
        model_config_kwargs = meta["model_config"]
        model_config = GPTConfig(**model_config_kwargs)
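        # Construct on the meta device so no weight storage is allocated up front;
        # to_empty() then materializes parameters on the target device, init_weights()
        # rebuilds the non-persistent rotary buffers there, and load_state_dict with
        # assign=True swaps in the checkpoint tensors.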
| with torch.device("meta"): | |
| model = GPT(model_config) | |
| model_data = torch.load( | |
| model_file, | |
| map_location=self.device, | |
| weights_only=True, | |
| ) | |
| model_data = {k.removeprefix("_orig_mod."): v for k, v in model_data.items()} | |
| model_data = { | |
| k: v.float() if v.dtype == torch.bfloat16 else v | |
| for k, v in model_data.items() | |
| } | |
| model.to_empty(device=self.device) | |
| model.init_weights() | |
| model.load_state_dict(model_data, strict=True, assign=True) | |
| model.eval() | |
| return model | |

    def _load_tokenizer(self) -> object:
        """Load the tokenizer from the model directory.

        Returns:
            Loaded tokenizer object
        """
        tokenizer_path = Path(self.model_dir) / "tokenizer.pkl"
        if not tokenizer_path.exists():
            msg = f"Tokenizer not found at {tokenizer_path}"
            raise FileNotFoundError(msg)
        with tokenizer_path.open("rb") as f:
            return pickle.load(f)

    def _setup_special_tokens(self) -> None:
        """Set up special token IDs for chat formatting."""
        try:
            try:
                self.bos_token_id = self.enc.encode_single_token("<|bos|>")
            except KeyError:
                self.bos_token_id = self.enc.encode_single_token("<|endoftext|>")
            self.user_start_id = self.enc.encode_single_token("<|user_start|>")
            self.user_end_id = self.enc.encode_single_token("<|user_end|>")
            self.assistant_start_id = self.enc.encode_single_token(
                "<|assistant_start|>",
            )
            self.assistant_end_id = self.enc.encode_single_token("<|assistant_end|>")
            self.stop_tokens = {self.bos_token_id, self.assistant_end_id}
        except KeyError as e:
            msg = f"Required special token missing from tokenizer: {e}"
            raise ValueError(msg) from e

    def format_prompt(self, message: str) -> list[int]:
        """Format a user message using chat format.

        Args:
            message: User's input message

        Returns:
            List of token IDs formatted for chat
        """
        prompt_tokens = self.enc.encode_ordinary(message)
        return [
            self.bos_token_id,
            self.user_start_id,
            *prompt_tokens,
            self.user_end_id,
            self.assistant_start_id,
        ]

    def format_conversation(self, history: list[dict[str, str]]) -> list[int]:
        """Format a multi-turn conversation using chat format.

        Args:
            history: List of message dictionaries with 'role' and 'content' keys;
                role can be 'user' or 'assistant'

        Returns:
            List of token IDs formatted for multi-turn chat
        """
        tokens = [self.bos_token_id]
        for message in history:
            role = message.get("role")
            content = message.get("content", "")
            content_tokens = self.enc.encode_ordinary(content)
            if role == "user":
                tokens.extend([
                    self.user_start_id,
                    *content_tokens,
                    self.user_end_id,
                ])
            elif role == "assistant":
                tokens.extend([
                    self.assistant_start_id,
                    *content_tokens,
                    self.assistant_end_id,
                ])
        tokens.append(self.assistant_start_id)
        return tokens
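
    # The resulting layout mirrors format_prompt; for a user turn followed by an
    # assistant turn the token stream is:
    #   <|bos|> <|user_start|> ...user tokens... <|user_end|>
    #   <|assistant_start|> ...assistant tokens... <|assistant_end|> <|assistant_start|>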

    def generate(
        self,
        prompt: str | None = None,
        history: list[dict[str, str]] | None = None,
        max_tokens: int = 512,
        temperature: float = 0.8,
        top_k: int = 50,
    ) -> Generator[str, None, None]:
        """Generate text from a prompt or conversation history.

        Args:
            prompt: The input text prompt (for single-turn)
            history: List of message dicts with 'role' and 'content' (for multi-turn)
            max_tokens: Maximum number of tokens to generate
            temperature: Sampling temperature
            top_k: Top-k sampling parameter

        Yields:
            Decoded token strings
        """
        if history is not None:
            input_ids = self.format_conversation(history)
        elif prompt is not None:
            input_ids = self.format_prompt(prompt)
        else:
            msg = "Either prompt or history must be provided"
            raise ValueError(msg)
        x = torch.tensor([input_ids], dtype=torch.long, device=self.device)
        with torch.inference_mode():
            for _ in range(max_tokens):
                logits = self.model(x)
                logits = logits[:, -1, :]
                logits = logits / temperature
                if top_k > 0:
                    v, _ = torch.topk(logits, min(top_k, logits.size(-1)))
                    logits[logits < v[:, [-1]]] = -float("inf")
                probs = F.softmax(logits, dim=-1)
                next_token = torch.multinomial(probs, num_samples=1)
                if next_token.item() in self.stop_tokens:
                    break
                token_str = self.enc.decode([next_token.item()])
                yield token_str
                x = torch.cat([x, next_token], dim=1)
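

if __name__ == "__main__":
    # Minimal usage sketch. Assumption: "out/nanochat" is a hypothetical checkpoint
    # directory containing model_*.pt, meta_*.json, and tokenizer.pkl; adjust to your
    # own layout.
    nanochat = NanochatModel("out/nanochat", device="cpu")
    for piece in nanochat.generate(prompt="Hello!", max_tokens=64):
        print(piece, end="", flush=True)
    print()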