| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| """ PyTorch OPT model.""" |
| from typing import List, Optional, Tuple, Union |
|
|
| import torch |
| import torch.nn.functional as F |
| import torch.utils.checkpoint |
| from torch import nn |
| from torch.nn import BCEWithLogitsLoss, CrossEntropyLoss, MSELoss |
|
|
| from transformers.activations import ACT2FN |
| from transformers.modeling_attn_mask_utils import _prepare_4d_causal_attention_mask |
| from transformers.modeling_outputs import ( |
| BaseModelOutputWithPast, |
| CausalLMOutputWithPast, |
| QuestionAnsweringModelOutput, |
| SequenceClassifierOutputWithPast, |
| ) |
| from transformers.modeling_utils import PreTrainedModel |
| from transformers.utils import ( |
| add_code_sample_docstrings, |
| add_start_docstrings, |
| add_start_docstrings_to_model_forward, |
| is_flash_attn_2_available, |
| is_flash_attn_greater_or_equal_2_10, |
| logging, |
| replace_return_docstrings, |
| ) |
| from .configuration_opt import OPTConfig |
|
|
|
|
| if is_flash_attn_2_available(): |
| from flash_attn import flash_attn_func, flash_attn_varlen_func |
| from flash_attn.bert_padding import index_first_axis, pad_input, unpad_input |
|
|
|
|
# Module-level logger obtained from `transformers.utils.logging` (imported above as `logging`).
logger = logging.get_logger(__name__)
|
|
# Checkpoint / config names referenced by the documentation helpers (presumably the
# `add_code_sample_docstrings` decorators further down this file — confirm there).
_CHECKPOINT_FOR_DOC = "facebook/opt-350m"
_CONFIG_FOR_DOC = "OPTConfig"

# Expected output shape for the base-model doc example.
_EXPECTED_OUTPUT_SHAPE = [1, 8, 1024]

# Sequence-classification doc example fixtures.
_CHECKPOINT_FOR_SEQUENCE_CLASSIFICATION = "ArthurZ/opt-350m-dummy-sc"
_SEQ_CLASS_EXPECTED_LOSS = 1.71
_SEQ_CLASS_EXPECTED_OUTPUT = "'LABEL_0'"

# Public OPT checkpoints, listed from smallest to largest.
OPT_PRETRAINED_MODEL_ARCHIVE_LIST = [
    f"facebook/opt-{size}"
    for size in ("125m", "350m", "1.3b", "2.7b", "6.7b", "13b", "30b")
]
|
|
|
|
| |
def _get_unpad_data(attention_mask):
    """
    Derive the packing metadata that flash-attn's variable-length kernels consume from a padding mask.

    Args:
        attention_mask: padding mask where non-zero entries mark real tokens. Assumed to be
            `(batch_size, seq_len)` — the sum over the last dimension is taken as each row's length.

    Returns:
        `(indices, cu_seqlens, max_seqlen_in_batch)`:
            - `indices`: positions of the non-zero entries of the flattened mask;
            - `cu_seqlens`: int32 cumulative row lengths with a leading 0;
            - `max_seqlen_in_batch`: the longest row length, as a Python number.
    """
    lengths = attention_mask.sum(dim=-1, dtype=torch.int32)
    kept_positions = torch.nonzero(attention_mask.flatten(), as_tuple=False).flatten()
    longest = lengths.max().item()
    running_total = torch.cumsum(lengths, dim=0, dtype=torch.int32)
    # Prepend a 0 so row i occupies cu_seqlens[i]:cu_seqlens[i + 1] in the packed layout.
    cu_seqlens = F.pad(running_total, (1, 0))
    return kept_positions, cu_seqlens, longest
|
|
|
|
| |
| |
| |
| |
|
|
| |
| |
| |
| |
| |
|
|
| |
| |
| |
|
|
| |
| |
|
|
| |
| |
|
|
| |
|
|
|
|
class OPTLearnedPositionalEmbedding(nn.Module):
    """
    This module learns positional embeddings up to a fixed maximum size.

    Positions are derived from the attention mask (a running count of real tokens), so padded
    slots do not consume position ids. Every position id is shifted by `self.offset` before the
    table lookup, which is why the underlying table has `num_embeddings + offset` rows.
    """

    def __init__(self, num_embeddings: int, embedding_dim: int):
        super().__init__()
        # Position ids are shifted by this amount before indexing `self.embeddings`.
        self.offset = 2
        table_rows = num_embeddings + self.offset
        self.embeddings = nn.Embedding(table_rows, embedding_dim)

    def forward(
        self, attention_mask: torch.LongTensor, past_key_values_length: int = 0
    ):
        """
        Look up positional embeddings for the tokens selected by `attention_mask`.

        Args:
            attention_mask: mask covering cached and current tokens; it is indexed as
                `[:, past_key_values_length:]`, so it is assumed to be `(batch, seq_len)`.
            past_key_values_length: number of leading (already cached) positions to drop.

        Returns:
            Embeddings for the remaining positions. A padded slot gets position id -1 before the
            offset is added.
        """
        mask = attention_mask.long()

        # Running count of real tokens gives 1-based positions; masked slots are zeroed, then
        # everything is shifted to 0-based (padding ends up at -1).
        running_count = torch.cumsum(mask, dim=1).type_as(mask)
        position_ids = (running_count * mask).long() - 1

        # Only the positions of the tokens fed in this step are needed.
        position_ids = position_ids[:, past_key_values_length:]

        shifted = position_ids + self.offset
        return self.embeddings(shifted)
