from typing import List, Tuple

import matplotlib.pyplot as plt
import numpy as np
import torch
from decord import VideoReader, cpu
from einops import rearrange
from PIL import Image
from timm.data.constants import IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD
from torchvision import transforms
from torchvision.transforms import ToPILImage


def get_frames(
    path: str, transform: transforms.Compose, num_frames: int = 16
) -> Tuple[torch.Tensor, List[int]]:
    """Decode, preprocess, and stack a clip of `num_frames` frames from `path`."""
    vr = VideoReader(path, ctx=cpu(0))
    # Sample every other frame starting at frame 60; this assumes the video
    # is at least 60 + 2 * num_frames frames long.
    frame_id_list = (np.arange(0, num_frames * 2, 2) + 60).tolist()
    video_data = vr.get_batch(frame_id_list).asnumpy()
    # The transform is expected to consume a (list of PIL images, label)
    # tuple and return a (T*C, H, W) tensor stacked along the channel axis.
    frames, _ = transform(
        (
            [
                Image.fromarray(video_data[vid, :, :, :]).convert("RGB")
                for vid, _ in enumerate(frame_id_list)
            ],
            None,
        )
    )
    # (T*C, H, W) -> (T, C, H, W) -> (C, T, H, W)
    frames = frames.view((num_frames, 3) + frames.size()[-2:]).transpose(0, 1)
    return frames, frame_id_list
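

# A minimal sketch of a compatible transform, since the script does not define
# one: it assumes the VideoMAE pretraining convention of consuming a
# (PIL image list, label) tuple and returning a (T*C, H, W) tensor stacked
# along the channel axis. `build_demo_transform` and its defaults are
# illustrative, not part of the original pipeline.
def build_demo_transform(input_size: int = 224):
    per_frame = transforms.Compose(
        [
            transforms.Resize(input_size),
            transforms.CenterCrop(input_size),
            transforms.ToTensor(),
            transforms.Normalize(IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD),
        ]
    )

    def _apply(batch):
        images, label = batch
        # Concatenate per-frame (C, H, W) tensors into one (T*C, H, W)
        # tensor, the layout get_frames() reshapes into (C, T, H, W).
        return torch.cat([per_frame(img) for img in images], dim=0), label

    return _apply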


def prepare_frames_masks(
    frames: torch.Tensor, masks: torch.Tensor, device: torch.device
) -> Tuple[torch.Tensor, torch.Tensor]:
    """Add a batch dimension and move both tensors to `device`."""
    frames = frames.unsqueeze(0).to(device, non_blocking=True)
    masks = masks.unsqueeze(0).to(device, non_blocking=True).flatten(1).to(torch.bool)
    return frames, masks
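

# A hedged stand-in for the tube-masking generator used with VideoMAE, which
# this script does not include: the same spatial positions are masked in every
# temporal tubelet (True = masked). With 16 frames, a temporal tubelet of 2,
# and 14x14 spatial patches (224 / 16), that is 8 * 14 * 14 = 1568 entries.
# `make_random_tube_mask` and its defaults are illustrative assumptions.
def make_random_tube_mask(
    num_tubelets: int = 8, patches_per_frame: int = 14 * 14, mask_ratio: float = 0.9
) -> torch.Tensor:
    num_masked = int(mask_ratio * patches_per_frame)
    # Draw one random spatial mask, then repeat it across time ("tube" masking).
    frame_mask = torch.zeros(patches_per_frame, dtype=torch.bool)
    frame_mask[torch.randperm(patches_per_frame)[:num_masked]] = True
    return frame_mask.repeat(num_tubelets)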


def get_videomae_outputs(
    frames: torch.Tensor,
    masks: torch.Tensor,
    outputs: torch.Tensor,
    ids: List[int],
    patch_size: Tuple[int, ...],
    device: torch.device,
) -> List[List[Image.Image]]:
    """Turn model outputs into per-frame [original, masked, reconstructed] images."""
    visualisations = []
    # Undo the ImageNet normalisation so pixel values are back in [0, 1].
    mean = torch.as_tensor(IMAGENET_DEFAULT_MEAN).to(device)[None, :, None, None, None]
    std = torch.as_tensor(IMAGENET_DEFAULT_STD).to(device)[None, :, None, None, None]
    ori_img = frames * std + mean  # in [0, 1]
    original_images = [
        ToPILImage()(ori_img[0, :, vid, :, :].cpu()) for vid, _ in enumerate(ids)
    ]
    # Split the clip into tubelet patches: temporal depth 2, spatial size
    # patch_size[0] x patch_size[1].
    img_squeeze = rearrange(
        ori_img,
        "b c (t p0) (h p1) (w p2) -> b (t h w) (p0 p1 p2) c",
        p0=2,
        p1=patch_size[0],
        p2=patch_size[1],
    )
    # Normalise each patch to zero mean and unit variance, matching
    # VideoMAE's normalised-pixel reconstruction target.
    img_norm = (img_squeeze - img_squeeze.mean(dim=-2, keepdim=True)) / (
        img_squeeze.var(dim=-2, unbiased=True, keepdim=True).sqrt() + 1e-6
    )
    img_patch = rearrange(img_norm, "b n p c -> b n (p c)")
    # Replace the masked patches with the model's predictions.
    img_patch[masks] = outputs
    # Build a pixel-level mask video: 0 where a patch was masked out,
    # 1 where the input was visible to the model.
    mask = torch.ones_like(img_patch)
    mask[masks] = 0
    mask = rearrange(mask, "b n (p c) -> b n p c", c=3)
    mask = rearrange(
        mask,
        "b (t h w) (p0 p1 p2) c -> b c (t p0) (h p1) (w p2)",
        p0=2,
        p1=patch_size[0],
        p2=patch_size[1],
        h=14,
        w=14,
    )
    # Un-normalise the patches with the per-patch statistics of the original
    # clip, then fold them back into a video tensor.
    rec_img = rearrange(img_patch, "b n (p c) -> b n p c", c=3)
    rec_img = rec_img * (
        img_squeeze.var(dim=-2, unbiased=True, keepdim=True).sqrt() + 1e-6
    ) + img_squeeze.mean(dim=-2, keepdim=True)
    rec_img = rearrange(
        rec_img,
        "b (t h w) (p0 p1 p2) c -> b c (t p0) (h p1) (w p2)",
        p0=2,
        p1=patch_size[0],
        p2=patch_size[1],
        h=14,
        w=14,
    )
    reconstructed_images = [
        ToPILImage()(rec_img[0, :, vid, :, :].cpu().clamp(0, 0.996))
        for vid, _ in enumerate(ids)
    ]
    # Black out the masked patches, leaving only the visible context the
    # model actually saw.
    img_mask = rec_img * mask
    masked_images = [
        ToPILImage()(img_mask[0, :, vid, :, :].cpu()) for vid, _ in enumerate(ids)
    ]
    assert len(original_images) == len(reconstructed_images) == len(masked_images)
    for i in range(len(original_images)):
        visualisations.append(
            [original_images[i], masked_images[i], reconstructed_images[i]]
        )
    return visualisations
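

# Shape sketch for get_videomae_outputs, assuming a 224x224, 16-frame clip and
# patch_size = (16, 16): ori_img is (1, 3, 16, 224, 224); img_patch is
# (1, 1568, 1536), i.e. 8 tubelets x 14 x 14 patches of 2*16*16*3 values each;
# `outputs` must then be (num_masked_patches, 1536) so it can fill
# img_patch[masks]. The hard-coded h=14, w=14 in the rearranges encode this
# 224 / 16 grid, so other resolutions would need those constants changed.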


def create_plot(images):
    """Lay out the [original, masked, reconstructed] triples, one frame per row."""
    num_cols = 3
    num_rows = len(images)
    column_names = ["Original Patch", "Masked Patch", "Reconstructed Patch"]
    fig, axes = plt.subplots(num_rows, num_cols, figsize=(12, 3 * num_rows))
    for i in range(num_rows):
        for j in range(num_cols):
            axes[i, j].imshow(images[i][j])
            axes[i, j].axis("off")
            if i == 0:
                axes[i, j].set_title(column_names[j], fontsize=16)
    plt.tight_layout()
    return fig
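

if __name__ == "__main__":
    # End-to-end smoke test (a sketch, not part of the original script): it
    # exercises the full visualisation path with random numbers standing in
    # for a real VideoMAE decoder output, so no checkpoint is needed.
    # "demo.mp4" is a placeholder path; swap `outputs` for the prediction of
    # a pretrained VideoMAE to see real reconstructions.
    device = torch.device("cpu")
    frames, ids = get_frames("demo.mp4", build_demo_transform())
    masks = make_random_tube_mask()
    frames, masks = prepare_frames_masks(frames, masks, device)
    # Each masked tubelet patch is predicted as 2 * 16 * 16 * 3 = 1536 values.
    outputs = torch.randn(int(masks.sum()), 1536, device=device)
    visualisations = get_videomae_outputs(
        frames, masks, outputs, ids, patch_size=(16, 16), device=device
    )
    create_plot(visualisations).savefig("videomae_demo.png")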