Instructions to use rhymes-ai/Aria-sequential_mlp with libraries, inference providers, notebooks, and local apps. Follow these links to get started.
- Libraries
- Transformers
How to use rhymes-ai/Aria-sequential_mlp with Transformers:
# Use a pipeline as a high-level helper
from transformers import pipeline

pipe = pipeline("image-text-to-text", model="rhymes-ai/Aria-sequential_mlp", trust_remote_code=True)
messages = [
    {
        "role": "user",
        "content": [
            {"type": "image", "url": "https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/p-blog/candy.JPG"},
            {"type": "text", "text": "What animal is on the candy?"}
        ]
    },
]
pipe(text=messages)

# Load model directly
from transformers import AutoProcessor, AutoModelForImageTextToText

processor = AutoProcessor.from_pretrained("rhymes-ai/Aria-sequential_mlp", trust_remote_code=True)
model = AutoModelForImageTextToText.from_pretrained("rhymes-ai/Aria-sequential_mlp", trust_remote_code=True)
messages = [
    {
        "role": "user",
        "content": [
            {"type": "image", "url": "https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/p-blog/candy.JPG"},
            {"type": "text", "text": "What animal is on the candy?"}
        ]
    },
]
inputs = processor.apply_chat_template(
    messages,
    add_generation_prompt=True,
    tokenize=True,
    return_dict=True,
    return_tensors="pt",
).to(model.device)

outputs = model.generate(**inputs, max_new_tokens=40)
print(processor.decode(outputs[0][inputs["input_ids"].shape[-1]:]))
- Notebooks
- Google Colab
- Kaggle
- Local Apps Settings
- vLLM
How to use rhymes-ai/Aria-sequential_mlp with vLLM:
Install from pip and serve model
# Install vLLM from pip:
pip install vllm

# Start the vLLM server:
vllm serve "rhymes-ai/Aria-sequential_mlp"

# Call the server using curl (OpenAI-compatible API):
curl -X POST "http://localhost:8000/v1/chat/completions" \
    -H "Content-Type: application/json" \
    --data '{
        "model": "rhymes-ai/Aria-sequential_mlp",
        "messages": [
            {
                "role": "user",
                "content": [
                    { "type": "text", "text": "Describe this image in one sentence." },
                    { "type": "image_url", "image_url": { "url": "https://cdn.britannica.com/61/93061-050-99147DCE/Statue-of-Liberty-Island-New-York-Bay.jpg" } }
                ]
            }
        ]
    }'
Use Docker
docker model run hf.co/rhymes-ai/Aria-sequential_mlp
- SGLang
How to use rhymes-ai/Aria-sequential_mlp with SGLang:
Install from pip and serve model
# Install SGLang from pip:
pip install sglang

# Start the SGLang server:
python3 -m sglang.launch_server \
    --model-path "rhymes-ai/Aria-sequential_mlp" \
    --host 0.0.0.0 \
    --port 30000

# Call the server using curl (OpenAI-compatible API):
curl -X POST "http://localhost:30000/v1/chat/completions" \
    -H "Content-Type: application/json" \
    --data '{
        "model": "rhymes-ai/Aria-sequential_mlp",
        "messages": [
            {
                "role": "user",
                "content": [
                    { "type": "text", "text": "Describe this image in one sentence." },
                    { "type": "image_url", "image_url": { "url": "https://cdn.britannica.com/61/93061-050-99147DCE/Statue-of-Liberty-Island-New-York-Bay.jpg" } }
                ]
            }
        ]
    }'
Use Docker images
docker run --gpus all \
    --shm-size 32g \
    -p 30000:30000 \
    -v ~/.cache/huggingface:/root/.cache/huggingface \
    --env "HF_TOKEN=<secret>" \
    --ipc=host \
    lmsysorg/sglang:latest \
    python3 -m sglang.launch_server \
        --model-path "rhymes-ai/Aria-sequential_mlp" \
        --host 0.0.0.0 \
        --port 30000