|
|
|
|
class OPTAttention(nn.Module):
    """Multi-headed attention from 'Attention Is All You Need' paper.

    Eager (pure PyTorch) implementation. Hyper-parameters come from `config`
    (`hidden_size`, `num_attention_heads`, `attention_dropout`, `enable_bias`). The legacy keyword
    arguments `embed_dim`, `num_heads`, `dropout` and `bias` are still accepted; when given they
    override the config value and a deprecation warning is logged.
    """

    def __init__(
        self,
        config: OPTConfig,
        is_decoder: bool = False,
        **kwargs,
    ):
        super().__init__()
        self.config = config

        def _handle_deprecated_argument(config_arg_name, config, fn_arg_name, kwargs):
            """
            Resolve one hyper-parameter, preferring a deprecated keyword argument when present.

            If `fn_arg_name` is in `kwargs`, log a deprecation warning, pop it from `kwargs` and
            return its value; otherwise return `getattr(config, config_arg_name)`.
            """
            if fn_arg_name in kwargs:
                # `logging` in this module is `transformers.utils.logging` (see imports), not the
                # stdlib module, so go through the module logger. The message is an f-string so
                # the argument and class names are actually interpolated.
                logger.warning_once(
                    f"Passing in {fn_arg_name} to {self.__class__.__name__} is deprecated and won't be "
                    "supported from v4.39. Please set it in the config instead"
                )
                return kwargs.pop(fn_arg_name)
            return getattr(config, config_arg_name)

        self.embed_dim = _handle_deprecated_argument(
            "hidden_size", config, "embed_dim", kwargs
        )
        self.num_heads = _handle_deprecated_argument(
            "num_attention_heads", config, "num_heads", kwargs
        )
        self.dropout = _handle_deprecated_argument(
            "attention_dropout", config, "dropout", kwargs
        )
        self.enable_bias = _handle_deprecated_argument(
            "enable_bias", config, "bias", kwargs
        )

        self.head_dim = self.embed_dim // self.num_heads
        # Read by the flash-attention subclass to decide whether to apply a causal mask.
        self.is_causal = True

        if (self.head_dim * self.num_heads) != self.embed_dim:
            raise ValueError(
                f"embed_dim must be divisible by num_heads (got `embed_dim`: {self.embed_dim}"
                f" and `num_heads`: {self.num_heads})."
            )
        self.scaling = self.head_dim**-0.5
        self.is_decoder = is_decoder

        # Creation order (k, v, q, out) is kept so parameter registration order is unchanged.
        self.k_proj = nn.Linear(self.embed_dim, self.embed_dim, bias=self.enable_bias)
        self.v_proj = nn.Linear(self.embed_dim, self.embed_dim, bias=self.enable_bias)
        self.q_proj = nn.Linear(self.embed_dim, self.embed_dim, bias=self.enable_bias)
        self.out_proj = nn.Linear(self.embed_dim, self.embed_dim, bias=self.enable_bias)

    def _shape(self, tensor: torch.Tensor, seq_len: int, bsz: int):
        """Split the last dim into heads: `(bsz, seq, embed)` -> contiguous `(bsz, heads, seq, head_dim)`."""
        return (
            tensor.view(bsz, seq_len, self.num_heads, self.head_dim)
            .transpose(1, 2)
            .contiguous()
        )

    def forward(
        self,
        hidden_states: torch.Tensor,
        key_value_states: Optional[torch.Tensor] = None,
        past_key_value: Optional[Tuple[torch.Tensor]] = None,
        attention_mask: Optional[torch.Tensor] = None,
        layer_head_mask: Optional[torch.Tensor] = None,
        output_attentions: bool = False,
    ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]:
        """Input shape: Batch x Time x Channel

        Args:
            hidden_states: `(bsz, tgt_len, embed_dim)` input.
            key_value_states: if given, keys/values are projected from it (cross-attention).
            past_key_value: cached `(key, value)` pair, each `(bsz, num_heads, past_len, head_dim)`.
            attention_mask: additive mask of shape `(bsz, 1, tgt_len, src_len)`.
            layer_head_mask: per-head multiplier of shape `(num_heads,)`.
            output_attentions: whether to also return the attention probabilities.

        Returns:
            `(attn_output, attn_weights, past_key_value)`: `attn_output` is `(bsz, tgt_len, embed_dim)`;
            `attn_weights` is `(bsz, num_heads, tgt_len, src_len)` or `None`; `past_key_value` is the
            updated `(key, value)` cache when `is_decoder` is set, otherwise the value passed in.

        Raises:
            ValueError: if the attention weights, mask, head mask or output have an unexpected size.
        """
        is_cross_attention = key_value_states is not None

        bsz, tgt_len, _ = hidden_states.size()

        # Scale the queries up front instead of the (larger) score matrix.
        query_states = self.q_proj(hidden_states) * self.scaling

        if is_cross_attention and past_key_value is not None:
            # Cross-attention with cache: reuse the cached keys/values as-is.
            key_states = past_key_value[0]
            value_states = past_key_value[1]
        elif is_cross_attention:
            # Cross-attention without cache: project the encoder-side states.
            key_states = self._shape(self.k_proj(key_value_states), -1, bsz)
            value_states = self._shape(self.v_proj(key_value_states), -1, bsz)
        elif past_key_value is not None:
            # Self-attention with cache: append the new keys/values along the sequence axis.
            key_states = self._shape(self.k_proj(hidden_states), -1, bsz)
            value_states = self._shape(self.v_proj(hidden_states), -1, bsz)
            key_states = torch.cat([past_key_value[0], key_states], dim=2)
            value_states = torch.cat([past_key_value[1], value_states], dim=2)
        else:
            # Plain self-attention.
            key_states = self._shape(self.k_proj(hidden_states), -1, bsz)
            value_states = self._shape(self.v_proj(hidden_states), -1, bsz)

        if self.is_decoder:
            # Return the full (past + current) keys/values so the caller can cache them.
            past_key_value = (key_states, value_states)

        # Fold heads into the batch dimension for the batched matmuls.
        proj_shape = (bsz * self.num_heads, -1, self.head_dim)
        query_states = self._shape(query_states, tgt_len, bsz).view(*proj_shape)
        key_states = key_states.view(*proj_shape)
        value_states = value_states.view(*proj_shape)

        src_len = key_states.size(1)
        attn_weights = torch.bmm(query_states, key_states.transpose(1, 2))

        if attn_weights.size() != (bsz * self.num_heads, tgt_len, src_len):
            raise ValueError(
                f"Attention weights should be of size {(bsz * self.num_heads, tgt_len, src_len)}, but is"
                f" {attn_weights.size()}"
            )

        if attention_mask is not None:
            if attention_mask.size() != (bsz, 1, tgt_len, src_len):
                raise ValueError(
                    f"Attention mask should be of size {(bsz, 1, tgt_len, src_len)}, but is {attention_mask.size()}"
                )
            attn_weights = (
                attn_weights.view(bsz, self.num_heads, tgt_len, src_len)
                + attention_mask
            )
            # Clamp so adding a large negative mask value cannot push scores to -inf.
            attn_weights = torch.max(
                attn_weights,
                torch.tensor(
                    torch.finfo(attn_weights.dtype).min, device=attn_weights.device
                ),
            )
            attn_weights = attn_weights.view(bsz * self.num_heads, tgt_len, src_len)

        # Compute the fp16 softmax in fp32 for numerical stability, then cast back.
        if attn_weights.dtype == torch.float16:
            attn_weights = nn.functional.softmax(
                attn_weights, dim=-1, dtype=torch.float32
            ).to(torch.float16)
        else:
            attn_weights = nn.functional.softmax(attn_weights, dim=-1)

        if layer_head_mask is not None:
            if layer_head_mask.size() != (self.num_heads,):
                raise ValueError(
                    f"Head mask for a single layer should be of size {(self.num_heads,)}, but is"
                    f" {layer_head_mask.size()}"
                )
            attn_weights = layer_head_mask.view(1, -1, 1, 1) * attn_weights.view(
                bsz, self.num_heads, tgt_len, src_len
            )
            attn_weights = attn_weights.view(bsz * self.num_heads, tgt_len, src_len)

        if output_attentions:
            # Round-trip through the 4D view so the returned tensor and the one used below
            # share the same graph node (the returned weights keep their gradient).
            attn_weights_reshaped = attn_weights.view(
                bsz, self.num_heads, tgt_len, src_len
            )
            attn_weights = attn_weights_reshaped.view(
                bsz * self.num_heads, tgt_len, src_len
            )
        else:
            attn_weights_reshaped = None

        attn_probs = nn.functional.dropout(
            attn_weights, p=self.dropout, training=self.training
        )

        attn_output = torch.bmm(attn_probs, value_states)

        if attn_output.size() != (bsz * self.num_heads, tgt_len, self.head_dim):
            # Report the same shape that is actually checked above.
            raise ValueError(
                f"`attn_output` should be of size {(bsz * self.num_heads, tgt_len, self.head_dim)}, but is"
                f" {attn_output.size()}"
            )

        # Un-fold heads and merge them back into the embedding dimension.
        attn_output = attn_output.view(bsz, self.num_heads, tgt_len, self.head_dim)
        attn_output = attn_output.transpose(1, 2)
        attn_output = attn_output.reshape(bsz, tgt_len, self.embed_dim)

        attn_output = self.out_proj(attn_output)

        return attn_output, attn_weights_reshaped, past_key_value
