Spaces:
Sleeping
Sleeping
| import gradio as gr | |
| import torchaudio | |
| import torch | |
| from transformers import WhisperProcessor, WhisperForConditionalGeneration, AutomaticSpeechRecognitionPipeline | |
| import numpy as np | |
| import tempfile | |
| import os | |
# Module-level cache for the loaded model objects; all three stay None until
# load_model() has run successfully.
processor = model = asr_pipeline = None
def load_model():
    """Load the Breeze ASR 25 checkpoint and build the shared ASR pipeline.

    Populates the module-level ``processor``, ``model`` and ``asr_pipeline``
    globals.

    Returns:
        A status string: a success message naming the device in use, or a
        failure message containing the exception text. Exceptions are
        reported through the return value rather than propagated.
    """
    global processor, model, asr_pipeline
    checkpoint = "MediaTek-Research/Breeze-ASR-25"
    try:
        processor = WhisperProcessor.from_pretrained(checkpoint)
        model = WhisperForConditionalGeneration.from_pretrained(checkpoint)

        # Run on the GPU whenever CUDA is available, otherwise on the CPU.
        if torch.cuda.is_available():
            device = "cuda"
        else:
            device = "cpu"
        model = model.to(device).eval()

        # Assemble the recognition pipeline from the loaded components.
        pipeline_options = {
            "model": model,
            "tokenizer": processor.tokenizer,
            "feature_extractor": processor.feature_extractor,
            "chunk_length_s": 0,
            "device": device,
        }
        asr_pipeline = AutomaticSpeechRecognitionPipeline(**pipeline_options)
    except Exception as e:
        return f"❌ 模型載入失敗: {e!s}"
    return f"✅ 模型載入成功!使用設備: {device}"
def preprocess_audio(audio_path, target_sample_rate=16000):
    """Load an audio file and return a mono waveform at the target rate.

    Args:
        audio_path: Path of an audio file readable by ``torchaudio.load``.
        target_sample_rate: Sample rate in Hz of the returned waveform.
            Defaults to 16000, the rate this app feeds to its ASR pipeline.

    Returns:
        A 1-D NumPy array of samples.
    """
    waveform, sample_rate = torchaudio.load(audio_path)

    # Collapse to a single channel: average when there are several,
    # otherwise take the only one. Indexing (instead of squeeze()) keeps the
    # result 1-D even for a one-sample clip.
    # NOTE(review): assumes the loaded tensor is channel-first.
    if waveform.shape[0] > 1:
        waveform = waveform.mean(dim=0)
    else:
        waveform = waveform[0]

    # Resample while the data is still a tensor, avoiding a
    # tensor -> NumPy -> tensor round trip.
    if sample_rate != target_sample_rate:
        resampler = torchaudio.transforms.Resample(sample_rate, target_sample_rate)
        waveform = resampler(waveform)

    return waveform.numpy()
def _format_srt_timestamp(seconds):
    """Render a time offset in seconds as an SRT ``HH:MM:SS,mmm`` timestamp."""
    # Round once to whole milliseconds; splitting the float with % and int()
    # truncates values such as 2.3 s to ",299".
    total_ms = max(0, int(round(seconds * 1000)))
    hours, remainder = divmod(total_ms, 3600 * 1000)
    minutes, remainder = divmod(remainder, 60 * 1000)
    secs, millis = divmod(remainder, 1000)
    return f"{hours:02d}:{minutes:02d}:{secs:02d},{millis:03d}"


def _write_recording_to_wav(sample_rate, audio_data):
    """Save a ``(sample_rate, samples)`` recording to a temp WAV; return its path."""
    samples = np.asarray(audio_data)
    if np.issubdtype(samples.dtype, np.integer):
        # Integer PCM: scale by the dtype's full-scale value. Deciding from
        # the dtype (not the peak) also normalises quiet or all-negative clips.
        full_scale = float(np.iinfo(samples.dtype).max) + 1.0
        samples = samples.astype(np.float32) / full_scale
    else:
        samples = samples.astype(np.float32)
        # NOTE(review): float data with a peak above 1.0 is presumed to hold
        # 16-bit PCM values, as the previous implementation assumed -- confirm.
        if samples.size and samples.max() > 1.0:
            samples = samples / 32768.0

    # The saved tensor is laid out (channels, frames).
    # NOTE(review): assumes multi-channel recordings arrive as
    # (frames, channels) -- confirm against the Gradio version in use.
    if samples.ndim == 1:
        samples = samples[np.newaxis, :]
    else:
        samples = samples.T

    # Reserve a unique name and close the handle before writing, so the file
    # is not opened twice at the same time.
    with tempfile.NamedTemporaryFile(delete=False, suffix=".wav") as tmp_file:
        wav_path = tmp_file.name
    try:
        torchaudio.save(wav_path, torch.tensor(samples), sample_rate)
    except Exception:
        # Do not leave a stray file behind when saving fails.
        if os.path.exists(wav_path):
            os.unlink(wav_path)
        raise
    return wav_path


