Remove stray cuda call
Browse filesA stray cuda call is preventing this model from being run on machines without a GPU.
This has been changed to a dynamically-chosen device allocation, to match input_ids' location.
- modeling_aria.py +3 -2
modeling_aria.py
CHANGED
|
@@ -343,6 +343,7 @@ class AriaModel(AriaPreTrainedModel):
|
|
| 343 |
if self.causal_mask is None:
|
| 344 |
self.causal_mask = precompute_causal_mask(
|
| 345 |
max_seq_len=self.model_config.max_seq_len,
|
|
|
|
| 346 |
).to(input_ids.device)
|
| 347 |
|
| 348 |
if self.freqs_cis is None:
|
|
@@ -617,10 +618,10 @@ class AriaForSequenceEmbedding(AriaPreTrainedModel):
|
|
| 617 |
)
|
| 618 |
|
| 619 |
|
| 620 |
-
def precompute_causal_mask(max_seq_len: int):
|
| 621 |
return torch.tril(
|
| 622 |
torch.ones(max_seq_len, max_seq_len, dtype=torch.bool)
|
| 623 |
-
).
|
| 624 |
|
| 625 |
|
| 626 |
def precompute_freqs_cis(
|
|
|
|
| 343 |
if self.causal_mask is None:
|
| 344 |
self.causal_mask = precompute_causal_mask(
|
| 345 |
max_seq_len=self.model_config.max_seq_len,
|
| 346 |
+
input_ids = input_ids
|
| 347 |
).to(input_ids.device)
|
| 348 |
|
| 349 |
if self.freqs_cis is None:
|
|
|
|
| 618 |
)
|
| 619 |
|
| 620 |
|
| 621 |
+
def precompute_causal_mask(max_seq_len: int, input_ids: torch.Tensor):
|
| 622 |
return torch.tril(
|
| 623 |
torch.ones(max_seq_len, max_seq_len, dtype=torch.bool)
|
| 624 |
+
).to(input_ids.device)
|
| 625 |
|
| 626 |
|
| 627 |
def precompute_freqs_cis(
|