|
|
|
|
class OptFlashAttention2(OPTAttention):
    """
    OPT flash attention module. This module inherits from `OPTAttention` as the weights of the module stays untouched.
    The only required change would be on the forward pass where it needs to correctly call the public API of flash
    attention and deal with padding tokens in case the input contains any of them.

    Flash attention never materialises the attention-probability matrix, so `forward` always returns `None` in
    place of the attention weights.
    """

    def __init__(self, *args, **kwargs):
        super().__init__(*args, **kwargs)

        # flash-attn releases before 2.1 (see `is_flash_attn_greater_or_equal_2_10`) align the causal mask to the
        # top-left of the (q_len, kv_len) score matrix; newer ones align it bottom-right. Remember which we have
        # so `_flash_attention_forward` can avoid a wrong mask when decoding a single token.
        self._flash_attn_uses_top_left_mask = not is_flash_attn_greater_or_equal_2_10()

    def forward(
        self,
        hidden_states: torch.Tensor,
        key_value_states: Optional[torch.Tensor] = None,
        past_key_value: Optional[Tuple[torch.Tensor]] = None,
        attention_mask: Optional[torch.Tensor] = None,
        layer_head_mask: Optional[torch.Tensor] = None,
        output_attentions: bool = False,
    ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]:
        """Input shape: Batch x Time x Channel

        Args:
            hidden_states: `(bsz, q_len, embed_dim)` input.
            key_value_states: if given, keys/values are projected from it (cross-attention).
            past_key_value: cached `(key, value)` pair, each `(bsz, num_heads, past_len, head_dim)`.
            attention_mask: 2D padding mask `(bsz, kv_len)` (1 = token, 0 = padding) or `None`; unlike the eager
                path this is NOT the 4D additive mask.
            layer_head_mask: accepted for signature compatibility; not applied by this implementation.
            output_attentions: attention probabilities are not available with flash attention; when `True` a
                warning is logged and `None` is returned for them.

        Returns:
            `(attn_output, None, past_key_value)` where `attn_output` is `(bsz, q_len, embed_dim)`.
        """
        if output_attentions:
            logger.warning_once(
                "`output_attentions=True` is not supported by OptFlashAttention2, which does not compute attention "
                "weights; `None` is returned in their place. Use the eager attention implementation to obtain them."
            )

        is_cross_attention = key_value_states is not None

        bsz, _, _ = hidden_states.size()

        # No pre-scaling here: flash attention applies the softmax scale (default 1/sqrt(head_dim)) itself.
        query_states = self.q_proj(hidden_states)

        if is_cross_attention and past_key_value is not None:
            # Cross-attention with cache: reuse the cached keys/values as-is.
            key_states = past_key_value[0]
            value_states = past_key_value[1]
        elif is_cross_attention:
            # Cross-attention without cache: project the encoder-side states.
            key_states = self._shape(self.k_proj(key_value_states), -1, bsz)
            value_states = self._shape(self.v_proj(key_value_states), -1, bsz)
        elif past_key_value is not None:
            # Self-attention with cache: append the new keys/values along the sequence axis.
            key_states = self._shape(self.k_proj(hidden_states), -1, bsz)
            value_states = self._shape(self.v_proj(hidden_states), -1, bsz)
            key_states = torch.cat([past_key_value[0], key_states], dim=2)
            value_states = torch.cat([past_key_value[1], value_states], dim=2)
        else:
            # Plain self-attention.
            key_states = self._shape(self.k_proj(hidden_states), -1, bsz)
            value_states = self._shape(self.v_proj(hidden_states), -1, bsz)

        if self.is_decoder:
            # Cache in the same (bsz, num_heads, kv_len, head_dim) layout as the eager implementation.
            past_key_value = (key_states, value_states)

        query_length = query_states.shape[1]
        kv_seq_len = key_states.shape[-2]

        # flash-attn consumes (batch, seq_len, num_heads, head_dim) tensors.
        query_states = query_states.view(
            bsz, query_length, self.num_heads, self.head_dim
        )
        key_states = key_states.transpose(1, 2).view(
            bsz, kv_seq_len, self.num_heads, self.head_dim
        )
        value_states = value_states.transpose(1, 2).view(
            bsz, kv_seq_len, self.num_heads, self.head_dim
        )

        attn_dropout = self.dropout if self.training else 0.0

        # Inputs may have been silently upcast to fp32 (e.g. fp32 layer norms); flash attention needs half
        # precision, so cast back to the autocast / pre-quantization / weight dtype.
        input_dtype = query_states.dtype
        if input_dtype == torch.float32:
            if torch.is_autocast_enabled():
                target_dtype = torch.get_autocast_gpu_dtype()
            elif hasattr(self.config, "_pre_quantization_dtype"):
                target_dtype = self.config._pre_quantization_dtype
            else:
                target_dtype = self.q_proj.weight.dtype

            logger.warning_once(
                f"The input hidden states seems to be silently casted in float32, this might be related to"
                f" the fact you have upcasted embedding or layer norm layers in float32. We will cast back the input in"
                f" {target_dtype}."
            )

            query_states = query_states.to(target_dtype)
            key_states = key_states.to(target_dtype)
            value_states = value_states.to(target_dtype)

        attn_output = self._flash_attention_forward(
            query_states,
            key_states,
            value_states,
            attention_mask,
            query_length,
            dropout=attn_dropout,
        )

        # Merge heads and project. The merged context is NOT an attention-weight tensor, so it is never
        # returned as one (previously it was, whenever `output_attentions=True`).
        context = attn_output.reshape(bsz, query_length, self.num_heads * self.head_dim)
        attn_output = self.out_proj(context)

        return attn_output, None, past_key_value

    def _flash_attention_forward(
        self,
        query_states,
        key_states,
        value_states,
        attention_mask,
        query_length,
        dropout=0.0,
        softmax_scale=None,
    ):
        """
        Calls the forward method of Flash Attention - if the input hidden states contain at least one padding token
        first unpad the input, then computes the attention scores and pad the final attention scores.

        Args:
            query_states (`torch.Tensor`):
                Input query states to be passed to Flash Attention API
            key_states (`torch.Tensor`):
                Input key states to be passed to Flash Attention API
            value_states (`torch.Tensor`):
                Input value states to be passed to Flash Attention API
            attention_mask (`torch.Tensor`):
                The padding mask - corresponds to a tensor of size `(batch_size, seq_len)` where 0 stands for the
                position of padding tokens and 1 for the position of non-padding tokens.
            query_length (`int`):
                Number of query positions in this call.
            dropout (`float`, *optional*):
                Attention dropout probability.
            softmax_scale (`float`, *optional*):
                The scaling of QK^T before applying softmax. Default to 1 / sqrt(head_dim)

        Returns:
            `torch.Tensor` of shape `(batch_size, query_length, num_heads, head_dim)`.
        """
        if not self._flash_attn_uses_top_left_mask:
            causal = self.is_causal
        else:
            # With a top-left-aligned mask a single query would only see the first key, so disable causal
            # masking for one-token decoding steps.
            causal = self.is_causal and query_length != 1

        if attention_mask is not None:
            # Padded batch: pack the real tokens, run the varlen kernel, then scatter back to the padded layout.
            batch_size = query_states.shape[0]
            (
                query_states,
                key_states,
                value_states,
                indices_q,
                cu_seq_lens,
                max_seq_lens,
            ) = self._upad_input(
                query_states, key_states, value_states, attention_mask, query_length
            )

            cu_seqlens_q, cu_seqlens_k = cu_seq_lens
            max_seqlen_in_batch_q, max_seqlen_in_batch_k = max_seq_lens

            attn_output_unpad = flash_attn_varlen_func(
                query_states,
                key_states,
                value_states,
                cu_seqlens_q=cu_seqlens_q,
                cu_seqlens_k=cu_seqlens_k,
                max_seqlen_q=max_seqlen_in_batch_q,
                max_seqlen_k=max_seqlen_in_batch_k,
                dropout_p=dropout,
                softmax_scale=softmax_scale,
                causal=causal,
            )

            attn_output = pad_input(
                attn_output_unpad, indices_q, batch_size, query_length
            )
        else:
            attn_output = flash_attn_func(
                query_states,
                key_states,
                value_states,
                dropout,
                softmax_scale=softmax_scale,
                causal=causal,
            )

        return attn_output

    def _upad_input(
        self, query_layer, key_layer, value_layer, attention_mask, query_length
    ):
        """
        Pack query/key/value tensors by dropping padded positions, for `flash_attn_varlen_func`.

        Keys/values are always packed with the full `attention_mask`. Queries are packed the same way when
        `query_length` equals the key length, treated as one token per sequence when `query_length == 1`, and
        otherwise packed with the last `query_length` columns of the mask.

        Returns:
            `(query_layer, key_layer, value_layer, indices_q, (cu_seqlens_q, cu_seqlens_k),
            (max_seqlen_in_batch_q, max_seqlen_in_batch_k))`.
        """
        indices_k, cu_seqlens_k, max_seqlen_in_batch_k = _get_unpad_data(attention_mask)
        batch_size, kv_seq_len, num_key_value_heads, head_dim = key_layer.shape

        key_layer = index_first_axis(
            key_layer.reshape(batch_size * kv_seq_len, num_key_value_heads, head_dim),
            indices_k,
        )
        value_layer = index_first_axis(
            value_layer.reshape(batch_size * kv_seq_len, num_key_value_heads, head_dim),
            indices_k,
        )
        if query_length == kv_seq_len:
            # Queries cover the same positions as keys: reuse the key packing.
            query_layer = index_first_axis(
                query_layer.reshape(batch_size * kv_seq_len, self.num_heads, head_dim),
                indices_k,
            )
            cu_seqlens_q = cu_seqlens_k
            max_seqlen_in_batch_q = max_seqlen_in_batch_k
            indices_q = indices_k
        elif query_length == 1:
            # One query token per sequence: offsets are simply 0..batch_size.
            max_seqlen_in_batch_q = 1
            cu_seqlens_q = torch.arange(
                batch_size + 1, dtype=torch.int32, device=query_layer.device
            )
            indices_q = cu_seqlens_q[:-1]
            query_layer = query_layer.squeeze(1)
        else:
            # Keep only the mask columns of the trailing `query_length` positions; this assumes the current
            # query tokens are the last positions covered by the mask.
            attention_mask = attention_mask[:, -query_length:]
            query_layer, indices_q, cu_seqlens_q, max_seqlen_in_batch_q = unpad_input(
                query_layer, attention_mask
            )

        return (
            query_layer,
            key_layer,
            value_layer,
            indices_q,
            (cu_seqlens_q, cu_seqlens_k),
            (max_seqlen_in_batch_q, max_seqlen_in_batch_k),
        )
