| |
|
|
| |
| |
| |
|
|
|
|
| |
| |
|
|
| from functools import partial |
| from typing import Optional, Tuple |
|
|
| import torch |
| import torch.nn as nn |
| import torch.nn.functional as F |
| import triton |
| import triton.language as tl |
| from torch.distributed import DeviceMesh |
| from torch.distributed.tensor import DTensor, Replicate, Shard, distribute_module |
| from torch.distributed.tensor.parallel import ParallelStyle |
|
|
| |
| |
| |
| |
# Upper bound (in elements) for a single fused Triton block: 65536 // 2 = 32768.
# Used below to cap BV / B block sizes so a whole vocab chunk fits in one program.
MAX_FUSED_SIZE = 65536 // 2
|
|
|
|
@triton.heuristics({
    # Only multiply by `scale` when the caller actually passed one.
    'HAS_SCALE': lambda args: args['scale'] is not None
})
@triton.autotune(
    configs=[
        triton.Config({}, num_warps=num_warps)
        for num_warps in [1, 2, 4, 8, 16, 32]
    ],
    key=['D']
)
@triton.jit
def logsumexp_fwd_kernel(
    x,
    z,
    scale,
    D: tl.constexpr,
    B: tl.constexpr,
    HAS_SCALE: tl.constexpr
):
    """
    Per-block logsumexp over the last dimension of a 2D input.

    Grid is (num_rows, num_blocks): program (i_n, i_d) reduces elements
    [i_d*B, (i_d+1)*B) of row i_n and writes one partial logsumexp to
    z[i_n, i_d]; the caller combines the partials with a second logsumexp.
    """
    i_n, i_d = tl.program_id(0).to(tl.int64), tl.program_id(1).to(tl.int64)
    # Column offsets handled by this program, masked against the row length D.
    o_d = i_d * B + tl.arange(0, B)
    m_d = o_d < D

    # Out-of-range lanes load -inf so they contribute exp(-inf) = 0 to the sum.
    b_x = tl.load(x + i_n * D + o_d, mask=m_d, other=-float('inf'))
    if HAS_SCALE:
        b_x = b_x * scale
    # Standard max-shift for numerical stability: lse = log(sum(exp(x - m))) + m.
    b_m = tl.max(b_x, 0)
    b_z = tl.log(tl.sum(tl.exp(b_x - b_m), 0)) + b_m
    tl.store(z + i_n * tl.cdiv(D, B) + i_d, b_z)
|
|
|
|
def logsumexp_fwd(
    x,
    scale: Optional[float] = None,
    dtype: Optional[torch.dtype] = None
):
    r"""
    Compute the logsumexp of ``x`` over its last dimension via a Triton kernel.

    Args:
        x (Tensor):
            Input tensor of any shape; the reduction runs over the final axis.
        scale (Optional[float]):
            Optional multiplier applied to ``x`` before the reduction.
            Default: `None` (no scaling).
        dtype (Optional[torch.dtype]):
            Desired output dtype; the reduction itself accumulates in float32.
            Default: `None` (keep float32).
    Returns:
        Tensor: The logsumexp of ``x`` with the last dimension removed.
    """
    original_shape = x.shape
    last_dim = original_shape[-1]
    # Flatten all leading dims into rows so the kernel sees a 2D problem.
    flat = x.view(-1, last_dim)
    num_rows = flat.shape[0]
    # Each program covers up to `block` elements of one row; cap at 64K lanes.
    block = min(triton.next_power_of_2(last_dim), 64 * 1024)
    num_blocks = triton.cdiv(last_dim, block)

    # One partial logsumexp per (row, block); always float32 for stability.
    partial_lse = flat.new_empty(num_rows, num_blocks, dtype=torch.float)
    logsumexp_fwd_kernel[(num_rows, num_blocks)](
        x=flat,
        z=partial_lse,
        scale=scale,
        D=last_dim,
        B=block
    )
    # Fold the per-block partials together, then restore the leading shape.
    out = partial_lse.logsumexp(-1).view(*original_shape[:-1])
    if dtype is not None and dtype != torch.float:
        out = out.to(dtype)
    return out
