Update app.py
app.py CHANGED

```diff
@@ -4,7 +4,6 @@ from transformers import AutoProcessor, AutoModelForSpeechSeq2Seq, BitsAndBytesConfig
 import gradio as gr
 import os
 import time
-import numpy as np
 
 # Load model and processor (runs once on startup)
 model_name = "ibm-granite/granite-speech-3.2-8b"
@@ -54,7 +53,7 @@ def transcribe_audio(audio_input):
         else:
             # File input: filepath string
             logs.append(f"Processing file input: {audio_input}")
-            wav, sr = torchaudio.load(audio_input)
+            wav, sr = torchaudio.load(audio_input, normalize=True)
             logs.append(f"Loaded audio file with sample rate {sr}Hz and shape {wav.shape}")
 
             # Convert to mono if stereo
@@ -71,10 +70,8 @@ def transcribe_audio(audio_input):
 
         logs.append(f"Final audio: sample rate {sr}Hz, shape {wav.shape}, min: {wav.min().item()}, max: {wav.max().item()}")
 
-        #
-
-        wav_np = wav.squeeze().numpy()
-        logs.append(f"Audio array shape for processor: {wav_np.shape}")
+        # Verify audio format matches what the model expects
+        assert wav.shape[0] == 1 and sr == 16000, "Audio must be mono and 16kHz"
 
         # Create text prompt
         chat = [
@@ -92,19 +89,15 @@ def transcribe_audio(audio_input):
             chat, tokenize=False, add_generation_prompt=True
         )
 
-        #
         logs.append("Preparing model inputs")
         model_inputs = speech_granite_processor(
-            text
-
-
+            text,
+            wav,
+            device=device,  # Explicitly set device
             return_tensors="pt",
         ).to(device)
 
-        # Verify audio tokens are present
-        if "audio_values" not in model_inputs:
-            logs.append(f"WARNING: No audio_values in model inputs. Keys present: {list(model_inputs.keys())}")
-
         # Generate transcription
         logs.append("Generating transcription")
         model_outputs = speech_granite.generate(
@@ -117,16 +110,21 @@ def transcribe_audio(audio_input):
             repetition_penalty=3.0,
             length_penalty=1.0,
             temperature=1.0,
+            bos_token_id=tokenizer.bos_token_id,
+            eos_token_id=tokenizer.eos_token_id,
+            pad_token_id=tokenizer.pad_token_id,
         )
 
         # Extract the generated text (skipping input tokens)
         logs.append("Processing output")
         num_input_tokens = model_inputs["input_ids"].shape[-1]
-        new_tokens = model_outputs[0, num_input_tokens:]
+        new_tokens = torch.unsqueeze(model_outputs[0, num_input_tokens:], dim=0)
 
-        output_text = tokenizer.
+        output_text = tokenizer.batch_decode(
+            new_tokens, add_special_tokens=False, skip_special_tokens=True
+        )
 
-        transcription = output_text.strip().upper()
+        transcription = output_text[0].strip().upper()
         logs.append(f"Transcription complete: {transcription[:50]}...")
 
     except Exception as e:
```
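The new `assert` requires mono 16 kHz audio, so the upstream `# Convert to mono if stereo` block (not shown in the diff) has to downmix and resample before control reaches it. A minimal sketch of that preprocessing with torchaudio, assuming nothing beyond the names in the diff (`prepare_audio` itself is an illustrative helper, not part of app.py):

```python
import torch
import torchaudio

def prepare_audio(path: str) -> tuple[torch.Tensor, int]:
    # normalize=True converts integer PCM to float32 in [-1, 1]
    wav, sr = torchaudio.load(path, normalize=True)

    # Downmix multi-channel audio to mono by averaging channels
    if wav.shape[0] > 1:
        wav = wav.mean(dim=0, keepdim=True)

    # Resample to the 16 kHz the model (and the new assert) expects
    if sr != 16000:
        wav = torchaudio.functional.resample(wav, orig_freq=sr, new_freq=16000)
        sr = 16000

    assert wav.shape[0] == 1 and sr == 16000, "Audio must be mono and 16kHz"
    return wav, sr
```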
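Downstream of that, the patched inference path passes `text` and the waveform positionally, pins the special-token ids on `generate`, and re-batches the new tokens before decoding. A condensed sketch under the same assumptions (`speech_granite_processor`, `speech_granite`, `tokenizer`, and `device` come from app.py's startup code; the prompt string and `max_new_tokens=200` are placeholders for values outside the shown hunks):

```python
import torch

def transcribe(wav, tokenizer, speech_granite_processor, speech_granite, device):
    # Placeholder prompt; the real instruction lives in the unshown `chat` block
    chat = [{"role": "user", "content": "<|audio|>can you transcribe the speech into written format?"}]
    text = tokenizer.apply_chat_template(chat, tokenize=False, add_generation_prompt=True)

    # Text and waveform are passed positionally, per the "CRITICAL CHANGE" comment
    model_inputs = speech_granite_processor(
        text,
        wav,
        device=device,
        return_tensors="pt",
    ).to(device)

    model_outputs = speech_granite.generate(
        **model_inputs,
        max_new_tokens=200,  # placeholder; the actual limit is outside the shown hunks
        repetition_penalty=3.0,
        length_penalty=1.0,
        temperature=1.0,
        bos_token_id=tokenizer.bos_token_id,
        eos_token_id=tokenizer.eos_token_id,
        pad_token_id=tokenizer.pad_token_id,
    )

    # batch_decode expects a batch dimension, which indexing model_outputs[0, ...]
    # drops, hence the unsqueeze back to shape (1, num_new_tokens)
    num_input_tokens = model_inputs["input_ids"].shape[-1]
    new_tokens = torch.unsqueeze(model_outputs[0, num_input_tokens:], dim=0)
    output_text = tokenizer.batch_decode(
        new_tokens, add_special_tokens=False, skip_special_tokens=True
    )
    return output_text[0].strip().upper()
```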