|
|
|
|
# Lookup from `config._attn_implementation` to the attention module class that
# `OPTDecoderLayer` instantiates for its self-attention.
OPT_ATTENTION_CLASSES = dict(
    eager=OPTAttention,
    flash_attention_2=OptFlashAttention2,
)
|
|
|
|
class OPTDecoderLayer(nn.Module):
    """
    One OPT transformer block: self-attention followed by a two-layer feed-forward network, each
    wrapped in a residual connection. `config.do_layer_norm_before` selects pre-LN (normalize the
    sub-block input) or post-LN (normalize after the residual add).
    """

    def __init__(self, config: OPTConfig):
        super().__init__()
        self.embed_dim = config.hidden_size

        attention_cls = OPT_ATTENTION_CLASSES[config._attn_implementation]
        self.self_attn = attention_cls(config=config, is_decoder=True)

        self.do_layer_norm_before = config.do_layer_norm_before
        self.dropout = config.dropout
        self.activation_fn = ACT2FN[config.activation_function]

        affine = config.layer_norm_elementwise_affine
        use_bias = config.enable_bias
        self.self_attn_layer_norm = nn.LayerNorm(self.embed_dim, elementwise_affine=affine)
        self.fc1 = nn.Linear(self.embed_dim, config.ffn_dim, bias=use_bias)
        self.fc2 = nn.Linear(config.ffn_dim, self.embed_dim, bias=use_bias)
        self.final_layer_norm = nn.LayerNorm(self.embed_dim, elementwise_affine=affine)

    def forward(
        self,
        hidden_states: torch.Tensor,
        attention_mask: Optional[torch.Tensor] = None,
        layer_head_mask: Optional[torch.Tensor] = None,
        past_key_value: Optional[Tuple[torch.Tensor]] = None,
        output_attentions: Optional[bool] = False,
        use_cache: Optional[bool] = False,
    ) -> Tuple[
        torch.FloatTensor, Optional[Tuple[torch.FloatTensor, torch.FloatTensor]]
    ]:
        """
        Run the block on `hidden_states`.

        Args:
            hidden_states (`torch.FloatTensor`): layer input of shape `(batch, seq_len, embed_dim)`.
            attention_mask (`torch.FloatTensor`, *optional*): mask forwarded unchanged to the
                self-attention module (its expected form depends on the attention implementation).
            layer_head_mask (`torch.FloatTensor`, *optional*): per-head mask forwarded to the
                self-attention module.
            past_key_value (`Tuple(torch.FloatTensor)`, *optional*): cached key/value states for
                this layer.
            output_attentions (`bool`, *optional*): if set, the attention weights returned by the
                self-attention module are appended to the outputs.
            use_cache (`bool`, *optional*): if set, the key/value cache returned by the
                self-attention module is appended to the outputs.

        Returns:
            `(hidden_states,)`, followed by the attention weights when `output_attentions` is set,
            followed by the present key/value when `use_cache` is set.
        """
        # ---- self-attention sub-block ----
        skip = hidden_states
        attn_input = (
            self.self_attn_layer_norm(hidden_states)
            if self.do_layer_norm_before
            else hidden_states
        )
        attn_out, attn_probs, new_cache = self.self_attn(
            hidden_states=attn_input,
            past_key_value=past_key_value,
            attention_mask=attention_mask,
            layer_head_mask=layer_head_mask,
            output_attentions=output_attentions,
        )
        attn_out = nn.functional.dropout(attn_out, p=self.dropout, training=self.training)
        hidden_states = skip + attn_out
        if not self.do_layer_norm_before:
            hidden_states = self.self_attn_layer_norm(hidden_states)

        # ---- feed-forward sub-block, applied to the tokens flattened to 2D ----
        original_shape = hidden_states.shape
        flat_tokens = hidden_states.reshape(-1, hidden_states.size(-1))
        skip = flat_tokens
        ffn = self.final_layer_norm(flat_tokens) if self.do_layer_norm_before else flat_tokens
        ffn = self.fc2(self.activation_fn(self.fc1(ffn)))
        ffn = nn.functional.dropout(ffn, p=self.dropout, training=self.training)
        hidden_states = (skip + ffn).view(original_shape)
        if not self.do_layer_norm_before:
            hidden_states = self.final_layer_norm(hidden_states)

        outputs = (hidden_states,)
        if output_attentions:
            outputs = outputs + (attn_probs,)
        if use_cache:
            outputs = outputs + (new_cache,)
        return outputs
|
|
|
|
# Class-level documentation fragment; passed to `add_start_docstrings` (e.g. on `OPTPreTrainedModel`),
# which prepends it to the decorated class's docstring.
OPT_START_DOCSTRING = r"""
    This model inherits from [`PreTrainedModel`]. Check the superclass documentation for the generic methods the
    library implements for all its model (such as downloading or saving, resizing the input embeddings, pruning heads
    etc.)

    This model is also a PyTorch [torch.nn.Module](https://pytorch.org/docs/stable/nn.html#torch.nn.Module) subclass.
    Use it as a regular PyTorch Module and refer to the PyTorch documentation for all matter related to general usage
    and behavior.

    Parameters:
        config ([`OPTConfig`]):
            Model configuration class with all the parameters of the model. Initializing with a config file does not
            load the weights associated with the model, only the configuration. Check out the
            [`~PreTrainedModel.from_pretrained`] method to load the model weights.
"""
|
|
|
|
@add_start_docstrings(
    "The bare OPT Model outputting raw hidden-states without any specific head on top.",
    OPT_START_DOCSTRING,
)
class OPTPreTrainedModel(PreTrainedModel):
    # Shared base for the OPT models: wires the config class and weight initialisation into
    # `PreTrainedModel`. (No class docstring on purpose — `add_start_docstrings` builds `__doc__`.)
    config_class = OPTConfig
    base_model_prefix = "model"
    supports_gradient_checkpointing = True
    # Decoder layers must not be split across devices when sharding.
    _no_split_modules = ["OPTDecoderLayer"]
    _supports_flash_attn_2 = True

    def _init_weights(self, module):
        """
        Initialise one submodule in place.

        `nn.Linear` and `nn.Embedding` weights are drawn from N(0, config.init_std); linear biases
        and the embedding row at `padding_idx` (when set) are zeroed. Other modules are untouched.
        """
        std = self.config.init_std
        if isinstance(module, nn.Linear):
            module.weight.data.normal_(mean=0.0, std=std)
            if module.bias is not None:
                module.bias.data.zero_()
            return
        if not isinstance(module, nn.Embedding):
            return
        module.weight.data.normal_(mean=0.0, std=std)
        if module.padding_idx is not None:
            module.weight.data[module.padding_idx].zero_()
