Gapeleon committed
Commit b0e4499 · verified · 1 Parent(s): eb92e9b

Update app.py

Files changed (1)
  1. app.py +15 -17
app.py CHANGED
@@ -4,7 +4,6 @@ from transformers import AutoProcessor, AutoModelForSpeechSeq2Seq, BitsAndBytesConfig
 import gradio as gr
 import os
 import time
-import numpy as np
 
 # Load model and processor (runs once on startup)
 model_name = "ibm-granite/granite-speech-3.2-8b"
@@ -54,7 +53,7 @@ def transcribe_audio(audio_input):
         else:
             # File input: filepath string
             logs.append(f"Processing file input: {audio_input}")
-            wav, sr = torchaudio.load(audio_input)
+            wav, sr = torchaudio.load(audio_input, normalize=True)
             logs.append(f"Loaded audio file with sample rate {sr}Hz and shape {wav.shape}")
 
         # Convert to mono if stereo
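Note (reviewer sketch, not part of the commit): normalize=True is torchaudio's default and makes load() return float32 samples in [-1, 1] rather than raw integer PCM. A minimal, self-contained sketch of coercing an arbitrary file to the mono/16 kHz tensor the assert below expects (the file name is a placeholder):

    import torchaudio

    wav, sr = torchaudio.load("input.wav", normalize=True)   # float32 in [-1, 1]
    if wav.shape[0] > 1:                                      # downmix stereo to mono
        wav = wav.mean(dim=0, keepdim=True)
    if sr != 16000:                                           # resample to the model's 16 kHz
        wav = torchaudio.transforms.Resample(sr, 16000)(wav)
        sr = 16000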
@@ -71,10 +70,8 @@ def transcribe_audio(audio_input):
 
         logs.append(f"Final audio: sample rate {sr}Hz, shape {wav.shape}, min: {wav.min().item()}, max: {wav.max().item()}")
 
-        # Convert to numpy array as expected by the processor
-        # Make sure it's in the format [time]
-        wav_np = wav.squeeze().numpy()
-        logs.append(f"Audio array shape for processor: {wav_np.shape}")
+        # Verify audio format matches what the model expects
+        assert wav.shape[0] == 1 and sr == 16000, "Audio must be mono and 16kHz"
 
         # Create text prompt
         chat = [
@@ -92,19 +89,15 @@ def transcribe_audio(audio_input):
             chat, tokenize=False, add_generation_prompt=True
         )
 
-        # Compute audio embeddings
+        # CRITICAL CHANGE: Pass text and waveform directly to processor (don't pass audio as named param)
         logs.append("Preparing model inputs")
         model_inputs = speech_granite_processor(
-            text=text,
-            audio=wav_np,  # Pass numpy array in format [time]
-            sampling_rate=sr,
+            text,
+            wav,
+            device=device,  # Explicitly set device
             return_tensors="pt",
         ).to(device)
 
-        # Verify audio tokens are present
-        if "audio_values" not in model_inputs:
-            logs.append(f"WARNING: No audio_values in model inputs. Keys present: {list(model_inputs.keys())}")
-
         # Generate transcription
         logs.append("Generating transcription")
         model_outputs = speech_granite.generate(
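Note (reviewer sketch, not part of the commit): the removed WARNING block suggests the old keyword-style call (audio=..., sampling_rate=...) was silently dropping the audio features. A quick way to confirm the new positional call actually attached them is to log the returned keys; the exact audio key name varies across transformers versions, so inspecting it is the safe check:

    logs.append(f"Processor output keys: {list(model_inputs.keys())}")
    # expect input_ids/attention_mask plus an audio feature key such as input_features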
@@ -117,16 +110,21 @@ def transcribe_audio(audio_input):
             repetition_penalty=3.0,
             length_penalty=1.0,
             temperature=1.0,
+            bos_token_id=tokenizer.bos_token_id,
+            eos_token_id=tokenizer.eos_token_id,
+            pad_token_id=tokenizer.pad_token_id,
         )
 
         # Extract the generated text (skipping input tokens)
         logs.append("Processing output")
         num_input_tokens = model_inputs["input_ids"].shape[-1]
-        new_tokens = model_outputs[0, num_input_tokens:]
+        new_tokens = torch.unsqueeze(model_outputs[0, num_input_tokens:], dim=0)
 
-        output_text = tokenizer.decode(new_tokens, skip_special_tokens=True)
+        output_text = tokenizer.batch_decode(
+            new_tokens, add_special_tokens=False, skip_special_tokens=True
+        )
 
-        transcription = output_text.strip().upper()
+        transcription = output_text[0].strip().upper()
         logs.append(f"Transcription complete: {transcription[:50]}...")
 
     except Exception as e:
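Note (reviewer sketch, not part of the commit): decode() on a 1-D id sequence and batch_decode() on the same ids with a batch dimension yield the same string, so this rewrite is about keeping an explicit batch shape end to end; the added bos/eos/pad ids mainly make generate()'s stopping and padding behaviour explicit and silence the pad_token_id warning. A stand-in tokenizer and arbitrary ids demonstrate the equivalence:

    import torch
    from transformers import AutoTokenizer

    tok = AutoTokenizer.from_pretrained("gpt2")   # stand-in tokenizer for the demo
    ids = torch.tensor([1212, 318, 257, 1332])    # arbitrary demo token ids

    assert tok.decode(ids, skip_special_tokens=True) == \
        tok.batch_decode(ids.unsqueeze(0), skip_special_tokens=True)[0]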
 