Removed stray cuda call
#2 opened by justbruno
A stray cuda call was preventing this model from being used on devices without a GPU or TPU.
The causal mask is now allocated on input_ids.device upon return, as it should be.
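For illustration only, here is a minimal sketch of the device-agnostic pattern described above. The function name `build_causal_mask` and the mask layout are assumptions made for this example; they are not taken from the model's actual code.

```python
import torch


def build_causal_mask(input_ids: torch.Tensor) -> torch.Tensor:
    """Hypothetical helper: boolean lower-triangular mask on the input's device."""
    seq_len = input_ids.shape[-1]
    # Create the index tensor on input_ids.device instead of calling .cuda(),
    # so the same code runs on CPU, GPU, or any other backend.
    idx = torch.arange(seq_len, device=input_ids.device)
    # Position i may attend to position j only when j <= i.
    return idx[None, :] <= idx[:, None]


input_ids = torch.zeros(1, 4, dtype=torch.long)  # a CPU tensor
mask = build_causal_mask(input_ids)
```

Because the mask inherits its device from `input_ids`, no GPU is required to run the forward pass.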
Thanks for pointing out this oversight. I just merged a fix.
By the way, if you intend to use this model for anything intensive, I'd recommend checking out the training and inference implementations (torch with cudagraphs, and MLX) on the GitHub repo.
loubb changed pull request status to closed