|
|
|
|
# Shared documentation of the model `forward` arguments; presumably attached to the OPT model
# heads via `add_start_docstrings_to_model_forward` further down this file.
# OPT is decoder-only: the text describes `input_ids`, a single self-attention key/value cache per
# layer, and the mask helper this module actually uses (`_prepare_4d_causal_attention_mask`).
OPT_INPUTS_DOCSTRING = r"""
    Args:
        input_ids (`torch.LongTensor` of shape `(batch_size, sequence_length)`):
            Indices of input sequence tokens in the vocabulary. Padding will be ignored by default should you provide
            it.

            Indices can be obtained using [`AutoTokenizer`]. See [`PreTrainedTokenizer.encode`] and
            [`PreTrainedTokenizer.__call__`] for details.

            [What are input IDs?](../glossary#input-ids)
        attention_mask (`torch.Tensor` of shape `(batch_size, sequence_length)`, *optional*):
            Mask to avoid performing attention on padding token indices. Mask values selected in `[0, 1]`:

            - 1 for tokens that are **not masked**,
            - 0 for tokens that are **masked**.

            [What are attention masks?](../glossary#attention-mask)

            Indices can be obtained using [`AutoTokenizer`]. See [`PreTrainedTokenizer.encode`] and
            [`PreTrainedTokenizer.__call__`] for details.

            If `past_key_values` is used, the mask must still cover both the cached and the current positions, i.e.
            its length should be `past_key_values_length + sequence_length`.

            If you want to change padding behavior, see [`~modeling_attn_mask_utils._prepare_4d_causal_attention_mask`],
            which turns this 2D mask into the 4D causal mask used by the eager attention implementation, and modify it
            to your needs.
        head_mask (`torch.Tensor` of shape `(num_hidden_layers, num_attention_heads)`, *optional*):
            Mask to nullify selected heads of the attention modules. Mask values selected in `[0, 1]`:

            - 1 indicates the head is **not masked**,
            - 0 indicates the head is **masked**.

            Only applied by the eager attention implementation.
        past_key_values (`tuple(tuple(torch.FloatTensor))`, *optional*, returned when `use_cache=True` is passed or when `config.use_cache=True`):
            Tuple of `tuple(torch.FloatTensor)` of length `config.num_hidden_layers`, with each tuple having 2 tensors
            (key and value) of shape `(batch_size, num_heads, sequence_length, embed_size_per_head)`.

            Contains pre-computed hidden-states (keys and values of the self-attention blocks) that can be used (see
            `past_key_values` input) to speed up sequential decoding.

            If `past_key_values` are used, the user can optionally input only the last `input_ids` (those that don't
            have their past key value states given to this model) of shape `(batch_size, 1)` instead of all
            `input_ids` of shape `(batch_size, sequence_length)`.
        inputs_embeds (`torch.FloatTensor` of shape `(batch_size, sequence_length, hidden_size)`, *optional*):
            Optionally, instead of passing `input_ids` you can choose to directly pass an embedded representation. This
            is useful if you want more control over how to convert `input_ids` indices into associated vectors than the
            model's internal embedding lookup matrix.
        use_cache (`bool`, *optional*):
            If set to `True`, `past_key_values` key value states are returned and can be used to speed up decoding (see
            `past_key_values`).
        output_attentions (`bool`, *optional*):
            Whether or not to return the attentions tensors of all attention layers. See `attentions` under returned
            tensors for more detail.
        output_hidden_states (`bool`, *optional*):
            Whether or not to return the hidden states of all layers. See `hidden_states` under returned tensors for
            more detail.
        return_dict (`bool`, *optional*):
            Whether or not to return a [`~utils.ModelOutput`] instead of a plain tuple.
"""
|
|
|
|
| class OPTDecoder(OPTPreTrainedModel): |
| """ |
| Transformer decoder consisting of *config.num_hidden_layers* layers. Each layer is a [`OPTDecoderLayer`] |
| |
| Args: |
| config: OPTConfig |
| """ |
|
|
| def __init__(self, config: OPTConfig): |
| super().__init__(config) |
| self.dropout = config.dropout |
| self.layerdrop = config.layerdrop |
| self.padding_idx = config.pad_token_id |
| self.max_target_positions = config.max_position_embeddings |
| self.vocab_size = config.vocab_size |
|
|
| self.embed_tokens = nn.Embedding( |
| config.vocab_size, config.word_embed_proj_dim, self.padding_idx |
| ) |
| self._embed_positions = OPTLearnedPositionalEmbedding( |
| config.max_position_embeddings, config.hidden_size |
| ) |
| self.embed_positions = self._embed_positions.embeddings |
|
|
| if config.word_embed_proj_dim != config.hidden_size: |
| self.project_out = nn.Linear( |
| config.hidden_size, config.word_embed_proj_dim, bias=False |
| ) |
| else: |
| self.project_out = None |
|
|
| if config.word_embed_proj_dim != config.hidden_size: |
| self.project_in = nn.Linear( |
| config.word_embed_proj_dim, config.hidden_size, bias=False |
| ) |
| else: |
| self.project_in = None |
|
|
| |
| |
| |
| if config.do_layer_norm_before and not config._remove_final_layer_norm: |
| self.final_layer_norm = nn.LayerNorm( |
| config.hidden_size, |
| elementwise_affine=config.layer_norm_elementwise_affine, |
| ) |
| else: |
| self.final_layer_norm = None |
|
|
| self.layers = nn.ModuleList( |
| [OPTDecoderLayer(config) for _ in range(config.num_hidden_layers)] |
| ) |
| self._use_flash_attention_2 = config._attn_implementation == "flash_attention_2" |
|
|
| self.gradient_checkpointing = False |
| |
| self.post_init() |
|
|
| def get_input_embeddings(self): |
| return self.embed_tokens |
|
|
| def set_input_embeddings(self, value): |
| self.embed_tokens = value |
|
|
| def forward( |
| self, |
| input_ids: torch.LongTensor = None, |
| attention_mask: Optional[torch.Tensor] = None, |
| head_mask: Optional[torch.Tensor] = None, |
| past_key_values: Optional[List[torch.FloatTensor]] = None, |
| inputs_embeds: Optional[torch.FloatTensor] = None, |
| use_cache: Optional[bool] = None, |
| output_attentions: Optional[bool] = None, |
| output_hidden_states: Optional[bool] = None, |
| return_dict: Optional[bool] = None, |
| ) -> Union[Tuple, BaseModelOutputWithPast]: |
| r""" |
| Args: |
| input_ids (`torch.LongTensor` of shape `(batch_size, sequence_length)`): |
| Indices of input sequence tokens in the vocabulary. Padding will be ignored by default should you |
| provide it. |
| |
| Indices can be obtained using [`AutoTokenizer`]. See [`PreTrainedTokenizer.encode`] and |
| [`PreTrainedTokenizer.__call__`] for details. |
| |
| [What are input IDs?](../glossary#input-ids) |
| attention_mask (`torch.Tensor` of shape `(batch_size, sequence_length)`, *optional*): |
| Mask to avoid performing attention on padding token indices. Mask values selected in `[0, 1]`: |
| |
| - 1 for tokens that are **not masked**, |
| - 0 for tokens that are **masked**. |
| |
| [What are attention masks?](../glossary#attention-mask) |
| head_mask (`torch.Tensor` of shape `(num_hidden_layers, num_attention_heads)`, *optional*): |
| Mask to nullify selected heads of the attention modules. Mask values selected in `[0, 1]`: |
| |
| - 1 indicates the head is **not masked**, |
| - 0 indicates the head is **masked**. |
| |
| past_key_values (`tuple(tuple(torch.FloatTensor))`, *optional*, returned when `use_cache=True` is passed or when `config.use_cache=True`): |
| Tuple of `tuple(torch.FloatTensor)` of length `config.n_layers`, with each tuple having 2 tensors of |
| shape `(batch_size, num_heads, sequence_length, embed_size_per_head)`) and 2 additional tensors of |
| |
| Contains pre-computed hidden-states (key and values in the self-attention blocks and in the |
| cross-attention blocks) that can be used (see `past_key_values` input) to speed up sequential decoding. |
| |
| If `past_key_values` are used, the user can optionally input only the last `decoder_input_ids` (those |
| that don't have their past key value states given to this model) of shape `(batch_size, 1)` instead of |
| all `decoder_input_ids` of shape `(batch_size, sequence_length)`. |
| |
| inputs_embeds (`torch.FloatTensor` of shape `(batch_size, sequence_length, hidden_size)`, *optional*): |
| Optionally, instead of passing `input_ids` you can choose to directly pass an embedded representation. |
| This is useful if you want more control over how to convert `input_ids` indices into associated vectors |
| than the model's internal embedding lookup matrix. |
| output_attentions (`bool`, *optional*): |
| Whether or not to return the attentions tensors of all attention layers. See `attentions` under |
| returned tensors for more detail. |
| output_hidden_states (`bool`, *optional*): |
| Whether or not to return the hidden states of all layers. See `hidden_states` under returned tensors |
| for more detail. |
| return_dict (`bool`, *optional*): |
| Whether or not to return a [`~utils.ModelOutput`] instead of a plain tuple. |
| """ |
| output_attentions = ( |
| output_attentions |
| if output_attentions is not None |
| else self.config.output_attentions |
| ) |
| output_hidden_states = ( |
| output_hidden_states |
| if output_hidden_states is not None |
| else self.config.output_hidden_states |
| ) |
| use_cache = use_cache if use_cache is not None else self.config.use_cache |
|
|
| return_dict = ( |
| return_dict if return_dict is not None else self.config.use_return_dict |
| ) |
|
|
| |
| if input_ids is not None and inputs_embeds is not None: |
| raise ValueError( |
| "You cannot specify both decoder_input_ids and decoder_inputs_embeds at the same time" |
| ) |
| elif input_ids is not None: |
| input_shape = input_ids.size() |
| input_ids = input_ids.view(-1, input_shape[-1]) |
| elif inputs_embeds is not None: |
| input_shape = inputs_embeds.size()[:-1] |
| else: |
| raise ValueError( |
| "You have to specify either decoder_input_ids or decoder_inputs_embeds" |
| ) |
|
|
| if inputs_embeds is None: |
| inputs_embeds = self.embed_tokens(input_ids) |
|
|
| batch_size, seq_length = input_shape |
| past_key_values_length = ( |
| past_key_values[0][0].shape[2] if past_key_values is not None else 0 |
| ) |
| |
| mask_seq_length = past_key_values_length + seq_length |
|
|
| |
| if self._use_flash_attention_2: |
| |
| causal_attention_mask = ( |
| attention_mask |
| if (attention_mask is not None and 0 in attention_mask) |
| else None |
| ) |
| attention_mask = ( |
| torch.ones(batch_size, mask_seq_length, device=inputs_embeds.device) |
| if attention_mask is None |
| else attention_mask |
| ) |
| else: |
| |
| if attention_mask is None: |
| attention_mask = torch.ones( |
| batch_size, mask_seq_length, device=inputs_embeds.device |
| ) |
| elif attention_mask.shape[1] != mask_seq_length: |
| raise ValueError( |
| f"The provided attention mask has length {attention_mask.shape[1]}, but its length should be " |
| f"{mask_seq_length} (sum of the lengths of current and past inputs)" |
| ) |
| causal_attention_mask = _prepare_4d_causal_attention_mask( |
| attention_mask, input_shape, inputs_embeds, past_key_values_length |
| ) |
|
|
| pos_embeds = self._embed_positions(attention_mask, past_key_values_length) |
|
|
| if self.project_in is not None: |
| inputs_embeds = self.project_in(inputs_embeds) |
|
|
| hidden_states = inputs_embeds + pos_embeds |
|
|
| if self.gradient_checkpointing and self.training: |
| if use_cache: |
| logger.warning_once( |
| "`use_cache=True` is incompatible with gradient checkpointing. Setting `use_cache=False`..." |
| ) |
| use_cache = False |
|
|
| |
| all_hidden_states = () if output_hidden_states else None |
| all_self_attns = () if output_attentions else None |
| next_decoder_cache = () if use_cache else None |
|
|
| |
| for attn_mask, mask_name in zip([head_mask], ["head_mask"]): |
| if attn_mask is not None: |
| if attn_mask.size()[0] != (len(self.layers)): |
| raise ValueError( |
| f"The `{mask_name}` should be specified for {len(self.layers)} layers, but it is for" |
| f" {head_mask.size()[0]}." |
| ) |
|
|
| for idx, decoder_layer in enumerate(self.layers): |
| |
| if output_hidden_states: |
| all_hidden_states += (hidden_states,) |
|
|
| if self.training: |
| dropout_probability = torch.rand([]) |
| if dropout_probability < self.layerdrop: |
| continue |
|
|
| past_key_value = ( |
| past_key_values[idx] if past_key_values is not None else None |
| ) |
|
|
| if self.gradient_checkpointing and self.training: |
| layer_outputs = self._gradient_checkpointing_func( |
| decoder_layer.__call__, |
| hidden_states, |
| causal_attention_mask, |
| head_mask[idx] if head_mask is not None else None, |
| None, |
| output_attentions, |
| use_cache, |
| ) |
| else: |
| layer_outputs = decoder_layer( |
| hidden_states, |
| attention_mask=causal_attention_mask, |
| layer_head_mask=(head_mask[idx] if head_mask is not None else None), |
| past_key_value=past_key_value, |
| output_attentions=output_attentions, |
| use_cache=use_cache, |
| ) |
|
|
| hidden_states = layer_outputs[0] |
|
|
| if use_cache: |
| next_decoder_cache += (layer_outputs[2 if output_attentions else 1],) |
|
|
| if output_attentions: |
| all_self_attns += (layer_outputs[1],) |
|
|
| if self.final_layer_norm is not None: |
| hidden_states = self.final_layer_norm(hidden_states) |
|
|
| if self.project_out is not None: |
| hidden_states = self.project_out(hidden_states) |
|
|
| |
| if output_hidden_states: |
| all_hidden_states += (hidden_states,) |
|
|
| next_cache = next_decoder_cache if use_cache else None |
| if not return_dict: |
| return tuple( |
| v |
| for v in [hidden_states, next_cache, all_hidden_states, all_self_attns] |
| if v is not None |
| ) |
| return BaseModelOutputWithPast( |
| last_hidden_state=hidden_states, |
| past_key_values=next_cache, |
| hidden_states=all_hidden_states, |
| attentions=all_self_attns, |
| ) |
|
|
|
|
@add_start_docstrings(
    "The bare OPT Model outputting raw hidden-states without any specific head on top.",
    OPT_START_DOCSTRING,
)
class OPTModel(OPTPreTrainedModel):
    # Backbone-only model: it owns a single `OPTDecoder` and delegates all of the
    # actual computation (and every accessor below) to it.

    def __init__(self, config: OPTConfig):
        super().__init__(config)
        self.decoder = OPTDecoder(config)
        # Base-class post-initialisation hook (presumably weight init / final setup).
        self.post_init()

    def get_input_embeddings(self):
        """Return the decoder's token-embedding module."""
        return self.decoder.embed_tokens

    def set_input_embeddings(self, value):
        """Install `value` as the decoder's token-embedding module."""
        self.decoder.embed_tokens = value

    def get_decoder(self):
        """Return the wrapped `OPTDecoder` instance."""
        return self.decoder

    @add_start_docstrings_to_model_forward(OPT_INPUTS_DOCSTRING)
    @add_code_sample_docstrings(
        checkpoint=_CHECKPOINT_FOR_DOC,
        output_type=BaseModelOutputWithPast,
        config_class=_CONFIG_FOR_DOC,
        expected_output=_EXPECTED_OUTPUT_SHAPE,
    )
    def forward(
        self,
        input_ids: torch.LongTensor = None,
        attention_mask: Optional[torch.Tensor] = None,
        head_mask: Optional[torch.Tensor] = None,
        past_key_values: Optional[List[torch.FloatTensor]] = None,
        inputs_embeds: Optional[torch.FloatTensor] = None,
        use_cache: Optional[bool] = None,
        output_attentions: Optional[bool] = None,
        output_hidden_states: Optional[bool] = None,
        return_dict: Optional[bool] = None,
    ) -> Union[Tuple, BaseModelOutputWithPast]:
        # Any flag left as None falls back to the matching config attribute.
        if output_attentions is None:
            output_attentions = self.config.output_attentions
        if output_hidden_states is None:
            output_hidden_states = self.config.output_hidden_states
        if use_cache is None:
            use_cache = self.config.use_cache
        if return_dict is None:
            return_dict = self.config.use_return_dict

        decoder_outputs = self.decoder(
            input_ids=input_ids,
            attention_mask=attention_mask,
            head_mask=head_mask,
            past_key_values=past_key_values,
            inputs_embeds=inputs_embeds,
            use_cache=use_cache,
            output_attentions=output_attentions,
            output_hidden_states=output_hidden_states,
            return_dict=return_dict,
        )

        if return_dict:
            # Re-wrap the decoder's fields in a fresh output object owned by this model.
            return BaseModelOutputWithPast(
                last_hidden_state=decoder_outputs.last_hidden_state,
                past_key_values=decoder_outputs.past_key_values,
                hidden_states=decoder_outputs.hidden_states,
                attentions=decoder_outputs.attentions,
            )

        # Tuple mode: pass the decoder's tuple straight through.
        return decoder_outputs