|
|
@triton.jit
def cross_entropy_kernel(
    logits,
    lse,
    target,
    p_mask,
    loss,
    total,
    ignore_index,
    label_smoothing: tl.constexpr,
    logit_scale: tl.constexpr,
    reduction: tl.constexpr,
    V: tl.constexpr,
    BV: tl.constexpr
):
    """
    This kernel computes both cross entropy loss and the gradient of the input.
    We only consider hard label + mean reduction for now.
    Please refer to https://pytorch.org/docs/stable/generated/torch.nn.CrossEntropyLoss.html for the math.

    The gradient is written IN-PLACE into `logits`; after this kernel runs the
    `logits` buffer no longer holds logits but d(loss)/d(logits).
    One program handles one row (one token position).

    Args:
        logits:
            Pointer to logits tensor; overwritten with the gradient.
        lse:
            Pointer to logsumexp tensor (one value per row, computed over
            `logit_scale`-scaled logits).
        target:
            Pointer to target tensor (one class index per row).
        p_mask:
            Pointer to per-row weights; loss and gradient are divided by this
            value (diffusion-style reweighting, e.g. 1/p at masked positions).
        loss:
            Pointer to tensor to store the per-row loss.
        total (int):
            The number of non-ignored rows (used for 'mean' reduction).
        ignore_index (int):
            The index to ignore in the target.
        label_smoothing (float):
            The amount of smoothing when computing the loss, where 0.0 means no smoothing.
        logit_scale (float):
            Multiplier applied to each logit before the loss computation.
        reduction (str):
            The string for the reduction to apply ('mean' or anything else = no
            division by `total`).
        V (int):
            The number of columns in the input tensor (vocab size).
        BV (int):
            The block size for vocab.
    """
    i_n = tl.program_id(0).to(tl.int64)
    NV = tl.cdiv(V, BV)

    # Target class and probability-mask weight for this row.
    b_y = tl.load(target + i_n)
    b_p_mask = tl.load(p_mask + i_n)

    # Advance to this row's logits.
    logits += i_n * V

    if b_y == ignore_index:
        # Ignored position: zero the whole gradient row and emit no loss.
        for i in range(0, V, BV):
            o_v = i + tl.arange(0, BV)
            tl.store(logits + o_v, 0.0, mask=o_v < V)
        return

    # Scaled logit of the target class and the row's logsumexp.
    b_l = tl.load(logits + b_y) * logit_scale
    b_lse = tl.load(lse + i_n)

    # Hard-label CE: loss = lse - logit[target], reweighted by p_mask.
    b_loss = (b_lse - b_l) / b_p_mask

    # Accumulator for the label-smoothing term; eps spread over all classes.
    b_z = 0.0
    eps = label_smoothing / V

    tl.debug_barrier()

    # Sweep the vocab in blocks: accumulate the smoothing sum and overwrite
    # each logit with its gradient softmax(x) - eps, scaled and reweighted.
    for iv in range(0, NV):
        o_v = iv * BV + tl.arange(0, BV)
        b_logits = tl.load(logits + o_v, mask=o_v < V, other=float('-inf')) * logit_scale
        if label_smoothing > 0:
            # Masked lanes contribute 0 (not -inf * eps).
            b_z += tl.sum(tl.where(o_v < V, -eps * b_logits, 0.0))
        b_p = (tl.exp(b_logits - b_lse) - eps) * logit_scale
        b_p /= b_p_mask
        if reduction == "mean":
            b_p = b_p / total
        tl.store(logits + o_v, b_p, mask=o_v < V)

    tl.debug_barrier()

    # Fold the smoothing term into the loss:
    # loss = (1 - eps_total) * ce + sum_k(-eps * x_k) + eps_total * lse.
    if label_smoothing > 0:
        b_loss = b_loss * (1 - label_smoothing) + (b_z + label_smoothing * b_lse)

    # Re-read the (now gradient) value at the target index to apply the
    # extra -1 (+ label_smoothing) term from d(ce)/d(logit[target]).
    b_l = tl.load(logits + b_y)

    if reduction == 'mean':
        b_loss = b_loss / total
        b_l += (label_smoothing - 1) / b_p_mask / total * logit_scale
    else:
        b_l += (label_smoothing - 1) / b_p_mask * logit_scale

    tl.store(loss + i_n, b_loss)
    tl.store(logits + b_y, b_l)
