Spaces:
Runtime error
Runtime error
| from typing import List, Optional, Tuple, Union | |
| import numpy as np | |
| import torch | |
| from PIL import Image | |
| from timm.data.constants import IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD | |
| from torchvision import transforms | |
class GroupResize:
    """Resize every frame of a clip with one shared torchvision ``Resize``.

    Instances are called on a ``(frames, label)`` pair and return
    ``(resized_frames, label)``; the label is passed through untouched.
    """

    def __init__(self, size: int = 256) -> None:
        # A single transform instance is reused for all frames.
        self.transform = transforms.Resize(size)

    def __call__(
        self, img_tuple: Tuple[List["Image.Image"], torch.Tensor]
    ) -> Tuple[List["Image.Image"], torch.Tensor]:
        frames, label = img_tuple
        resize = self.transform
        resized = list(map(resize, frames))
        return resized, label
class GroupNormalize:
    """Normalize a channel-stacked clip tensor in place.

    ``mean`` and ``std`` hold the per-channel statistics of a single frame
    (e.g. three values for RGB). They are tiled along dimension 0 of the
    input tensor, which is expected to contain whole frames stacked back to
    back (``num_frames * len(mean)`` channels), so every frame is normalized
    with the same statistics.
    """

    def __init__(self, mean: List[float], std: List[float]) -> None:
        # Mismatched lengths used to be silently truncated by zip().
        if len(mean) == 0 or len(mean) != len(std):
            raise ValueError(
                "mean and std must be non-empty and of equal length, "
                f"got {len(mean)} and {len(std)}"
            )
        self.mean = mean
        self.std = std

    def __call__(
        self, tensor_tuple: Tuple[torch.Tensor, torch.Tensor]
    ) -> Tuple[torch.Tensor, torch.Tensor]:
        """Return ``((tensor - mean) / std, label)``.

        ``tensor`` is modified in place and returned; ``label`` is passed
        through unchanged.

        Raises:
            ValueError: if ``tensor.size(0)`` is not a multiple of
                ``len(mean)`` (previously the trailing channels were left
                un-normalized without any error).
        """
        tensor, label = tensor_tuple
        channels = tensor.size(0)
        n_stats = len(self.mean)
        if channels % n_stats != 0:
            raise ValueError(
                f"channel dimension {channels} is not a multiple of the "
                f"{n_stats} normalization statistics"
            )
        reps = channels // n_stats
        # Match float dtypes exactly; for non-float tensors keep float stats so
        # the in-place op errors (as before) instead of truncating the stats.
        stat_dtype = tensor.dtype if tensor.is_floating_point() else None
        mean = torch.as_tensor(self.mean, dtype=stat_dtype, device=tensor.device)
        std = torch.as_tensor(self.std, dtype=stat_dtype, device=tensor.device)
        # Shape (channels, 1, 1, ...) broadcasts over all trailing dimensions.
        bcast = (channels,) + (1,) * (tensor.dim() - 1)
        tensor.sub_(mean.repeat(reps).view(bcast)).div_(std.repeat(reps).view(bcast))
        return tensor, label
class GroupCenterCrop:
    """Center-crop every frame of a clip with one shared torchvision ``CenterCrop``.

    Instances are called on a ``(frames, label)`` pair and return
    ``(cropped_frames, label)``; the label is passed through untouched.
    """

    def __init__(self, size: int) -> None:
        # The crop transform is built once and applied to each frame in turn.
        self.worker = transforms.CenterCrop(size)

    def __call__(
        self, img_tuple: Tuple[List["Image.Image"], torch.Tensor]
    ) -> Tuple[List["Image.Image"], torch.Tensor]:
        img_group, label = img_tuple
        cropped = []
        for frame in img_group:
            cropped.append(self.worker(frame))
        return cropped, label
class Stack:
    """Stack the frames of a clip into one ``H x W x (T*C)`` numpy array.

    Frames are objects exposing a PIL-style ``mode`` attribute that numpy can
    convert to an array. Grayscale (``"L"``) frames contribute one channel
    each; ``"RGB"`` frames contribute three, with the channel order reversed
    (RGB -> BGR) when ``roll`` is truthy.
    """

    def __init__(self, roll: Optional[bool] = False) -> None:
        # When truthy, reverse the channel order of each RGB frame.
        self.roll = roll

    def __call__(
        self, img_tuple: Tuple[List["Image.Image"], torch.Tensor]
    ) -> Tuple[np.ndarray, torch.Tensor]:
        """Return ``(stacked_array, label)``; the label is passed through.

        Raises:
            ValueError: if the frame group is empty or the first frame's mode
                is neither ``"L"`` nor ``"RGB"`` (previously the method fell
                through and returned ``None``).
        """
        img_group, label = img_tuple
        if len(img_group) == 0:
            raise ValueError("Stack received an empty frame group")
        mode = img_group[0].mode
        if mode == "L":
            # Each 2-D frame becomes one channel along the new last axis.
            return np.stack([np.asarray(x) for x in img_group], axis=2), label
        if mode == "RGB":
            frames = [np.asarray(x) for x in img_group]
            if self.roll:
                frames = [f[:, :, ::-1] for f in frames]
            return np.concatenate(frames, axis=2), label
        raise ValueError(f"Unsupported image mode {mode!r}; expected 'L' or 'RGB'")
