import re

import numpy as np
import torch
from torch.utils.data import DataLoader, DistributedSampler

# NOTE: `data`, `losses`, `optim`, and `tasks` are referenced below but were not
# imported in the original snippet; they are assumed to be sibling modules of
# `models` within this package.
from . import data, losses, models, optim, tasks

def get_name_and_params(base):
    name = getattr(base, 'name')
    params = getattr(base, 'params') or {}
    return name, params

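# Every config node handed to get_name_and_params() follows the same
# {name, params} layout. A purely illustrative sketch (the optimizer name and
# learning rate below are assumptions, not taken from this module):
#
#   optimizer:
#     name: AdamW
#     params:
#       lr: 1.0e-4
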
def get_transform(base, transform, mode=None):
    if not base:
        return None
    transform = getattr(base, transform)
    if not transform:
        return None
    name, params = get_name_and_params(transform)
    if mode:
        params.update({'mode': mode})
    return getattr(data.transforms, name)(**params)

def build_transforms(cfg, mode):
    # 1- Resize
    resizer = get_transform(cfg.transform, 'resize')
    # 2- (Optional) Data augmentation, training only
    augmenter = None
    if mode == 'train':
        augmenter = get_transform(cfg.transform, 'augment')
    # 3- (Optional) Crop
    cropper = get_transform(cfg.transform, 'crop', mode=mode)
    # 4- Preprocess
    preprocessor = get_transform(cfg.transform, 'preprocess')
    return {
        'resize': resizer,
        'augment': augmenter,
        'crop': cropper,
        'preprocess': preprocessor,
    }

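# A minimal sketch of the `transform` config section that build_transforms()
# consumes. Only the resize/augment/crop/preprocess layout comes from the code
# above; the class names and parameters are hypothetical stand-ins for entries
# in data.transforms:
#
#   transform:
#     resize:
#       name: Resize
#       params: {imsize: [512, 512]}
#     augment:
#       name: Augmenter            # hypothetical
#       params: {p: 0.5}
#     crop:
#       name: RandomCrop           # hypothetical; `mode` is injected automatically
#       params: {imsize: [448, 448]}
#     preprocess:
#       name: Preprocessor         # hypothetical
#       params: {}
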
def build_dataset(cfg, data_info, mode):
    dataset_class = getattr(data.datasets, cfg.data.dataset.name)
    dataset_params = cfg.data.dataset.params
    dataset_params.test_mode = mode != 'train'
    dataset_params = dict(dataset_params)
    if "FeatureDataset" not in cfg.data.dataset.name:
        transforms = build_transforms(cfg, mode)
        dataset_params.update(transforms)
    dataset_params.update(data_info)
    return dataset_class(**dataset_params)

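# build_dataset() merges three sources into the dataset constructor kwargs:
# the configured cfg.data.dataset.params, the transform dict from
# build_transforms() (skipped for FeatureDataset variants), and a
# caller-supplied `data_info` dict. A hedged sketch of the call (the
# `data_info` keys are assumptions, not defined in this module):
#
#   data_info = {'inputs': train_image_paths, 'labels': train_labels}
#   train_dataset = build_dataset(cfg, data_info, mode='train')
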
def build_dataloader(cfg, dataset, mode):

    def worker_init_fn(worker_id):
        # Give each worker process a distinct, reproducible NumPy seed
        np.random.seed(np.random.get_state()[1][0] + worker_id)

    dataloader_params = {}
    dataloader_params['num_workers'] = cfg.data.num_workers
    dataloader_params['drop_last'] = mode == 'train'
    dataloader_params['shuffle'] = mode == 'train'
    dataloader_params['pin_memory'] = cfg.data.get('pin_memory', True)

    if mode in ('train', 'valid'):
        if mode == 'train':
            dataloader_params['batch_size'] = cfg.train.batch_size
        elif mode == 'valid':
            dataloader_params['batch_size'] = cfg.evaluate.get('batch_size') or cfg.train.batch_size
        sampler = None
        if cfg.data.get('sampler') and mode == 'train':
            name, params = get_name_and_params(cfg.data.sampler)
            sampler = getattr(data.samplers, name)(dataset, **params)
        if sampler:
            dataloader_params['shuffle'] = False
            if cfg.strategy == 'ddp':
                sampler = data.samplers.DistributedSamplerWrapper(sampler)
            dataloader_params['sampler'] = sampler
            print(f'Using sampler {sampler} for training ...')
        elif cfg.strategy == 'ddp':
            dataloader_params['shuffle'] = False
            dataloader_params['sampler'] = DistributedSampler(dataset, shuffle=mode == 'train')
    else:
        assert cfg.strategy != 'ddp', 'DDP currently not supported for inference'
        dataloader_params['batch_size'] = cfg.evaluate.get('batch_size') or cfg.train.batch_size

    loader = DataLoader(dataset,
                        **dataloader_params,
                        worker_init_fn=worker_init_fn)
    return loader

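# Usage sketch for build_dataloader(); `cfg` is assumed to be an OmegaConf-style
# config exposing the fields referenced above (data.num_workers,
# train.batch_size, evaluate.batch_size, strategy, ...):
#
#   train_loader = build_dataloader(cfg, train_dataset, mode='train')
#   valid_loader = build_dataloader(cfg, valid_dataset, mode='valid')
#
# For mode='train', batches are shuffled and the last incomplete batch is
# dropped; a configured sampler from data.samplers (wrapped in
# DistributedSamplerWrapper under DDP) disables shuffling and takes over
# sample ordering.
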
def build_model(cfg):
    name, params = get_name_and_params(cfg.model)
    if cfg.model.params.get('cnn_params', None):
        cnn_params = cfg.model.params.cnn_params
        if cnn_params.get('load_pretrained_backbone', None):
            if 'foldx' in cnn_params.load_pretrained_backbone:
                # Resolve the `foldx` placeholder to the current outer fold
                cfg.model.params.cnn_params.load_pretrained_backbone = \
                    cnn_params.load_pretrained_backbone.replace('foldx', f'fold{cfg.data.outer_fold}')
    print(f'Creating model <{name}> ...')
    model = getattr(models.engine, name)(**params)
    if 'backbone' in cfg.model.params:
        print(f'  Using backbone <{cfg.model.params.backbone}> ...')
    if 'pretrained' in cfg.model.params:
        print(f'  Pretrained : {cfg.model.params.pretrained}')
    if 'load_pretrained' in cfg.model:
        if 'foldx' in cfg.model.load_pretrained:
            cfg.model.load_pretrained = cfg.model.load_pretrained.replace('foldx', f'fold{cfg.data.outer_fold}')
        print(f'  Loading pretrained checkpoint from {cfg.model.load_pretrained}')
        weights = torch.load(cfg.model.load_pretrained, map_location=lambda storage, loc: storage)['state_dict']
        # Strip the `model.` prefix added by the Lightning wrapper and drop loss buffers
        weights = {re.sub(r'^model\.', '', k): v for k, v in weights.items() if 'loss_fn' not in k}
        model.load_state_dict(weights)
    return model

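# The `foldx` placeholder in checkpoint paths is resolved against
# cfg.data.outer_fold at build time. A hedged example (the path itself is
# hypothetical):
#
#   model:
#     load_pretrained: checkpoints/foldx/best.ckpt
#     # -> checkpoints/fold0/best.ckpt when cfg.data.outer_fold == 0
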
def build_loss(cfg):
    name, params = get_name_and_params(cfg.loss)
    print(f'Using loss function <{name}> ...')
    params = dict(params)
    if 'pos_weight' in params:
        # BCE-style losses expect `pos_weight` as a tensor rather than a raw list/float
        params['pos_weight'] = torch.tensor(params['pos_weight'])
    criterion = getattr(losses, name)(**params)
    return criterion

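# Example `loss` config section (the loss name and weight value are
# assumptions); `pos_weight` is converted to a tensor before the loss class
# from `losses` is instantiated:
#
#   loss:
#     name: BCEWithLogitsLoss
#     params:
#       pos_weight: [2.0]
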
def build_scheduler(cfg, optimizer):
    # Some schedulers require manipulating the config params; they are
    # remapped here into a more intuitive specification.
    name, params = get_name_and_params(cfg.scheduler)
    print(f'Using learning rate schedule <{name}> ...')

    if name == 'CosineAnnealingLR':
        # eta_min <-> final_lr
        # T_max is set to a placeholder of 100000; it is corrected in the
        # on_train_start() method of the LightningModule task.
        params = {
            'T_max': 100000,
            'eta_min': max(params.final_lr, 1.0e-8)
        }

    if name in ('OneCycleLR', 'CustomOneCycleLR'):
        # Use the learning rate from the optimizer parameters as the initial learning rate
        lr_0 = cfg.optimizer.params.lr
        lr_1 = params.max_lr
        lr_2 = params.final_lr
        # Schedule: lr_0 -> lr_1 -> lr_2
        pct_start = params.pct_start
        params = {}
        params['steps_per_epoch'] = 100000  # placeholder; corrected in the task (see above)
        params['epochs'] = cfg.train.num_epochs
        params['max_lr'] = lr_1
        params['pct_start'] = pct_start
        params['div_factor'] = lr_1 / lr_0                     # max / initial
        params['final_div_factor'] = lr_0 / max(lr_2, 1.0e-8)  # initial / final

    scheduler = getattr(optim, name)(optimizer=optimizer, **params)

    # Some schedulers need further manipulation after instantiation
    if name in ('OneCycleLR', 'CustomOneCycleLR'):
        scheduler.pct_start = params['pct_start']

    # Set update frequency
    if name in ('OneCycleLR', 'CustomOneCycleLR', 'CosineAnnealingLR'):
        scheduler.update_frequency = 'on_batch'
    elif name == 'ReduceLROnPlateau':
        scheduler.update_frequency = 'on_valid'
    else:
        scheduler.update_frequency = 'on_epoch'

    return scheduler

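# Worked example of the OneCycleLR mapping above (values are illustrative):
# with lr_0 = 1e-4 (optimizer lr), lr_1 = 1e-3 (max_lr), lr_2 = 1e-6 (final_lr),
#
#   div_factor       = lr_1 / lr_0 = 10.0    # initial lr = max_lr / div_factor = 1e-4
#   final_div_factor = lr_0 / lr_2 = 100.0   # final lr = initial lr / final_div_factor = 1e-6
#
# which reproduces the intended lr_0 -> lr_1 -> lr_2 schedule.
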
def build_optimizer(cfg, parameters):
    name, params = get_name_and_params(cfg.optimizer)
    print(f'Using optimizer <{name}> ...')
    optimizer = getattr(optim, name)(parameters, **params)
    return optimizer

def build_task(cfg, model):
    name, params = get_name_and_params(cfg.task)
    print(f'Building task <{name}> ...')
    return getattr(tasks, name)(cfg, model, **params)

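# End-to-end sketch of how these builders are typically composed by a training
# script (the builder functions are defined above; the surrounding driver code
# and `data_info` are assumptions):
#
#   dataset = build_dataset(cfg, data_info, mode='train')
#   loader = build_dataloader(cfg, dataset, mode='train')
#   model = build_model(cfg)
#   criterion = build_loss(cfg)
#   optimizer = build_optimizer(cfg, model.parameters())
#   scheduler = build_scheduler(cfg, optimizer)
#   task = build_task(cfg, model)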