|
|
|
|
class OPTForCausalLM(OPTPreTrainedModel):
    # OPT backbone plus a bias-free linear LM head; the head weight is listed as tied.
    _tied_weights_keys = ["lm_head.weight"]

    def __init__(self, config):
        super().__init__(config)
        self.model = OPTModel(config)

        # The head consumes `word_embed_proj_dim` features. Presumably this matches the
        # decoder's output width after its optional `project_out` — confirm against config.
        self.lm_head = nn.Linear(
            config.word_embed_proj_dim, config.vocab_size, bias=False
        )

        # Base-class post-initialisation hook (presumably weight init / final setup).
        self.post_init()

    def get_input_embeddings(self):
        """Return the backbone decoder's token-embedding module."""
        return self.model.decoder.embed_tokens

    def set_input_embeddings(self, value):
        """Install `value` as the backbone decoder's token-embedding module."""
        self.model.decoder.embed_tokens = value

    def get_output_embeddings(self):
        """Return the LM head (the output projection onto the vocabulary)."""
        return self.lm_head

    def set_output_embeddings(self, new_embeddings):
        """Replace the LM head with `new_embeddings`."""
        self.lm_head = new_embeddings

    def set_decoder(self, decoder):
        """Swap the backbone's decoder for `decoder`."""
        self.model.decoder = decoder

    def get_decoder(self):
        """Return the backbone's decoder."""
        return self.model.decoder

    @replace_return_docstrings(
        output_type=CausalLMOutputWithPast, config_class=_CONFIG_FOR_DOC
    )
    def forward(
        self,
        input_ids: torch.LongTensor = None,
        attention_mask: Optional[torch.Tensor] = None,
        head_mask: Optional[torch.Tensor] = None,
        past_key_values: Optional[List[torch.FloatTensor]] = None,
        inputs_embeds: Optional[torch.FloatTensor] = None,
        labels: Optional[torch.LongTensor] = None,
        use_cache: Optional[bool] = None,
        output_attentions: Optional[bool] = None,
        output_hidden_states: Optional[bool] = None,
        return_dict: Optional[bool] = None,
    ) -> Union[Tuple, CausalLMOutputWithPast]:
        r"""
        Args:
            input_ids (`torch.LongTensor` of shape `(batch_size, sequence_length)`):
                Indices of input sequence tokens in the vocabulary. Padding will be ignored by default should you
                provide it.

                Indices can be obtained using [`AutoTokenizer`]. See [`PreTrainedTokenizer.encode`] and
                [`PreTrainedTokenizer.__call__`] for details.

                [What are input IDs?](../glossary#input-ids)
            attention_mask (`torch.Tensor` of shape `(batch_size, sequence_length)`, *optional*):
                Mask to avoid performing attention on padding token indices. Mask values selected in `[0, 1]`:

                - 1 for tokens that are **not masked**,
                - 0 for tokens that are **masked**.

                [What are attention masks?](../glossary#attention-mask)
            head_mask (`torch.Tensor` of shape `(num_hidden_layers, num_attention_heads)`, *optional*):
                Mask to nullify selected heads of the attention modules. Mask values selected in `[0, 1]`:

                - 1 indicates the head is **not masked**,
                - 0 indicates the head is **masked**.

            past_key_values (`tuple(tuple(torch.FloatTensor))`, *optional*, returned when `use_cache=True` is passed or when `config.use_cache=True`):
                Tuple of `tuple(torch.FloatTensor)` of length `config.n_layers`, with each tuple having 2 tensors of
                shape `(batch_size, num_heads, sequence_length, embed_size_per_head)`) and 2 additional tensors of
                shape `(batch_size, num_heads, encoder_sequence_length, embed_size_per_head)`. The two additional
                tensors are only required when the model is used as a decoder in a Sequence to Sequence model.

                Contains pre-computed hidden-states (key and values in the self-attention blocks and in the
                cross-attention blocks) that can be used (see `past_key_values` input) to speed up sequential decoding.

                If `past_key_values` are used, the user can optionally input only the last `decoder_input_ids` (those
                that don't have their past key value states given to this model) of shape `(batch_size, 1)` instead of
                all `decoder_input_ids` of shape `(batch_size, sequence_length)`.
            inputs_embeds (`torch.FloatTensor` of shape `(batch_size, sequence_length, hidden_size)`, *optional*):
                Optionally, instead of passing `input_ids` you can choose to directly pass an embedded representation.
                This is useful if you want more control over how to convert `input_ids` indices into associated vectors
                than the model's internal embedding lookup matrix.
            labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*):
                Labels for computing the causal language modeling loss. The labels are shifted inside the model, so
                the prediction at position `i` is scored against `labels[..., i + 1]`; pass the same ids as
                `input_ids` for standard next-token training. Indices should either be in `[0, ...,
                config.vocab_size]` or -100 (see `input_ids` docstring). Tokens with indices set to `-100` are ignored
                (masked), the loss is only computed for the tokens with labels in `[0, ..., config.vocab_size]`.
            use_cache (`bool`, *optional*):
                If set to `True`, `past_key_values` key value states are returned and can be used to speed up decoding
                (see `past_key_values`).
            output_attentions (`bool`, *optional*):
                Whether or not to return the attentions tensors of all attention layers. See `attentions` under
                returned tensors for more detail.
            output_hidden_states (`bool`, *optional*):
                Whether or not to return the hidden states of all layers. See `hidden_states` under returned tensors
                for more detail.
            return_dict (`bool`, *optional*):
                Whether or not to return a [`~utils.ModelOutput`] instead of a plain tuple.

        Returns:

        Example:

        ```python
        >>> from transformers import AutoTokenizer, OPTForCausalLM

        >>> model = OPTForCausalLM.from_pretrained("facebook/opt-350m")
        >>> tokenizer = AutoTokenizer.from_pretrained("facebook/opt-350m")

        >>> prompt = "Hey, are you conscious? Can you talk to me?"
        >>> inputs = tokenizer(prompt, return_tensors="pt")

        >>> # Generate
        >>> generate_ids = model.generate(inputs.input_ids, max_length=30)
        >>> tokenizer.batch_decode(generate_ids, skip_special_tokens=True, clean_up_tokenization_spaces=False)[0]
        "Hey, are you conscious? Can you talk to me?\nI'm not conscious. I'm just a little bit of a weirdo."
        ```"""
        # Flags left as None fall back to the config. `use_cache` is forwarded as-is;
        # the decoder resolves a None value against the config itself.
        if output_attentions is None:
            output_attentions = self.config.output_attentions
        if output_hidden_states is None:
            output_hidden_states = self.config.output_hidden_states
        if return_dict is None:
            return_dict = self.config.use_return_dict

        # Call the decoder directly (rather than OPTModel) to get its outputs.
        outputs = self.model.decoder(
            input_ids=input_ids,
            attention_mask=attention_mask,
            head_mask=head_mask,
            past_key_values=past_key_values,
            inputs_embeds=inputs_embeds,
            use_cache=use_cache,
            output_attentions=output_attentions,
            output_hidden_states=output_hidden_states,
            return_dict=return_dict,
        )

        last_hidden = outputs[0]
        logits = self.lm_head(last_hidden).contiguous()

        loss = None
        if labels is not None:
            # Bring the labels onto the logits' device (they may differ under model parallelism).
            labels = labels.to(logits.device)
            # Next-token objective: position t predicts token t + 1, so drop the final
            # prediction and the first label before flattening.
            vocab_size = self.config.vocab_size
            flat_predictions = logits[..., :-1, :].contiguous().view(-1, vocab_size)
            flat_targets = labels[..., 1:].contiguous().view(-1)
            criterion = CrossEntropyLoss()
            loss = criterion(flat_predictions, flat_targets)

        if return_dict:
            return CausalLMOutputWithPast(
                loss=loss,
                logits=logits,
                past_key_values=outputs.past_key_values,
                hidden_states=outputs.hidden_states,
                attentions=outputs.attentions,
            )

        # Tuple mode: logits replace the last hidden state; the loss is prepended when computed.
        output = (logits,) + outputs[1:]
        if loss is None:
            return output
        return (loss,) + output

    def prepare_inputs_for_generation(
        self,
        input_ids,
        past_key_values=None,
        attention_mask=None,
        inputs_embeds=None,
        **kwargs,
    ):
        """Build the keyword arguments for one generation step.

        When a cache is present, `input_ids` is trimmed to the tokens the cache has not
        seen yet. `inputs_embeds` is used only on the first step (no cache yet).
        """
        if past_key_values is not None:
            cached_length = past_key_values[0][0].shape[2]
            total_length = input_ids.shape[1]
            # If the caller already trimmed the ids, fall back to keeping only the final id.
            keep_from = cached_length if total_length > cached_length else total_length - 1
            input_ids = input_ids[:, keep_from:]

        use_embeds = inputs_embeds is not None and past_key_values is None
        model_inputs = (
            {"inputs_embeds": inputs_embeds} if use_embeds else {"input_ids": input_ids}
        )
        model_inputs["past_key_values"] = past_key_values
        model_inputs["use_cache"] = kwargs.get("use_cache")
        model_inputs["attention_mask"] = attention_mask
        return model_inputs

    @staticmethod
    def _reorder_cache(past_key_values, beam_idx):
        """Reorder every cached tensor along dim 0 according to `beam_idx`.

        `beam_idx` is moved to each tensor's device before `index_select`. The result is a
        tuple with one tuple of reordered tensors per entry in `past_key_values`.
        """
        return tuple(
            tuple(
                cached.index_select(0, beam_idx.to(cached.device))
                for cached in layer_cache
            )
            for layer_cache in past_key_values
        )