|
|
|
|
@triton.jit
def elementwise_mul_kernel(
    x,
    g,
    N: tl.constexpr,
    B: tl.constexpr
):
    """
    This function multiplies each element of the tensor pointed by x with the value pointed by g.
    The multiplication is performed in-place on the tensor pointed by x.

    Parameters:
        x:
            Pointer to the input tensor.
        g:
            Pointer to the gradient output value (a single scalar).
        N (int):
            The total number of elements in the input tensor.
        B (int):
            The block size for Triton operations.
    """
    # Each program scales one contiguous block of B elements.
    i_x = tl.program_id(0).to(tl.int64)
    o_x = i_x * B + tl.arange(0, B)

    # Broadcast the scalar g over the block; mask guards the final ragged block.
    b_g = tl.load(g)
    b_x = tl.load(x + o_x, mask=o_x < N)
    tl.store(x + o_x, b_x * b_g, mask=o_x < N)
|
|
|
|
def fused_linear_cross_entropy_forward(
    x: torch.Tensor,
    target: torch.LongTensor,
    weight: torch.Tensor,
    bias: torch.Tensor = None,
    p_mask: torch.Tensor = None,
    ignore_index: int = -100,
    label_smoothing: float = 0.0,
    logit_scale: float = 1.0,
    num_chunks: int = 8,
    reduction: str = "mean"
):
    """
    Chunked forward pass of linear projection + cross entropy that also
    computes the input/weight/bias gradients, avoiding materialization of the
    full [N, V] logits tensor.

    Args:
        x: [N, H] input features.
        target: [N] class indices.
        weight: [V, H] projection weight.
        bias: optional [V] projection bias.
        p_mask: optional [N] per-position weight dividing loss and gradient.
            FIX: was declared optional but crashed on None; now defaults to
            all-ones (standard, unweighted cross entropy).
        ignore_index: target value whose rows contribute no loss/gradient.
        label_smoothing: amount of label smoothing (0.0 = none).
        logit_scale: multiplier applied to the logits.
        num_chunks: upper bound on the number of row chunks to process.
        reduction: 'mean' (divide by non-ignored count) or 'sum'.

    Returns:
        (loss, dx, dw, db): scalar summed loss, plus gradients w.r.t.
        x, weight, and bias (dw/db are None when weight/bias are None).
    """
    device = x.device

    N, H, V = *x.shape, weight.shape[0]
    # Default to an unweighted loss when no p_mask is supplied; the kernel
    # unconditionally divides by it, so ones is the identity.
    if p_mask is None:
        p_mask = torch.ones(N, device=device, dtype=torch.float)
    BV = min(MAX_FUSED_SIZE, triton.next_power_of_2(V))
    # Chunk rows so that each chunk's logits [C, V] stay roughly the size of
    # the input [N, H] (hence the V/H ratio bound on the chunk count).
    NC = min(num_chunks, triton.cdiv(V, H))
    C = triton.next_power_of_2(triton.cdiv(N, NC))
    NC = triton.cdiv(N, C)

    # Gradient buffers; dw/db accumulate in float32 for precision.
    dx = torch.zeros_like(x, device=device)
    dw = torch.zeros_like(weight, device=device, dtype=torch.float) if weight is not None else None
    db = torch.zeros_like(bias, device=device, dtype=torch.float) if bias is not None else None
    loss = torch.zeros(N, device=device, dtype=torch.float)

    # Number of non-ignored positions, used by 'mean' reduction.
    total = target.ne(ignore_index).sum().item()

    for ic in range(NC):
        start, end = ic * C, min((ic + 1) * C, N)
        c_x = x[start:end]

        # Materialize logits only for this chunk.
        c_logits = F.linear(c_x, weight, bias)
        c_target = target[start:end]
        c_p_mask = p_mask[start:end]

        c_lse = logsumexp_fwd(c_logits, scale=logit_scale, dtype=torch.float)

        c_loss = loss[start:end]

        # Writes per-row loss into c_loss and overwrites c_logits with
        # d(loss)/d(logits) in place.
        cross_entropy_kernel[(c_logits.shape[0],)](
            logits=c_logits,
            lse=c_lse,
            target=c_target,
            p_mask=c_p_mask,
            loss=c_loss,
            total=total,
            ignore_index=ignore_index,
            label_smoothing=label_smoothing,
            logit_scale=logit_scale,
            reduction=reduction,
            V=V,
            BV=BV,
            num_warps=32
        )

        # Backprop through the linear layer: c_logits now holds dlogits.
        dx[start:end] = torch.mm(c_logits, weight)

        if weight is not None:
            dw += c_logits.t() @ c_x

        if bias is not None:
            torch.add(input=db, other=c_logits.sum(0), out=db)

    loss = loss.sum()
    # Cast accumulated gradients back to the parameter dtypes.
    if dw is not None:
        dw = dw.to(weight)
    if db is not None:
        db = db.to(bias)
    return loss, dx, dw, db