def _format_transcription(result):
    """Convert a pipeline result dict into ``(plain_text, srt_text)``."""
    chunks = result.get("chunks")
    if not chunks:
        # No timestamped chunks: emit the whole text as one placeholder cue.
        transcription = result["text"].strip()
        srt_text = f"1\n00:00:00,000 --> 00:00:10,000\n{transcription}"
        return transcription, srt_text.strip()

    plain_lines = []
    srt_cues = []
    for chunk in chunks:
        text = chunk["text"].strip()
        if not text:
            continue  # skip chunks that carry no text
        start, end = chunk["timestamp"]
        if start is None:
            start = 0
        if end is None:
            # A missing end time must not place the cue's end before its start.
            end = start
        # Number cues by how many were emitted so skipped chunks leave no gaps.
        cue_number = len(srt_cues) + 1
        srt_cues.append(
            f"{cue_number}\n"
            f"{_format_srt_timestamp(start)} --> {_format_srt_timestamp(end)}\n"
            f"{text}"
        )
        plain_lines.append(text)
    return "\n".join(plain_lines), "\n\n".join(srt_cues)


def transcribe_audio(audio_input):
    """Transcribe an uploaded file or a recording.

    Args:
        audio_input: Either a file path (``str``), a
            ``(sample_rate, samples)`` tuple, or ``None``.

    Returns:
        Always a 3-tuple ``(status, plain_text, srt_text)``, matching the
        three output components bound to the transcribe button. On any
        failure the two text fields are empty strings and ``status``
        describes the problem; exceptions are not propagated.
    """
    # Validate the input before the (slow) model load.
    if audio_input is None:
        return "❌ 請先上傳音訊檔案或進行錄音", "", ""
    if not isinstance(audio_input, (str, tuple)):
        return "❌ 不支援的音訊格式", "", ""

    temp_path = None
    try:
        # Load the model lazily on first use.
        if asr_pipeline is None:
            status = load_model()
            if asr_pipeline is None:
                return status, "", ""

        if isinstance(audio_input, tuple):
            sample_rate, audio_data = audio_input
            temp_path = _write_recording_to_wav(sample_rate, audio_data)
            audio_path = temp_path
        else:
            audio_path = audio_input

        waveform = preprocess_audio(audio_path)
        result = asr_pipeline(waveform, return_timestamps=True)
        pure_text, srt_text = _format_transcription(result)
        return "✅ 辨識完成", pure_text, srt_text
    except Exception as e:
        # UI boundary: report the error in the status box instead of raising.
        return f"❌ 辨識過程發生錯誤: {e}", "", ""
    finally:
        # Remove the temp WAV on success and on failure.
        if temp_path is not None and os.path.exists(temp_path):
            os.unlink(temp_path)
def clear_all():
    """Reset the audio input and every output field.

    Returns:
        A 4-tuple ``(audio, status, plain_text, srt_text)``: one value per
        component bound to the clear button (audio input, status box,
        plain-text box, SRT box).
    """
    return None, "🔄 已清除所有內容", "", ""
# Build the Gradio interface.
with gr.Blocks(title="語音辨識系統", theme=gr.themes.Soft()) as demo:
    gr.Markdown("""
    # 🎤 語音辨識系統 - Breeze ASR 25
    ### 功能特色:
    - 🔧 使用 Breeze ASR 25 模型,專為繁體中文優化
    - ⏰ 顯示時間戳記
    - 🌐 強化中英混用辨識能力
    - 感謝[MediaTek-Research/Breeze-ASR-25](https://huggingface.co/MediaTek-Research/Breeze-ASR-25)
    """)
    with gr.Row():
        with gr.Column(scale=1):
            # Audio input area.
            gr.Markdown("### 📂 音訊輸入")
            with gr.Tab("檔案上傳"):
                audio_file = gr.Audio(
                    sources=["upload"],
                    label="上傳音訊檔案",
                    type="filepath"
                )
            # Action buttons.
            with gr.Row():
                transcribe_btn = gr.Button("🚀 開始辨識", variant="primary", size="lg")
                clear_btn = gr.Button("🗑️ 清除", variant="secondary")
        with gr.Column(scale=1):
            # Status display.
            status_output = gr.Textbox(
                label="📊 狀態",
                placeholder="等待操作...",
                interactive=False,
                lines=2
            )
            # Plain-text result.
            pure_text_output = gr.Textbox(
                label="📄 純文字結果",
                placeholder="純文字結果...",
                lines=4,
                max_lines=10,
                show_copy_button=True
            )
            # Result in SRT subtitle format.
            srt_output = gr.Textbox(
                label="🎬 SRT 字幕格式",
                placeholder="SRT 格式字幕...",
                lines=6,
                max_lines=15,
                show_copy_button=True
            )

    # Event wiring.
    def transcribe_wrapper(audio_file_val, audio_mic_val=None):
        """Forward whichever audio source is set to ``transcribe_audio``.

        ``audio_mic_val`` defaults to ``None`` because only the upload
        component is wired as an input below; without the default every
        click would fail with a missing-argument error.
        """
        audio_input = audio_file_val if audio_file_val else audio_mic_val
        return transcribe_audio(audio_input)

    transcribe_btn.click(
        fn=transcribe_wrapper,
        inputs=[audio_file],
        outputs=[status_output, pure_text_output, srt_output]
    )
    clear_btn.click(
        fn=clear_all,
        outputs=[audio_file, status_output, pure_text_output, srt_output]
    )

# Launch the app when run as a script.
if __name__ == "__main__":
    demo.launch()