""" From https://github.com/CompVis/latent-diffusion/main/ldm/models/diffusion/ddpm.py Pared down to simplify code. The original file acknowledges: https://github.com/lucidrains/denoising-diffusion-pytorch/blob/7706bdfc6f527f58d33f84b7b522e61e6e3164b3/denoising_diffusion_pytorch/denoising_diffusion_pytorch.py https://github.com/openai/improved-diffusion/blob/e94489283bb876ac1477d5dd7709bbbd2d9902ce/improved_diffusion/gaussian_diffusion.py https://github.com/CompVis/taming-transformers """ import torch import torch.nn as nn import numpy as np import pytorch_lightning as pl from contextlib import contextmanager from functools import partial from torchmetrics import MeanSquaredError from .utils import make_beta_schedule, extract_into_tensor, noise_like, timestep_embedding from .ema import LitEma from ..blocks.afno import PatchEmbed3d, PatchExpand3d, AFNOBlock3d class LatentDiffusion(pl.LightningModule): def __init__(self, model, autoencoder, context_encoder=None, timesteps=1000, beta_schedule="linear", loss_type="l2", use_ema=True, lr=1e-4, lr_warmup=0, linear_start=1e-4, linear_end=2e-2, cosine_s=8e-3, parameterization="eps", # all assuming fixed variance schedules ): super().__init__() self.model = model self.autoencoder = autoencoder.requires_grad_(False) self.conditional = (context_encoder is not None) self.context_encoder = context_encoder self.lr = lr self.lr_warmup = lr_warmup self.val_loss = MeanSquaredError() assert parameterization in ["eps", "x0"], 'currently only supporting "eps" and "x0"' self.parameterization = parameterization self.use_ema = use_ema if self.use_ema: self.model_ema = LitEma(self.model) self.register_schedule( beta_schedule=beta_schedule, timesteps=timesteps, linear_start=linear_start, linear_end=linear_end, cosine_s=cosine_s ) self.loss_type = loss_type def register_schedule(self, beta_schedule="linear", timesteps=1000, linear_start=1e-4, linear_end=2e-2, cosine_s=8e-3): betas = make_beta_schedule( beta_schedule, timesteps, linear_start=linear_start, linear_end=linear_end, cosine_s=cosine_s ) alphas = 1. - betas alphas_cumprod = np.cumprod(alphas, axis=0) alphas_cumprod_prev = np.append(1., alphas_cumprod[:-1]) timesteps, = betas.shape self.num_timesteps = int(timesteps) self.linear_start = linear_start self.linear_end = linear_end assert alphas_cumprod.shape[0] == self.num_timesteps, 'alphas have to be defined for each timestep' to_torch = partial(torch.tensor, dtype=torch.float32) self.register_buffer('betas', to_torch(betas)) self.register_buffer('alphas_cumprod', to_torch(alphas_cumprod)) self.register_buffer('alphas_cumprod_prev', to_torch(alphas_cumprod_prev)) # calculations for diffusion q(x_t | x_{t-1}) and others self.register_buffer('sqrt_alphas_cumprod', to_torch(np.sqrt(alphas_cumprod))) self.register_buffer('sqrt_one_minus_alphas_cumprod', to_torch(np.sqrt(1. 
    @contextmanager
    def ema_scope(self, context=None):
        if self.use_ema:
            self.model_ema.store(self.model.parameters())
            self.model_ema.copy_to(self.model)
            if context is not None:
                print(f"{context}: Switched to EMA weights")
        try:
            yield None
        finally:
            if self.use_ema:
                self.model_ema.restore(self.model.parameters())
                if context is not None:
                    print(f"{context}: Restored training weights")

    def apply_model(self, x_noisy, t, cond=None, return_ids=False):
        if self.conditional:
            cond = self.context_encoder(cond)
        with self.ema_scope():
            return self.model(x_noisy, t, context=cond)

    def q_sample(self, x_start, t, noise=None):
        # Closed-form forward process:
        # x_t = sqrt(alpha_bar_t) * x_0 + sqrt(1 - alpha_bar_t) * noise
        if noise is None:
            noise = torch.randn_like(x_start)
        return (
            extract_into_tensor(self.sqrt_alphas_cumprod, t, x_start.shape) * x_start +
            extract_into_tensor(self.sqrt_one_minus_alphas_cumprod, t, x_start.shape) * noise
        )

    def get_loss(self, pred, target, mean=True):
        if self.loss_type == 'l1':
            loss = (target - pred).abs()
            if mean:
                loss = loss.mean()
        elif self.loss_type == 'l2':
            if mean:
                loss = torch.nn.functional.mse_loss(target, pred)
            else:
                loss = torch.nn.functional.mse_loss(target, pred, reduction='none')
        else:
            raise NotImplementedError(f"unknown loss type '{self.loss_type}'")

        return loss

    def p_losses(self, x_start, t, noise=None, context=None):
        if noise is None:
            noise = torch.randn_like(x_start)
        x_noisy = self.q_sample(x_start=x_start, t=t, noise=noise)
        model_out = self.model(x_noisy, t, context=context)

        if self.parameterization == "eps":
            target = noise
        elif self.parameterization == "x0":
            target = x_start
        else:
            raise NotImplementedError(f"Parameterization {self.parameterization} not yet supported")

        return self.get_loss(model_out, target, mean=False).mean()

    def forward(self, x, *args, **kwargs):
        # Sample a random timestep for each element of the batch.
        t = torch.randint(0, self.num_timesteps, (x.shape[0],), device=self.device).long()
        return self.p_losses(x, t, *args, **kwargs)

    def shared_step(self, batch):
        (x, y) = batch
        # Diffusion is run in the latent space of the frozen autoencoder.
        y = self.autoencoder.encode(y)[0]
        context = self.context_encoder(x) if self.conditional else None
        return self(y, context=context)

    def training_step(self, batch, batch_idx):
        loss = self.shared_step(batch)
        self.log("train_loss", loss)
        return loss

    @torch.no_grad()
    def validation_step(self, batch, batch_idx):
        loss = self.shared_step(batch)
        with self.ema_scope():
            loss_ema = self.shared_step(batch)
        log_params = {"on_step": False, "on_epoch": True, "prog_bar": True}
        self.log("val_loss", loss, **log_params)
        self.log("val_loss_ema", loss_ema, **log_params)

    def test_step(self, batch, batch_idx):
        return self.validation_step(batch, batch_idx)

    def on_train_batch_end(self, *args, **kwargs):
        if self.use_ema:
            self.model_ema(self.model)

    def configure_optimizers(self):
        optimizer = torch.optim.AdamW(
            self.parameters(), lr=self.lr, betas=(0.5, 0.9), weight_decay=1e-3
        )
        reduce_lr = torch.optim.lr_scheduler.ReduceLROnPlateau(
            optimizer, patience=3, factor=0.25, verbose=True
        )
        return {
            "optimizer": optimizer,
            "lr_scheduler": {
                "scheduler": reduce_lr,
                "monitor": "val_loss_ema",
                "frequency": 1,
            },
        }

    def optimizer_step(
        self,
        epoch,
        batch_idx,
        optimizer,
        optimizer_idx,
        #optimizer_closure,
        **kwargs
    ):
        # Linear learning-rate warmup over the first `lr_warmup` steps.
        if self.trainer.global_step < self.lr_warmup:
            lr_scale = (self.trainer.global_step + 1) / self.lr_warmup
            for pg in optimizer.param_groups:
                pg['lr'] = lr_scale * self.lr

        super().optimizer_step(
            epoch,
            batch_idx,
            optimizer,
            optimizer_idx,
            #optimizer_closure,
            **kwargs
        )
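

# ----------------------------------------------------------------------------
# Usage sketch (not part of the original module): a minimal, hypothetical
# example of exercising LatentDiffusion with toy stand-in networks. The
# classes below only illustrate the interfaces this module expects -- a
# denoiser called as model(x, t, context=...) and an autoencoder whose
# encode() returns a tuple with the latent first; the real project supplies
# its own denoiser, pretrained autoencoder and context encoder instead.
# Because of the relative imports above, run this as a module
# (python -m <package>.<this_module>) rather than as a script.
# ----------------------------------------------------------------------------
if __name__ == "__main__":
    class _ToyDenoiser(nn.Module):
        """Hypothetical denoiser matching the (x, t, context) call signature."""
        def __init__(self):
            super().__init__()
            self.proj = nn.Conv3d(1, 1, kernel_size=3, padding=1)

        def forward(self, x, t, context=None):
            return self.proj(x)

    class _ToyAutoencoder(nn.Module):
        """Hypothetical frozen autoencoder; encode() returns (latent, ...)."""
        def encode(self, y):
            return (y,)

    ldm = LatentDiffusion(
        model=_ToyDenoiser(),
        autoencoder=_ToyAutoencoder(),
        context_encoder=None,  # unconditional for this sketch
        timesteps=100,
        use_ema=False,
    )
    # Calling the module draws a random timestep per sample and returns the
    # denoising loss on these (here random) latents.
    latents = torch.randn(2, 1, 4, 16, 16)
    print(float(ldm(latents)))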