|
|
|
|
@add_start_docstrings(
    """
    The OPT Model transformer with a sequence classification head on top (linear layer).

    [`OPTForSequenceClassification`] uses the last token in order to do the classification, as other causal models
    (e.g. GPT-2) do.

    Since it does classification on the last token, it requires to know the position of the last token. If a
    `pad_token_id` is defined in the configuration, it finds the last token that is not a padding token in each row. If
    no `pad_token_id` is defined, it simply takes the last value in each row of the batch. Since it cannot guess the
    padding tokens when `inputs_embeds` are passed instead of `input_ids`, it does the same (take the last value in
    each row of the batch).
    """,
    OPT_START_DOCSTRING,
)
class OPTForSequenceClassification(OPTPreTrainedModel):
    def __init__(self, config: OPTConfig):
        super().__init__(config)
        self.num_labels = config.num_labels
        self.model = OPTModel(config)
        # Per-token classifier; only one token per row is pooled in `forward`.
        self.score = nn.Linear(config.word_embed_proj_dim, self.num_labels, bias=False)

        # Base-class post-initialisation hook (presumably weight init / final setup).
        self.post_init()

    @add_start_docstrings_to_model_forward(OPT_INPUTS_DOCSTRING)
    @add_code_sample_docstrings(
        checkpoint=_CHECKPOINT_FOR_SEQUENCE_CLASSIFICATION,
        output_type=SequenceClassifierOutputWithPast,
        config_class=_CONFIG_FOR_DOC,
        expected_output=_SEQ_CLASS_EXPECTED_OUTPUT,
        expected_loss=_SEQ_CLASS_EXPECTED_LOSS,
    )
    def forward(
        self,
        input_ids: Optional[torch.LongTensor] = None,
        attention_mask: Optional[torch.FloatTensor] = None,
        head_mask: Optional[torch.FloatTensor] = None,
        past_key_values: Optional[Tuple[Tuple[torch.Tensor]]] = None,
        inputs_embeds: Optional[torch.FloatTensor] = None,
        labels: Optional[torch.LongTensor] = None,
        use_cache: Optional[bool] = None,
        output_attentions: Optional[bool] = None,
        output_hidden_states: Optional[bool] = None,
        return_dict: Optional[bool] = None,
    ) -> Union[Tuple, SequenceClassifierOutputWithPast]:
        r"""
        labels (`torch.LongTensor` of shape `(batch_size,)`, *optional*):
            Labels for computing the sequence classification/regression loss. Indices should be in `[0, ...,
            config.num_labels - 1]`. If `config.num_labels == 1` a regression loss is computed (Mean-Square loss), If
            `config.num_labels > 1` a classification loss is computed (Cross-Entropy).
        """
        return_dict = (
            return_dict if return_dict is not None else self.config.use_return_dict
        )

        transformer_outputs = self.model(
            input_ids,
            past_key_values=past_key_values,
            attention_mask=attention_mask,
            head_mask=head_mask,
            inputs_embeds=inputs_embeds,
            use_cache=use_cache,
            output_attentions=output_attentions,
            output_hidden_states=output_hidden_states,
            return_dict=return_dict,
        )
        hidden_states = transformer_outputs[0]
        # Logits for every token; one position per row is selected below.
        logits = self.score(hidden_states)

        if input_ids is not None:
            batch_size = input_ids.shape[0]
        else:
            batch_size = inputs_embeds.shape[0]

        if self.config.pad_token_id is None:
            # No pad token configured: classify from the final position of every row.
            sequence_lengths = -1
        elif input_ids is not None:
            # Position just before the first pad token. When a row has no pad token,
            # argmax returns 0 and the modulo turns -1 into the last index (a
            # non-negative index is used for ONNX compatibility).
            sequence_lengths = (
                torch.eq(input_ids, self.config.pad_token_id).int().argmax(-1) - 1
            )
            sequence_lengths = sequence_lengths % input_ids.shape[-1]
            sequence_lengths = sequence_lengths.to(logits.device)
        else:
            sequence_lengths = -1
            logger.warning(
                f"{self.__class__.__name__} will not detect padding tokens in `inputs_embeds`. Results may be "
                "unexpected if using padding tokens in conjunction with `inputs_embeds.`"
            )

        pooled_logits = logits[
            torch.arange(batch_size, device=logits.device), sequence_lengths
        ]

        loss = None
        if labels is not None:
            # Move labels next to the logits (the head may live on a different device
            # under model parallelism), mirroring OPTForCausalLM.
            labels = labels.to(pooled_logits.device)

            # Infer the problem type once from num_labels and the label dtype.
            if self.config.problem_type is None:
                if self.num_labels == 1:
                    self.config.problem_type = "regression"
                elif self.num_labels > 1 and (
                    labels.dtype == torch.long or labels.dtype == torch.int
                ):
                    self.config.problem_type = "single_label_classification"
                else:
                    self.config.problem_type = "multi_label_classification"

            if self.config.problem_type == "regression":
                loss_fct = MSELoss()
                if self.num_labels == 1:
                    loss = loss_fct(pooled_logits.squeeze(), labels.squeeze())
                else:
                    loss = loss_fct(pooled_logits, labels)
            elif self.config.problem_type == "single_label_classification":
                loss_fct = CrossEntropyLoss()
                loss = loss_fct(
                    pooled_logits.view(-1, self.num_labels), labels.view(-1)
                )
            elif self.config.problem_type == "multi_label_classification":
                loss_fct = BCEWithLogitsLoss()
                loss = loss_fct(pooled_logits, labels)

        if not return_dict:
            output = (pooled_logits,) + transformer_outputs[1:]
            return ((loss,) + output) if loss is not None else output

        return SequenceClassifierOutputWithPast(
            loss=loss,
            logits=pooled_logits,
            past_key_values=transformer_outputs.past_key_values,
            hidden_states=transformer_outputs.hidden_states,
            attentions=transformer_outputs.attentions,
        )

    def get_input_embeddings(self):
        """Return the backbone decoder's token-embedding module."""
        return self.model.decoder.embed_tokens

    def set_input_embeddings(self, value):
        """Install `value` as the backbone decoder's token-embedding module."""
        self.model.decoder.embed_tokens = value
