import logging
import os
import pathlib
import shutil
import sys
from typing import Dict

import matplotlib

import utils
from utils.text_encoder import TokenTextEncoder

matplotlib.use('Agg')

import torch.utils.data
from torchmetrics import Metric, MeanMetric
import lightning.pytorch as pl
from lightning.pytorch.utilities.rank_zero import rank_zero_debug, rank_zero_info, rank_zero_only

from basics.base_module import CategorizedModule
from utils.hparams import hparams
from utils.training_utils import (
    DsModelCheckpoint, DsTQDMProgressBar,
    DsBatchSampler, DsTensorBoardLogger,
    get_latest_checkpoint_path, get_strategy
)
from utils.phoneme_utils import locate_dictionary, build_phoneme_list

torch.multiprocessing.set_sharing_strategy(os.getenv('TORCH_SHARE_STRATEGY', 'file_system'))

log_format = '%(asctime)s %(message)s'
logging.basicConfig(stream=sys.stdout, level=logging.INFO,
                    format=log_format, datefmt='%m/%d %I:%M:%S %p')


class BaseTask(pl.LightningModule):
    """
    Base class for training tasks.

    What this class provides:
    1. *training_step*:
        run the model, then record and log the losses;
    2. *on_validation_epoch_end*:
        aggregate and log the validation losses and metrics;
    3. *start*:
        load the training configuration, build the trainer and start training or testing.

    Subclasses should define:
    1. *_build_model*, *build_optimizer*, *build_scheduler*:
        how to build the model, the optimizer and the LR scheduler;
    2. *build_losses_and_metrics* and *run_model*:
        which losses/metrics to register and how to compute them;
    3. *_training_step* and *_validation_step*:
        one training or validation step of the model;
    4. *_on_validation_epoch_end*:
        postprocess the validation output.

    An illustrative subclass sketch is provided at the bottom of this module.
    """

    def __init__(self, *args, **kwargs):
        super().__init__(*args, **kwargs)
        self.max_batch_frames = hparams['max_batch_frames']
        self.max_batch_size = hparams['max_batch_size']
        self.max_val_batch_frames = hparams['max_val_batch_frames']
        if self.max_val_batch_frames == -1:
            hparams['max_val_batch_frames'] = self.max_val_batch_frames = self.max_batch_frames
        self.max_val_batch_size = hparams['max_val_batch_size']
        if self.max_val_batch_size == -1:
            hparams['max_val_batch_size'] = self.max_val_batch_size = self.max_batch_size

        self.training_sampler = None
        self.skip_immediate_validation = False
        self.skip_immediate_ckpt_save = False

        self.phone_encoder = self.build_phone_encoder()
        self.build_model()
        self.valid_losses: Dict[str, Metric] = {}
        self.valid_metrics: Dict[str, Metric] = {}

    def _finish_init(self):
        self.register_validation_loss('total_loss')
        self.build_losses_and_metrics()
        assert len(self.valid_losses) > 0, "No validation loss registered. Please check your configuration file."

    ###########
    # Training, validation and testing
    ###########
    def setup(self, stage):
        self.train_dataset = self.dataset_cls('train')
        self.valid_dataset = self.dataset_cls('valid')
        self.num_replicas = (self.trainer.distributed_sampler_kwargs or {}).get('num_replicas', 1)

    def get_need_freeze_state_dict_key(self, model_state_dict) -> list:
        key_list = []
        for i in hparams['frozen_params']:
            for j in model_state_dict:
                if j.startswith(i):
                    key_list.append(j)
        return list(set(key_list))

    def freeze_params(self) -> None:
        model_state_dict = self.state_dict().keys()
        freeze_key = self.get_need_freeze_state_dict_key(model_state_dict=model_state_dict)
        for i in freeze_key:
            params = self.get_parameter(i)
            params.requires_grad = False

    def unfreeze_all_params(self) -> None:
        for i in self.model.parameters():
            i.requires_grad = True

    def load_finetune_ckpt(self, state_dict) -> None:
        adapt_shapes = hparams['finetune_strict_shapes']
        if not adapt_shapes:
            cur_model_state_dict = self.state_dict()
            unmatched_keys = []
            for key, param in state_dict.items():
                if key in cur_model_state_dict:
                    new_param = cur_model_state_dict[key]
                    if new_param.shape != param.shape:
                        unmatched_keys.append(key)
                        print('| Unmatched keys: ', key, new_param.shape, param.shape)
            for key in unmatched_keys:
                del state_dict[key]
        self.load_state_dict(state_dict, strict=False)

    def load_pre_train_model(self):
        pre_train_ckpt_path = hparams['finetune_ckpt_path']
        blacklist = hparams['finetune_ignored_params']
        if blacklist is None:
            blacklist = []
        if pre_train_ckpt_path is not None:
            ckpt = torch.load(pre_train_ckpt_path)
            if isinstance(self.model, CategorizedModule):
                self.model.check_category(ckpt.get('category'))
            state_dict = {}
            for i in ckpt['state_dict']:
                # skip any parameter whose name starts with a blacklisted prefix
                skip = False
                for b in blacklist:
                    if i.startswith(b):
                        skip = True
                        break
                if skip:
                    continue
                state_dict[i] = ckpt['state_dict'][i]
                print(i)  # log the parameter names loaded from the pre-trained checkpoint
            return state_dict
        else:
            raise RuntimeError("Fine-tuning is enabled but 'finetune_ckpt_path' is not specified.")

    @staticmethod
    def build_phone_encoder():
        phone_list = build_phoneme_list()
        return TokenTextEncoder(vocab_list=phone_list)

    def _build_model(self):
        raise NotImplementedError()

    def build_model(self):
        self.model = self._build_model()
        # utils.load_warp(self)
        self.unfreeze_all_params()
        if hparams['freezing_enabled']:
            self.freeze_params()
        if hparams['finetune_enabled'] and get_latest_checkpoint_path(pathlib.Path(hparams['work_dir'])) is None:
            self.load_finetune_ckpt(self.load_pre_train_model())
        self.print_arch()

    def print_arch(self):
        utils.print_arch(self.model)

    def build_losses_and_metrics(self):
        raise NotImplementedError()

    def register_validation_metric(self, name: str, metric: Metric):
        assert isinstance(metric, Metric)
        self.valid_metrics[name] = metric

    def register_validation_loss(self, name: str, Aggregator: type = MeanMetric):
        assert issubclass(Aggregator, Metric)
        self.valid_losses[name] = Aggregator()

    def run_model(self, sample, infer=False):
        """
        Run the full model on one collated batch.

        Steps:
        1. run the full model;
        2. calculate the losses if not infer.

        :return: a dict of named losses when infer is False; the (subclass-specific) model output when infer is True
        """
        raise NotImplementedError()
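
    # A hedged sketch of what a concrete ``run_model`` override usually looks like
    # (illustrative only; the input key 'tokens', the target key 'mel' and the loss
    # name 'mel_loss' are assumptions, not part of this base class):
    #
    #     def run_model(self, sample, infer=False):
    #         output = self.model(sample['tokens'])
    #         if infer:
    #             return output
    #         return {'mel_loss': torch.nn.functional.l1_loss(output, sample['mel'])}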

    def on_train_epoch_start(self):
        if self.training_sampler is not None:
            self.training_sampler.set_epoch(self.current_epoch)

    def _training_step(self, sample):
        """
        :return: total_loss: torch.Tensor, loss_log: dict
        """
        losses = self.run_model(sample)
        total_loss = sum(losses.values())
        return total_loss, {**losses, 'batch_size': float(sample['size'])}

    def training_step(self, sample, batch_idx):
        total_loss, log_outputs = self._training_step(sample)

        # logs to progress bar
        self.log_dict(log_outputs, prog_bar=True, logger=False, on_step=True, on_epoch=False)
        self.log('lr', self.lr_schedulers().get_last_lr()[0], prog_bar=True, logger=False, on_step=True, on_epoch=False)
        # logs to tensorboard
        if self.global_step % hparams['log_interval'] == 0:
            tb_log = {f'training/{k}': v for k, v in log_outputs.items()}
            tb_log['training/lr'] = self.lr_schedulers().get_last_lr()[0]
            self.logger.log_metrics(tb_log, step=self.global_step)

        return total_loss

    # def on_before_optimizer_step(self, *args, **kwargs):
    #     self.log_dict(grad_norm(self, norm_type=2))

    def _on_validation_start(self):
        pass

    def on_validation_start(self):
        if self.skip_immediate_validation:
            rank_zero_debug("Skip validation")
            return
        self._on_validation_start()
        for metric in self.valid_losses.values():
            metric.to(self.device)
            metric.reset()
        for metric in self.valid_metrics.values():
            metric.to(self.device)
            metric.reset()

    def _validation_step(self, sample, batch_idx):
        """
        :param sample:
        :param batch_idx:
        :return: loss_log: dict, weight: int
        """
        raise NotImplementedError()

    def validation_step(self, sample, batch_idx):
        """
        :param sample:
        :param batch_idx:
        """
        if self.skip_immediate_validation:
            rank_zero_debug("Skip validation")
            return
        if sample['size'] > 0:
            with torch.autocast(self.device.type, enabled=False):
                losses, weight = self._validation_step(sample, batch_idx)
            losses = {
                'total_loss': sum(losses.values()),
                **losses
            }
            for k, v in losses.items():
                self.valid_losses[k].update(v, weight=weight)

    def _on_validation_epoch_end(self):
        pass

    def on_validation_epoch_end(self):
        if self.skip_immediate_validation:
            self.skip_immediate_validation = False
            self.skip_immediate_ckpt_save = True
            return
        self._on_validation_epoch_end()
        loss_vals = {k: v.compute() for k, v in self.valid_losses.items()}
        metric_vals = {k: v.compute() for k, v in self.valid_metrics.items()}
        self.log('val_loss', loss_vals['total_loss'], on_epoch=True, prog_bar=True, logger=False, sync_dist=True)
        self.logger.log_metrics({f'validation/{k}': v for k, v in loss_vals.items()}, step=self.global_step)
        self.logger.log_metrics({f'metrics/{k}': v for k, v in metric_vals.items()}, step=self.global_step)

    # noinspection PyMethodMayBeStatic
    def build_scheduler(self, optimizer):
        from utils import build_lr_scheduler_from_config

        scheduler_args = hparams['lr_scheduler_args']
        assert scheduler_args['scheduler_cls'] != ''
        scheduler = build_lr_scheduler_from_config(optimizer, scheduler_args)
        return scheduler

    # noinspection PyMethodMayBeStatic
    def build_optimizer(self, model):
        from utils import build_object_from_class_name

        optimizer_args = hparams['optimizer_args']
        assert optimizer_args['optimizer_cls'] != ''
        if 'beta1' in optimizer_args and 'beta2' in optimizer_args and 'betas' not in optimizer_args:
            optimizer_args['betas'] = (optimizer_args['beta1'], optimizer_args['beta2'])
        optimizer = build_object_from_class_name(
            optimizer_args['optimizer_cls'],
            torch.optim.Optimizer,
            model.parameters(),
            **optimizer_args
        )
        return optimizer
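
    # The two builders above are driven entirely by hparams['optimizer_args'] and
    # hparams['lr_scheduler_args']. A minimal sketch of such a configuration, assuming
    # AdamW and StepLR (the concrete classes and values are illustrative, not project
    # defaults):
    #
    #     hparams['optimizer_args'] = {
    #         'optimizer_cls': 'torch.optim.AdamW',
    #         'lr': 4e-4, 'beta1': 0.9, 'beta2': 0.98, 'weight_decay': 0.0,
    #     }
    #     hparams['lr_scheduler_args'] = {
    #         'scheduler_cls': 'torch.optim.lr_scheduler.StepLR',
    #         'step_size': 50000, 'gamma': 0.5,
    #     }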

    def configure_optimizers(self):
        optm = self.build_optimizer(self.model)
        scheduler = self.build_scheduler(optm)
        if scheduler is None:
            return optm
        return {
            "optimizer": optm,
            "lr_scheduler": {
                "scheduler": scheduler,
                "interval": "step",
                "frequency": 1
            }
        }

    def train_dataloader(self):
        self.training_sampler = DsBatchSampler(
            self.train_dataset,
            max_batch_frames=self.max_batch_frames,
            max_batch_size=self.max_batch_size,
            num_replicas=self.num_replicas,
            rank=self.global_rank,
            sort_by_similar_size=hparams['sort_by_len'],
            size_reversed=True,
            required_batch_count_multiple=hparams['accumulate_grad_batches'],
            shuffle_sample=True,
            shuffle_batch=True
        )
        return torch.utils.data.DataLoader(
            self.train_dataset,
            collate_fn=self.train_dataset.collater,
            batch_sampler=self.training_sampler,
            num_workers=hparams['ds_workers'],
            prefetch_factor=hparams['dataloader_prefetch_factor'],
            pin_memory=True,
            persistent_workers=True
        )

    def val_dataloader(self):
        sampler = DsBatchSampler(
            self.valid_dataset,
            max_batch_frames=self.max_val_batch_frames,
            max_batch_size=self.max_val_batch_size,
            num_replicas=self.num_replicas,
            rank=self.global_rank,
            shuffle_sample=False,
            shuffle_batch=False,
            disallow_empty_batch=False,
            pad_batch_assignment=False
        )
        return torch.utils.data.DataLoader(
            self.valid_dataset,
            collate_fn=self.valid_dataset.collater,
            batch_sampler=sampler,
            num_workers=hparams['ds_workers'],
            prefetch_factor=hparams['dataloader_prefetch_factor'],
            persistent_workers=True
        )

    def test_dataloader(self):
        return self.val_dataloader()

    def on_test_start(self):
        self.on_validation_start()

    def test_step(self, sample, batch_idx):
        return self.validation_step(sample, batch_idx)

    def on_test_end(self):
        return self.on_validation_end()

    ###########
    # Running configuration
    ###########
    @classmethod
    def start(cls):
        task = cls()
        work_dir = pathlib.Path(hparams['work_dir'])
        trainer = pl.Trainer(
            accelerator=hparams['pl_trainer_accelerator'],
            devices=hparams['pl_trainer_devices'],
            num_nodes=hparams['pl_trainer_num_nodes'],
            strategy=get_strategy(
                hparams['pl_trainer_devices'],
                hparams['pl_trainer_num_nodes'],
                hparams['pl_trainer_accelerator'],
                hparams['pl_trainer_strategy'],
                hparams['pl_trainer_precision'],
            ),
            precision=hparams['pl_trainer_precision'],
            callbacks=[
                DsModelCheckpoint(
                    dirpath=work_dir,
                    filename='model_ckpt_steps_{step}',
                    auto_insert_metric_name=False,
                    monitor='step',
                    mode='max',
                    save_last=False,
                    # every_n_train_steps=hparams['val_check_interval'],
                    save_top_k=hparams['num_ckpt_keep'],
                    permanent_ckpt_start=hparams['permanent_ckpt_start'],
                    permanent_ckpt_interval=hparams['permanent_ckpt_interval'],
                    verbose=True
                ),
                # LearningRateMonitor(logging_interval='step'),
                DsTQDMProgressBar(),
            ],
            logger=DsTensorBoardLogger(
                save_dir=str(work_dir),
                name='lightning_logs',
                version='latest'
            ),
            gradient_clip_val=hparams['clip_grad_norm'],
            val_check_interval=hparams['val_check_interval'] * hparams['accumulate_grad_batches'],
            # so this is global_steps
            check_val_every_n_epoch=None,
            log_every_n_steps=1,
            max_steps=hparams['max_updates'],
            use_distributed_sampler=False,
            num_sanity_val_steps=hparams['num_sanity_val_steps'],
            accumulate_grad_batches=hparams['accumulate_grad_batches']
        )
        if not hparams['infer']:  # train
            # only copy payload files on the rank-zero process
            @rank_zero_only
            def train_payload_copy():
                # Copy spk_map.json and dictionary.txt to work dir
                binary_dir = pathlib.Path(hparams['binary_data_dir'])
                spk_map = work_dir / 'spk_map.json'
                spk_map_src = binary_dir / 'spk_map.json'
                if not spk_map.exists() and spk_map_src.exists():
                    shutil.copy(spk_map_src, spk_map)
                    print(f'| Copied spk map to {spk_map}.')
                dictionary = work_dir / 'dictionary.txt'
                dict_src = binary_dir / 'dictionary.txt'
                if not dictionary.exists():
                    if dict_src.exists():
                        shutil.copy(dict_src, dictionary)
                    else:
                        shutil.copy(locate_dictionary(), dictionary)
                    print(f'| Copied dictionary to {dictionary}.')

            train_payload_copy()
            trainer.fit(task, ckpt_path=get_latest_checkpoint_path(work_dir))
        else:
            trainer.test(task)

    def on_save_checkpoint(self, checkpoint):
        if isinstance(self.model, CategorizedModule):
            checkpoint['category'] = self.model.category
        checkpoint['trainer_stage'] = self.trainer.state.stage.value

    def on_load_checkpoint(self, checkpoint):
        # When resuming, the optimizer and LR scheduler states stored in the checkpoint
        # may disagree with the current hparams; the values from hparams take precedence.
        from lightning.pytorch.trainer.states import RunningStage
        from utils import simulate_lr_scheduler

        if checkpoint.get('trainer_stage', '') == RunningStage.VALIDATING.value:
            self.skip_immediate_validation = True

        optimizer_args = hparams['optimizer_args']
        scheduler_args = hparams['lr_scheduler_args']
        if 'beta1' in optimizer_args and 'beta2' in optimizer_args and 'betas' not in optimizer_args:
            optimizer_args['betas'] = (optimizer_args['beta1'], optimizer_args['beta2'])

        if checkpoint.get('optimizer_states', None):
            opt_states = checkpoint['optimizer_states']
            assert len(opt_states) == 1  # only support one optimizer
            opt_state = opt_states[0]
            for param_group in opt_state['param_groups']:
                for k, v in optimizer_args.items():
                    if k in param_group and param_group[k] != v:
                        if 'lr_schedulers' in checkpoint and checkpoint['lr_schedulers'] and k == 'lr':
                            continue
                        rank_zero_info(f'| Overriding optimizer parameter {k} from checkpoint: {param_group[k]} -> {v}')
                        param_group[k] = v
                if 'initial_lr' in param_group and param_group['initial_lr'] != optimizer_args['lr']:
                    rank_zero_info(
                        f'| Overriding optimizer parameter initial_lr from checkpoint: '
                        f'{param_group["initial_lr"]} -> {optimizer_args["lr"]}'
                    )
                    param_group['initial_lr'] = optimizer_args['lr']

        if checkpoint.get('lr_schedulers', None):
            assert checkpoint.get('optimizer_states', False)
            assert len(checkpoint['lr_schedulers']) == 1  # only support one scheduler
            # Re-simulate the scheduler with the current configuration up to the stored step,
            # then sync the resulting learning rates back into the optimizer state.
            checkpoint['lr_schedulers'][0] = simulate_lr_scheduler(
                optimizer_args, scheduler_args,
                step_count=checkpoint['global_step'],
                num_param_groups=len(checkpoint['optimizer_states'][0]['param_groups'])
            )
            for param_group, new_lr in zip(
                    checkpoint['optimizer_states'][0]['param_groups'],
                    checkpoint['lr_schedulers'][0]['_last_lr'],
            ):
                if param_group['lr'] != new_lr:
                    rank_zero_info(f'| Overriding optimizer parameter lr from checkpoint: {param_group["lr"]} -> {new_lr}')
                    param_group['lr'] = new_lr
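

# --------------------------------------------------------------------------------------
# Illustrative subclass sketch. It is not used anywhere in the codebase; it only shows
# how the abstract hooks documented in BaseTask fit together. The placeholder model, the
# batch keys ('tokens', 'mel', 'size') and the loss name 'mel_loss' are assumptions made
# for the sake of the example, and a real task would also have to set ``self.dataset_cls``.
# --------------------------------------------------------------------------------------
class ExampleTask(BaseTask):
    def __init__(self, *args, **kwargs):
        super().__init__(*args, **kwargs)
        # self.dataset_cls = <your dataset class>  # required before setup() is called
        self._finish_init()

    def _build_model(self):
        # placeholder model; a real task builds its acoustic/variance model here
        return torch.nn.Linear(256, 128)

    def build_losses_and_metrics(self):
        self.register_validation_loss('mel_loss')

    def run_model(self, sample, infer=False):
        output = self.model(sample['tokens'])
        if infer:
            return output
        return {'mel_loss': torch.nn.functional.l1_loss(output, sample['mel'])}

    def _validation_step(self, sample, batch_idx):
        losses = self.run_model(sample)
        return {k: v.item() for k, v in losses.items()}, sample['size']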