|
|
|
|
def fused_linear_cross_entropy_backward(
    do: torch.Tensor,
    dx: torch.Tensor,
    dw: torch.Tensor,
    db: torch.Tensor
):
    """
    Scale the gradients pre-computed during the forward pass by the upstream
    gradient ``do``, in place.

    Args:
        do: scalar upstream gradient of the loss.
        dx: [N, H] input gradient buffer.
        dw: optional [V, H] weight gradient buffer.
        db: optional [V] bias gradient buffer.

    Returns:
        The (possibly rescaled) ``dx``, ``dw``, ``db``.
    """
    # If the upstream gradient is exactly 1.0 the scaling is a no-op, so the
    # kernel launches are skipped entirely.
    if torch.ne(do, torch.tensor(1.0, device=do.device)):
        n_rows, hidden = dx.shape
        block = min(MAX_FUSED_SIZE, triton.next_power_of_2(hidden))

        dx_numel = n_rows * hidden
        elementwise_mul_kernel[(triton.cdiv(dx_numel, block),)](
            x=dx,
            g=do,
            N=dx_numel,
            B=block,
            num_warps=32,
        )

        if dw is not None:
            vocab, hidden = dw.shape
            dw_numel = vocab * hidden
            elementwise_mul_kernel[(triton.cdiv(dw_numel, block),)](
                x=dw,
                g=do,
                N=dw_numel,
                B=block,
                num_warps=32,
            )

        if db is not None:
            db_numel = db.shape[0]
            elementwise_mul_kernel[(triton.cdiv(db_numel, block),)](
                x=db,
                g=do,
                N=db_numel,
                B=block,
                num_warps=32,
            )
    return dx, dw, db
|
|
|
|
class FusedLinearCrossEntropyFunction(torch.autograd.Function):
    """Autograd wrapper around the fused linear + cross-entropy kernels."""

    @staticmethod
    def forward(
        ctx,
        x: torch.Tensor,
        target: torch.LongTensor,
        weight: torch.Tensor,
        bias: torch.Tensor = None,
        p_mask: torch.Tensor = None,
        ignore_index: int = -100,
        label_smoothing: float = 0.0,
        logit_scale: float = 1.0,
        num_chunks: int = 8,
        reduction: str = "mean"
    ):
        """
        Fusing the last linear layer with cross-entropy loss.
        Reference: https://github.com/mgmalek/efficient_cross_entropy

        All gradients are computed eagerly during the forward pass (cross
        entropy is the final layer, so the upstream gradient is a scalar);
        only the gradients — not x/target — need to be saved for backward.

        Args:
            x (torch.Tensor): [batch_size * seq_len, hidden_size]
            target (torch.LongTensor): [batch_size * seq_len],
                each value in [0, vocab_size).
            weight (torch.Tensor): [vocab_size, hidden_size]
            bias (Optional[torch.Tensor]): [vocab_size]
            p_mask (torch.Tensor): [batch_size * seq_len], same shape as target;
                per-position loss weight.
            ignore_index: target value to ignore.
            label_smoothing: smoothing amount, 0.0 means none.
            logit_scale: scaling factor applied to the logits. Default: 1.0.
            num_chunks: number of row chunks for memory-efficient processing.
                Default: 8.
            reduction: 'mean' | 'sum'. Default: 'mean'.
        """
        loss, dx, dw, db = fused_linear_cross_entropy_forward(
            x, target, weight, bias, p_mask,
            ignore_index, label_smoothing, logit_scale, num_chunks, reduction
        )
        # Stash the already-computed gradients for the backward pass.
        ctx.save_for_backward(
            dx.detach(),
            dw.detach() if weight is not None else None,
            db.detach() if bias is not None else None,
        )
        return loss

    @staticmethod
    def backward(ctx, do):
        dx, dw, db = ctx.saved_tensors
        dx, dw, db = fused_linear_cross_entropy_backward(do, dx, dw, db)
        # One slot per forward argument; non-differentiable args get None.
        return dx, None, dw, db, None, None, None, None, None, None
