Bapt120 committed on
Commit 299e18a · verified · 1 Parent(s): 01a806f

Update app.py

Files changed (1): app.py +9 -5
app.py CHANGED
@@ -2,12 +2,11 @@
 import subprocess
 import sys
 
-# CRITICAL: Import spaces FIRST before any CUDA initialization
-import spaces
 
-# Now we can import torch and other packages
+import spaces
 import torch
 
+
 # Install flash-attn for GPU only (after spaces import)
 if torch.cuda.is_available():
     print("CUDA detected - installing flash-attn for optimal GPU performance...")
@@ -99,8 +98,13 @@ def extract_text_from_image(image, temperature=0.2):
         return_tensors="pt"
     )
 
-    # Move inputs to device
-    inputs = {k: v.to(device) if isinstance(v, torch.Tensor) else v for k, v in inputs.items()}
+    # Move inputs to device AND convert to the correct dtype
+    inputs = {
+        k: v.to(device=device, dtype=dtype) if isinstance(v, torch.Tensor) and v.dtype in [torch.float32, torch.float16, torch.bfloat16]
+        else v.to(device) if isinstance(v, torch.Tensor)
+        else v
+        for k, v in inputs.items()
+    }
 
     # Generate text with appropriate settings
     with torch.no_grad():  # Disable gradients for inference
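For context on the first hunk: on ZeroGPU Spaces, `spaces` must be imported before `torch` initializes CUDA, which is what the removed "CRITICAL" comment was warning about; the commit keeps the ordering but drops the commentary. A hedged sketch of how the surrounding import-and-install block plausibly reads after this change — the exact pip invocation is outside the shown hunks, so the command below is an assumption:

```python
import subprocess
import sys

import spaces  # ZeroGPU requires this import before torch touches CUDA
import torch

if torch.cuda.is_available():
    print("CUDA detected - installing flash-attn for optimal GPU performance...")
    # Assumed install command; the actual one is not visible in this diff.
    subprocess.check_call(
        [sys.executable, "-m", "pip", "install", "flash-attn", "--no-build-isolation"]
    )
```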
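The second hunk is the substantive fix: floating-point tensors (e.g. `pixel_values`) are now cast to the model's compute dtype, while integer tensors such as `input_ids` only change device. With the old single-branch `v.to(device)`, a float32 `pixel_values` fed to a model loaded in half precision would likely raise a dtype-mismatch error; casting `input_ids` to a floating dtype would instead break the embedding lookup, hence the two-way split. A minimal sketch of the behavior, where `device` and `dtype` are hypothetical stand-ins for whatever the app derives from the loaded model:

```python
import torch

# Assumed placeholders for values taken from the loaded model.
device = "cuda" if torch.cuda.is_available() else "cpu"
dtype = torch.bfloat16

# Example processor output: token IDs are integers, image tensors are float32.
inputs = {
    "input_ids": torch.tensor([[101, 2023, 102]]),
    "pixel_values": torch.rand(1, 3, 224, 224),
}

inputs = {
    k: v.to(device=device, dtype=dtype)
    if isinstance(v, torch.Tensor) and v.dtype in [torch.float32, torch.float16, torch.bfloat16]
    else v.to(device) if isinstance(v, torch.Tensor)
    else v
    for k, v in inputs.items()
}

assert inputs["input_ids"].dtype == torch.int64  # IDs keep their integer dtype
assert inputs["pixel_values"].dtype == dtype     # floats now match the model
```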