import torch
from diffusers import LCMScheduler
from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion import (
    StableDiffusionPipeline,
    rescale_noise_cfg,
    retrieve_timesteps,
)
class Hack_SDPipe_Stepwise(StableDiffusionPipeline):
    def _use_lcm(self, use=True, ckpt='latent-consistency/lcm-lora-sdv1-5'):
        if use:
            self.use_lcm = True
            adapter_id = ckpt
            self.scheduler = LCMScheduler.from_config(self.scheduler.config)
            # load and fuse the LCM LoRA; LCM runs without classifier-free
            # guidance, so the guidance scale is set to 0.0
            self._guidance_scale = 0.0
            self.load_lora_weights(adapter_id)
            self.fuse_lora()
        else:
            self.use_lcm = False
            self._guidance_scale = 7.5
    def re_init(self, num_inference_steps=50):
        # hyper-parameters
        eta = 0.0
        timesteps = None
        generator = None
        self._clip_skip = None
        self._interrupt = False
        self._guidance_rescale = 0.0
        self.added_cond_kwargs = None
        self._cross_attention_kwargs = None
        self._do_classifier_free_guidance = (
            self._guidance_scale > 1 and self.unet.config.time_cond_proj_dim is None
        )
        # step numbering below follows StableDiffusionPipeline.__call__
        # 2. Define call parameters
        batch_size = 1
        device = self._execution_device
        # 4. Prepare timesteps
        self.timesteps, num_inference_steps = retrieve_timesteps(
            self.scheduler, num_inference_steps, device, timesteps
        )
        # 6. Prepare extra step kwargs. TODO: Logic should ideally just be moved out of the pipeline
        self.extra_step_kwargs = self.prepare_extra_step_kwargs(generator, eta)
        # 6.2 Optionally get Guidance Scale Embedding
        self.timestep_cond = None
        if self.unet.config.time_cond_proj_dim is not None:
            guidance_scale_tensor = torch.tensor(self.guidance_scale - 1).repeat(batch_size * 1)
            self.timestep_cond = self.get_guidance_scale_embedding(
                guidance_scale_tensor, embedding_dim=self.unet.config.time_cond_proj_dim
            ).to(device=device)
    def _encode_text_prompt(self, prompt, negative_prompt='fake,ugly,unreal'):
        # 3. Encode input prompt
        lora_scale = (
            self.cross_attention_kwargs.get("scale", None)
            if self.cross_attention_kwargs is not None
            else None
        )
        prompt_embeds, negative_prompt_embeds = self.encode_prompt(
            prompt,
            self._execution_device,
            1,
            self.do_classifier_free_guidance,
            negative_prompt,
            prompt_embeds=None,
            negative_prompt_embeds=None,
            lora_scale=lora_scale,
            clip_skip=self.clip_skip,
        )
        # For classifier-free guidance we need two forward passes. Concatenate the
        # unconditional and text embeddings into a single batch to avoid running them separately.
        if self.do_classifier_free_guidance:
            prompt_embeds = torch.cat([negative_prompt_embeds, prompt_embeds])
        return prompt_embeds
    def _step_noise(self, latents, time_step, prompt_embeds):
        # expand the latents if we are doing classifier-free guidance
        latent_model_input = torch.cat([latents] * 2) if self.do_classifier_free_guidance else latents
        latent_model_input = self.scheduler.scale_model_input(latent_model_input, time_step)
        # predict the noise residual
        noise_pred = self.unet(
            latent_model_input,
            time_step,
            encoder_hidden_states=prompt_embeds,
            timestep_cond=self.timestep_cond,
            cross_attention_kwargs=self.cross_attention_kwargs,
            added_cond_kwargs=self.added_cond_kwargs,
            return_dict=False,
        )[0]
        # perform guidance
        if self.do_classifier_free_guidance:
            noise_pred_uncond, noise_pred_text = noise_pred.chunk(2)
            noise_pred = noise_pred_uncond + self.guidance_scale * (noise_pred_text - noise_pred_uncond)
        if self.do_classifier_free_guidance and self.guidance_rescale > 0.0:
            # Based on Section 3.4 of https://arxiv.org/pdf/2305.08891.pdf
            noise_pred = rescale_noise_cfg(noise_pred, noise_pred_text, guidance_rescale=self.guidance_rescale)
        return noise_pred
    # @torch.no_grad()
    def _encode(self, input):
        '''
        single condition encoding
        input:  B3HW image tensor
        return: B4H'W' latent tensor
        note: in a low-VRAM setup where the VAE sits on CPU, the input must be on CPU as well
        '''
        h = self.vae.encoder(input)
        moments = self.vae.quant_conv(h)
        mean, logvar = torch.chunk(moments, 2, dim=1)
        # use the posterior mean (no sampling) and scale the latent
        latent = mean * self.vae.config.scaling_factor
        return latent
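
    # Note: for a standard AutoencoderKL (tiling/slicing disabled) the above should be
    # equivalent to the public API call
    #   self.vae.encode(input).latent_dist.mean * self.vae.config.scaling_factor
    # the manual encoder/quant_conv path just makes the deterministic mean explicit.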
    def _decode(self, latent):
        '''
        single target decoding
        input:  B4H'W' latent tensor
        return: B3HW image tensor
        '''
        # un-scale latent
        latent = latent / self.vae.config.scaling_factor
        # decode
        z = self.vae.post_quant_conv(latent)
        output = self.vae.decoder(z)
        return output
    def _solve_x0_full_step(self, latents, noise_pred, t):
        # forward process: x_t = a_t * x_0 + s_t * eps with a_t = sqrt(alphas_cumprod[t])
        # and s_t = sqrt(1 - alphas_cumprod[t]), hence x_0 = (x_t - s_t * eps) / a_t
        self.alpha_t = torch.sqrt(self.scheduler.alphas_cumprod).to(t.device)
        self.sigma_t = torch.sqrt(1 - self.scheduler.alphas_cumprod).to(t.device)
        a_t, s_t = self.alpha_t[t], self.sigma_t[t]
        x0_latents = (latents - s_t * noise_pred) / a_t
        x0 = self._decode(x0_latents)
        return x0_latents, x0
    def _solve_x0(self, latents, noise_pred, t):
        step_output = self.scheduler.step(noise_pred, t.squeeze(), latents)
        # this step is only a probe for x_0, not a real denoising step: rewind the
        # scheduler's internal step index so the actual step at this timestep still runs
        self.scheduler._step_index -= 1
        # results
        x0_latents = step_output.denoised if self.use_lcm else step_output.pred_original_sample
        x0 = self._decode(x0_latents)
        return x0_latents, x0
    def _step_denoise(self, latents, noise_pred, t):
        latents = self.scheduler.step(noise_pred, t.squeeze(), latents).prev_sample
        return latents
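
    # Usage note: _solve_x0 and _step_denoise are intended to be called at the same
    # timestep; the probe scheduler.step inside _solve_x0 rewinds _step_index, so a
    # loop can preview x_0 and then actually advance, roughly:
    #
    #   x0_latents, x0 = pipe._solve_x0(latents, noise_pred, t)  # preview x_0 at t
    #   latents = pipe._step_denoise(latents, noise_pred, t)     # advance x_t -> x_{t-1}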
    def xt_x0_noise(
        self,
        xt_latents: torch.Tensor,
        x0_latents: torch.Tensor,
        timesteps: torch.IntTensor,
    ) -> torch.Tensor:
        # invert the forward process:
        #   eps = (x_t - sqrt(alphas_cumprod[t]) * x_0) / sqrt(1 - alphas_cumprod[t])
        # Make sure alphas_cumprod and timesteps have the same device and dtype as the
        # latents; moving alphas_cumprod to the device once avoids redundant CPU-to-GPU
        # transfers on subsequent calls.
        alphas_cumprod = self.scheduler.alphas_cumprod.to(dtype=xt_latents.dtype, device=xt_latents.device)
        timesteps = timesteps.to(xt_latents.device)
        sqrt_alpha_prod = alphas_cumprod[timesteps] ** 0.5
        sqrt_alpha_prod = sqrt_alpha_prod.flatten()
        # broadcast the per-timestep scalars over the latent dimensions
        while len(sqrt_alpha_prod.shape) < len(xt_latents.shape):
            sqrt_alpha_prod = sqrt_alpha_prod.unsqueeze(-1)
        sqrt_one_minus_alpha_prod = (1 - alphas_cumprod[timesteps]) ** 0.5
        sqrt_one_minus_alpha_prod = sqrt_one_minus_alpha_prod.flatten()
        while len(sqrt_one_minus_alpha_prod.shape) < len(xt_latents.shape):
            sqrt_one_minus_alpha_prod = sqrt_one_minus_alpha_prod.unsqueeze(-1)
        noise = (xt_latents - sqrt_alpha_prod * x0_latents) / sqrt_one_minus_alpha_prod
        return noise
    def _solve_noise_given_x0_latent(self, latents, x0_latents, t):
        noise = self.xt_x0_noise(latents, x0_latents, t)
        # -------------------- noise for supervision -----------------
        if self.scheduler.config.prediction_type == "epsilon":
            # epsilon prediction: the recovered noise is already the target
            noise = noise
        elif self.scheduler.config.prediction_type == "v_prediction":
            # v-prediction: convert (x_0, eps) into the velocity target
            noise = self.scheduler.get_velocity(x0_latents, noise, t)
        # ------------------------------------------------------------
        return noise
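

# ---------------------------------------------------------------------------
# Minimal usage sketch. Everything below is illustrative rather than fixed by
# the class itself: the SD v1.5 checkpoint id, fp16 on CUDA, the 64x64 latent
# grid (512x512 output), and the prompt are all assumptions.
# ---------------------------------------------------------------------------
if __name__ == '__main__':
    pipe = Hack_SDPipe_Stepwise.from_pretrained(
        'runwayml/stable-diffusion-v1-5', torch_dtype=torch.float16
    ).to('cuda')
    pipe._use_lcm(use=False)              # plain SD path: CFG with scale 7.5
    pipe.re_init(num_inference_steps=50)  # prepare scheduler state and timesteps
    prompt_embeds = pipe._encode_text_prompt('a photo of a cat')
    # start from pure noise, scaled to the scheduler's initial sigma
    latents = torch.randn(
        (1, pipe.unet.config.in_channels, 64, 64),
        device=pipe._execution_device, dtype=prompt_embeds.dtype,
    ) * pipe.scheduler.init_noise_sigma
    with torch.no_grad():
        for t in pipe.timesteps:
            noise_pred = pipe._step_noise(latents, t, prompt_embeds)
            latents = pipe._step_denoise(latents, noise_pred, t)
        image = pipe._decode(latents)     # B3HW tensor, roughly in [-1, 1]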