import streamlit as st
import os, glob, pydub, time
from pytube import YouTube
import torch, torchaudio
import yaml
import matplotlib.pyplot as plt
from torch.utils.data import Dataset, DataLoader
from torchvision import transforms
import torchaudio.transforms as T
from src.models import models
from st_audiorec import st_audiorec
from pathlib import Path
import numpy as np
import subprocess
# Run the command
command = "apt-get update"
process = subprocess.Popen(command, shell=True, stdout=subprocess.PIPE, stderr=subprocess.PIPE)
# Print the command output
stdout, stderr = process.communicate()
print(stdout, stderr)

command = "apt-get install sox libsox-dev -y"
process = subprocess.Popen(command, shell=True, stdout=subprocess.PIPE, stderr=subprocess.PIPE)
# Print the command output
stdout, stderr = process.communicate()
print(stdout, stderr)
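# Note: sox/libsox-dev are installed because torchaudio.sox_effects
# (used by resample_wave and apply_trim below) needs the sox backend at runtime.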
from twilio.base.exceptions import TwilioRestException
from twilio.rest import Client
import queue


def get_ice_servers():
    """Use Twilio's TURN server because Streamlit Community Cloud has changed
    its infrastructure and WebRTC connection cannot be established without TURN server now.  # noqa: E501
    We considered Open Relay Project (https://www.metered.ca/tools/openrelay/) too,
    but it is not stable and hardly works as some people reported like https://github.com/aiortc/aiortc/issues/832#issuecomment-1482420656  # noqa: E501
    See https://github.com/whitphx/streamlit-webrtc/issues/1213
    """
    # Ref: https://www.twilio.com/docs/stun-turn/api
    try:
        account_sid = os.environ["TWILIO_ACCOUNT_SID"]
        auth_token = os.environ["TWILIO_AUTH_TOKEN"]
    except KeyError:
        return [{"urls": ["stun:stun.l.google.com:19302"]}]

    client = Client(account_sid, auth_token)
    try:
        token = client.tokens.create()
    except TwilioRestException as e:
        st.warning(
            f"Error occurred while accessing Twilio API. Fallback to a free STUN server from Google. ({e})"  # noqa: E501
        )
        return [{"urls": ["stun:stun.l.google.com:19302"]}]

    return token.ice_servers
from streamlit_webrtc import WebRtcMode, webrtc_streamer
from pydub import AudioSegment
from pyannote.audio import Pipeline
import soundfile as sf
device = torch.device('cuda:0' if torch.cuda.is_available() else 'cpu')

# Replace with your actual Hugging Face API token
huggingface_token = os.environ["key"]
pipeline = Pipeline.from_pretrained("pyannote/[email protected]",
                                    use_auth_token=huggingface_token).to(device)

output_directory = '/MP3_Split'
def split_by_speaker(file_path, output_dir):
    # Load the MP3 file
    audio = AudioSegment.from_mp3(file_path)

    # Convert audio to wav format (PyAnnote requires wav format)
    wav_path = file_path.replace('.mp3', '.wav')
    audio.export(wav_path, format="wav")

    # Perform speaker diarization
    diarization = pipeline(wav_path)

    # Two output buckets, each seeded with 5 ms of silence (pydub durations are in ms)
    audio_0_2_4 = AudioSegment.silent(duration=5)
    audio_1_3_5 = AudioSegment.silent(duration=5)

    # Split the audio based on diarization results:
    # even-numbered segments (0, 2, 4, ...) go to one bucket, odd-numbered to the other
    base_filename = os.path.splitext(os.path.basename(file_path))[0]
    for i, (segment, _, speaker) in enumerate(diarization.itertracks(yield_label=True)):
        # Extract segment (PyAnnote uses seconds, pydub uses milliseconds)
        start_time = segment.start * 1000
        end_time = segment.end * 1000
        audio_segment = audio[start_time:end_time]
        if i % 2 == 0:
            audio_0_2_4 += audio_segment
        else:
            audio_1_3_5 += audio_segment

    # Save each bucket as a separate MP3 file
    os.makedirs(output_dir, exist_ok=True)
    audio_0_2_4.export(os.path.join(output_dir, "0_speaker.mp3"), format="mp3")
    audio_1_3_5.export(os.path.join(output_dir, "1_speaker.mp3"), format="mp3")
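# Usage sketch (hypothetical input file; output paths follow from the exports above):
#   split_by_speaker("sample.mp3", output_directory)
#   -> /MP3_Split/0_speaker.mp3 and /MP3_Split/1_speaker.mp3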
def clear_files_in_directory(directory):
    if os.path.exists(directory):
        for filename in os.listdir(directory):
            file_path = os.path.join(directory, filename)
            try:
                if os.path.isfile(file_path) or os.path.islink(file_path):
                    os.unlink(file_path)
                elif os.path.isdir(file_path):
                    clear_files_in_directory(file_path)
                    os.rmdir(file_path)  # Empty the subdirectory, then remove it
            except Exception as e:
                print(f'Error while deleting {file_path}: {e}')
    else:
        print(f'Directory {directory} does not exist.')
# Preprocessing functions
SAMPLING_RATE = 16_000


def apply_preprocessing(
    waveform,
    sample_rate,
):
    if sample_rate != SAMPLING_RATE and SAMPLING_RATE != -1:
        waveform, sample_rate = resample_wave(waveform, sample_rate, SAMPLING_RATE)

    # Stereo to mono
    if waveform.dim() > 1 and waveform.shape[0] > 1:
        waveform = waveform[:1, ...]

    waveform, sample_rate = apply_trim(waveform, sample_rate)
    waveform = apply_pad(waveform, 480_000)
    return waveform, sample_rate
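# Preprocessing sketch (hypothetical clip name):
#   waveform, sr = torchaudio.load("clip.wav")
#   waveform, sr = apply_preprocessing(waveform, sr)
#   # waveform is now 16 kHz mono, silence-trimmed, repeat-padded/cut to 480_000 samples (30 s)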
def resample_wave(waveform, sample_rate, target_sample_rate):
    waveform, sample_rate = torchaudio.sox_effects.apply_effects_tensor(
        waveform, sample_rate, [["rate", f"{target_sample_rate}"]]
    )
    return waveform, sample_rate


def apply_trim(waveform, sample_rate):
    (
        waveform_trimmed,
        sample_rate_trimmed,
    ) = torchaudio.sox_effects.apply_effects_tensor(
        waveform, sample_rate, [["silence", "1", "0.2", "1%", "-1", "0.2", "1%"]]
    )
    if waveform_trimmed.size()[1] > 0:
        waveform = waveform_trimmed
        sample_rate = sample_rate_trimmed
    return waveform, sample_rate
def apply_pad(waveform, cut):
    """Pad wave by repeating signal until `cut` length is achieved."""
    waveform = waveform.squeeze(0)
    waveform_len = waveform.shape[0]

    if waveform_len >= cut:
        return waveform[:cut]

    # Repeat the signal, then truncate to exactly `cut` samples
    num_repeats = int(cut / waveform_len) + 1
    padded_waveform = torch.tile(waveform, (1, num_repeats))[:, :cut][0]
    return padded_waveform
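# Illustration of the repeat-padding:
#   apply_pad(torch.tensor([[1., 2., 3.]]), 8)
#   -> tensor([1., 2., 3., 1., 2., 3., 1., 2.])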
# Model configuration and loading
device = "cuda" if torch.cuda.is_available() else "cpu"

with open('augmentation_ko_whisper_frontend_lcnn_mfcc.yaml', 'r') as f:
    model_config = yaml.safe_load(f)
model_paths = model_config["checkpoint"]["path"]
model_name, model_parameters = model_config["model"]["name"], model_config["model"]["parameters"]

model = models.get_model(
    model_name=model_name,
    config=model_parameters,
    device=device,
)
model.load_state_dict(torch.load(model_paths, map_location=torch.device('cpu')))
model = model.to(device)
model.eval()
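# Smoke-test sketch (assumption: the model takes (batch, 480_000) waveforms,
# inferred from segment.unsqueeze(0) in the inference functions below):
#   with torch.no_grad():
#       print(torch.sigmoid(model(torch.zeros(1, 480_000, device=device))).item())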
# Download a YouTube video and extract its audio
def download_youtube_audio(youtube_url, output_path="temp"):
    yt = YouTube(youtube_url)
    audio_stream = yt.streams.get_audio_only()
    output_file = audio_stream.download(output_path=output_path)
    title = audio_stream.default_filename
    return output_file, title
# Prediction from a YouTube URL
def pred_from_url(youtube_url, segment_length=30):
    global model
    audio_path, title = download_youtube_audio(youtube_url)
    print(f"- Running on [{title}]\n\n")

    waveform, sample_rate = torchaudio.load(audio_path, normalize=True)
    waveform = torchaudio.functional.resample(waveform, orig_freq=sample_rate, new_freq=SAMPLING_RATE)
    sample_rate = SAMPLING_RATE
    if waveform.size(0) > 1:  # Stereo to mono
        waveform = waveform.mean(dim=0, keepdim=True)

    # Split the audio into fixed-length chunks
    num_samples_per_segment = int(segment_length * sample_rate)
    total_samples = waveform.size(1)
    if total_samples <= num_samples_per_segment:
        num_samples_per_segment = total_samples
        num_segments = 1
    else:
        num_segments = total_samples // num_samples_per_segment

    preds = []
    print("Number of audio chunks:", num_segments)
    for i in range(num_segments):
        start_sample = i * num_samples_per_segment
        end_sample = start_sample + num_samples_per_segment
        segment = waveform[:, start_sample:end_sample]
        segment, sample_rate = apply_preprocessing(segment, sample_rate)
        pred = model(segment.unsqueeze(0).to(device))
        pred = torch.sigmoid(pred)
        preds.append(pred.item())

    avg_pred = torch.tensor(preds).mean().item()
    os.remove(audio_path)

    # Round the average probability to get the label
    output = "fake" if int(avg_pred + 0.5) else "real"
    return f"""Prediction: {output}
{(avg_pred*100):.2f}% probability of being fake."""
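# Example call (URL taken from the example list below; the returned figure is illustrative):
#   pred_from_url("https://youtu.be/54y1sYLZjqs")
#   -> "Prediction: real\n3.21% probability of being fake."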
# Prediction from an uploaded file
def pred_from_file(file_path, segment_length=30):
    global model
    clear_files_in_directory(output_directory)
    split_by_speaker(file_path, output_directory)

    output = ""
    for p in list(Path(output_directory).glob("*.mp3")):
        waveform, sample_rate = torchaudio.load(p, normalize=True)
        waveform = torchaudio.functional.resample(waveform, orig_freq=sample_rate, new_freq=SAMPLING_RATE)
        sample_rate = SAMPLING_RATE
        if waveform.size(0) > 1:  # Stereo to mono
            waveform = waveform.mean(dim=0, keepdim=True)

        num_samples_per_segment = int(segment_length * sample_rate)
        total_samples = waveform.size(1)
        if total_samples <= num_samples_per_segment:
            num_samples_per_segment = total_samples
            num_segments = 1
        else:
            num_segments = total_samples // num_samples_per_segment

        preds = []
        print(f"Speaker file {p.name}: number of audio chunks: {num_segments}")
        for i in range(num_segments):
            # Run inference on each chunk
            start_sample = i * num_samples_per_segment
            end_sample = start_sample + num_samples_per_segment
            segment = waveform[:, start_sample:end_sample]
            segment, sample_rate = apply_preprocessing(segment, sample_rate)
            pred = model(segment.unsqueeze(0).to(device))
            pred = torch.sigmoid(pred)
            preds.append(pred.item())

        avg_pred = torch.tensor(preds).mean().item()
        output += f"Speaker {p.name}: {(avg_pred*100):.2f}% probability of being fake.\n\n"
    return output
def pred_from_realtime_audio(data):
    global model
    data = torch.tensor(data, dtype=torch.float32)
    data = data.unsqueeze(0)
    # Browser audio arrives at 48 kHz; resample to the model rate
    data = torchaudio.functional.resample(data, orig_freq=48000, new_freq=SAMPLING_RATE)

    # Peak-normalize, then standardize
    data = data / torch.max(torch.abs(data))
    mean = torch.mean(data)
    std = torch.std(data)
    data = (data - mean) / std

    data, sample_rate = apply_preprocessing(data, SAMPLING_RATE)
    pred = model(data.unsqueeze(0).to(device))
    pred = torch.sigmoid(pred)
    return pred.item()
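# Returns the probability of "fake" in [0, 1] for one rolling audio window;
# the real-time tab below averages up to the last 100 of these values.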
# Streamlit UI
st.title("DeepFake Detection Demo")
st.markdown("whisper-LCNN (using MLAAD, MAILABS, the AIHub speech-synthesis dataset covering emotion and speaking style, and self-collected/generated KoAAD)")
st.markdown("github : https://github.com/ldh-Hoon/ko_deepfake-whisper-features")

tab1, tab2, tab3 = st.tabs(["YouTube URL", "File Upload", "Real-time Audio Input"])
example_urls_fake = [
    "https://youtu.be/ha3gfD7S0_E",
    "https://youtu.be/5lmJ0Rhr-ec",
    "https://youtu.be/q6ra0KDgVbg",
    "https://youtu.be/hfmm1Oo6SSY?feature=shared"
]
example_urls_real = [
    "https://youtu.be/54y1sYLZjqs",
    "https://youtu.be/7qT0Stb3QNY",
]

if 'youtube_url' not in st.session_state:
    st.session_state['youtube_url'] = ''
with tab1:
    st.markdown("""example
>fake:
""")
    for url in example_urls_fake:
        if st.button(url, key=url):
            st.session_state.youtube_url = url
    st.markdown(""">real:
""")
    for url in example_urls_real:
        if st.button(url, key=url):
            st.session_state.youtube_url = url

    youtube_url = st.text_input("YouTube URL", value=st.session_state.youtube_url)
    if youtube_url:
        result = pred_from_url(youtube_url)
        st.text_area("Result", value=result, height=150)
        st.video(youtube_url)
with tab2:
    file = st.file_uploader("Upload an audio file", type=['mp3', 'wav'])
    if file is not None and st.button("RUN"):
        # Save to a temporary file
        with open(file.name, "wb") as f:
            f.write(file.getbuffer())
        result = pred_from_file(file.name)
        st.text_area("Result", value=result, height=150)
        os.remove(file.name)  # Delete the temporary file
with tab3:
    p = st.empty()
    preds = []
    fig, [ax_time, ax_freq] = plt.subplots(2, 1, gridspec_kw={"top": 1.5, "bottom": 0.2})

    sound_window_len = 2000  # 2 s rolling window (ms)
    sound_window_buffer = None

    webrtc_ctx = webrtc_streamer(
        key="sendonly-audio",
        mode=WebRtcMode.SENDONLY,
        audio_receiver_size=1024,
        rtc_configuration={"iceServers": get_ice_servers()},
        media_stream_constraints={"audio": True},
    )

    while True:
        if webrtc_ctx.audio_receiver:
            try:
                audio_frames = webrtc_ctx.audio_receiver.get_frames(timeout=1)
            except queue.Empty:
                break

            sound_chunk = pydub.AudioSegment.empty()
            for audio_frame in audio_frames:
                sound = pydub.AudioSegment(
                    data=audio_frame.to_ndarray().tobytes(),
                    sample_width=audio_frame.format.bytes,
                    frame_rate=audio_frame.sample_rate,
                    channels=len(audio_frame.layout.channels),
                )
                sound_chunk += sound

            if len(sound_chunk) > 0:
                if sound_window_buffer is None:
                    sound_window_buffer = pydub.AudioSegment.silent(
                        duration=sound_window_len
                    )
                sound_window_buffer += sound_chunk
                if len(sound_window_buffer) > sound_window_len:
                    sound_window_buffer = sound_window_buffer[-sound_window_len:]

            if sound_window_buffer:
                # Ref: https://own-search-and-study.xyz/2017/10/27/python%E3%82%92%E4%BD%BF%E3%81%A3%E3%81%A6%E9%9F%B3%E5%A3%B0%E3%83%87%E3%83%BC%E3%82%BF%E3%81%8B%E3%82%89%E3%82%B9%E3%83%9A%E3%82%AF%E3%83%88%E3%83%AD%E3%82%B0%E3%83%A9%E3%83%A0%E3%82%92%E4%BD%9C/  # noqa
                sound_window_buffer = sound_window_buffer.set_channels(1)  # Stereo to mono
                sample = np.array(sound_window_buffer.get_array_of_samples())
                preds.append(pred_from_realtime_audio(sample))
                if len(preds) > 100:
                    preds = preds[-100:]
                p.write(f"pred : {np.mean(preds)*100:.2f}%")
        else:
            break