# Copyright (c) 2022, NVIDIA CORPORATION. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

# Example Run Command:
# python make_supdata.py --ssl_model_ckpt_path <PATH TO CKPT> --manifest_paths <PATH TO MANIFEST>
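#
# For every utterance (keyed by its relative-path id), this script writes into sup_data_dir:
#   mel_spec_<id>.pt, <ssl_content_emb_type>_content_embedding_<id>.pt,
#   speaker_embedding_<id>.pt, duration_embedding_<id>.pt and, when
#   --compute_pitch_contours=1, pitch_contour_<id>.pt plus a single
#   speaker_pitch_stats.json holding per-speaker pitch mean/std.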

import argparse
import json
import os
import time
from multiprocessing import Pool
from pathlib import Path

import hydra.utils
import librosa
import numpy as np
import torch
from omegaconf import open_dict
from tqdm import tqdm

from nemo.collections.asr.parts.preprocessing.segment import AudioSegment
from nemo.collections.tts.models import ssl_tts
from nemo.collections.tts.parts.utils.tts_dataset_utils import get_base_dir
from nemo.core.classes import Dataset
from nemo.utils import logging


class AudioDataset(Dataset):
    def __init__(
        self,
        manifest_paths,
        min_duration=0.5,
        max_duration=16.0,
        pad_multiple=1024,
        sample_rate=22050,
        sup_data_dir=None,
    ):
        self.data = []
        for manifest_path in manifest_paths:
            with open(manifest_path, "r") as f:
                for line in f:
                    record = json.loads(line)
                    if record['duration'] < min_duration or record['duration'] > max_duration:
                        continue
                    self.data.append(record)

        self.base_data_dir = get_base_dir([item["audio_filepath"] for item in self.data])
        if sup_data_dir is not None:
            self.sup_data_dir = sup_data_dir
        else:
            self.sup_data_dir = os.path.join(self.base_data_dir, "sup_data")
        if not os.path.exists(self.sup_data_dir):
            os.makedirs(self.sup_data_dir)

        self.pad_multiple = pad_multiple
        self.sample_rate = sample_rate

    def __len__(self):
        return len(self.data)

    def _get_wav_from_filepath(self, audio_filepath):
        features = AudioSegment.segment_from_file(
            audio_filepath, target_sr=self.sample_rate, n_segments=-1, trim=False,
        )
        audio_samples = features.samples
        audio, audio_length = torch.tensor(audio_samples), torch.tensor(audio_samples.shape[0]).long()

        # pad audio to a multiple of self.pad_multiple
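        # (e.g. a 50000-sample wav with pad_multiple=1024 gets 176 trailing zeros -> 50176 samples)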
        if audio.shape[0] % self.pad_multiple != 0:
            audio = torch.cat(
                [audio, torch.zeros(self.pad_multiple - audio.shape[0] % self.pad_multiple, dtype=torch.float)]
            )
            audio_length = torch.tensor(audio.shape[0]).long()

        return audio, audio_length

    def pad_collate_fn(self, batch):
        final_batch = {}
        for row in batch:
            for key in row:
                if key not in final_batch:
                    final_batch[key] = []
                final_batch[key].append(row[key])

        max_audio_len = max([_audio_len.item() for _audio_len in final_batch["audio_len"]])

        audios_padded = []
        for audio in final_batch["audio"]:
            audio_padded = torch.nn.functional.pad(audio, (0, max_audio_len - audio.size(0)), value=0)
            audios_padded.append(audio_padded)
        final_batch["audio"] = audios_padded

        for key in final_batch:
            if key not in ["rel_audio_path_as_text_id", "wav_path"]:
                final_batch[key] = torch.stack(final_batch[key])

        return final_batch

    def __getitem__(self, index):
        sample = self.data[index]
        rel_audio_path = Path(sample["audio_filepath"]).relative_to(self.base_data_dir).with_suffix("")
        rel_audio_path_as_text_id = str(rel_audio_path).replace("/", "_")
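        # e.g. <base_data_dir>/speaker1/chapter_02/utt_0003.wav -> "speaker1_chapter_02_utt_0003"
        # (illustrative path); this id keys every supplementary file saved for the utterance.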
        speaker = torch.tensor(sample["speaker"]).long()
        audio, audio_length = self._get_wav_from_filepath(sample["audio_filepath"])

        return {
            "audio": audio,
            "audio_len": audio_length,
            "rel_audio_path_as_text_id": rel_audio_path_as_text_id,
            "wav_path": sample["audio_filepath"],
            "speaker": speaker,
        }


def segment_wav(wav, segment_length, segment_hop_size, min_segment_length):
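    # E.g. with segment_length=44100, segment_hop_size=22050 and min_segment_length=22050,
    # a 70000-sample wav yields segments starting at samples 0, 22050 and 44100; the last
    # one is zero-padded from 25900 back up to 44100 samples.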
    if len(wav) < segment_length:
        pad = torch.zeros(segment_length - len(wav))
        segment = torch.cat([wav, pad])
        return [segment]
    else:
        si = 0
        segments = []
        while si < len(wav) - min_segment_length:
            segment = wav[si : si + segment_length]
            if len(segment) < segment_length:
                pad = torch.zeros(segment_length - len(segment))
                segment = torch.cat([segment, pad])
            segments.append(segment)
            si += segment_hop_size
        return segments


def segment_batch(batch, segment_length=44100, segment_hop_size=22050, min_segment_length=22050):
    all_segments = []
    segment_indices = []
    si = 0
    for bidx in range(len(batch['audio'])):
        audio = batch['audio'][bidx]
        audio_length = batch['audio_len'][bidx]
        audio_actual = audio[:audio_length]
        audio_segments = segment_wav(audio_actual, segment_length, segment_hop_size, min_segment_length)
        all_segments += audio_segments
        segment_indices.append((si, si + len(audio_segments) - 1))
        si += len(audio_segments)

    return torch.stack(all_segments), segment_indices


def get_mel_spectrogram(fb, wav, stft_params):
    EPSILON = 1e-9
    window_fn = torch.hann_window

    spec = torch.stft(
        input=wav,
        n_fft=stft_params['n_fft'],  # 1024
        hop_length=stft_params['hop_length'],  # 256
        win_length=stft_params['win_length'],  # 1024
        # build the STFT window on the same device as the input so this also runs on CPU
        window=window_fn(stft_params['win_length'], periodic=False).to(torch.float).to(wav.device) if window_fn else None,
        return_complex=True,
        center=True,
    )

    if spec.dtype in [torch.cfloat, torch.cdouble]:
        spec = torch.view_as_real(spec)
    spec = torch.sqrt(spec.pow(2).sum(-1) + EPSILON)

    mel = torch.matmul(fb.to(spec.dtype), spec)
    log_mel = torch.log(torch.clamp(mel, min=torch.finfo(mel.dtype).tiny))
    return log_mel
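
# Note: with center=True, torch.stft yields len(wav) // hop_length + 1 frames. main() feeds
# waveforms that were padded to a multiple of pad_multiple and then trimmed by one sample,
# so the frame count comes out to exactly audio_len // hop_length (4 mel frames per
# 1024-sample block at the default hop of 256), matching how mel_len is computed later.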


def load_wav(wav_path, sample_rate=22050, pad_multiple=1024):
    wav = AudioSegment.segment_from_file(wav_path, target_sr=sample_rate, n_segments=-1, trim=False,).samples
    if wav.shape[0] % pad_multiple != 0:
        wav = np.concatenate([wav, np.zeros(pad_multiple - wav.shape[0] % pad_multiple)])
    wav = wav[:-1]
    return wav


def save_pitch_contour(record):
    wav_path = record['wav_path']
    wav_text_id = record['wav_id']
    sup_data_dir = record['sup_data_dir']
    stft_params = record['stft_params']
    wav = load_wav(wav_path, stft_params['sample_rate'], stft_params['pad_multiple'])

    pitch_contour_fn = f"pitch_contour_{wav_text_id}.pt"
    pitch_contour_fp = os.path.join(sup_data_dir, pitch_contour_fn)

    f0, _, _ = librosa.pyin(
        wav,
        fmin=librosa.note_to_hz('C2'),
        fmax=stft_params['yin_fmax'],
        frame_length=stft_params['win_length'],
        hop_length=stft_params['hop_length'],
        sr=stft_params['sample_rate'],
        center=True,
        fill_na=0.0,
    )
    pitch_contour = torch.tensor(f0, dtype=torch.float32)
    torch.save(pitch_contour, pitch_contour_fp)
    logging.info("saved {}".format(pitch_contour_fp))
    return pitch_contour
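
# Because load_wav applies the same pad_multiple padding and one-sample trim, and pyin uses the
# same hop_length with center=True, the saved pitch contour should line up frame-for-frame with
# the utterance's saved mel spectrogram (audio_len // hop_length frames).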


def compute_pitch_stats(records):
    def _is_valid_pitch(pitch_mean, pitch_std):
        c1 = pitch_mean > 0 and pitch_mean < 1000
        c2 = pitch_std > 0 and pitch_std < 1000
        return c1 and c2

    speaker_wise_pitch_contours = {}
    for item in records:
        wav_id = item['wav_id']
        speaker = item['speaker']
        sup_data_dir = item['sup_data_dir']
        pitch_contour_fn = f"pitch_contour_{wav_id}.pt"
        pitch_contour_fp = os.path.join(sup_data_dir, pitch_contour_fn)
        if speaker not in speaker_wise_pitch_contours:
            speaker_wise_pitch_contours[speaker] = []
        speaker_wise_pitch_contours[speaker].append(pitch_contour_fp)

    speaker_pitch_stats = {}
    for speaker in speaker_wise_pitch_contours:
        non_zero_pc = []
        for pitch_contour_fp in speaker_wise_pitch_contours[speaker][:50]:
            pitch_contour = torch.load(pitch_contour_fp)
            pitch_contour_nonzero = pitch_contour[pitch_contour != 0]
            if len(pitch_contour_nonzero) > 0:
                non_zero_pc.append(pitch_contour_nonzero)

        if len(non_zero_pc) > 0:
            non_zero_pc = torch.cat(non_zero_pc)
            pitch_mean = non_zero_pc.mean().item()
            pitch_std = non_zero_pc.std().item()
            valid = True
            if not _is_valid_pitch(pitch_mean, pitch_std):
                logging.warning("invalid pitch stats for speaker {}".format(speaker))
                pitch_mean = 212.0
                pitch_std = 70.0
                valid = False
        else:
            logging.warning("could not find pitch contour for speaker {}".format(speaker))
            valid = False
            pitch_mean = 212.0
            pitch_std = 70.0

        speaker_pitch_stats[speaker] = {"pitch_mean": pitch_mean, "pitch_std": pitch_std, "valid": valid}

    with open(os.path.join(sup_data_dir, "speaker_pitch_stats.json"), "w") as f:
        json.dump(speaker_pitch_stats, f)
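    # Illustrative contents of speaker_pitch_stats.json (speaker ids become string keys in JSON;
    # the numbers below are made up):
    # {"92": {"pitch_mean": 211.3, "pitch_std": 52.7, "valid": true}, ...}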


def main():
    parser = argparse.ArgumentParser(
        description='Compute supplementary data (content/speaker embeddings, durations, mel spectrograms, pitch contours) for SSL-based TTS training'
    )
    parser.add_argument(
        '--ssl_model_ckpt_path', type=str, required=True,
    )
    parser.add_argument('--manifest_paths', type=str, required=True)
    parser.add_argument('--sup_data_dir', type=str, default=None)
    parser.add_argument('--batch_size', type=int, default=32)
    parser.add_argument('--ssl_content_emb_type', type=str, default="embedding_and_probs")
    parser.add_argument('--use_unique_tokens', type=int, default=1)
    parser.add_argument('--num_workers', type=int, default=8)
    parser.add_argument('--pool_workers', type=int, default=30)
    parser.add_argument('--compute_pitch_contours', type=int, default=1)
    parser.add_argument('--num_pitch_per_speaker', type=int, default=None)  # saves time.
    parser.add_argument('--sample_rate', type=int, default=22050)
    parser.add_argument('--pad_multiple', type=int, default=1024)
    parser.add_argument('--ssl_downsampling_factor', type=int, default=4)
    parser.add_argument('--stft_n_fft', type=int, default=1024)
    parser.add_argument('--stft_hop_length', type=int, default=256)
    parser.add_argument('--stft_win_length', type=int, default=1024)
    parser.add_argument('--stft_n_mel', type=int, default=80)
    parser.add_argument('--stft_fmin', type=int, default=0)
    parser.add_argument('--stft_fmax', type=int, default=8000)
    parser.add_argument('--yin_fmax', type=int, default=500)
    parser.add_argument('--segment_length', type=int, default=44100)
    parser.add_argument('--segment_hop_size', type=int, default=22050)
    parser.add_argument('--min_segment_length', type=int, default=22050)
    args = parser.parse_args()

    device = "cuda:0" if torch.cuda.is_available() else "cpu"

    manifest_paths = args.manifest_paths.split(",")
    ssl_model_ckpt_path = args.ssl_model_ckpt_path

    dataset = AudioDataset(
        manifest_paths, pad_multiple=args.pad_multiple, sample_rate=args.sample_rate, sup_data_dir=args.sup_data_dir
    )
    dataloader = torch.utils.data.DataLoader(
        dataset,
        batch_size=args.batch_size,
        shuffle=False,
        collate_fn=dataset.pad_collate_fn,
        num_workers=args.num_workers,
    )

    ssl_model = ssl_tts.SSLDisentangler.load_from_checkpoint(ssl_model_ckpt_path, strict=False)
    with open_dict(ssl_model.cfg):
        ssl_model.cfg.preprocessor.exact_pad = True
    ssl_model.preprocessor = hydra.utils.instantiate(ssl_model.cfg.preprocessor)
    ssl_model.preprocessor_disentangler = ssl_model.preprocessor
    ssl_model.eval()
    ssl_model.to(device)

    sample_rate = args.sample_rate
    stft_params = {
        "n_fft": args.stft_n_fft,
        "hop_length": args.stft_hop_length,
        "win_length": args.stft_win_length,
        "n_mel": args.stft_n_mel,
        "sample_rate": sample_rate,
        "pad_multiple": args.pad_multiple,
        "fmin": args.stft_fmin,
        "fmax": args.stft_fmax,
        "yin_fmax": args.yin_fmax,
    }

    fb = (
        torch.tensor(
            librosa.filters.mel(
                sr=sample_rate,
                n_fft=stft_params['n_fft'],
                n_mels=stft_params['n_mel'],
                fmin=stft_params['fmin'],
                fmax=stft_params['fmax'],
            ),
            dtype=torch.float,
        )
        .unsqueeze(0)
        .to(device)
    )
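    # fb has shape (1, n_mel, n_fft // 2 + 1), i.e. (1, 80, 513) with the defaults; it is applied
    # to the (batch, n_fft // 2 + 1, frames) magnitude spectrogram via a broadcast matmul
    # inside get_mel_spectrogram.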

    st = time.time()
    bidx = 0
    wav_and_id_list = []
    for batch in tqdm(dataloader):
        bidx += 1
        with torch.no_grad():
            (
                _,
                _,
                batch_content_embedding,
                batch_content_log_probs,
                batch_encoded_len,
            ) = ssl_model.forward_for_export(
                input_signal=batch['audio'].to(device),
                input_signal_length=batch['audio_len'].to(device),
                normalize_content=True,
            )

            batch_mel_specs = get_mel_spectrogram(fb, batch['audio'][:, :-1].to(device), stft_params)

            audio_segmented, segment_indices = segment_batch(
                batch, args.segment_length, args.segment_hop_size, args.min_segment_length
            )
            audio_seg_len = torch.tensor([len(segment) for segment in audio_segmented]).to(device).long()
            _, batch_speaker_embeddings, _, _, _ = ssl_model.forward_for_export(
                input_signal=audio_segmented.to(device), input_signal_length=audio_seg_len, normalize_content=True,
            )

            for idx in range(batch['audio'].shape[0]):
                _speaker = batch['speaker'][idx].item()
                wav_path = batch['wav_path'][idx]
                wav_id = batch['rel_audio_path_as_text_id'][idx]
                wav_and_id_list.append((wav_path, wav_id, _speaker))

                content_embedding = batch_content_embedding[idx].detach()
                content_log_probs = batch_content_log_probs[:, idx, :].detach()  # content log probs are (t, b, c)
                encoded_len = batch_encoded_len[idx].detach()
                content_embedding = content_embedding[: encoded_len.item()]
                content_embedding = content_embedding.t()
                content_log_probs = content_log_probs[: encoded_len.item()]
                content_log_probs = content_log_probs.t()
                content_probs = torch.exp(content_log_probs)
                duration = torch.ones(content_embedding.shape[1]) * args.ssl_downsampling_factor

                bsi_start = segment_indices[idx][0]
                bsi_end = segment_indices[idx][1]
                speaker_embedding = torch.mean(batch_speaker_embeddings[bsi_start : bsi_end + 1], dim=0)
                l2_norm = torch.norm(speaker_embedding, p=2)
                speaker_embedding = speaker_embedding / l2_norm

                if args.ssl_content_emb_type == "probs":
                    # content embedding is only the character probabilities
                    final_content_embedding = content_probs
                elif args.ssl_content_emb_type == "embedding":
                    # content embedding is only the output of the content head of the SSL backbone
                    final_content_embedding = content_embedding
                elif args.ssl_content_emb_type == "log_probs":
                    # content embedding is only the log of the character probabilities
                    final_content_embedding = content_log_probs
                elif args.ssl_content_emb_type == "embedding_and_probs":
                    # content embedding is the concatenation of the content-head output and the character probabilities
                    final_content_embedding = torch.cat([content_embedding, content_probs], dim=0)

                if args.use_unique_tokens == 1:
                    # Group content embeddings that share the same predicted token (by averaging them)
                    # and add up the durations of the grouped embeddings.
                    # E.g. by default each content embedding corresponds to 4 spectrogram frames (ssl_downsampling_factor);
                    # if we group 3 content embeddings, the duration of the grouped embedding becomes 12 frames.
                    # This is useful for adapting the duration during inference based on the speaker.
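                    # Worked example: if the per-frame argmax tokens are [5, 5, 5, 9, 9, 7] and
                    # ssl_downsampling_factor is 4, the three runs are averaged into three embeddings
                    # with durations [12, 8, 4] spectrogram frames.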
                    token_predictions = torch.argmax(content_probs, dim=0)
                    content_buffer = [final_content_embedding[:, 0]]
                    unique_content_embeddings = []
                    unique_tokens = []
                    durations = []
                    for _t in range(1, final_content_embedding.shape[1]):
                        if token_predictions[_t] == token_predictions[_t - 1]:
                            content_buffer.append(final_content_embedding[:, _t])
                        else:
                            durations.append(len(content_buffer) * args.ssl_downsampling_factor)
                            unique_content_embeddings.append(torch.mean(torch.stack(content_buffer), dim=0))
                            content_buffer = [final_content_embedding[:, _t]]
                            unique_tokens.append(token_predictions[_t].item())

                    if len(content_buffer) > 0:
                        durations.append(len(content_buffer) * args.ssl_downsampling_factor)
                        unique_content_embeddings.append(torch.mean(torch.stack(content_buffer), dim=0))
                        unique_tokens.append(token_predictions[_t].item())

                    unique_content_embedding = torch.stack(unique_content_embeddings)
                    final_content_embedding = unique_content_embedding.t()
                    duration = torch.tensor(durations).float()

                mel_len = int(batch['audio_len'][idx].item() / stft_params['hop_length'])
                item_mel = batch_mel_specs[idx][:, :mel_len]

                wav_text_id = batch["rel_audio_path_as_text_id"][idx]
                content_emb_fn = f"{args.ssl_content_emb_type}_content_embedding_{wav_text_id}.pt"
                speaker_emb_fn = f"speaker_embedding_{wav_text_id}.pt"
                duration_fn = f"duration_embedding_{wav_text_id}.pt"  # called "embedding" only to keep the filename pattern consistent
                content_emb_fp = os.path.join(dataset.sup_data_dir, content_emb_fn)
                speaker_emb_fp = os.path.join(dataset.sup_data_dir, speaker_emb_fn)
                duration_fp = os.path.join(dataset.sup_data_dir, duration_fn)
                mel_spec_fn = f"mel_spec_{wav_text_id}.pt"
                mel_spec_fp = os.path.join(dataset.sup_data_dir, mel_spec_fn)

                torch.save(item_mel.cpu(), mel_spec_fp)
                torch.save(final_content_embedding.cpu(), content_emb_fp)
                torch.save(speaker_embedding.cpu(), speaker_emb_fp)
                torch.save(duration.cpu(), duration_fp)

        et = time.time()
        logging.info(
            "Processed batch {} of {} | Average time per batch: {:.4f} s".format(bidx, len(dataloader), (et - st) / bidx)
        )

    if args.compute_pitch_contours == 1:
        speaker_wise_records = {}
        for row in wav_and_id_list:
            wav_path, wav_id, speaker = row
            if speaker not in speaker_wise_records:
                speaker_wise_records[speaker] = []
            speaker_wise_records[speaker].append(
                {
                    "wav_path": wav_path,
                    "wav_id": wav_id,
                    "sup_data_dir": dataset.sup_data_dir,
                    "stft_params": stft_params,
                    "speaker": speaker,
                }
            )

        filtered_records = []
        for speaker in speaker_wise_records:
            if args.num_pitch_per_speaker is not None:
                filtered_records += speaker_wise_records[speaker][: args.num_pitch_per_speaker]
            else:
                filtered_records += speaker_wise_records[speaker]

        with Pool(args.pool_workers) as p:
            p.map(save_pitch_contour, filtered_records)

        compute_pitch_stats(filtered_records)


if __name__ == '__main__':
    main()