Spaces:
Build error
Build error
| import PIL | |
| import torch | |
| import numpy as np | |
| import gsplat as gs | |
| import torch.nn as nn | |
| from copy import deepcopy | |
| import torch.nn.functional as F | |
| from dataclasses import dataclass | |
| from ops.utils import ( | |
| dpt2xyz, | |
| alpha_inpaint_mask, | |
| transform_points, | |
| numpy_normalize, | |
| numpy_quaternion_from_matrix | |
| ) | |
class Frame():
    '''
    Per-view 2D container holding an image, its depth, masks and camera.

    rgb: in shape of H*W*3, in range of 0-1
    dpt: in shape of H*W, real depth
    inpaint: bool mask in shape of H*W for inpainting
    intrinsic: 3*3
    extrinsic: array in shape of 4*4
    As a class for:
        initialize camera
        accept rendering result
        accept inpainting result
    All at 2D-domain
    '''
    def __init__(self,
                 H: int = None,
                 W: int = None,
                 rgb: np.ndarray = None,
                 dpt: np.ndarray = None,
                 sky: np.ndarray = None,
                 inpaint: np.ndarray = None,
                 intrinsic: np.ndarray = None,
                 extrinsic: np.ndarray = None,
                 # detailed target
                 ideal_dpt: np.ndarray = None,
                 ideal_nml: np.ndarray = None,
                 prompt: str = None) -> None:
        '''
        Store the given fields, then normalize rgb and the extrinsic.

        rgb may be a numpy array or any array-convertible image object
        (e.g. a PIL image); values above ~1 are treated as 0-255 and
        rescaled to 0-1. A missing extrinsic becomes the 4*4 identity.
        All other arguments are stored as given.
        '''
        self.H = H
        self.W = W
        self.rgb = rgb
        self.dpt = dpt
        self.sky = sky
        self.prompt = prompt
        self.intrinsic = intrinsic
        self.extrinsic = extrinsic
        self._rgb_rect()
        self._extr_rect()
        # for inpainting
        self.inpaint = inpaint
        self.inpaint_wo_edge = inpaint
        # for supervision
        self.ideal_dpt = ideal_dpt
        self.ideal_nml = ideal_nml

    def _rgb_rect(self):
        '''Convert self.rgb to a numpy array scaled to the 0-1 range.'''
        if self.rgb is None:
            return
        # Convert every non-array input instead of testing for specific PIL
        # plugin classes: a bare `import PIL` does not load
        # PIL.PngImagePlugin / PIL.JpegImagePlugin, and images produced by
        # resize()/convert() are neither of those classes anyway.
        if not isinstance(self.rgb, np.ndarray):
            self.rgb = np.array(self.rgb)
        # 1.1 rather than 1.0 leaves slack for already-normalized images
        if np.amax(self.rgb) > 1.1:
            self.rgb = self.rgb / 255

    def _extr_rect(self):
        '''Default the extrinsic to identity and cache its inverse.'''
        if self.extrinsic is None:
            self.extrinsic = np.eye(4)
        self.inv_extrinsic = np.linalg.inv(self.extrinsic)
