import gradio as gr
import torch
import torchaudio
import torchaudio.transforms as T
import numpy as np
import traceback
import io
import time

# Attempt to import SNAC (should work if requirements.txt is correct)
try:
    from snac import SNAC
    print("SNAC module imported successfully.")
except ImportError as e:
    print(f"Error importing SNAC: {e}")
    # Raise a more informative error if SNAC isn't installed
    raise ImportError("Could not import SNAC. Make sure 'snac' is listed in requirements.txt and installed correctly.") from e
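# (Assumption) A minimal requirements.txt for this Space would likely contain the
# packages imported above; exact version pins may differ:
#   gradio
#   torch
#   torchaudio
#   numpy
#   snac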
# --- Configuration ---
TARGET_SR = 32000  # SNAC operates at 32kHz
DEVICE = "cuda" if torch.cuda.is_available() else "cpu"
print(f"Using device: {DEVICE}")
# --- Load Model (Load once globally) ---
snac_model = None
try:
    print("Loading SNAC model...")
    start_time = time.time()
    snac_model = SNAC.from_pretrained("hubertsiuzdak/snac_32khz")
    snac_model = snac_model.to(DEVICE)
    snac_model.eval()  # Set model to evaluation mode
    end_time = time.time()
    print(f"SNAC model loaded successfully to {DEVICE}. Time taken: {end_time - start_time:.2f} seconds.")
except Exception as e:
    print(f"FATAL: Error loading SNAC model: {e}")
    print(traceback.format_exc())
    # If the model fails to load, the app can't function.
    # Gradio will likely show an error, but we print the specifics here.
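# For reference, a minimal SNAC round-trip looks roughly like this (shapes are
# assumed from how the model is used later in this script):
#   audio = torch.zeros(1, 1, TARGET_SR)        # [batch, channel, time], 1 s of 32 kHz audio
#   with torch.inference_mode():
#       codes = snac_model.encode(audio)        # list of integer code tensors, one per codebook level
#       recon = snac_model.decode(codes)        # [batch, 1, time'] float waveform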
# --- Main Processing Function ---
def process_audio(audio_filepath):
    """
    Loads an audio file, resamples it to 32 kHz, encodes and decodes it with SNAC,
    and returns the original, resampled, and reconstructed audio plus a log.
    """
    if snac_model is None:
        return None, None, None, "Error: SNAC model could not be loaded. Cannot process audio."
    if audio_filepath is None:
        return None, None, None, "Please upload an audio file."

| logs = ["--- Starting Audio Processing ---"] | |
| try: | |
| # 1. Load Audio | |
| logs.append(f"Loading audio file: {audio_filepath}") | |
| load_start = time.time() | |
| original_waveform, original_sr = torchaudio.load(audio_filepath) | |
| load_end = time.time() | |
| logs.append(f"Audio loaded. Original SR: {original_sr} Hz, Shape: {original_waveform.shape}, Time: {load_end - load_start:.2f}s") | |
| # Ensure float32 | |
| original_waveform = original_waveform.to(dtype=torch.float32) | |
| # Handle multi-channel audio: Use the first channel | |
| if original_waveform.shape[0] > 1: | |
| logs.append(f"Warning: Input audio has {original_waveform.shape[0]} channels. Using only the first channel.") | |
| original_waveform = original_waveform[0:1, :] # Keep channel dim for consistency initially | |
| # --- Prepare Original for Playback --- | |
| # Gradio Audio component expects (sample_rate, numpy_array) | |
| # Ensure numpy array is 1D or 2D [channels, samples] | |
| original_audio_playback = (original_sr, original_waveform.squeeze().numpy()) # Squeeze removes channel dim if 1 | |
| logs.append("Prepared original audio for playback.") | |
        # 2. Resample if necessary
        resample_start = time.time()
        if original_sr != TARGET_SR:
            logs.append(f"Resampling waveform from {original_sr} Hz to {TARGET_SR} Hz...")
            resampler = T.Resample(orig_freq=original_sr, new_freq=TARGET_SR).to(original_waveform.device)  # Keep resampler on the same device
            waveform_to_encode = resampler(original_waveform)
            logs.append(f"Resampling complete. New Shape: {waveform_to_encode.shape}")
        else:
            logs.append("Waveform is already at the target sample rate (32kHz).")
            waveform_to_encode = original_waveform
        resample_end = time.time()
        logs.append(f"Resampling time: {resample_end - resample_start:.2f}s")

        # --- Prepare Resampled for Playback ---
        resampled_audio_playback = (TARGET_SR, waveform_to_encode.squeeze().numpy())
        logs.append("Prepared resampled audio for playback.")
        # 3. Prepare for SNAC Encoding (add batch dim, move to device)
        # The encoder input should be [Batch, Channel, Time] = [1, 1, Time];
        # waveform_to_encode is currently [1, Time] after channel selection/resampling.
        waveform_batch = waveform_to_encode.unsqueeze(0).to(DEVICE)  # Add batch dimension -> [1, 1, Time]
        logs.append(f"Waveform prepared for encoding. Shape: {waveform_batch.shape}, Device: {DEVICE}")
        # 4. Encode Audio using SNAC
        logs.append("Encoding audio with snac_model.encode()...")
        encode_start = time.time()
        with torch.inference_mode():
            codes = snac_model.encode(waveform_batch)
        encode_end = time.time()

        if not codes or not all(isinstance(c, torch.Tensor) for c in codes):
            log_msg = f"Encoding failed: Expected a list of Tensors, but got: {type(codes)}"
            if isinstance(codes, list):
                log_msg += f" with types {[type(c) for c in codes]}"
            logs.append(log_msg)
            raise ValueError(log_msg)

        logs.append(f"Encoding complete. Received {len(codes)} code layers. Time: {encode_end - encode_start:.2f}s")
        for i, layer_codes in enumerate(codes):
            logs.append(f"  Layer {i+1} codes shape: {layer_codes.shape}, Device: {layer_codes.device}")
        # 5. Decode the Codes using SNAC
        logs.append("Decoding the generated codes with snac_model.decode()...")
        decode_start = time.time()
        with torch.inference_mode():
            reconstructed_waveform = snac_model.decode(codes)  # codes are already on DEVICE
        decode_end = time.time()
        logs.append(f"Decoding complete. Reconstructed waveform shape: {reconstructed_waveform.shape}, Device: {reconstructed_waveform.device}. Time: {decode_end - decode_start:.2f}s")
        # 6. Prepare Reconstructed Audio for Playback
        # The output is [Batch, 1, Time]: move to CPU, drop the Batch/Channel dims, convert to NumPy.
        reconstructed_audio_np = reconstructed_waveform.cpu().squeeze().numpy()  # squeeze() drops Batch and Channel dims
        logs.append(f"Reconstructed audio prepared for playback. Shape: {reconstructed_audio_np.shape}")
        reconstructed_audio_playback = (TARGET_SR, reconstructed_audio_np)
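        # Gradio accepts float32 arrays here; values are assumed to lie roughly in
        # [-1, 1], which the SNAC decoder output should already satisfy.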
| logs.append("\n--- Audio Processing Completed Successfully ---") | |
| return original_audio_playback, resampled_audio_playback, reconstructed_audio_playback, "\n".join(logs) | |
    except Exception as e:
        logs.append("\n--- An Error Occurred ---")
        logs.append(f"Error Type: {type(e).__name__}")
        logs.append(f"Error Details: {e}")
        logs.append("\n--- Traceback ---")
        logs.append(traceback.format_exc())
        # Return None for the audio components on error, plus the detailed log
        return None, None, None, "\n".join(logs)

# --- Gradio Interface ---
DESCRIPTION = """
This Space demonstrates the **SNAC (Multi-Scale Neural Audio Codec)** model (`hubertsiuzdak/snac_32khz`).

1. Upload an audio file (wav, mp3, flac, etc.).
2. The audio is automatically resampled to 32 kHz if needed.
3. The 32 kHz audio is encoded into discrete codes by SNAC.
4. These codes are then decoded back into audio by SNAC.
5. You can listen to the original, the 32 kHz version (if resampled), and the final reconstructed audio.

**Note:** Processing happens on the server, so larger files take longer. If the input is stereo, only the first channel is processed.
"""

iface = gr.Interface(
    fn=process_audio,
    inputs=gr.Audio(type="filepath", label="Upload Audio File"),
    outputs=[
        gr.Audio(label="Original Audio"),
        gr.Audio(label="Resampled Audio (32kHz Input to SNAC)"),
        gr.Audio(label="Reconstructed Audio (Output from SNAC)"),
        gr.Textbox(label="Log Output", lines=15),
    ],
    title="SNAC Audio Codec Demo (32kHz)",
    description=DESCRIPTION,
    examples=[
        # Add paths to example audio files if you upload some to your Space repo
        # ["examples/example1.wav"],
        # ["examples/example2.mp3"],
    ],
    cache_examples=False,  # Disable caching if examples change or have issues
)
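# Note (assumption): for long files on a CPU-only Space, enabling Gradio's request
# queue (iface.queue() before launching) can help avoid request timeouts.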
if __name__ == "__main__":
    if snac_model is None:
        print("Cannot launch Gradio interface because the SNAC model failed to load.")
    else:
        print("Launching Gradio Interface...")
        iface.launch()