justbruno commited on
Commit
e205151
·
verified ·
1 Parent(s): 8e73c70

Remove stray cuda call

Browse files

A stray cuda call is preventing this model from being run on machines without a GPU.

This has been changed to a dynamically-chosen device allocation, to match input_ids' location.

Files changed (1) hide show
  1. modeling_aria.py +3 -2
modeling_aria.py CHANGED
@@ -343,6 +343,7 @@ class AriaModel(AriaPreTrainedModel):
343
  if self.causal_mask is None:
344
  self.causal_mask = precompute_causal_mask(
345
  max_seq_len=self.model_config.max_seq_len,
 
346
  ).to(input_ids.device)
347
 
348
  if self.freqs_cis is None:
@@ -617,10 +618,10 @@ class AriaForSequenceEmbedding(AriaPreTrainedModel):
617
  )
618
 
619
 
620
- def precompute_causal_mask(max_seq_len: int):
621
  return torch.tril(
622
  torch.ones(max_seq_len, max_seq_len, dtype=torch.bool)
623
- ).cuda()
624
 
625
 
626
  def precompute_freqs_cis(
 
343
  if self.causal_mask is None:
344
  self.causal_mask = precompute_causal_mask(
345
  max_seq_len=self.model_config.max_seq_len,
346
+ input_ids = input_ids
347
  ).to(input_ids.device)
348
 
349
  if self.freqs_cis is None:
 
618
  )
619
 
620
 
621
+ def precompute_causal_mask(max_seq_len: int, input_ids: torch.Tensor):
622
  return torch.tril(
623
  torch.ones(max_seq_len, max_seq_len, dtype=torch.bool)
624
+ ).to(input_ids.device)
625
 
626
 
627
  def precompute_freqs_cis(