class Gaussian_Frame():
    '''
    In-frame-frustum
    Gaussians from a single RGBD frame
    As a class for:
        accept information from initialized/inpainting+geo-estimated frame
        saving pixelsplat properties including rgb, xyz, scale, rotation, opacity; note here, we made a modification to xyz;
            we first project depth to xyz
            then we tune a scale map(initialized to ones) and a shift map(initialized to zeros), they are optimized and add to the original xyz when rendering
    NOTE(review): no scale/shift maps are created in this class as shown here;
    xyz itself is stored as the parameter -- confirm whether the note above is stale.
    '''
    # Class-level defaults (plain values, not 1-tuples); every one of them is
    # overwritten per instance by _set_property_from_frame.
    # as pixelsplat gaussian
    rgb: torch.Tensor = None
    scale: torch.Tensor = None
    opacity: torch.Tensor = None
    rotation: torch.Tensor = None
    # gaussian center
    dpt: torch.Tensor = None
    xyz: torch.Tensor = None
    # as a frame
    H: int = 480
    W: int = 640
    # colours are clamped to [eps, 1-eps] before logit so that pure black /
    # pure white pixels do not become -inf / +inf parameters
    _COLOR_EPS: float = 1e-6

    def __init__(self, frame: "Frame", device = 'cuda'):
        '''
        Build per-pixel Gaussians from a complete (initialized or inpainted) frame.

        frame: Frame providing H, W, rgb, dpt, intrinsic, extrinsic and inpaint_wo_edge
        device: torch device the Gaussian tensors are moved to
        '''
        # de-active functions (inverses of the activations applied at render time)
        self.rgbs_deact = torch.logit
        self.scales_deact = torch.log
        self.opacity_deact = torch.logit
        self.device = device
        # for gaussian initialization
        self._set_property_from_frame(frame)

    def _to_3d(self):
        '''Back-project self.dpt with the intrinsic, then apply the inverse extrinsic.'''
        # inv intrinsic
        xyz = dpt2xyz(self.dpt, self.intrinsic)
        inv_extrinsic = np.linalg.inv(self.extrinsic)
        xyz = transform_points(xyz, inv_extrinsic)
        return xyz

    def _paint_filter(self, paint_mask):
        '''
        Keep only the pixels selected by paint_mask (H*W), flattening H*W*C maps to N*C.

        paint_mask: H*W mask; None keeps every pixel. When fewer than 3 pixels
        are selected, the first image row is used instead so that the frame
        never ends up with (almost) no Gaussians.
        '''
        if paint_mask is None:
            # no mask supplied (Frame's default) -> every pixel becomes a Gaussian
            paint_mask = np.ones((self.H, self.W), dtype=bool)
        if np.sum(paint_mask) < 3:
            paint_mask = np.zeros((self.H, self.W))
            paint_mask[0:1] = 1
        # boolean mask is required for the fancy indexing below
        paint_mask = paint_mask > .5
        self.rgb = self.rgb[paint_mask]
        self.xyz = self.xyz[paint_mask]
        self.scale = self.scale[paint_mask]
        self.opacity = self.opacity[paint_mask]
        self.rotation = self.rotation[paint_mask]

    def _to_cuda(self):
        '''Convert the numpy properties to float32 tensors on self.device.'''
        self.rgb = torch.from_numpy(self.rgb.astype(np.float32)).to(self.device)
        self.xyz = torch.from_numpy(self.xyz.astype(np.float32)).to(self.device)
        self.scale = torch.from_numpy(self.scale.astype(np.float32)).to(self.device)
        self.opacity = torch.from_numpy(self.opacity.astype(np.float32)).to(self.device)
        self.rotation = torch.from_numpy(self.rotation.astype(np.float32)).to(self.device)

    def _fine_init_scale_rotations(self):
        # from https://arxiv.org/pdf/2406.09394
        """ Compute rotation matrices that align z-axis with given normal vectors using matrix operations.

        Reads self.nml, self.extrinsic, self.dpt and self.intrinsic and writes
        self.rotation (quaternions) and self.scale (3 values per pixel).
        NOTE(review): self.nml is not assigned anywhere in this class as shown;
        the caller must set it (presumably an H*W*3 normal map) before calling.
        """
        up_axis = np.array([0,1,0])
        nml = self.nml @ self.extrinsic[0:3,0:3]
        # orthonormal basis: qz along the normal, qx/qy completing the frame
        qz = numpy_normalize(nml)
        qx = np.cross(up_axis,qz)
        qx = numpy_normalize(qx)
        qy = np.cross(qz,qx)
        qy = numpy_normalize(qy)
        rot = np.concatenate([qx[...,None],qy[...,None],qz[...,None]],axis=-1)
        self.rotation = numpy_quaternion_from_matrix(rot)
        # scale
        safe_nml = deepcopy(self.nml)
        # floor the z component so the cosines below cannot get arbitrarily small
        safe_nml[safe_nml[:,:,-1]<0.2,-1] = .2
        normal_xoz = deepcopy(safe_nml)
        normal_yoz = deepcopy(safe_nml)
        normal_xoz[...,1] = 0.
        normal_yoz[...,0] = 0.
        normal_xoz = numpy_normalize(normal_xoz)
        normal_yoz = numpy_normalize(normal_yoz)
        cos_theta_x = np.abs(normal_xoz[...,2])
        cos_theta_y = np.abs(normal_yoz[...,2])
        scale_basic = self.dpt / self.intrinsic[0,0] / np.sqrt(2)
        scale_x = scale_basic / cos_theta_x
        scale_y = scale_basic / cos_theta_y
        scale_z = (scale_x + scale_y) / 10.
        self.scale = np.concatenate([scale_x[...,None],
                                     scale_y[...,None],
                                     scale_z[...,None]],axis=-1)

    def _coarse_init_scale_rotations(self):
        '''Isotropic scale depth / fx / sqrt(2) per pixel and quaternion (1,0,0,0).'''
        # gaussian property -- HW3 scale
        self.scale = self.dpt / self.intrinsic[0,0] / np.sqrt(2)
        self.scale = self.scale[:,:,None].repeat(3,-1)
        # gaussian property -- HW4 rotation
        self.rotation = np.zeros((self.H,self.W,4))
        self.rotation[:,:,0] = 1.

    def _set_property_from_frame(self, frame: "Frame"):
        '''frame here is a complete init/inpainted frame'''
        # basic frame-level property
        self.H = frame.H
        self.W = frame.W
        self.dpt = frame.dpt
        self.intrinsic = frame.intrinsic
        self.extrinsic = frame.extrinsic
        # gaussian property -- xyz
        self.xyz = self._to_3d()
        # gaussian property -- HW3 rgb
        self.rgb = frame.rgb
        # gaussian property -- HW4 rotation HW3 scale
        self._coarse_init_scale_rotations()
        # gaussian property -- HW opacity
        self.opacity = np.ones((self.H,self.W,1)) * 0.8
        # select the pixels to keep, then move to the target device
        self._paint_filter(frame.inpaint_wo_edge)
        self._to_cuda()
        # de-activate; clamp colours first because logit(0) = -inf and
        # logit(1) = +inf, which would leave those colours with zero
        # sigmoid gradient
        self.rgb = self.rgbs_deact(self.rgb.clamp(self._COLOR_EPS, 1. - self._COLOR_EPS))
        self.scale = self.scales_deact(self.scale)
        self.opacity = self.opacity_deact(self.opacity)
        # to torch parameters (frozen until _require_grad is called)
        self.rgb = nn.Parameter(self.rgb,requires_grad=False)
        self.xyz = nn.Parameter(self.xyz,requires_grad=False)
        self.scale = nn.Parameter(self.scale,requires_grad=False)
        self.opacity = nn.Parameter(self.opacity,requires_grad=False)
        self.rotation = nn.Parameter(self.rotation,requires_grad=False)

    def _require_grad(self, sign=True):
        '''Enable (sign=True) or disable gradients on all Gaussian parameters.'''
        self.rgb = self.rgb.requires_grad_(sign)
        self.xyz = self.xyz.requires_grad_(sign)
        self.scale = self.scale.requires_grad_(sign)
        self.opacity = self.opacity.requires_grad_(sign)
        self.rotation = self.rotation.requires_grad_(sign)