# Call the server using curl (OpenAI-compatible API):
curl -X POST "http://localhost:30000/v1/chat/completions" \
    -H "Content-Type: application/json" \
    --data '{
        "model": "rhymes-ai/Aria-sequential_mlp",
        "messages": [
            {
                "role": "user",
                "content": [
                    { "type": "text", "text": "Describe this image in one sentence." },
                    { "type": "image_url", "image_url": { "url": "https://cdn.britannica.com/61/93061-050-99147DCE/Statue-of-Liberty-Island-New-York-Bay.jpg" } }
                ]
            }
        ]
    }'
- Docker Model Runner
How to use rhymes-ai/Aria-sequential_mlp with Docker Model Runner:
docker model run hf.co/rhymes-ai/Aria-sequential_mlp
| # Copyright 2024 Rhymes AI. All rights reserved. | |
| # | |
| # Licensed to the Apache Software Foundation (ASF) under one | |
| # or more contributor license agreements. See the NOTICE file | |
| # distributed with this work for additional information | |
| # regarding copyright ownership. The ASF licenses this file | |
| # to you under the Apache License, Version 2.0 (the | |
| # "License"); you may not use this file except in compliance | |
| # with the License. You may obtain a copy of the License at | |
| # | |
| # http://www.apache.org/licenses/LICENSE-2.0 | |
| # | |
| # Unless required by applicable law or agreed to in writing, | |
| # software distributed under the License is distributed on an | |
| # "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY | |
| # KIND, either express or implied. See the License for the | |
| # specific language governing permissions and limitations | |
| # under the License. | |
| import logging | |
| import os | |
| from typing import Tuple | |
| import torch | |
| import torch.nn as nn | |
| import torch.nn.functional as F | |
| from torch import nn | |
| from transformers import LlamaConfig | |
| from transformers.models.llama.modeling_llama import ( | |
| ACT2FN, | |
| LLAMA_ATTENTION_CLASSES, | |
| LlamaDecoderLayer, | |
| LlamaForCausalLM, | |
| LlamaMLP, | |
| LlamaModel, | |
| LlamaRMSNorm, | |
| LlamaRotaryEmbedding, | |
| ) | |
| logger = logging.getLogger(__name__) | |
class AriaMoELMConfig(LlamaConfig):
    """
    Configuration for the AriaMoE language model.

    Adds the Mixture of Experts (MoE) hyper-parameters on top of `LlamaConfig`.
    Every keyword argument not listed below is forwarded unchanged to
    `LlamaConfig.__init__`.

    Args:
        moe_intermediate_size (int): Intermediate size of the MoE layers. Default is 4096.
        moe_num_experts (int): Number of experts in each MoE layer. Default is 8.
        moe_topk (int): Number of top experts each token is routed to. Default is 2.
        moe_z_loss_coeff (float): Coefficient of the auxiliary z-loss. Default is 1e-5.
        moe_aux_loss_coeff (float): Coefficient of the auxiliary load balancing loss. Default is 1e-3.
        moe_num_shared_experts (int): Number of shared experts. Default is 2.
        **kwargs: Additional keyword arguments passed to the parent `LlamaConfig`.
    """

    model_type = "aria_moe_lm"

    def __init__(
        self,
        moe_intermediate_size: int = 4096,
        moe_num_experts: int = 8,
        moe_topk: int = 2,
        moe_z_loss_coeff: float = 1e-5,
        moe_aux_loss_coeff: float = 1e-3,
        moe_num_shared_experts: int = 2,
        **kwargs,
    ):
        """Build the base Llama configuration, then record the MoE settings on it."""
        super().__init__(**kwargs)

        # Expert geometry.
        self.moe_intermediate_size = moe_intermediate_size
        self.moe_num_experts = moe_num_experts

        # Routing fan-out and the coefficients of the two router losses.
        self.moe_topk = moe_topk
        self.moe_z_loss_coeff = moe_z_loss_coeff
        self.moe_aux_loss_coeff = moe_aux_loss_coeff

        # Experts applied to every token, independent of routing.
        self.moe_num_shared_experts = moe_num_shared_experts
