import math
from dataclasses import dataclass
from functools import lru_cache
from pathlib import Path
from typing import TYPE_CHECKING, Annotated, Callable, Optional

import torch
import torch.nn.functional as F
from einops import rearrange, repeat
from megatron.core.models.common.embeddings.rotary_pos_embedding import RotaryEmbedding
from nemo.collections.llm.gpt.model.llama import Llama3Config, LlamaModel
from nemo.collections.llm.utils import Config
from nemo.lightning import OptimizerModule, io
from nemo.lightning.base import teardown
from torch import Tensor, nn

from .log import log
|
| |
|
| | class RotaryEmbedding3D(RotaryEmbedding): |
| | """Rotary Embedding3D for Cosmos Language model. |
| | Args: |
| | kv_channels (int): Projection weights dimension in multi-head attention. Obtained |
| | from transformer config |
| | rotary_base (int, optional): Base period for rotary position embeddings. Defaults to |
| | 10000. |
| | use_cpu_initialization (bool, optional): If False, initialize the inv_freq directly |
| | on the GPU. Defaults to False |
| | latent_shape: The shape of the latents produced by the video after being tokenized |
| | """ |
| |
|
    def __init__(
        self,
        seq_len: int,
        kv_channels: int,
        training_type: Optional[str] = None,
        rotary_base: int = 10000,
        use_cpu_initialization: bool = False,
        latent_shape=[5, 40, 64],
        apply_yarn=False,
        original_latent_shape=None,
        beta_fast=32,
        beta_slow=1,
        scale=None,
        max_position_embeddings=None,
        original_max_position_embeddings=None,
        extrapolation_factor=1,
        attn_factor=1,
    ) -> None:
        super().__init__(
            kv_channels=kv_channels,
            rotary_base=rotary_base,
            rotary_percent=1.0,
            use_cpu_initialization=use_cpu_initialization,
        )
        self.latent_shape = latent_shape
        self.device = "cpu" if use_cpu_initialization else torch.cuda.current_device()
        self.dim = kv_channels
        self.rope_theta = rotary_base
        self.apply_yarn = apply_yarn
        self.original_latent_shape = original_latent_shape
        self.beta_fast = beta_fast
        self.beta_slow = beta_slow
        self.scale = scale
        self.max_position_embeddings = max_position_embeddings
        self.original_max_position_embeddings = original_max_position_embeddings
        self.attn_factor = attn_factor
        # Split the head dimension into two equal spatial chunks (height, width) and a
        # temporal remainder.
        dim_h = self.dim // 6 * 2
        dim_t = self.dim - 2 * dim_h
        self.dim_spatial_range = torch.arange(0, dim_h, 2)[: (dim_h // 2)].float().to(self.device) / dim_h
        spatial_inv_freq = 1.0 / (self.rope_theta**self.dim_spatial_range)
        self.dim_temporal_range = torch.arange(0, dim_t, 2)[: (dim_t // 2)].float().to(self.device) / dim_t
        temporal_inv_freq = 1.0 / (self.rope_theta**self.dim_temporal_range)
        if self.apply_yarn:
            assert self.original_latent_shape is not None, "Original latent shape required."
            assert self.beta_slow is not None, "Beta slow value required."
            assert self.beta_fast is not None, "Beta fast value required."
            assert self.scale is not None, "Scale value required."
            # Rescale the inverse frequencies for YaRN context extension.
            scale_factors_spatial = self.get_scale_factors(spatial_inv_freq, self.original_latent_shape[1])
            spatial_inv_freq = spatial_inv_freq * scale_factors_spatial
            scale_factors_temporal = self.get_scale_factors(temporal_inv_freq, self.original_latent_shape[0])
            temporal_inv_freq = temporal_inv_freq * scale_factors_temporal
            self.mscale = float(self.get_mscale(self.scale) * self.attn_factor)
        self.spatial_inv_freq = spatial_inv_freq
        self.temporal_inv_freq = temporal_inv_freq
        max_seq_len_cached = max(self.latent_shape)
        if self.apply_yarn and seq_len > max_seq_len_cached:
            max_seq_len_cached = seq_len
        self.max_seq_len_cached = max_seq_len_cached
        self.freqs = self.get_freqs_non_repeated(self.max_seq_len_cached)

    def get_mscale(self, scale: float = 1.0) -> float:
        """Get the magnitude scaling factor for YaRN."""
        if scale <= 1:
            return 1.0
        return 0.1 * math.log(scale) + 1.0

    def get_scale_factors(self, inv_freq: torch.Tensor, original_seq_len: int) -> torch.Tensor:
        """Get the per-frequency scale factors for YaRN."""
        # Frequency cutoffs derived from the rotation counts over the original context length.
        high_freq_cutoff = 2 * math.pi * self.beta_fast / original_seq_len
        low_freq_cutoff = 2 * math.pi * self.beta_slow / original_seq_len
        # Linear ramp between the cutoffs: 0 for fully interpolated (low) frequencies,
        # 1 for fully extrapolated (high) frequencies.
        smooth_mask = torch.clamp((inv_freq - low_freq_cutoff) / (high_freq_cutoff - low_freq_cutoff), min=0, max=1)
        scale_factors = (1 - smooth_mask) / self.scale + smooth_mask
        return scale_factors

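    # Worked example (illustrative numbers matching the CosmosConfig12B defaults below): with
    # scale=2, beta_fast=4, beta_slow=1 and an original length of 8192, components whose
    # inverse frequency lies above the fast cutoff keep a factor of 1.0 (pure extrapolation),
    # the slowest components are divided by the scale (factor 0.5), and the stored mscale
    # becomes 0.1 * ln(2) + 1 ≈ 1.069 (times attn_factor).
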
    def get_freqs_non_repeated(self, max_seq_len_cached: int, offset: int = 0) -> Tensor:
        """Compute non-repeated 3D rotary frequencies over the (T, H, W) latent grid.

        Returns a tensor of shape (T * H * W, 1, 1, dim), concatenating temporal, height and
        width frequency blocks (duplicated once to fill the full head dimension).
        """
        dtype = self.spatial_inv_freq.dtype
        device = self.spatial_inv_freq.device

        self.seq = (torch.arange(max_seq_len_cached, device=device, dtype=dtype) + offset).cuda()

        assert hasattr(
            self, "latent_shape"
        ), "Latent shape is not set. Please run set_latent_shape() method on rope embedding."
        T, H, W = self.latent_shape
        half_emb_t = torch.outer(self.seq[:T], self.temporal_inv_freq.cuda())
        half_emb_h = torch.outer(self.seq[:H], self.spatial_inv_freq.cuda())
        half_emb_w = torch.outer(self.seq[:W], self.spatial_inv_freq.cuda())
        # Broadcast each axis embedding over the full (T, H, W) grid, then concatenate the
        # three halves twice so the last dimension matches the full head dimension.
        emb = torch.cat(
            [
                repeat(half_emb_t, "t d -> t h w d", h=H, w=W),
                repeat(half_emb_h, "h d -> t h w d", t=T, w=W),
                repeat(half_emb_w, "w d -> t h w d", t=T, h=H),
            ]
            * 2,
            dim=-1,
        )
        emb = rearrange(emb, "t h w d -> (t h w) 1 1 d").float()
        return emb

    @lru_cache(maxsize=32)
    def forward(self, seq_len: int, offset: int = 0, packed_seq: bool = False) -> Tensor:
        """Return the cached 3D rotary frequencies, extending the cache under YaRN if needed."""
        if self.spatial_inv_freq.device.type == "cpu":
            # Move the inverse frequencies to the current GPU once training starts.
            self.spatial_inv_freq = self.spatial_inv_freq.to(device=torch.cuda.current_device())

        max_seq_len_cached = self.max_seq_len_cached
        if self.apply_yarn and seq_len > max_seq_len_cached:
            max_seq_len_cached = seq_len
        self.max_seq_len_cached = max_seq_len_cached
        emb = self.get_freqs_non_repeated(self.max_seq_len_cached)
        return emb


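# Minimal usage sketch (illustrative only; the values mirror what CosmosConfig.configure_model
# passes below, everything else is assumed): build the 3D rope for a 5x40x64 latent grid and
# query the frequencies for the flattened sequence.
#
#     rope = RotaryEmbedding3D(seq_len=5 * 40 * 64, kv_channels=128, latent_shape=[5, 40, 64])
#     freqs = rope(seq_len=5 * 40 * 64)  # shape: (5 * 40 * 64, 1, 1, 128)
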
if TYPE_CHECKING:
    from megatron.core.models.gpt.gpt_model import GPTModel as MCoreGPTModel
    from nemo.collections.common.tokenizers.tokenizer_spec import TokenizerSpec


@dataclass
class CosmosConfig(Llama3Config):
    qk_layernorm: bool = True
    rope_dim: str = "3D"
    vocab_size: int = 64000
    activation_func = F.silu

    def configure_model(self, tokenizer) -> "MCoreGPTModel":
        model = super().configure_model(tokenizer)
        if self.rope_dim == "3D":
            model.rotary_pos_emb = RotaryEmbedding3D(
                seq_len=self.seq_length,
                training_type=None,
                kv_channels=self.kv_channels,
                max_position_embeddings=self.seq_length,
                original_max_position_embeddings=getattr(self, "original_seq_len", None),
                rotary_base=self.rotary_base,
                apply_yarn=getattr(self, "apply_yarn", False),
                scale=getattr(self, "yarn_scale", None),
                extrapolation_factor=1,
                attn_factor=1,
                beta_fast=getattr(self, "yarn_beta_fast", 32),
                beta_slow=getattr(self, "yarn_beta_slow", 1),
                latent_shape=[5, 40, 64],
                original_latent_shape=getattr(self, "original_latent_shape", None),
            )
        return model


@dataclass
class CosmosConfig4B(CosmosConfig):
    rotary_base: int = 500_000
    seq_length: int = 15360
    num_layers: int = 16
    hidden_size: int = 4096
    ffn_hidden_size: int = 14336
    num_attention_heads: int = 32
    num_query_groups: int = 8
    layernorm_epsilon: float = 1e-5
    use_cpu_initialization: bool = True
    make_vocab_size_divisible_by: int = 128
    kv_channels: int = 128


@dataclass
class CosmosConfig12B(CosmosConfig):
    rotary_base: int = 500_000
    seq_length: int = 15360
    num_layers: int = 40
    hidden_size: int = 5120
    ffn_hidden_size: int = 14336
    num_attention_heads: int = 32
    num_query_groups: int = 8
    layernorm_epsilon: float = 1e-5
    use_cpu_initialization: bool = True
    make_vocab_size_divisible_by: int = 128
    kv_channels: int = 128
    # original_latent_shape and original_seq_len are unannotated, so they are plain class
    # attributes (not dataclass fields) read via getattr() in configure_model above.
    original_latent_shape = [3, 40, 64]
    apply_yarn: bool = True
    yarn_beta_fast: int = 4
    yarn_beta_slow: int = 1
    yarn_scale: int = 2
    original_seq_len = 8192


class CosmosModel(LlamaModel):
    def __init__(
        self,
        config: Annotated[Optional[CosmosConfig], Config[CosmosConfig]] = None,
        optim: Optional[OptimizerModule] = None,
        tokenizer: Optional["TokenizerSpec"] = None,
        model_transform: Optional[Callable[[nn.Module], nn.Module]] = None,
    ):
        super().__init__(config or CosmosConfig4B(), optim=optim, tokenizer=tokenizer, model_transform=model_transform)
        # Keep the same default as the super() call so self.config is never None.
        self.config = config or CosmosConfig4B()


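# Illustrative usage sketch (assumes a NeMo recipe supplies the trainer and tokenizer; nothing
# below is defined in this module):
#
#     model = CosmosModel(CosmosConfig12B())
#
# configure_model() is invoked by NeMo during setup and swaps in RotaryEmbedding3D whenever
# rope_dim == "3D".
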
@io.state_transform(
    source_key=(
        "model.layers.*.feed_forward.w1.weight",
        "model.layers.*.feed_forward.w3.weight",
    ),
    target_key="decoder.layers.*.mlp.linear_fc1.weight",
)
def _mlp_glu(ctx: io.TransformCTX, w1, w3):
    # Concatenate the gate (w1) and up (w3) projections into Megatron's fused linear_fc1.
    return torch.cat((w1, w3), dim=0)


@io.state_transform(
    source_key=(
        "model.layers.*.attention.wq.weight",
        "model.layers.*.attention.wk.weight",
        "model.layers.*.attention.wv.weight",
    ),
    target_key="decoder.layers.*.self_attention.linear_qkv.weight",
)
def _import_qkv_cosmos(ctx: io.TransformCTX, q, k, v):
    megatron_config = ctx.target.config

    head_num = megatron_config.num_attention_heads
    num_query_groups = megatron_config.num_query_groups
    heads_per_group = head_num // num_query_groups
    hidden_size = megatron_config.hidden_size
    head_size = megatron_config.kv_channels

    old_tensor_shape = q.size()
    new_q_tensor_shape = (head_num, head_size) + old_tensor_shape[1:]
    new_kv_tensor_shape = (num_query_groups, head_size) + old_tensor_shape[1:]

    q = q.view(*new_q_tensor_shape)
    k = k.view(*new_kv_tensor_shape)
    v = v.view(*new_kv_tensor_shape)

    # Interleave the query heads of each group with their shared key/value head to match
    # Megatron's fused linear_qkv layout.
    qkv_weights_l = []
    for i in range(num_query_groups):
        qkv_weights_l.append(q[i * heads_per_group : (i + 1) * heads_per_group, :, :])
        qkv_weights_l.append(k[i : i + 1, :, :])
        qkv_weights_l.append(v[i : i + 1, :, :])
    qkv_weights = torch.cat(qkv_weights_l)
    assert qkv_weights.ndim == 3, qkv_weights.shape
    assert qkv_weights.shape[0] == (heads_per_group + 2) * num_query_groups, qkv_weights.shape
    assert qkv_weights.shape[1] == head_size, qkv_weights.shape
    assert qkv_weights.shape[2] == old_tensor_shape[1], qkv_weights.shape

    qkv_weights = qkv_weights.reshape([head_size * (head_num + 2 * num_query_groups), hidden_size])

    return qkv_weights


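# Layout sketch (toy numbers, not from any config): with head_num=4, num_query_groups=2 and
# head_size=2, the fused tensor stacks head blocks as [q0, q1, k0, v0, q2, q3, k1, v1] along
# dim 0 before the final reshape to (head_size * (head_num + 2 * num_query_groups), hidden_size).
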
@io.model_importer(CosmosModel, "pt")
class PTCosmosImporter(io.ModelConnector["PTCosmosModel", CosmosModel]):
    def init(self) -> CosmosModel:
        return CosmosModel(self.config, tokenizer=self.tokenizer)

    def apply(self, output_path: Path) -> Path:
        pt_model_path = str(self)
        cosmos_model_state_dict = torch.load(pt_model_path, map_location="cpu")
        # Cast all weights to float32 before conversion.
        for k, v in cosmos_model_state_dict.items():
            cosmos_model_state_dict[k] = v.float()

        # Minimal wrapper exposing a state_dict() so the raw checkpoint can be passed to
        # io.apply_transforms like a regular module.
        class WrapperCosmos:
            def __init__(self, model_state_dict):
                self.model_state_dict = model_state_dict

            def state_dict(self):
                return self.model_state_dict

        source = WrapperCosmos(cosmos_model_state_dict)
        target = self.init()
        trainer = self.nemo_setup(target)
        self.convert_state(source, target)
        self.nemo_save(output_path, trainer)

        log.info(f"Converted PT Cosmos model to NeMo; model saved to {output_path}")

        teardown(trainer, target)
        del trainer, target

        return output_path

    def convert_state(self, source, target):
        """Map PT checkpoint keys onto the NeMo/Megatron state dict and apply the QKV/MLP transforms."""
        mapping = {
            "model.tok_embeddings.weight": "embedding.word_embeddings.weight",
            "model.layers.*.attention.wo.weight": "decoder.layers.*.self_attention.linear_proj.weight",
            "model.layers.*.attention.q_norm.weight": "decoder.layers.*.self_attention.q_layernorm.weight",
            "model.layers.*.attention.k_norm.weight": "decoder.layers.*.self_attention.k_layernorm.weight",
            "model.layers.*.attention_norm.weight": "decoder.layers.*.self_attention.linear_qkv.layer_norm_weight",
            "model.layers.*.feed_forward.w2.weight": "decoder.layers.*.mlp.linear_fc2.weight",
            "model.layers.*.ffn_norm.weight": "decoder.layers.*.mlp.linear_fc1.layer_norm_weight",
            "model.norm.weight": "decoder.final_layernorm.weight",
            "model.output.weight": "output_layer.weight",
        }

        return io.apply_transforms(source, target, mapping=mapping, transforms=[_import_qkv_cosmos, _mlp_glu])

    @property
    def tokenizer(self):
        return None

    @property
    def config(self):
        # Infer the model size from the checkpoint path.
        if "4B" in str(self) or "4b" in str(self):
            return CosmosConfig4B()
        elif "12B" in str(self) or "12b" in str(self):
            return CosmosConfig12B()
        else:
            raise ValueError(f"Unable to infer model size from checkpoint path: {self}")
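
# Illustrative conversion sketch (hypothetical paths; the connector is registered for the
# "pt" format by @io.model_importer above, and the checkpoint path must contain "4B"/"4b"
# or "12B"/"12b" so the config can be inferred):
#
#     importer = PTCosmosImporter("checkpoints/cosmos_4b/model.pt")
#     importer.apply(Path("nemo_checkpoints/cosmos_4b"))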