justbruno commited on
Commit
4fe741c
·
verified ·
1 Parent(s): 8e73c70

Removed stray cuda call

Browse files

A stray cuda call was preventing this model from being used on devices without a GPU or TPU.

The causal mask is allocated to input_ids.device upon return, as it should.

Files changed (1) hide show
  1. modeling_aria.py +1 -1
modeling_aria.py CHANGED
@@ -620,7 +620,7 @@ class AriaForSequenceEmbedding(AriaPreTrainedModel):
620
  def precompute_causal_mask(max_seq_len: int):
621
  return torch.tril(
622
  torch.ones(max_seq_len, max_seq_len, dtype=torch.bool)
623
- ).cuda()
624
 
625
 
626
  def precompute_freqs_cis(
 
620
  def precompute_causal_mask(max_seq_len: int):
621
  return torch.tril(
622
  torch.ones(max_seq_len, max_seq_len, dtype=torch.bool)
623
+ )
624
 
625
 
626
  def precompute_freqs_cis(