| # copied from https://github.com/NVIDIA/Megatron-LM/blob/54f1f78529cbc2b9cddad313e7f9d96ac0420a27/megatron/core/transformer/moe/moe_utils.py#L101-L142 | |
class MoEAuxLossAutoScaler(torch.autograd.Function):
    """An AutoScaler that computes and scales the gradient for an auxiliary loss.

    The function is an identity on `output` in the forward pass. Its only
    purpose is to attach `aux_loss` to the autograd graph so that, during the
    backward pass, `aux_loss` receives a gradient equal to
    `main_loss_backward_scale` without being added to the main loss explicitly.
    """

    # Class-level scale shared by every use of this function; change it with
    # `set_loss_scale`.
    main_loss_backward_scale: torch.Tensor = torch.tensor(1.0)

    # torch.autograd.Function expects `forward` and `backward` to be static
    # methods that receive the context object as their first argument.
    @staticmethod
    def forward(ctx, output: torch.Tensor, aux_loss: torch.Tensor):
        """Preserve the aux_loss by storing it in the context to avoid garbage collection.

        Args:
            ctx: Autograd context used to carry `aux_loss` to `backward`.
            output (torch.Tensor): The output tensor.
            aux_loss (torch.Tensor): The auxiliary loss tensor.

        Returns:
            torch.Tensor: The output tensor, unchanged.
        """
        ctx.save_for_backward(aux_loss)
        return output

    @staticmethod
    def backward(ctx, grad_output: torch.Tensor):
        """Compute and scale the gradient for the auxiliary loss.

        Args:
            ctx: Autograd context holding the `aux_loss` saved in `forward`.
            grad_output (torch.Tensor): The gradient of the output.

        Returns:
            Tuple[torch.Tensor, torch.Tensor]: The gradient of the output
            (passed through untouched) and the scaled auxiliary loss gradient.
        """
        (aux_loss,) = ctx.saved_tensors
        aux_loss_backward_scale = MoEAuxLossAutoScaler.main_loss_backward_scale
        # The aux loss is seeded with `scale` instead of the usual 1.
        scaled_aux_loss_grad = torch.ones_like(aux_loss) * aux_loss_backward_scale
        return grad_output, scaled_aux_loss_grad

    @staticmethod
    def set_loss_scale(scale: torch.Tensor):
        """Set the scale of the aux loss.

        Args:
            scale (torch.Tensor): The scale value to set. Please ensure that the
                scale passed in matches the scale of the main_loss.
        """
        MoEAuxLossAutoScaler.main_loss_backward_scale = scale
def z_loss_func(logits, z_loss_coeff):
    """Compute the router z-loss, which penalises large router logits for stability.

    Please refer to the ST-MoE paper (https://arxiv.org/pdf/2202.08906.pdf) for details.

    Args:
        logits (torch.Tensor): The logits of the router; the last dimension is
            reduced with log-sum-exp.
        z_loss_coeff (float): Multiplier applied to the averaged penalty.

    Returns:
        torch.Tensor: Scalar z-loss, i.e. mean(logsumexp(logits) ** 2) * z_loss_coeff.
    """
    log_partition = torch.logsumexp(logits, dim=-1)
    return log_partition.square().mean() * z_loss_coeff