class ToTorchFormatTensor:
    """Convert an ``H x W x C`` numpy array or PIL image to a ``C x H x W`` float tensor.

    2-D inputs (``H x W``, e.g. a single grayscale image) get a singleton
    channel axis. When ``div`` is truthy, values are divided by 255.0
    (mapping the uint8 range onto [0, 1]).
    """

    def __init__(self, div: Optional[bool] = True) -> None:
        # When truthy, scale the output by 1/255.
        self.div = div

    def __call__(
        self, pic_tuple: Tuple[Union[np.ndarray, "Image.Image"], torch.Tensor]
    ) -> Tuple[torch.Tensor, torch.Tensor]:
        """Return ``(float_chw_tensor, label)``; the label is passed through.

        Raises:
            TypeError: if ``pic`` is neither a numpy array nor a PIL image.
        """
        pic, label = pic_tuple
        if isinstance(pic, np.ndarray):
            arr = pic
        elif isinstance(pic, Image.Image):
            # np.array (not asarray) copies into a writable buffer, which
            # torch.from_numpy expects; it also honours every PIL mode's
            # real dtype/channel count instead of reinterpreting raw bytes.
            arr = np.array(pic)
        else:
            raise TypeError(
                f"Unsupported type {type(pic)}; must be np.ndarray or PIL.Image.Image"
            )
        if arr.ndim == 2:
            # Single-channel image without an explicit channel axis.
            arr = arr[:, :, np.newaxis]
        # HWC -> CHW.
        img = torch.from_numpy(arr).permute(2, 0, 1).contiguous().float()
        return (img.div(255.0) if self.div else img), label
class TubeMaskingGenerator:
    """Generate "tube" masks: one random spatial mask repeated on every frame.

    ``input_size`` is ``(frames, height, width)`` measured in patches.
    ``int(mask_ratio * height * width)`` patches per frame are masked; the
    same positions are masked in every frame.
    """

    def __init__(self, input_size: Tuple[int, int, int], mask_ratio: float) -> None:
        frames, height, width = input_size
        self.frames, self.height, self.width = frames, height, width
        per_frame = height * width
        self.num_patches_per_frame = per_frame
        self.total_patches = frames * per_frame
        self.num_masks_per_frame = int(mask_ratio * per_frame)
        self.total_masks = frames * self.num_masks_per_frame

    def __call__(self) -> np.ndarray:
        """Return a flat float64 array of length ``total_patches`` (1.0 = masked).

        Uses the global ``np.random`` state, so seeding it makes the result
        reproducible.
        """
        n_masked = self.num_masks_per_frame
        n_visible = self.num_patches_per_frame - n_masked
        frame_mask = np.concatenate((np.zeros(n_visible), np.ones(n_masked)))
        np.random.shuffle(frame_mask)
        # Repeat the single shuffled frame mask once per frame, frame-major.
        return np.tile(frame_mask, self.frames)
def get_videomae_transform(
    input_size: int = 224, resize_size: int = 384
) -> "transforms.Compose":
    """Build the VideoMAE clip preprocessing pipeline.

    The returned callable takes a ``(frames, label)`` pair, where ``frames``
    is a list of PIL images, and returns ``(tensor, label)``. It applies:

    1. ``Resize(resize_size)`` to each frame,
    2. ``CenterCrop(input_size)`` to each frame,
    3. channel-axis stacking into one ``H x W x (T*C)`` array,
    4. conversion to a float ``(T*C) x H x W`` tensor scaled by 1/255,
    5. per-channel normalization with the ImageNet default mean/std.

    Args:
        input_size: side length of the square center crop.
        resize_size: size passed to ``Resize`` before cropping. It defaults to
            384, the value that was previously hard-coded. It should be at
            least ``input_size``, otherwise the crop is larger than the
            resized frame.
    """
    return transforms.Compose(
        [
            GroupResize(size=resize_size),
            GroupCenterCrop(input_size),
            Stack(roll=False),
            ToTorchFormatTensor(div=True),
            GroupNormalize(mean=IMAGENET_DEFAULT_MEAN, std=IMAGENET_DEFAULT_STD),
        ]
    )