|
|
|
|
def fused_linear_cross_entropy_loss(
    x: torch.Tensor,
    target: torch.LongTensor,
    weight: torch.Tensor,
    bias: torch.Tensor = None,
    p_mask: torch.Tensor = None,
    ignore_index: int = -100,
    label_smoothing: float = 0.0,
    logit_scale: float = 1.0,
    num_chunks: int = 8,
    reduction: str = "mean"
) -> Tuple[torch.Tensor, torch.Tensor]:
    """
    Functional entry point for the fused linear + cross-entropy loss.

    Args:
        x (torch.Tensor): [batch_size * seq_len, hidden_size]
        target (torch.LongTensor): [batch_size * seq_len],
            each value in [0, vocab_size).
        weight (torch.Tensor): [vocab_size, hidden_size]
        bias (Optional[torch.Tensor]): [vocab_size]
        p_mask (torch.Tensor): [batch_size * seq_len], same shape as target.
        ignore_index (int): targets equal to this contribute zero loss.
        label_smoothing (float): smoothing amount, 0.0 means none.
        logit_scale (float): scaling factor applied to the logits. Default: 1.0.
        num_chunks (int): number of row chunks for processing; balances memory
            use against speed. Default: 8.
        reduction: 'mean' | 'sum'. Default: 'mean'.

    Returns:
        The reduced loss tensor.
    """
    # autograd.Function.apply only accepts positional arguments.
    return FusedLinearCrossEntropyFunction.apply(
        x, target, weight, bias, p_mask,
        ignore_index, label_smoothing, logit_scale, num_chunks, reduction
    )