def switch_load_balancing_loss_func(
    probs: torch.Tensor,
    tokens_per_expert: torch.Tensor,
    topk: int,
    moe_aux_loss_coeff: float,
):
    """Calculate the auxiliary loss for better load balancing.

    Please refer to the Switch Transformer paper (https://arxiv.org/abs/2101.03961) for details.

    Args:
        probs (torch.Tensor): The softmax probs output by the router for each token. [num_tokens, num_experts]
        tokens_per_expert (torch.Tensor): The number of assigned tokens for each expert. [num_experts]
        topk (int): Number of experts each token is routed to; the token count
            used for normalisation is `probs.shape[0] * topk`.
        moe_aux_loss_coeff (float): Multiplier applied to the loss.

    Returns:
        torch.Tensor: The auxiliary loss for load balancing (a scalar).
    """
    token_count, expert_count = probs.shape[0], probs.shape[1]
    routed_slots = token_count * topk
    # Scalar factor: num_experts / (num_tokens * topk) * coefficient.
    scale = expert_count / routed_slots * moe_aux_loss_coeff
    mean_prob_per_expert = probs.mean(dim=0)
    return torch.sum(mean_prob_per_expert * tokens_per_expert) * scale
| # adapted from https://github.com/NVIDIA/Megatron-LM/blob/54f1f78529cbc2b9cddad313e7f9d96ac0420a27/megatron/core/transformer/moe/router.py#L96-L304 | |
class TopKRouter(nn.Module):
    """
    Top-K Router for Mixture of Experts (MoE) models.

    This router determines which experts should process each token based on the top-k scoring experts.
    It also applies auxiliary losses to encourage load balancing among experts.

    Args:
        config (AriaMoELMConfig): Configuration object containing MoE-related parameters.
    """

    def __init__(self, config: AriaMoELMConfig):
        super().__init__()
        self.config = config

        self.weight = nn.Parameter(
            torch.empty((self.config.moe_num_experts, self.config.hidden_size))
        )
        # torch.empty returns uninitialized memory, so give the router a
        # defined starting point. Loading a checkpoint overwrites these values.
        nn.init.normal_(
            self.weight, mean=0.0, std=getattr(config, "initializer_range", 0.02)
        )

    def gating(self, input: torch.Tensor) -> torch.Tensor:
        """
        Compute the gating logits for each token-expert pair.

        Args:
            input (torch.Tensor): Input tensor of shape [batch_size * seq_len, hidden_size].

        Returns:
            torch.Tensor: Logits tensor of shape [batch_size * seq_len, num_experts].
        """
        logits = torch.nn.functional.linear(input, self.weight)
        return logits

    def apply_z_loss(self, logits: torch.Tensor) -> torch.Tensor:
        """
        Apply z-loss to encourage router logits to remain small for enhanced stability.

        Args:
            logits (torch.Tensor): Router logits.

        Returns:
            torch.Tensor: The same logits, with the z-loss attached to the
            autograd graph through `MoEAuxLossAutoScaler`.
        """
        z_loss = z_loss_func(logits, self.config.moe_z_loss_coeff)
        logits = MoEAuxLossAutoScaler.apply(logits, z_loss)
        return logits

    def apply_aux_loss(
        self,
        logits: torch.Tensor,
        tokens_per_expert: torch.Tensor,
        activation: torch.Tensor,
    ) -> torch.Tensor:
        """
        Apply auxiliary loss for load balancing among experts.

        Args:
            logits (torch.Tensor): Router logits.
            tokens_per_expert (torch.Tensor): Number of tokens assigned to each expert.
            activation (torch.Tensor): Activation values.

        Returns:
            torch.Tensor: The same activation, with the load balancing loss
            attached to the autograd graph through `MoEAuxLossAutoScaler`.
        """
        # The balancing loss uses probabilities over all experts, in float32.
        probs = torch.softmax(logits, dim=-1, dtype=torch.float32)
        aux_loss = switch_load_balancing_loss_func(
            probs,
            tokens_per_expert,
            self.config.moe_topk,
            self.config.moe_aux_loss_coeff,
        )
        return MoEAuxLossAutoScaler.apply(activation, aux_loss)

    def routing(
        self, logits: torch.Tensor
    ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
        """
        Perform the routing operation to determine expert assignments.

        Args:
            logits (torch.Tensor): Router logits.

        Returns:
            Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
                - scores: Softmax probabilities for top-k experts.
                - top_indices: Indices of top-k experts for each token.
                - tokens_per_expert: Number of tokens assigned to each expert.
        """
        logits = self.apply_z_loss(logits)

        top_logits, top_indices = torch.topk(logits, k=self.config.moe_topk, dim=1)
        # Softmax over the selected experts only, computed in float32 and cast
        # back to the logits dtype.
        scores = torch.softmax(top_logits, dim=-1, dtype=torch.float32).type_as(logits)

        # torch.histc has no CPU kernel for integer inputs, so count on a
        # float32 copy of the indices and cast the counts back to the index
        # dtype; this keeps the integer result while also running on CPU.
        tokens_per_expert = torch.histc(
            top_indices.flatten().to(torch.float32),
            bins=self.config.moe_num_experts,
            min=0,
            max=self.config.moe_num_experts - 1,
        ).to(top_indices.dtype)

        scores = self.apply_aux_loss(logits, tokens_per_expert, scores)
        return scores, top_indices, tokens_per_expert

    def forward(
        self, input: torch.Tensor
    ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
        """
        Forward pass of the TopKRouter.

        Args:
            input (torch.Tensor): Input tensor of shape [batch_size * seq_len, hidden_size].

        Returns:
            Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
                - scores: Softmax probabilities for top-k experts.
                - top_indices: Indices of top-k experts for each token.
                - tokens_per_expert: Number of tokens assigned to each expert.
        """
        logits = self.gating(input)
        # Collapse any leading dimensions so that each row is one token.
        logits = logits.view(-1, self.config.moe_num_experts)
        scores, top_indices, tokens_per_expert = self.routing(logits)
        return scores, top_indices, tokens_per_expert
| # adapted from https://github.com/NVIDIA/Megatron-LM/blob/54f1f78529cbc2b9cddad313e7f9d96ac0420a27/megatron/core/transformer/moe/token_dispatcher.py#L291-L587 | |
class TokenDispatcher:
    """
    Moves tokens to their experts and back.

    `token_permutation` groups the tokens by assigned expert and remembers how
    to undo that grouping; `token_unpermutation` restores the original order
    and merges the top-k expert outputs of each token. The two calls share
    state stored on the instance, so they must be used as a pair.

    Args:
        config (AriaMoELMConfig): Configuration object containing MoE-related parameters.
    """

    def __init__(self, config: AriaMoELMConfig):
        self.config = config
        # Both are filled in by `token_permutation` and read by `token_unpermutation`.
        self.hidden_states_shape = None
        self.reversed_input_permutation_mapping = None

    def token_permutation(
        self, hidden_states: torch.Tensor, indices: torch.Tensor
    ) -> torch.Tensor:
        """
        Group tokens by the expert they were assigned to.

        Args:
            hidden_states (torch.Tensor): Input hidden states.
            indices (torch.Tensor): Expert assignment indices.

        Returns:
            torch.Tensor: One row per (token, expert) assignment, ordered by expert index.
        """
        self.hidden_states_shape = hidden_states.shape
        flat_states = hidden_states.view(-1, hidden_states.size(-1))

        # A stable sort keeps assignments to the same expert in their original order.
        assignment_order = torch.argsort(indices.flatten(), stable=True)

        # Each token owns `moe_topk` consecutive flattened slots, so integer
        # division maps a slot back to the row of its token.
        source_rows = assignment_order // self.config.moe_topk
        grouped = flat_states.index_select(0, source_rows)

        self.reversed_input_permutation_mapping = assignment_order
        return grouped

    def token_unpermutation(
        self, permuted_tokens: torch.Tensor, scores: torch.Tensor
    ) -> torch.Tensor:
        """
        Restore the original token order and combine the expert outputs.

        Args:
            permuted_tokens (torch.Tensor): Tokens after expert processing.
            scores (torch.Tensor): Expert assignment scores.

        Returns:
            torch.Tensor: Score-weighted sum of the expert outputs of every
            token, viewed with the shape recorded by `token_permutation`.
        """
        feature_dim = permuted_tokens.size(1)
        restored = torch.zeros(
            (scores.numel(), feature_dim),
            dtype=permuted_tokens.dtype,
            device=permuted_tokens.device,
        )
        # Scatter every expert output back to the slot it was taken from.
        restored.index_copy_(
            0, self.reversed_input_permutation_mapping, permuted_tokens
        )

        per_token = restored.reshape(-1, self.config.moe_topk, feature_dim)
        weighted = per_token * scores.unsqueeze(-1)
        combined = weighted.sum(dim=1).type_as(permuted_tokens)
        return combined.view(self.hidden_states_shape)
class SharedExpertMLP(LlamaMLP):
    """
    MLP holding the shared experts.

    Shared experts see every token; no routing is involved. The only change
    made here relative to `LlamaMLP` is the intermediate size, which is
    `moe_intermediate_size * moe_num_shared_experts`.

    Args:
        config (AriaMoELMConfig): Configuration object for the AriaMoE language model.
    """

    def __init__(self, config: AriaMoELMConfig):
        # Initialise as a bare nn.Module: the projections are created below
        # with the MoE-specific sizes instead of by LlamaMLP.__init__.
        nn.Module.__init__(self)
        self.config = config
        self.hidden_size = config.hidden_size
        self.intermediate_size = (
            config.moe_num_shared_experts * config.moe_intermediate_size
        )

        narrow, wide = self.hidden_size, self.intermediate_size
        self.gate_proj = nn.Linear(narrow, wide, bias=config.mlp_bias)
        self.up_proj = nn.Linear(narrow, wide, bias=config.mlp_bias)
        self.down_proj = nn.Linear(wide, narrow, bias=config.mlp_bias)
        self.act_fn = ACT2FN[config.hidden_act]
def sequential_gemm(input, weight, tokens_per_expert):
    """
    Compute the matrix multiplication (GEMM) for each expert sequentially. This approach is computationally inefficient, especially when dealing with a large number of experts.

    Rows of `input` are expected to be grouped by expert: the first
    `tokens_per_expert[0]` rows belong to expert 0, the next
    `tokens_per_expert[1]` rows to expert 1, and so on.

    Args:
        input (torch.Tensor): Input tensor of shape (num_tokens, in_features).
        weight (torch.Tensor): Weight tensor of shape (num_experts, in_features, out_features).
        tokens_per_expert (torch.Tensor): Number of tokens assigned to each expert.
            Integer or integer-valued floating point counts are accepted.

    Returns:
        torch.Tensor: Output tensor of shape (num_tokens, out_features). Rows
        not covered by any expert stay zero.
    """
    num_tokens = input.shape[0]
    out_features = weight.shape[-1]
    output = torch.zeros(
        num_tokens, out_features, dtype=input.dtype, device=input.device
    )

    # Bring all slice boundaries to the host in a single transfer instead of
    # converting one 0-dim tensor per boundary inside the loop. The cast makes
    # floating point counts usable as slice indices.
    expert_ends = torch.cumsum(tokens_per_expert.to(torch.long), dim=0).tolist()

    start = 0
    for expert_num in range(weight.shape[0]):
        end = expert_ends[expert_num]
        output[start:end] = torch.matmul(input[start:end], weight[expert_num])
        start = end
    return output
class ExpertMLP(LlamaMLP):
    """
    A single routed expert of the Mixture of Experts (MoE) layer.

    Structurally a `LlamaMLP` whose intermediate size is taken from
    `config.moe_intermediate_size`.

    Args:
        config (AriaMoELMConfig): Configuration object for the AriaMoE language model.
    """

    def __init__(self, config: AriaMoELMConfig):
        # Initialise as a bare nn.Module: the projections are created below
        # with the MoE-specific sizes instead of by LlamaMLP.__init__.
        nn.Module.__init__(self)
        self.config = config
        self.hidden_size = config.hidden_size
        self.intermediate_size = config.moe_intermediate_size

        narrow, wide = self.hidden_size, self.intermediate_size
        self.gate_proj = nn.Linear(narrow, wide, bias=config.mlp_bias)
        self.up_proj = nn.Linear(narrow, wide, bias=config.mlp_bias)
        self.down_proj = nn.Linear(wide, narrow, bias=config.mlp_bias)
        self.act_fn = ACT2FN[config.hidden_act]
class SequentialMLP(nn.Module):
    """
    Sequential MLP for handling multiple experts in the Mixture of Experts (MoE) layer.

    This class manages a collection of ExpertMLPs and processes tokens through them sequentially.

    Args:
        config (AriaMoELMConfig): Configuration object for the AriaMoE language model.
    """

    def __init__(self, config: AriaMoELMConfig):
        super().__init__()
        self.config = config
        self.experts = nn.ModuleList(
            [ExpertMLP(config) for _ in range(config.moe_num_experts)]
        )

    def forward(self, permuted_tokens: torch.Tensor, tokens_per_expert: torch.Tensor) -> torch.Tensor:
        """
        Forward pass of the SequentialMLP.

        This method processes the permuted tokens through each expert sequentially,
        based on the number of tokens assigned to each expert. The rows of
        `permuted_tokens` are expected to be grouped by expert, in expert order.

        Args:
            permuted_tokens (torch.Tensor): Permuted input tokens.
            tokens_per_expert (torch.Tensor): Number of tokens assigned to each expert.
                Integer or integer-valued floating point counts are accepted.

        Returns:
            torch.Tensor: Processed output from all experts, with the same
            shape as `permuted_tokens`.
        """
        output = torch.zeros_like(permuted_tokens)

        # Bring all slice boundaries to the host in a single transfer instead
        # of converting one 0-dim tensor per boundary inside the loop. The cast
        # makes floating point counts usable as slice indices.
        expert_ends = torch.cumsum(tokens_per_expert.to(torch.long), dim=0).tolist()

        start = 0
        # Every expert is invoked, even with an empty slice, exactly as before.
        for expert_num, expert in enumerate(self.experts):
            end = expert_ends[expert_num]
            output[start:end] = expert(permuted_tokens[start:end])
            start = end
        return output
class MoELayer(nn.Module):
    """
    Mixture of Experts (MoE) layer of the AriaMoE model.

    Combines a top-k router, a token dispatcher, the routed experts and the
    shared experts into a single feed-forward block.

    Args:
        config (AriaMoELMConfig): Configuration object for the MoE layer.
    """

    def __init__(self, config: AriaMoELMConfig):
        super().__init__()

        self.router = TopKRouter(config)
        self.token_dispatcher = TokenDispatcher(config)
        self.experts = SequentialMLP(config)
        self.shared_experts = SharedExpertMLP(config)

    def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
        """
        Run the MoE layer on a batch of hidden states.

        Args:
            hidden_states (torch.Tensor): Input tensor of shape (batch_size, sequence_length, hidden_size).

        Returns:
            torch.Tensor: Routed expert output plus shared expert output.

        Steps:
            1. The router picks the top-k experts and their scores for every token.
            2. The dispatcher groups the tokens by expert.
            3. The routed experts process their groups.
            4. The dispatcher restores the token order and merges the top-k outputs.
            5. The shared expert output is added on top.
        """
        dispatcher = self.token_dispatcher

        gate_scores, expert_ids, expert_counts = self.router(hidden_states)
        grouped_tokens = dispatcher.token_permutation(hidden_states, expert_ids)
        routed = self.experts(grouped_tokens, expert_counts)
        combined = dispatcher.token_unpermutation(routed, gate_scores)

        # In-place accumulation of the shared expert contribution.
        combined += self.shared_experts(hidden_states)
        return combined
class MoEDecoderLayer(LlamaDecoderLayer):
    """
    Decoder layer of the AriaMoE model: a `LlamaDecoderLayer` whose MLP is a
    Mixture of Experts (MoE) layer.

    Args:
        config (LlamaConfig): Configuration object for the layer.
        layer_idx (int): Index of the current layer in the model.
    """

    def __init__(self, config: LlamaConfig, layer_idx: int):
        # Initialise as a bare nn.Module; all sub-modules are created here
        # rather than by LlamaDecoderLayer.__init__.
        nn.Module.__init__(self)
        self.hidden_size = config.hidden_size

        # The attention class is selected by the configured implementation name.
        attention_cls = LLAMA_ATTENTION_CLASSES[config._attn_implementation]
        self.self_attn = attention_cls(config=config, layer_idx=layer_idx)

        self.mlp = MoELayer(config)

        norm_eps = config.rms_norm_eps
        self.input_layernorm = LlamaRMSNorm(config.hidden_size, eps=norm_eps)
        self.post_attention_layernorm = LlamaRMSNorm(config.hidden_size, eps=norm_eps)
class AriaMoELMModel(LlamaModel):
    """
    `LlamaModel` variant used by AriaMoE, in which every decoder layer is a
    `MoEDecoderLayer` instead of a `LlamaDecoderLayer`.

    Args:
        config (LlamaConfig): Configuration object for the model.
    """

    def __init__(self, config: LlamaConfig):
        super().__init__(config)
        # The attributes below are (re)assigned after the parent constructor
        # has run, so the MoE decoder layers are the ones kept on the model.
        hidden_size = config.hidden_size

        self.padding_idx = config.pad_token_id
        self.vocab_size = config.vocab_size
        self.embed_tokens = nn.Embedding(
            config.vocab_size, hidden_size, self.padding_idx
        )

        decoder_layers = [
            MoEDecoderLayer(config, idx) for idx in range(config.num_hidden_layers)
        ]
        self.layers = nn.ModuleList(decoder_layers)

        self.norm = LlamaRMSNorm(hidden_size, eps=config.rms_norm_eps)
        self.rotary_emb = LlamaRotaryEmbedding(config=config)
        self.gradient_checkpointing = False

        # Initialize weights and apply final processing
        self.post_init()
class AriaMoELMForCausalLM(LlamaForCausalLM):
    """
    Causal language model built on `AriaMoELMModel`.

    Extends `LlamaForCausalLM` by swapping the backbone for the Mixture of
    Experts (MoE) model and by exposing setters for the two router loss
    coefficients stored on the configuration.

    Args:
        config (AriaMoELMConfig): Configuration object for the model.
    """

    _tied_weights_keys = ["lm_head.weight"]
    config_class = AriaMoELMConfig
    _no_split_modules = ["MoEDecoderLayer"]

    def __init__(self, config):
        super().__init__(config)
        # Replace the backbone and the output head after the parent constructor.
        self.model = AriaMoELMModel(config)
        self.vocab_size = config.vocab_size
        self.lm_head = nn.Linear(
            in_features=config.hidden_size,
            out_features=config.vocab_size,
            bias=False,
        )

        # Initialize weights and apply final processing
        self.post_init()

    def set_z_loss_coeff(self, z_loss_coeff: float):
        """
        Update the z-loss coefficient used by the MoE routing.

        Only `self.config.moe_z_loss_coeff` is written here.
        NOTE(review): presumably the routers share this config object and so
        pick up the new value on their next forward pass - confirm.

        Args:
            z_loss_coeff (float): The coefficient for the z-loss.
        """
        self.config.moe_z_loss_coeff = z_loss_coeff

    def set_aux_loss_coeff(self, aux_loss_coeff: float):
        """
        Update the auxiliary (load balancing) loss coefficient used by the MoE routing.

        Only `self.config.moe_aux_loss_coeff` is written here.
        NOTE(review): presumably the routers share this config object and so
        pick up the new value on their next forward pass - confirm.

        Args:
            aux_loss_coeff (float): The coefficient for the auxiliary loss.
        """
        self.config.moe_aux_loss_coeff = aux_loss_coeff