ACloudCenter committed on
Commit
978fbbf
·
1 Parent(s): dc88c51

fix: modify reqs with GPU decorator for HF spaces

Browse files
Files changed (2) hide show
  1. app.py +77 -16
  2. requirements.txt +7 -9
app.py CHANGED
@@ -1,24 +1,85 @@
1
  import gradio as gr
2
- from transformers import pipeline
3
  import torch
4
- from nemo.collections.speechlm2 import SALM
5
  import spaces
 
 
6
 
7
- if torch.cuda.is_available():
8
- device = torch.device("cuda")
9
- else:
10
- device = torch.device("cpu")
11
 
12
- # Initialize the ASR model which is based on the "nvidia/canary-qwen-2.5b" architecture and uses NVIDIA's NeMo framework
13
  model = SALM.from_pretrained("nvidia/canary-qwen-2.5b").bfloat16().eval().to(device)
14
- transcriber = pipeline("automatic-speech-recognition", model = model)
15
 
16
- # Transcribe audio file using NeMo's transcribe class
17
- def transcribe_audio(audio_file):
18
- transcript = transcriber([audio_file])[0].text
19
- return transcript
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
20
 
21
- demo = gr.Interface(
22
- fn=transcribe_audio,
23
- inputs=gr.Audio(source="upload", type="filepath"),
24
- outputs=gr.Textbox())
 
1
  import gradio as gr
 
2
  import torch
 
3
  import spaces
4
+ from lhotse import Recording
5
+ from nemo.collections.speechlm2 import SALM
6
 
7
# Sample rate (Hz) that uploaded audio is resampled to before inference.
SAMPLE_RATE = 16000

# Run on the GPU when CUDA is visible, otherwise fall back to the CPU.
device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")

# Load the pretrained checkpoint once at import time, cast it to bfloat16,
# switch it to eval mode and move it to the selected device.
model = SALM.from_pretrained("nvidia/canary-qwen-2.5b").bfloat16().eval().to(device)
 
11
 
12
@spaces.GPU
def transcribe_audio(audio_filepath):
    """Transcribe one audio file with the module-level SALM model.

    Args:
        audio_filepath: Path to the audio file to transcribe, or ``None``
            when no audio was provided.

    Returns:
        A 2-tuple ``(transcript, transcript)``. The same text is returned
        twice because the UI wires this function to two outputs (the
        transcript textbox and the transcript state). When
        ``audio_filepath`` is ``None`` the tuple is
        ``("Please upload an audio file", "")``.
    """
    if audio_filepath is None:
        return "Please upload an audio file", ""

    rec = Recording.from_file(audio_filepath, recording_id="temp")
    cut = rec.resample(SAMPLE_RATE).to_cut()
    if cut.num_channels > 1:
        # Collapse multi-channel audio to a single channel.
        cut = cut.to_mono(mono_downmix=True)

    # A single cut's load_audio() returns ONE array of samples (channels x
    # samples in lhotse), not an (audio, lengths) pair -- unpacking it into
    # two names fails for mono audio. Build the batch-of-one tensor and its
    # length explicitly instead.
    samples = torch.as_tensor(cut.load_audio()).reshape(1, -1)
    audio_lens = torch.tensor([samples.shape[-1]])

    with torch.inference_mode():
        output_ids = model.generate(
            prompts=[[{"role": "user", "content": f"Transcribe the following: {model.audio_locator_tag}"}]],
            audios=samples.to(device),
            audio_lens=audio_lens.to(device),
            max_new_tokens=256,
        )

    transcript = model.tokenizer.ids_to_text(output_ids[0].cpu())
    return transcript, transcript
34
+
35
@spaces.GPU
def answer_question(transcript, question):
    """Answer a question about a previously produced transcript.

    Args:
        transcript: Transcript text; any falsy value (``None`` or ``""``)
            means nothing has been transcribed yet.
        question: The user's question about the transcript.

    Returns:
        The generated answer with surrounding whitespace stripped, or
        ``"Please transcribe audio first"`` when ``transcript`` is falsy.
    """
    if not transcript:
        return "Please transcribe audio first"

    # Single-turn, text-only conversation: the question, a blank line, then
    # the transcript.
    conversation = [[{"role": "user", "content": f"{question}\n\n{transcript}"}]]

    with torch.inference_mode():
        # Generation runs with the LLM adapter disabled (presumably so the
        # base LLM answers instead of the ASR-tuned weights -- confirm
        # against the model card).
        with model.llm.disable_adapter():
            generated = model.generate(prompts=conversation, max_new_tokens=512)

    decoded = model.tokenizer.ids_to_text(generated[0].cpu())
    # Keep only the text after the last assistant marker, if one is present.
    *_, reply = decoded.split("<|im_start|>assistant")
    return reply.strip()
49
+
50
# UI: the first row holds the audio input and its transcript, the second row
# holds the question box and its answer.
with gr.Blocks(title="Canary-Qwen Transcriber & Q&A") as demo:
    gr.Markdown("# Canary-Qwen Transcriber with Q&A")
    gr.Markdown("Upload audio to transcribe, then ask questions about it!")

    with gr.Row():
        with gr.Column():
            audio_input = gr.Audio(
                sources=["microphone", "upload"],
                type="filepath",
                label="Audio Input",
            )
            transcribe_btn = gr.Button("Transcribe", variant="primary")

        with gr.Column():
            transcript_output = gr.Textbox(label="Transcript", lines=8)

    # Holds the latest transcript so the Q&A handler can read it.
    transcript_state = gr.State()

    with gr.Row():
        with gr.Column():
            question_input = gr.Textbox(
                label="Ask a question about the transcript",
                placeholder="What is the main topic?",
            )
            ask_btn = gr.Button("Ask", variant="primary")

        with gr.Column():
            answer_output = gr.Textbox(label="Answer", lines=4)

    # Transcribe: fills both the visible textbox and the state.
    transcribe_btn.click(
        transcribe_audio,
        [audio_input],
        [transcript_output, transcript_state],
    )

    # Ask: reads the stored transcript plus the typed question.
    ask_btn.click(
        answer_question,
        [transcript_state, question_input],
        [answer_output],
    )

demo.queue()
demo.launch()
 
 
requirements.txt CHANGED
@@ -1,9 +1,7 @@
1
- gradio
2
- transformers
3
- spaces
4
- nemo_toolkit[asr] @ git+https://github.com/NVIDIA/NeMo.git
5
- peft
6
- sacrebleu
7
- seaborn
8
- --extra-index-url https://download.pytorch.org/whl/cu113
9
- torch
 
1
+ gradio
2
+ spaces
3
+ nemo_toolkit[asr] @ git+https://github.com/NVIDIA/NeMo.git
4
+ lhotse
5
+ peft
6
+ --extra-index-url https://download.pytorch.org/whl/cu113
7
+ torch