from base64 import b64encode

import torch
import numpy as np
import torch.nn.functional as F
from diffusers import AutoencoderKL, LMSDiscreteScheduler, UNet2DConditionModel
from huggingface_hub import notebook_login

# For video display:
from IPython.display import HTML
from matplotlib import pyplot as plt
from pathlib import Path
from PIL import Image
from torch import autocast
from torchvision import transforms as tfms
from tqdm.auto import tqdm
from transformers import CLIPTextModel, CLIPTokenizer, logging
import os

from device import torch_device, vae, text_encoder, unet, tokenizer, scheduler, token_emb_layer, pos_emb_layer, position_embeddings

# Suppress some unnecessary warnings when loading the CLIPTextModel
logging.set_verbosity_error()
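
# The `device` module imported above is not shown in this Space. Below is a minimal
# sketch of what it likely provides: the standard Stable Diffusion v1 components loaded
# with diffusers/transformers. The model IDs, scheduler settings, and the way the
# embedding layers are pulled out of the CLIP text model are assumptions, not taken
# from the original file.
def _load_sd_components_sketch():
    torch_device = "cuda" if torch.cuda.is_available() else "cpu"
    vae = AutoencoderKL.from_pretrained("CompVis/stable-diffusion-v1-4", subfolder="vae").to(torch_device)
    tokenizer = CLIPTokenizer.from_pretrained("openai/clip-vit-large-patch14")
    text_encoder = CLIPTextModel.from_pretrained("openai/clip-vit-large-patch14").to(torch_device)
    unet = UNet2DConditionModel.from_pretrained("CompVis/stable-diffusion-v1-4", subfolder="unet").to(torch_device)
    scheduler = LMSDiscreteScheduler(
        beta_start=0.00085, beta_end=0.012, beta_schedule="scaled_linear", num_train_timesteps=1000
    )
    # Embedding layers from the CLIP text model, plus precomputed position embeddings
    token_emb_layer = text_encoder.text_model.embeddings.token_embedding
    pos_emb_layer = text_encoder.text_model.embeddings.position_embedding
    position_ids = torch.arange(tokenizer.model_max_length, device=torch_device).unsqueeze(0)
    position_embeddings = pos_emb_layer(position_ids)
    return torch_device, vae, text_encoder, unet, tokenizer, scheduler, token_emb_layer, pos_emb_layer, position_embeddings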
def pil_to_latent(input_im):
    # Single image -> single latent in a batch (so size 1, 4, 64, 64)
    with torch.no_grad():
        latent = vae.encode(tfms.ToTensor()(input_im).unsqueeze(0).to(torch_device) * 2 - 1)  # Note scaling to [-1, 1]
    return 0.18215 * latent.latent_dist.sample()
def latents_to_pil(latents):
    # Batch of latents -> list of PIL images
    latents = (1 / 0.18215) * latents
    with torch.no_grad():
        image = vae.decode(latents).sample
    image = (image / 2 + 0.5).clamp(0, 1)
    image = image.detach().cpu().permute(0, 2, 3, 1).numpy()
    images = (image * 255).round().astype("uint8")
    pil_images = [Image.fromarray(image) for image in images]
    return pil_images
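
# Usage sketch (not in the original file): round-trip an image through the VAE with the
# two helpers above. The solid-colour test image is a placeholder; any 512x512 RGB PIL
# image works the same way.
def _vae_roundtrip_demo():
    demo_im = Image.new("RGB", (512, 512), color=(255, 128, 0))
    demo_latents = pil_to_latent(demo_im)    # shape: (1, 4, 64, 64)
    recon = latents_to_pil(demo_latents)[0]  # decode back to a PIL image
    recon.save("vae_roundtrip.png")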
def set_timesteps(scheduler, num_inference_steps):
    # Wrapper around scheduler.set_timesteps that also casts the timesteps to float32,
    # a small fix to keep the sampling loop compatible with the MPS backend.
    scheduler.set_timesteps(num_inference_steps)
    scheduler.timesteps = scheduler.timesteps.to(torch.float32)
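
# Usage sketch (not in the original file): prepare the scheduler for a 30-step run and
# scale the initial noise by the scheduler's starting sigma before sampling begins.
def _scheduler_setup_demo(num_inference_steps=30):
    set_timesteps(scheduler, num_inference_steps)
    latents = torch.randn((1, unet.config.in_channels, 64, 64), device=torch_device)
    latents = latents * scheduler.init_noise_sigma  # scale to the scheduler's expected starting noise level
    return latents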
def orange_loss(image):
    # Guidance loss that pushes a decoded image (values in [0, 1]) towards orange tones.
    # Combine the red and green channels as a rough proxy for "orange-ness"
    orange_channel = image[:, 0, :, :] + image[:, 1, :, :]
    # Target mean intensity for the combined channel; adjust to taste
    target_mean = 0.8
    # Mean absolute difference from the target (stays differentiable for guidance)
    loss = torch.abs(orange_channel - target_mean).mean()
    return loss
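
# Sketch (not in the original file) of how orange_loss can steer sampling: inside the
# denoising loop, decode the predicted x0 with the VAE, backprop the loss through it,
# and nudge the latents down the gradient before the scheduler step. `text_embeddings`
# and the guidance/loss scales are assumed inputs, following the usual loss-guidance
# pattern with an LMS-style scheduler.
def _guided_sampling_step(latents, t, i, text_embeddings, guidance_scale=8.0, loss_scale=100.0):
    sigma = scheduler.sigmas[i]

    # Classifier-free guidance: run the UNet on unconditional + conditional embeddings
    latent_model_input = torch.cat([latents] * 2)
    latent_model_input = scheduler.scale_model_input(latent_model_input, t)
    with torch.no_grad():
        noise_pred = unet(latent_model_input, t, encoder_hidden_states=text_embeddings).sample
    noise_pred_uncond, noise_pred_text = noise_pred.chunk(2)
    noise_pred = noise_pred_uncond + guidance_scale * (noise_pred_text - noise_pred_uncond)

    # Extra guidance: differentiate orange_loss w.r.t. the latents via the predicted x0
    latents = latents.detach().requires_grad_()
    latents_x0 = latents - sigma * noise_pred                            # predicted denoised latents
    denoised_images = vae.decode(latents_x0 / 0.18215).sample / 2 + 0.5  # decode to [0, 1] images
    loss = orange_loss(denoised_images) * loss_scale
    grad = torch.autograd.grad(loss, latents)[0]
    latents = latents.detach() - grad * sigma ** 2                       # nudge latents towards lower loss

    # Standard scheduler update
    return scheduler.step(noise_pred, t, latents).prev_sample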