class Gaussian_Scene():
    '''
    A scene made of per-frame Gaussian sets that are rendered together.

    frames keeps the source Frame objects, gaussian_frames the matching
    Gaussian_Frame objects whose parameters are concatenated at render time.
    '''
    def __init__(self, cfg=None):
        '''
        cfg: optional config exposing cfg.scene.traj.{near_percentage,
             far_percentage, traj_forward_ratio, traj_backward_ratio};
             when None the defaults 5, 50, 0.3, 0.4 are used.
        '''
        # frames initialing the frame
        self.frames = []
        self.gaussian_frames: list[Gaussian_Frame] = []  # gaussian frame require training at this optimization
        # activate functions (inverses of Gaussian_Frame's de-activations)
        self.rgbs_act = torch.sigmoid
        self.scales_act = torch.exp
        self.opacity_act = torch.sigmoid
        self.device = 'cuda' if torch.cuda.is_available() else 'cpu'
        # for traj generation
        self.traj_type = 'spiral'
        if cfg is not None:
            self.traj_min_percentage = cfg.scene.traj.near_percentage
            self.traj_max_percentage = cfg.scene.traj.far_percentage
            self.traj_forward_ratio = cfg.scene.traj.traj_forward_ratio
            self.traj_backward_ratio = cfg.scene.traj.traj_backward_ratio
        else:
            self.traj_min_percentage = 5
            self.traj_max_percentage = 50
            self.traj_forward_ratio = 0.3
            self.traj_backward_ratio = 0.4

    # basic operations
    def _render_RGBD(self, frame, background_color='black'):
        '''
        Rasterize all Gaussian frames from the camera of `frame`.

        :frame: object with H, W, intrinsic (3*3 numpy) and extrinsic (4*4 numpy)
        :background_color: 'white' uses a constant 0.1 colour background with
                           zero depth; anything else passes no background
        :out: render_rgb H*W*3 tensor, render_dpt H*W tensor, and the alpha
              map exactly as returned by gsplat (not squeezed)
        :raises RuntimeError: when the scene holds no Gaussian frames
        '''
        if not self.gaussian_frames:
            # torch.cat on an empty list fails with an opaque message
            raise RuntimeError('Gaussian_Scene has no Gaussian frames to render; '
                               'add one with _add_trainable_frame first')
        background = None
        if background_color == 'white':
            background = torch.ones(1,4,device=self.device)*0.1
            background[:,-1] = 0.  # for depth
        # gather the parameters of every Gaussian frame into flat tensors
        xyz = torch.cat([gf.xyz.reshape(-1,3) for gf in self.gaussian_frames],dim=0)
        rgb = torch.cat([gf.rgb.reshape(-1,3) for gf in self.gaussian_frames],dim=0)
        scale = torch.cat([gf.scale.reshape(-1,3) for gf in self.gaussian_frames],dim=0)
        opacity = torch.cat([gf.opacity.reshape(-1) for gf in self.gaussian_frames],dim=0)
        rotation = torch.cat([gf.rotation.reshape(-1,4) for gf in self.gaussian_frames],dim=0)
        # activate
        rgb = self.rgbs_act(rgb)
        scale = self.scales_act(scale)
        rotation = F.normalize(rotation,dim=1)
        opacity = self.opacity_act(opacity)
        # property
        H,W = frame.H, frame.W
        intrinsic = torch.from_numpy(frame.intrinsic.astype(np.float32)).to(self.device)
        extrinsic = torch.from_numpy(frame.extrinsic.astype(np.float32)).to(self.device)
        # render
        render_out,render_alpha,_ = gs.rendering.rasterization(means = xyz,
                                                               scales = scale,
                                                               quats = rotation,
                                                               opacities = opacity,
                                                               colors = rgb,
                                                               Ks = intrinsic[None],
                                                               viewmats = extrinsic[None],
                                                               width = W,
                                                               height = H,
                                                               packed = False,
                                                               near_plane= 0.01,
                                                               render_mode="RGB+ED",
                                                               backgrounds=background)  # render: 1*H*W*(3+1)
        # drop only the camera dimension; a bare squeeze() would also drop
        # H or W when either equals 1
        render_out = render_out.squeeze(0)  # result: H*W*(3+1)
        render_rgb = render_out[:,:,0:3]
        render_dpt = render_out[:,:,-1]
        return render_rgb, render_dpt, render_alpha

    def _render_for_inpaint(self, frame):
        '''
        Render the scene into `frame` and attach the mask of regions to inpaint.

        Overwrites frame.rgb, frame.dpt (numpy) and frame.inpaint in place and
        returns the same frame object.
        '''
        # first render
        render_rgb, render_dpt, render_alpha = self._render_RGBD(frame)
        render_msk = alpha_inpaint_mask(render_alpha)
        # to numpy and assign back
        frame.rgb = render_rgb.detach().cpu().numpy()
        frame.dpt = render_dpt.detach().cpu().numpy()
        frame.inpaint = render_msk
        return frame

    def _add_trainable_frame(self, frame: "Frame", require_grad=True):
        '''Register `frame` and its Gaussians; require_grad toggles their gradients.'''
        # for the init frame, we keep all pixels for finetuning
        self.frames.append(frame)
        gf = Gaussian_Frame(frame, self.device)
        gf._require_grad(require_grad)
        self.gaussian_frames.append(gf)