|
|
|
|
@add_start_docstrings(
    """
    The OPT Model transformer with a span classification head on top for extractive question-answering tasks like SQuAD
    (a linear layers on top of the hidden-states output to compute `span start logits` and `span end logits`).
    """,
    OPT_START_DOCSTRING,
)
class OPTForQuestionAnswering(OPTPreTrainedModel):
    def __init__(self, config: OPTConfig):
        super().__init__(config)
        self.model = OPTModel(config)
        # Two scores per token: span-start and span-end logits.
        self.qa_outputs = nn.Linear(config.word_embed_proj_dim, 2)

        # Base-class post-initialisation hook (presumably weight init / final setup).
        self.post_init()

    @add_start_docstrings_to_model_forward(OPT_INPUTS_DOCSTRING)
    @replace_return_docstrings(
        output_type=QuestionAnsweringModelOutput, config_class=_CONFIG_FOR_DOC
    )
    def forward(
        self,
        input_ids: Optional[torch.LongTensor] = None,
        attention_mask: Optional[torch.FloatTensor] = None,
        head_mask: Optional[torch.FloatTensor] = None,
        past_key_values: Optional[Tuple[Tuple[torch.Tensor]]] = None,
        inputs_embeds: Optional[torch.FloatTensor] = None,
        start_positions: Optional[torch.LongTensor] = None,
        end_positions: Optional[torch.LongTensor] = None,
        use_cache: Optional[bool] = None,
        output_attentions: Optional[bool] = None,
        output_hidden_states: Optional[bool] = None,
        return_dict: Optional[bool] = None,
    ) -> Union[Tuple, QuestionAnsweringModelOutput]:
        r"""
        start_positions (`torch.LongTensor` of shape `(batch_size,)`, *optional*):
            Labels for position (index) of the start of the labelled span for computing the token classification loss.
            Positions are clamped to the length of the sequence (`sequence_length`). Position outside of the sequence
            are not taken into account for computing the loss.
        end_positions (`torch.LongTensor` of shape `(batch_size,)`, *optional*):
            Labels for position (index) of the end of the labelled span for computing the token classification loss.
            Positions are clamped to the length of the sequence (`sequence_length`). Position outside of the sequence
            are not taken into account for computing the loss.

        Returns:

        Example:

        ```python
        >>> from transformers import AutoTokenizer, OPTForQuestionAnswering
        >>> import torch

        >>> torch.manual_seed(4)  # doctest: +IGNORE_RESULT
        >>> tokenizer = AutoTokenizer.from_pretrained("facebook/opt-350m")

        >>> # note: we are loading a OPTForQuestionAnswering from the hub here,
        >>> # so the head will be randomly initialized, hence the predictions will be random
        >>> model = OPTForQuestionAnswering.from_pretrained("facebook/opt-350m")

        >>> question, text = "Who was Jim Henson?", "Jim Henson was a nice puppet"

        >>> inputs = tokenizer(question, text, return_tensors="pt")
        >>> with torch.no_grad():
        ...     outputs = model(**inputs)

        >>> answer_start_index = outputs.start_logits.argmax()
        >>> answer_end_index = outputs.end_logits.argmax()

        >>> answer_offset = len(tokenizer(question)[0])

        >>> predict_answer_tokens = inputs.input_ids[
        ...     0, answer_offset + answer_start_index : answer_offset + answer_end_index + 1
        ... ]
        >>> predicted = tokenizer.decode(predict_answer_tokens)
        >>> predicted
        ' a nice puppet'
        ```"""
        return_dict = (
            return_dict if return_dict is not None else self.config.use_return_dict
        )

        # Always request a ModelOutput from the backbone. Its plain-tuple form drops
        # `None` entries (e.g. there is no cache entry when `use_cache` is False or is
        # forced off by gradient checkpointing), so slicing that tuple by position
        # cannot reliably locate hidden_states / attentions.
        transformer_outputs = self.model(
            input_ids,
            past_key_values=past_key_values,
            attention_mask=attention_mask,
            head_mask=head_mask,
            inputs_embeds=inputs_embeds,
            use_cache=use_cache,
            output_attentions=output_attentions,
            output_hidden_states=output_hidden_states,
            return_dict=True,
        )
        hidden_states = transformer_outputs.last_hidden_state

        # Score every token, then split into start / end logits of shape (batch, seq_len).
        logits = self.qa_outputs(hidden_states)
        start_logits, end_logits = logits.split(1, dim=-1)
        start_logits = start_logits.squeeze(-1).contiguous()
        end_logits = end_logits.squeeze(-1).contiguous()

        total_loss = None
        if start_positions is not None and end_positions is not None:
            # Accept (batch, 1) targets as well as (batch,).
            if start_positions.dim() > 1:
                start_positions = start_positions.squeeze(-1)
            if end_positions.dim() > 1:
                end_positions = end_positions.squeeze(-1)
            # Keep targets on the logits' device (model parallelism), like OPTForCausalLM.
            start_positions = start_positions.to(start_logits.device)
            end_positions = end_positions.to(end_logits.device)

            # Out-of-range positions are clamped to `sequence_length`, which is then the
            # ignore index, so they do not contribute to the loss.
            ignored_index = start_logits.size(1)
            start_positions = start_positions.clamp(0, ignored_index)
            end_positions = end_positions.clamp(0, ignored_index)

            loss_fct = CrossEntropyLoss(ignore_index=ignored_index)
            start_loss = loss_fct(start_logits, start_positions)
            end_loss = loss_fct(end_logits, end_positions)
            total_loss = (start_loss + end_loss) / 2

        if not return_dict:
            # Follow the field order of QuestionAnsweringModelOutput (which has no cache
            # field): logits, then hidden states and attentions when they were produced.
            extras = tuple(
                value
                for value in (
                    transformer_outputs.hidden_states,
                    transformer_outputs.attentions,
                )
                if value is not None
            )
            output = (start_logits, end_logits) + extras
            return ((total_loss,) + output) if total_loss is not None else output

        return QuestionAnsweringModelOutput(
            loss=total_loss,
            start_logits=start_logits,
            end_logits=end_logits,
            hidden_states=transformer_outputs.hidden_states,
            attentions=transformer_outputs.attentions,
        )

    def get_input_embeddings(self):
        """Return the backbone decoder's token-embedding module."""
        return self.model.decoder.embed_tokens

    def set_input_embeddings(self, value):
        """Install `value` as the backbone decoder's token-embedding module."""
        self.model.decoder.embed_tokens = value