Removed stray cuda call

#2
by justbruno - opened

A stray cuda call was preventing this model from being used on devices without a GPU or TPU.

The causal mask is now placed on input_ids.device when it is returned, as it should be.
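For anyone hitting the same issue elsewhere, here is a minimal sketch of the pattern, not the actual code from this model. The helper name `make_causal_mask` and the boolean mask layout are assumptions for illustration; the point is only that the mask follows `input_ids.device` instead of relying on a hard-coded `.cuda()` call.

```python
import torch


def make_causal_mask(input_ids: torch.Tensor) -> torch.Tensor:
    """Hypothetical helper: build a boolean causal mask on the input's device."""
    seq_len = input_ids.shape[-1]
    # Create the position indices on input_ids.device rather than calling
    # .cuda(), so the same code runs on CPU-only machines as well as GPUs.
    positions = torch.arange(seq_len, device=input_ids.device)
    # Entry [i, j] is True when position i may attend to position j (j <= i).
    return positions[None, :] <= positions[:, None]


input_ids = torch.zeros(1, 4, dtype=torch.long)  # a CPU tensor
mask = make_causal_mask(input_ids)
```

Because the mask is derived from the input tensor, moving `input_ids` to another device moves the mask along with it and no device needs to be named in the model code.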

Thanks for pointing out this oversight - I just merged a fix.

By the way, if you intend to use this model for anything intensive, I'd recommend checking out the training & inference implementations (torch w/ cudagraphs, and MLX) on the GitHub repo.

https://github.com/EleutherAI/aria

loubb changed pull request status to closed
