ridger sirorezka committed on
Commit
d917c94
·
1 Parent(s): 2d13cf8

rope_type='default' excluded from ROPE_INIT_FUNCTIONS in transformers >=5.0 (#6)

Browse files

- rope_type='default' excluded from ROPE_INIT_FUNCTIONS in transformers >=5.0 (c49969e88ea97b05838040e8b0333a0f63a8fe9f)


Co-authored-by: Petrov <sirorezka@users.noreply.huggingface.co>

Files changed (1) hide show
  1. modeling_ouro.py +28 -2
modeling_ouro.py CHANGED
@@ -478,12 +478,38 @@ class OuroRotaryEmbedding(nn.Module):
478
  self.original_max_seq_len = config.max_position_embeddings
479
 
480
  self.config = config
481
- self.rope_init_fn = ROPE_INIT_FUNCTIONS[self.rope_type]
 
 
 
482
 
483
- inv_freq, self.attention_scaling = self.rope_init_fn(self.config, device)
484
  self.register_buffer("inv_freq", inv_freq, persistent=False)
485
  self.original_inv_freq = self.inv_freq
486
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
487
  @torch.no_grad()
488
  @dynamic_rope_update # power user: used with advanced RoPE types (e.g. dynamic rope)
489
  def forward(self, x, position_ids):
 
478
  self.original_max_seq_len = config.max_position_embeddings
479
 
480
  self.config = config
481
+
482
+ rope_init_fn: Callable = self.compute_default_rope_parameters
483
+ if self.rope_type != "default":
484
+ rope_init_fn = ROPE_INIT_FUNCTIONS[self.rope_type]
485
 
486
+ inv_freq, self.attention_scaling = rope_init_fn(self.config, device)
487
  self.register_buffer("inv_freq", inv_freq, persistent=False)
488
  self.original_inv_freq = self.inv_freq
489
 
490
+
491
+ @staticmethod
492
+ def compute_default_rope_parameters(
493
+ config: Optional[OuroConfig] = None,
494
+ device: Optional["torch.device"] = None,
495
+ seq_len: Optional[int] = None,
496
+ ) -> tuple["torch.Tensor", float]:
497
+ """
498
+ Computes the inverse frequencies according to the original RoPE implementation
499
+ """
500
+
501
+ base = config.rope_theta
502
+ partial_rotary_factor = getattr(config, "partial_rotary_factor", 1.0)
503
+ head_dim = getattr(config, "head_dim", None) or config.hidden_size // config.num_attention_heads
504
+ dim = int(head_dim * partial_rotary_factor)
505
+
506
+ attention_factor = 1.0 # Unused in this type of RoPE
507
+
508
+ # Compute the inverse frequencies
509
+ inv_freq = 1.0 / (base ** (torch.arange(0, dim, 2, dtype=torch.int64).to(device=device, dtype=torch.float) / dim))
510
+ return inv_freq, attention_factor
511
+
512
+
513
  @torch.no_grad()
514
  @dynamic_rope_update # power user: used with advanced RoPE types (e.g. dynamic rope)
515
  def forward(self, x, position_ids):