|
|
|
|
| class FusedLinearDiffusionCrossEntropyLoss(nn.Module): |
|
|
| def __init__( |
| self, |
| ignore_index: int = -100, |
| label_smoothing: float = 0.0, |
| logit_scale: float = 1.0, |
| num_chunks: int = 8, |
| reduction: str = "mean" |
| ): |
| """ |
| Args: |
| ignore_index: int. |
| If target == ignore_index, the loss is set to 0.0. |
| label_smoothing: float |
| logit_scale: float |
| A scaling factor applied to the logits. Default: 1.0 |
| num_chunks: int |
| The number of chunks to split the input tensor into for processing. |
| This can help optimize memory usage and computation speed. |
| Default: 8 |
| reduction: |
| Specifies the reduction to apply to the output: 'mean' | 'sum'. |
| 'mean': the weighted mean of the output is taken, |
| 'sum': the output will be summed. |
| Default: 'mean'. |
| """ |
| super().__init__() |
|
|
| assert reduction in ["mean", "sum"], f"reduction: {reduction} is not supported" |
|
|
| self.ignore_index = ignore_index |
| self.label_smoothing = label_smoothing |
| self.logit_scale = logit_scale |
| self.num_chunks = num_chunks |
| self.reduction = reduction |
|
|
| @torch.compiler.disable |
| def forward( |
| self, |
| x: torch.Tensor, |
| target: torch.LongTensor, |
| weight: torch.Tensor, |
| bias: Optional[torch.Tensor] = None, |
| p_mask: torch.Tensor = None |
| ): |
| """ |
| Args: |
| x (torch.Tensor): [batch_size, seq_len, hidden_size] |
| target (torch.LongTensor): [batch_size, seq_len] |
| where each value is in [0, V). |
| weight (torch.Tensor): [vocab_size, hidden_size] |
| where `vocab_size` is the number of classes. |
| bias (Optional[torch.Tensor]): [vocab_size] |
| where `vocab_size` is the number of classes. |
| p_mask(torch.Tensor): [batch_size, seq_len] |
| Its shape is same as target. |
| Shape: (1, packed_length) when varlen attn is used. |
| Returns: |
| loss |
| |
| TODO: |
| follow https://github.com/ML-GSAI/LLaDA/blob/main/GUIDELINES.md#pre-training |
| ```py |
| unreduced_loss /= p_mask |
| ``` |
| Scale the values of `unreduced_loss at different positions |
| """ |
| if p_mask is None: |
| p_mask = torch.ones_like(target, dtype=torch.float, device=x.device) |
| |
| x = x.contiguous().view(-1, x.shape[-1]) |
| target = target.contiguous().view(-1) |
| weight = weight.contiguous() |
| bias = bias.contiguous() if bias else None |
| p_mask = p_mask.contiguous().view(-1) |
| l, d = x.shape |
| assert l == target.shape[0] == p_mask.shape[0], f"{x.shape=}, {target.shape=}, {p_mask.shape=}" |
| |
| loss = fused_linear_cross_entropy_loss( |
| x, |
| target, |
| weight=weight, |
| bias=bias, |
| p_mask=p_mask, |
| ignore_index=self.ignore_index, |
| label_smoothing=self.label_smoothing, |
| logit_scale=self.logit_scale, |
| num_chunks=self.num_chunks, |
| reduction=self.reduction |
| ) |
| return loss |
|
|
|
|
| class LinearLossParallel(ParallelStyle): |
| def __init__( |
| self, |
| *, |
| sequence_dim: int = 1, |
| use_local_output: bool = False, |
| ): |
| super().__init__() |
|
|
| self.sequence_sharding = (Shard(sequence_dim),) |
| self.use_local_output = use_local_output |
|
|
| @staticmethod |
| def _prepare_input_fn(sequence_sharding, mod, inputs, device_mesh): |
| x, target, weight, bias = inputs |
|
|
| if not isinstance(x, DTensor): |
| |
| x = DTensor.from_local(x, device_mesh, sequence_sharding) |
| if x.placements != sequence_sharding: |
| x = x.redistribute(placements=sequence_sharding, async_op=True) |
| if not isinstance(target, DTensor): |
| target = DTensor.from_local(target, device_mesh, [Replicate()]) |
| if target.placements != sequence_sharding: |
| target = target.redistribute(placements=sequence_sharding, async_op=True) |
|
|
| if not isinstance(weight, DTensor): |
| weight = DTensor.from_local(weight, device_mesh, [Replicate()]) |
| if weight.placements != [Replicate()]: |
| |
| weight = weight.redistribute(placements=[Replicate()], async_op=True) |
|
|
| if bias is not None and not isinstance(bias, DTensor): |
| bias = DTensor.from_local(bias, device_mesh, [Replicate()]) |
| if bias is not None and bias.placements != [Replicate()]: |
| bias = bias.redistribute(placements=[Replicate()], async_op=True) |
|
|
| return x.to_local(), target.to_local(), weight.to_local(), bias.to_local() if bias is not None else bias |
|
|
| @staticmethod |
| def _prepare_output_fn(use_local_output, mod, outputs, device_mesh): |
| return outputs.to_local() if use_local_output else outputs |
|
|
| def _apply(self, module: nn.Module, device_mesh: DeviceMesh) -> nn.Module: |
| return distribute_module( |
| module, |
| device_mesh, |
| partition_fn=None, |
| input_fn=partial(self._prepare_input_fn, self.sequence_sharding), |
| output_fn=partial(self._prepare_output_fn, self.use_local_output) |
| ) |
|
|