import argparse
import glob
import logging
import math
import os
import pickle
import random
from urllib.parse import urlparse

import cv2
import numpy as np
import PIL
import torch
import yaml
from diffusers.utils import load_image
from einops import rearrange
from PIL import Image, ImageDraw
from torch.utils.data import Dataset


# copied from https://github.com/crowsonkb/k-diffusion.git
def rand_log_normal(shape, loc=0., scale=1., device='cpu', dtype=torch.float32):
    """Draws samples from a lognormal distribution."""
    u = torch.rand(shape, dtype=dtype, device=device) * (1 - 2e-7) + 1e-7
    return torch.distributions.Normal(loc, scale).icdf(u).exp()
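

# Illustrative usage (a sketch, not part of the original training loop): during
# EDM-style training, per-sample noise levels can be drawn from this lognormal;
# the loc/scale values below are assumptions, similar to common SVD settings.
#
#   sigmas = rand_log_normal(shape=[batch_size], loc=0.7, scale=1.6,
#                            device=latents.device, dtype=latents.dtype)
#   sigmas = sigmas.view(-1, 1, 1, 1, 1)  # broadcast over (b, f, c, h, w) latents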


def encode_image(pixel_values, feature_extractor, image_encoder, weight_dtype, accelerator):
    # pixel_values are expected in [-1, 1]
    pixel_values = _resize_with_antialiasing(pixel_values, (224, 224))
    # Un-normalize to [0, 1] after resizing.
    pixel_values = (pixel_values + 1.0) / 2.0
    # Normalize the image for CLIP input.
    pixel_values = feature_extractor(
        images=pixel_values,
        do_normalize=True,
        do_center_crop=False,
        do_resize=False,
        do_rescale=False,
        return_tensors="pt",
    ).pixel_values
    pixel_values = pixel_values.to(
        device=accelerator.device, dtype=weight_dtype)
    image_embeddings = image_encoder(pixel_values).image_embeds
    return image_embeddings
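

# Illustrative usage (a sketch; the checkpoint name and subfolders are assumptions,
# not taken from this file): with the SVD CLIP image encoder, a batch of images in
# [-1, 1] of shape (B, 3, H, W) typically maps to embeddings of shape (B, 1024).
#
#   from transformers import CLIPImageProcessor, CLIPVisionModelWithProjection
#   feature_extractor = CLIPImageProcessor.from_pretrained(
#       "stabilityai/stable-video-diffusion-img2vid", subfolder="feature_extractor")
#   image_encoder = CLIPVisionModelWithProjection.from_pretrained(
#       "stabilityai/stable-video-diffusion-img2vid", subfolder="image_encoder")
#   embeds = encode_image(pixel_values, feature_extractor, image_encoder,
#                         weight_dtype=torch.float16, accelerator=accelerator)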


def get_add_time_ids(
    fps,
    motion_bucket_id,
    noise_aug_strength,
    dtype,
    batch_size,
    unet,
):
    add_time_ids = [fps, motion_bucket_id, noise_aug_strength]
    passed_add_embed_dim = unet.config.addition_time_embed_dim * \
        len(add_time_ids)
    expected_add_embed_dim = unet.add_embedding.linear_1.in_features
    if expected_add_embed_dim != passed_add_embed_dim:
        raise ValueError(
            f"Model expects an added time embedding vector of length {expected_add_embed_dim}, "
            f"but a vector of length {passed_add_embed_dim} was created. The model has an "
            f"incorrect config. Please check `unet.config.addition_time_embed_dim`."
        )
    add_time_ids = torch.tensor([add_time_ids], dtype=dtype)
    add_time_ids = add_time_ids.repeat(batch_size, 1)
    return add_time_ids
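

# Illustrative usage (a sketch; the conditioning values below are the common SVD
# defaults and are assumptions, not taken from this file):
#
#   added_time_ids = get_add_time_ids(
#       fps=7, motion_bucket_id=127, noise_aug_strength=0.02,
#       dtype=weight_dtype, batch_size=latents.shape[0], unet=unet)
#   # shape: (batch_size, 3), later embedded by the UNet's add_embedding MLP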


def find_scale(height, width):
    """
    Finds a scale factor such that the number of pixels is at most 500,000
    and each dimension is rounded down to the nearest multiple of 64.

    Args:
        height (int): The original height of the image.
        width (int): The original width of the image.

    Returns:
        tuple: The scaled (height, width) as integers.
    """
    max_pixels = 500000
    # Start with no scaling
    scale = 1.0
    while True:
        # Calculate the scaled dimensions
        scaled_height = math.floor((height * scale) / 64) * 64
        scaled_width = math.floor((width * scale) / 64) * 64
        # Check if the scaled dimensions meet the pixel constraint
        if scaled_height * scaled_width <= max_pixels:
            return scaled_height, scaled_width
        # Reduce the scale slightly
        scale -= 0.01
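

# Illustrative usage (a sketch): for a 1600x1200 landscape photo, find_scale
# shrinks the image until height * width <= 500,000 and both sides are multiples
# of 64, giving roughly (576, 832) for this aspect ratio.
#
#   scaled_h, scaled_w = find_scale(height=1200, width=1600)
#   img_resized = img.resize((scaled_w, scaled_h))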


class OutsidePhotosDataset(Dataset):
    def __init__(self, data_folder, width=1024, height=576, sample_frames=9):
        self.data_folder = data_folder
        self.scenes = sorted(glob.glob(os.path.join(data_folder, "*")))
        # Keep only images that end in .jpg/.JPG/.jpeg/.png
        self.scenes = [scene for scene in self.scenes
                       if scene.endswith((".JPG", ".jpg", ".jpeg", ".png"))]
        # Expand each scene into 9 (scene_name, idx) tuples, one per focal-stack slot (idx 0-8)
        self.scenes = [(scene, idx) for scene in self.scenes for idx in range(9)]
        self.num_scenes = len(self.scenes)
        self.width = width
        self.height = height
        self.sample_frames = sample_frames
        self.icc_profiles = [None] * self.num_scenes

    def __len__(self):
        return self.num_scenes

    def __getitem__(self, idx):
        # Get the scene and its focal-stack index, then load and resize the image
        # and place it into an otherwise empty tensor of sample_frames frames.
        scene, focal_stack_num = self.scenes[idx]
        with Image.open(scene) as img:
            self.icc_profiles[idx] = img.info.get("icc_profile")
            icc_profile = img.info.get("icc_profile")
            if icc_profile is None:
                icc_profile = "none"
            original_pixels = torch.from_numpy(np.array(img)).float().permute(2, 0, 1)
            original_pixels = original_pixels / 255
            width, height = img.size
            scaled_width, scaled_height = find_scale(width, height)
            img_resized = img.resize((scaled_width, scaled_height))
            img_tensor = torch.from_numpy(np.array(img_resized)).float()
            img_normalized = img_tensor / 127.5 - 1
            img_normalized = img_normalized.permute(2, 0, 1)
        pixels = torch.zeros((self.sample_frames, 3, scaled_height, scaled_width))
        pixels[focal_stack_num] = img_normalized
        return {"pixel_values": pixels, "idx": idx // 9, "focal_stack_num": focal_stack_num,
                "original_pixel_values": original_pixels, "icc_profile": icc_profile}


class FocalStackDataset(Dataset):
    def __init__(self, data_folder: str, splits_dir, split="train", num_samples=100000,
                 width=640, height=896, sample_frames=9):  # 4.5
        # 800*600 -> 480,000 px; 896*672 -> 602,112 px
        """
        Args:
            data_folder (str): Root folder containing the captured focal-stack scenes.
            splits_dir (str): Folder containing the train/test split pickle files.
            split (str): One of "train", "val", or "test".
            num_samples (int): Number of samples in the dataset.
            width (int): Target frame width.
            height (int): Target frame height.
            sample_frames (int): Number of frames per focal stack (images are loaded as 3-channel RGB).
        """
        self.num_samples = num_samples
        self.sample_frames = sample_frames
        # Define the path to the folder containing the focal-stack frames
        self.data_folder = data_folder
        self.splits_dir = splits_dir
        size = "midsize"
        # Collect every RigCenter* directory under "<size>/undistorted" that holds a full stack
        rig_directories = []
        # Walk through the directory tree
        for root, dirs, files in os.walk(data_folder):
            # Check if the path matches "<size>/undistorted/RigCenter*"
            for directory in dirs:
                if directory.startswith("RigCenter") and f"{size}/undistorted" in root.replace("\\", "/"):
                    rig_directory = os.path.join(root, directory)
                    # Check that rig_directory contains all 9 images
                    if len(glob.glob(os.path.join(rig_directory, "*.jpg"))) == 9:
                        rig_directories.append(rig_directory)
        self.scenes = sorted(rig_directories)  # sort the scenes by name
        if split == "train":
            # Shuffle the scenes
            random.shuffle(self.scenes)
        self.split = split
        debug = False
        if debug:
            self.scenes = self.scenes[50:60]
        elif split == "train":
            pkl_file = os.path.join(self.splits_dir, "train_scenes.pkl")
            # Load the train scenes
            with open(pkl_file, "rb") as f:
                pkl_scenes = pickle.load(f)
            # Keep only scenes listed in the pickle file
            self.scenes = [scene for scene in self.scenes if scene.split('/')[-4] in pkl_scenes]
        elif split == "val":
            # Use the first 10 test scenes for validation (just for visualization)
            pkl_file = os.path.join(self.splits_dir, "test_scenes.pkl")
            # Load the test scenes
            with open(pkl_file, "rb") as f:
                pkl_scenes = pickle.load(f)
            # Keep only scenes listed in the pickle file
            self.scenes = [scene for scene in self.scenes if scene.split('/')[-4] in pkl_scenes]
            self.scenes = self.scenes[:10]
        else:
            pkl_file = os.path.join(self.splits_dir, "test_scenes.pkl")
            # Load the test scenes
            with open(pkl_file, "rb") as f:
                pkl_scenes = pickle.load(f)
            # Keep only scenes listed in the pickle file
            self.scenes = [scene for scene in self.scenes if scene.split('/')[-4] in pkl_scenes]
        if split == "test":
            # Expand each test scene into one item per focal-stack slot
            self.scenes = [(scene, idx) for scene in self.scenes for idx in range(self.sample_frames)]
        self.num_scenes = len(self.scenes)
        max_trdata = 0
        if max_trdata > 0:
            self.scenes = self.scenes[:max_trdata]
        self.data_store = {}
        logging.info(f'Creating {split} dataset with {self.num_scenes} examples')
        self.channels = 3
        self.width = width
        self.height = height

    def __len__(self):
        return self.num_scenes

    def __getitem__(self, idx):
        """
        Args:
            idx (int): Index of the sample to return.

        Returns:
            dict: A dictionary with a 'pixel_values' tensor of shape
            (sample_frames, channels, height, width), the original-resolution
            'original_pixel_values', and the scene index 'idx'.
        """
        # Select the scene folder for this index
        if self.split == "test":
            chosen_folder, focal_stack_num = self.scenes[idx]
        else:
            chosen_folder = self.scenes[idx]
        frames = os.listdir(chosen_folder)
        # Keep only jpg frames
        frames = [frame for frame in frames if frame.endswith(".jpg")]
        # Sort the frames by name
        frames.sort()
        # Take the first sample_frames frames
        selected_frames = frames[:self.sample_frames]
        # Initialize tensors to store the pixel values
        pixel_values = torch.empty((self.sample_frames, self.channels, self.height, self.width))
        original_pixel_values = torch.empty((self.sample_frames, self.channels, 896, 640))
        # Load and process each frame
        for i, frame_name in enumerate(selected_frames):
            frame_path = os.path.join(chosen_folder, frame_name)
            with Image.open(frame_path) as img:
                # Resize the image and convert it to a tensor
                img_resized = img.resize((self.width, self.height))
                img_tensor = torch.from_numpy(np.array(img_resized)).float()
                original_img_tensor = torch.from_numpy(np.array(img)).float()
                # Normalize the image by scaling pixel values to [-1, 1]
                img_normalized = img_tensor / 127.5 - 1
                original_img_normalized = original_img_tensor / 127.5 - 1
                # Rearrange channels if necessary
                if self.channels == 3:
                    img_normalized = img_normalized.permute(
                        2, 0, 1)  # For RGB images
                    original_img_normalized = original_img_normalized.permute(2, 0, 1)
                pixel_values[i] = img_normalized
                original_pixel_values[i] = original_img_normalized
        if self.sample_frames == 10:
            # Special case for 10 frames: duplicate the 9th frame (sometimes reduces color artifacts)
            pixel_values[9] = pixel_values[8]
            original_pixel_values[9] = original_pixel_values[8]
        out_dict = {'pixel_values': pixel_values, "idx": idx, "original_pixel_values": original_pixel_values}
        if self.split == "test":
            out_dict["focal_stack_num"] = focal_stack_num
            out_dict["idx"] = idx // 9
        return out_dict
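

# Illustrative usage (a sketch; paths are hypothetical): each item stacks the 9
# focal-stack frames of one RigCenter* capture into a single tensor.
#
#   train_set = FocalStackDataset("data/focal_stacks", splits_dir="data/splits", split="train")
#   sample = train_set[0]
#   # sample["pixel_values"]: (sample_frames, 3, height, width), values in [-1, 1]
#   # sample["original_pixel_values"]: (sample_frames, 3, 896, 640)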


# resizing utils
# TODO: clean up later
def _resize_with_antialiasing(input, size, interpolation="bicubic", align_corners=True):
    h, w = input.shape[-2:]
    factors = (h / size[0], w / size[1])

    # First, we have to determine sigma
    # Taken from skimage: https://github.com/scikit-image/scikit-image/blob/v0.19.2/skimage/transform/_warps.py#L171
    sigmas = (
        max((factors[0] - 1.0) / 2.0, 0.001),
        max((factors[1] - 1.0) / 2.0, 0.001),
    )

    # Now kernel size. Good results are for 3 sigma, but that is kind of slow. Pillow uses 1 sigma
    # https://github.com/python-pillow/Pillow/blob/master/src/libImaging/Resample.c#L206
    # But they do it in 2 passes, which gives better results. Let's try 2 sigmas for now
    ks = int(max(2.0 * 2 * sigmas[0], 3)), int(max(2.0 * 2 * sigmas[1], 3))

    # Make sure it is odd
    if (ks[0] % 2) == 0:
        ks = ks[0] + 1, ks[1]
    if (ks[1] % 2) == 0:
        ks = ks[0], ks[1] + 1

    input = _gaussian_blur2d(input, ks, sigmas)

    output = torch.nn.functional.interpolate(
        input, size=size, mode=interpolation, align_corners=align_corners)
    return output


def _compute_padding(kernel_size):
    """Compute padding tuple."""
    # 4 or 6 ints: (padding_left, padding_right, padding_top, padding_bottom)
    # https://pytorch.org/docs/stable/nn.html#torch.nn.functional.pad
    if len(kernel_size) < 2:
        raise AssertionError(kernel_size)
    computed = [k - 1 for k in kernel_size]

    # for even kernels we need to do asymmetric padding :(
    out_padding = 2 * len(kernel_size) * [0]

    for i in range(len(kernel_size)):
        computed_tmp = computed[-(i + 1)]

        pad_front = computed_tmp // 2
        pad_rear = computed_tmp - pad_front

        out_padding[2 * i + 0] = pad_front
        out_padding[2 * i + 1] = pad_rear

    return out_padding


def _filter2d(input, kernel):
    # prepare kernel
    b, c, h, w = input.shape
    tmp_kernel = kernel[:, None, ...].to(
        device=input.device, dtype=input.dtype)

    tmp_kernel = tmp_kernel.expand(-1, c, -1, -1)

    height, width = tmp_kernel.shape[-2:]

    padding_shape: list[int] = _compute_padding([height, width])
    input = torch.nn.functional.pad(input, padding_shape, mode="reflect")

    # kernel and input tensor reshape to align element-wise or batch-wise params
    tmp_kernel = tmp_kernel.reshape(-1, 1, height, width)
    input = input.view(-1, tmp_kernel.size(0), input.size(-2), input.size(-1))

    # convolve the tensor with the kernel.
    output = torch.nn.functional.conv2d(
        input, tmp_kernel, groups=tmp_kernel.size(0), padding=0, stride=1)

    out = output.view(b, c, h, w)
    return out


def _gaussian(window_size: int, sigma):
    if isinstance(sigma, float):
        sigma = torch.tensor([[sigma]])

    batch_size = sigma.shape[0]

    x = (torch.arange(window_size, device=sigma.device,
                      dtype=sigma.dtype) - window_size // 2).expand(batch_size, -1)

    if window_size % 2 == 0:
        x = x + 0.5

    gauss = torch.exp(-x.pow(2.0) / (2 * sigma.pow(2.0)))

    return gauss / gauss.sum(-1, keepdim=True)


def _gaussian_blur2d(input, kernel_size, sigma):
    if isinstance(sigma, tuple):
        sigma = torch.tensor([sigma], dtype=input.dtype)
    else:
        sigma = sigma.to(dtype=input.dtype)

    ky, kx = int(kernel_size[0]), int(kernel_size[1])
    bs = sigma.shape[0]
    kernel_x = _gaussian(kx, sigma[:, 1].view(bs, 1))
    kernel_y = _gaussian(ky, sigma[:, 0].view(bs, 1))
    out_x = _filter2d(input, kernel_x[..., None, :])
    out = _filter2d(out_x, kernel_y[..., None])
    return out


def export_to_video(video_frames, output_video_path, fps):
    fourcc = cv2.VideoWriter_fourcc(*"mp4v")
    h, w, _ = video_frames[0].shape
    video_writer = cv2.VideoWriter(
        output_video_path, fourcc, fps=fps, frameSize=(w, h))
    for i in range(len(video_frames)):
        # OpenCV expects BGR channel order
        img = cv2.cvtColor(video_frames[i], cv2.COLOR_RGB2BGR)
        video_writer.write(img)
    # Release the writer so the file is finalized
    video_writer.release()
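

# Illustrative usage (a sketch with dummy data): frames are expected as uint8
# RGB arrays of shape (H, W, 3).
#
#   frames = [np.zeros((576, 1024, 3), dtype=np.uint8) for _ in range(9)]
#   export_to_video(frames, "out.mp4", fps=7)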


def export_to_gif(frames, output_gif_path, fps):
    """
    Export a list of frames to a GIF.

    Args:
    - frames (list): List of frames (as numpy arrays or PIL Image objects).
    - output_gif_path (str): Path to save the output GIF (an '.mp4' suffix, if present, is replaced with '.gif').
    - fps (int): Frames per second; currently unused, each frame is shown for a fixed 500 ms.
    """
    # Convert numpy arrays to PIL Images if needed
    pil_frames = [Image.fromarray(frame) if isinstance(
        frame, np.ndarray) else frame for frame in frames]

    pil_frames[0].save(output_gif_path.replace('.mp4', '.gif'),
                       format='GIF',
                       append_images=pil_frames[1:],
                       save_all=True,
                       duration=500,  # fixed 500 ms per frame; fps is not applied here
                       loop=0)


def tensor_to_vae_latent(t, vae, otype="sample"):
    video_length = t.shape[1]

    t = rearrange(t, "b f c h w -> (b f) c h w")
    if otype == "sample":
        latents = vae.encode(t).latent_dist.sample()
    else:
        latents = vae.encode(t).latent_dist.mode()
    latents = rearrange(latents, "(b f) c h w -> b f c h w", f=video_length)
    latents = latents * vae.config.scaling_factor

    return latents
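

# Illustrative usage (a sketch; the checkpoint name is an assumption, not taken
# from this file): encode a (b, f, 3, H, W) video tensor in [-1, 1] into scaled
# latents of shape (b, f, 4, H/8, W/8).
#
#   from diffusers import AutoencoderKLTemporalDecoder
#   vae = AutoencoderKLTemporalDecoder.from_pretrained(
#       "stabilityai/stable-video-diffusion-img2vid", subfolder="vae")
#   latents = tensor_to_vae_latent(pixel_values, vae, otype="sample")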


def parse_config(config_path="config.yaml"):
    with open(config_path, "r") as f:
        config = yaml.safe_load(f)

    # handle distributed training rank
    env_local_rank = int(os.environ.get("LOCAL_RANK", -1))
    if env_local_rank != -1 and env_local_rank != config.get("local_rank", -1):
        config["local_rank"] = env_local_rank

    # default fallback: non_ema_revision = revision
    if config.get("non_ema_revision") is None:
        config["non_ema_revision"] = config.get("revision")

    return config


def parse_args():
    parser = argparse.ArgumentParser(description="SVD Training Script")
    parser.add_argument(
        "--config",
        type=str,
        default="svd/scripts/training/configs/stage1_base.yaml",
        help="Path to the config file.",
    )
    args = parser.parse_args()

    # load YAML and merge into args
    config = parse_config(args.config)

    # combine yaml + command-line args (command line has priority)
    for k, v in vars(args).items():
        if v is not None:
            config[k] = v

    # convert dict to argparse.Namespace for downstream compatibility
    args = argparse.Namespace(**config)
    print("OUTPUT DIR: ", args.output_dir)
    return args
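

# Illustrative YAML config sketch (values are hypothetical; only keys referenced
# in this file are shown):
#
#   output_dir: outputs/stage1
#   revision: null
#   non_ema_revision: null
#   local_rank: -1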


def download_image(url):
    # Load from a URL when a scheme is present, otherwise open a local file.
    if urlparse(url).scheme:
        return load_image(url)
    return PIL.Image.open(url).convert("RGB")
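

# Illustrative usage (a sketch; the URL and path are hypothetical):
#
#   img = download_image("https://example.com/photo.png")  # fetched via diffusers.utils.load_image
#   img = download_image("data/photo.png")                 # opened locally and converted to RGB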