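"""Long short-term transformer (LSTT) and gated-propagation building blocks.

This file defines two memory-driven transformer modules, `LongShortTermTransformer`
and `DualBranchGPM`, together with the per-layer blocks they stack. Inputs are
flattened frame features of shape (H*W, batch, d_model); each block combines
self-attention with attention over a long-term (global) memory and a short-term
(spatially local) memory.
"""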
import torch
import torch.nn.functional as F
from torch import nn

from networks.layers.basic import DropPath, GroupNorm1D, GNActDWConv2d, seq_to_2d, ScaleOffset, mask_out
from networks.layers.attention import silu, MultiheadAttention, MultiheadLocalAttentionV2, MultiheadLocalAttentionV3, GatedPropagation, LocalGatedPropagation


def _get_norm(indim, type='ln', groups=8):
    if type == 'gn':
        return GroupNorm1D(indim, groups)
    else:
        return nn.LayerNorm(indim)


def _get_activation_fn(activation):
    """Return an activation function given a string"""
    if activation == "relu":
        return F.relu
    if activation == "gelu":
        return F.gelu
    if activation == "glu":
        return F.glu
    raise RuntimeError(
        f"activation should be relu/gelu/glu, not {activation}.")
class LongShortTermTransformer(nn.Module):
    def __init__(self,
                 num_layers=2,
                 d_model=256,
                 self_nhead=8,
                 att_nhead=8,
                 dim_feedforward=1024,
                 emb_dropout=0.,
                 droppath=0.1,
                 lt_dropout=0.,
                 st_dropout=0.,
                 droppath_lst=False,
                 droppath_scaling=False,
                 activation="gelu",
                 return_intermediate=False,
                 intermediate_norm=True,
                 final_norm=True,
                 block_version="v1"):
        super().__init__()
        self.intermediate_norm = intermediate_norm
        self.final_norm = final_norm
        self.num_layers = num_layers
        self.return_intermediate = return_intermediate
        self.emb_dropout = nn.Dropout(emb_dropout, True)
        self.mask_token = nn.Parameter(torch.randn([1, 1, d_model]))

        if block_version == "v1":
            block = LongShortTermTransformerBlock
        elif block_version == "v2":
            block = LongShortTermTransformerBlockV2
        elif block_version == "v3":
            block = LongShortTermTransformerBlockV3
        else:
            raise NotImplementedError

        layers = []
        for idx in range(num_layers):
            if droppath_scaling:
                if num_layers == 1:
                    droppath_rate = 0
                else:
                    droppath_rate = droppath * idx / (num_layers - 1)
            else:
                droppath_rate = droppath
            layers.append(
                block(d_model, self_nhead, att_nhead, dim_feedforward,
                      droppath_rate, lt_dropout, st_dropout, droppath_lst,
                      activation))
        self.layers = nn.ModuleList(layers)

        num_norms = num_layers - 1 if intermediate_norm else 0
        if final_norm:
            num_norms += 1
        self.decoder_norms = [
            _get_norm(d_model, type='ln') for _ in range(num_norms)
        ] if num_norms > 0 else None

        if self.decoder_norms is not None:
            self.decoder_norms = nn.ModuleList(self.decoder_norms)
    def forward(self,
                tgt,
                long_term_memories,
                short_term_memories,
                curr_id_emb=None,
                self_pos=None,
                size_2d=None):

        output = self.emb_dropout(tgt)

        # output = mask_out(output, self.mask_token, 0.15, self.training)

        intermediate = []
        intermediate_memories = []

        for idx, layer in enumerate(self.layers):
            output, memories = layer(output,
                                     long_term_memories[idx] if
                                     long_term_memories is not None else None,
                                     short_term_memories[idx] if
                                     short_term_memories is not None else None,
                                     curr_id_emb=curr_id_emb,
                                     self_pos=self_pos,
                                     size_2d=size_2d)

            if self.return_intermediate:
                intermediate.append(output)
                intermediate_memories.append(memories)

        if self.decoder_norms is not None:
            if self.final_norm:
                output = self.decoder_norms[-1](output)

                if self.return_intermediate:
                    intermediate.pop()
                    intermediate.append(output)

            if self.intermediate_norm:
                for idx in range(len(intermediate) - 1):
                    intermediate[idx] = self.decoder_norms[idx](
                        intermediate[idx])

        if self.return_intermediate:
            return intermediate, intermediate_memories

        return output, memories
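

# DualBranchGPM mirrors LongShortTermTransformer but stacks
# GatedPropagationModule layers, which carry two parallel streams: the visual
# features (`output`) and an identity stream (`output_id`). The streams are
# concatenated along the channel dimension before normalization; GroupNorm1D
# with 2 groups is used, presumably so each branch is normalized separately.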
class DualBranchGPM(nn.Module):
    def __init__(self,
                 num_layers=2,
                 d_model=256,
                 self_nhead=8,
                 att_nhead=8,
                 dim_feedforward=1024,
                 emb_dropout=0.,
                 droppath=0.1,
                 lt_dropout=0.,
                 st_dropout=0.,
                 droppath_lst=False,
                 droppath_scaling=False,
                 activation="gelu",
                 return_intermediate=False,
                 intermediate_norm=True,
                 final_norm=True):
        super().__init__()
        self.intermediate_norm = intermediate_norm
        self.final_norm = final_norm
        self.num_layers = num_layers
        self.return_intermediate = return_intermediate
        self.emb_dropout = nn.Dropout(emb_dropout, True)
        # self.mask_token = nn.Parameter(torch.randn([1, 1, d_model]))

        block = GatedPropagationModule

        layers = []
        for idx in range(num_layers):
            if droppath_scaling:
                if num_layers == 1:
                    droppath_rate = 0
                else:
                    droppath_rate = droppath * idx / (num_layers - 1)
            else:
                droppath_rate = droppath
            layers.append(
                block(d_model,
                      self_nhead,
                      att_nhead,
                      dim_feedforward,
                      droppath_rate,
                      lt_dropout,
                      st_dropout,
                      droppath_lst,
                      activation,
                      layer_idx=idx))
        self.layers = nn.ModuleList(layers)

        num_norms = num_layers - 1 if intermediate_norm else 0
        if final_norm:
            num_norms += 1
        self.decoder_norms = [
            _get_norm(d_model * 2, type='gn', groups=2)
            for _ in range(num_norms)
        ] if num_norms > 0 else None

        if self.decoder_norms is not None:
            self.decoder_norms = nn.ModuleList(self.decoder_norms)
    def forward(self,
                tgt,
                long_term_memories,
                short_term_memories,
                curr_id_emb=None,
                self_pos=None,
                size_2d=None):

        output = self.emb_dropout(tgt)

        # output = mask_out(output, self.mask_token, 0.15, self.training)

        intermediate = []
        intermediate_memories = []
        output_id = None

        for idx, layer in enumerate(self.layers):
            output, output_id, memories = layer(
                output,
                output_id,
                long_term_memories[idx]
                if long_term_memories is not None else None,
                short_term_memories[idx]
                if short_term_memories is not None else None,
                curr_id_emb=curr_id_emb,
                self_pos=self_pos,
                size_2d=size_2d)

            cat_output = torch.cat([output, output_id], dim=2)

            if self.return_intermediate:
                intermediate.append(cat_output)
                intermediate_memories.append(memories)

        if self.decoder_norms is not None:
            if self.final_norm:
                cat_output = self.decoder_norms[-1](cat_output)

                if self.return_intermediate:
                    intermediate.pop()
                    intermediate.append(cat_output)

            if self.intermediate_norm:
                for idx in range(len(intermediate) - 1):
                    intermediate[idx] = self.decoder_norms[idx](
                        intermediate[idx])

        if self.return_intermediate:
            return intermediate, intermediate_memories

        return cat_output, memories
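

# A single LSTT block: pre-norm self-attention, then long-term attention
# against global memory keys/values plus short-term local attention against
# 2D-shaped local memory, and finally a feed-forward with a depthwise-conv
# activation (GNActDWConv2d). When `curr_id_emb` is given, the block fuses the
# ID embedding into the current frame's keys/values and uses them as fresh
# memories; otherwise it consumes the memories passed in. It returns the
# updated features plus [[curr_K, curr_V], [global_K, global_V],
# [local_K, local_V]] so the caller can refresh its memory banks.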
class LongShortTermTransformerBlock(nn.Module):
    def __init__(self,
                 d_model,
                 self_nhead,
                 att_nhead,
                 dim_feedforward=1024,
                 droppath=0.1,
                 lt_dropout=0.,
                 st_dropout=0.,
                 droppath_lst=False,
                 activation="gelu",
                 local_dilation=1,
                 enable_corr=True):
        super().__init__()

        # Long Short-Term Attention
        self.norm1 = _get_norm(d_model)
        self.linear_Q = nn.Linear(d_model, d_model)
        self.linear_V = nn.Linear(d_model, d_model)

        self.long_term_attn = MultiheadAttention(d_model,
                                                 att_nhead,
                                                 use_linear=False,
                                                 dropout=lt_dropout)

        # MultiheadLocalAttention = MultiheadLocalAttentionV2 if enable_corr else MultiheadLocalAttentionV3
        if enable_corr:
            try:
                import spatial_correlation_sampler
                MultiheadLocalAttention = MultiheadLocalAttentionV2
            except Exception as inst:
                print(inst)
                print("Failed to import PyTorch Correlation. "
                      "For better efficiency, please install it.")
                MultiheadLocalAttention = MultiheadLocalAttentionV3
        else:
            MultiheadLocalAttention = MultiheadLocalAttentionV3
        self.short_term_attn = MultiheadLocalAttention(d_model,
                                                       att_nhead,
                                                       dilation=local_dilation,
                                                       use_linear=False,
                                                       dropout=st_dropout)
        self.lst_dropout = nn.Dropout(max(lt_dropout, st_dropout), True)
        self.droppath_lst = droppath_lst

        # Self-attention
        self.norm2 = _get_norm(d_model)
        self.self_attn = MultiheadAttention(d_model, self_nhead)

        # Feed-forward
        self.norm3 = _get_norm(d_model)
        self.linear1 = nn.Linear(d_model, dim_feedforward)
        self.activation = GNActDWConv2d(dim_feedforward)
        self.linear2 = nn.Linear(dim_feedforward, d_model)

        self.droppath = DropPath(droppath, batch_dim=1)
        self._init_weight()
    def with_pos_embed(self, tensor, pos=None):
        size = tensor.size()
        if len(size) == 4 and pos is not None:
            n, c, h, w = size
            pos = pos.view(h, w, n, c).permute(2, 3, 0, 1)
        return tensor if pos is None else tensor + pos

    def forward(self,
                tgt,
                long_term_memory=None,
                short_term_memory=None,
                curr_id_emb=None,
                self_pos=None,
                size_2d=(30, 30)):

        # Self-attention
        _tgt = self.norm1(tgt)
        q = k = self.with_pos_embed(_tgt, self_pos)
        v = _tgt
        tgt2 = self.self_attn(q, k, v)[0]

        tgt = tgt + self.droppath(tgt2)

        # Long Short-Term Attention
        _tgt = self.norm2(tgt)

        curr_Q = self.linear_Q(_tgt)
        curr_K = curr_Q
        curr_V = _tgt

        local_Q = seq_to_2d(curr_Q, size_2d)

        if curr_id_emb is not None:
            global_K, global_V = self.fuse_key_value_id(
                curr_K, curr_V, curr_id_emb)
            local_K = seq_to_2d(global_K, size_2d)
            local_V = seq_to_2d(global_V, size_2d)
        else:
            global_K, global_V = long_term_memory
            local_K, local_V = short_term_memory

        tgt2 = self.long_term_attn(curr_Q, global_K, global_V)[0]
        tgt3 = self.short_term_attn(local_Q, local_K, local_V)[0]

        if self.droppath_lst:
            tgt = tgt + self.droppath(tgt2 + tgt3)
        else:
            tgt = tgt + self.lst_dropout(tgt2 + tgt3)

        # Feed-forward
        _tgt = self.norm3(tgt)

        tgt2 = self.linear2(self.activation(self.linear1(_tgt), size_2d))

        tgt = tgt + self.droppath(tgt2)

        return tgt, [[curr_K, curr_V], [global_K, global_V],
                     [local_K, local_V]]

    def fuse_key_value_id(self, key, value, id_emb):
        K = key
        V = self.linear_V(value + id_emb)
        return K, V

    def _init_weight(self):
        for p in self.parameters():
            if p.dim() > 1:
                nn.init.xavier_uniform_(p)
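

# V2 of the block computes queries/values with a single fused projection
# (`linear_QV`) and fuses the ID embedding differently: `linear_ID_KV` yields a
# per-head gate for the keys, applied as K * (1 + tanh(ID_K)), plus an additive
# term ID_V for the values.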
class LongShortTermTransformerBlockV2(nn.Module):
    def __init__(self,
                 d_model,
                 self_nhead,
                 att_nhead,
                 dim_feedforward=1024,
                 droppath=0.1,
                 lt_dropout=0.,
                 st_dropout=0.,
                 droppath_lst=False,
                 activation="gelu",
                 local_dilation=1,
                 enable_corr=True):
        super().__init__()
        self.d_model = d_model
        self.att_nhead = att_nhead

        # Self-attention
        self.norm1 = _get_norm(d_model)
        self.self_attn = MultiheadAttention(d_model, self_nhead)

        # Long Short-Term Attention
        self.norm2 = _get_norm(d_model)
        self.linear_QV = nn.Linear(d_model, 2 * d_model)
        self.linear_ID_KV = nn.Linear(d_model, d_model + att_nhead)

        self.long_term_attn = MultiheadAttention(d_model,
                                                 att_nhead,
                                                 use_linear=False,
                                                 dropout=lt_dropout)

        # MultiheadLocalAttention = MultiheadLocalAttentionV2 if enable_corr else MultiheadLocalAttentionV3
        if enable_corr:
            try:
                import spatial_correlation_sampler
                MultiheadLocalAttention = MultiheadLocalAttentionV2
            except Exception as inst:
                print(inst)
                print("Failed to import PyTorch Correlation. "
                      "For better efficiency, please install it.")
                MultiheadLocalAttention = MultiheadLocalAttentionV3
        else:
            MultiheadLocalAttention = MultiheadLocalAttentionV3
        self.short_term_attn = MultiheadLocalAttention(d_model,
                                                       att_nhead,
                                                       dilation=local_dilation,
                                                       use_linear=False,
                                                       dropout=st_dropout)
        self.lst_dropout = nn.Dropout(max(lt_dropout, st_dropout), True)
        self.droppath_lst = droppath_lst

        # Feed-forward
        self.norm3 = _get_norm(d_model)
        self.linear1 = nn.Linear(d_model, dim_feedforward)
        self.activation = GNActDWConv2d(dim_feedforward)
        self.linear2 = nn.Linear(dim_feedforward, d_model)

        self.droppath = DropPath(droppath, batch_dim=1)
        self._init_weight()
    def with_pos_embed(self, tensor, pos=None):
        size = tensor.size()
        if len(size) == 4 and pos is not None:
            n, c, h, w = size
            pos = pos.view(h, w, n, c).permute(2, 3, 0, 1)
        return tensor if pos is None else tensor + pos

    def forward(self,
                tgt,
                long_term_memory=None,
                short_term_memory=None,
                curr_id_emb=None,
                self_pos=None,
                size_2d=(30, 30)):

        # Self-attention
        _tgt = self.norm1(tgt)
        q = k = self.with_pos_embed(_tgt, self_pos)
        v = _tgt
        tgt2 = self.self_attn(q, k, v)[0]

        tgt = tgt + self.droppath(tgt2)

        # Long Short-Term Attention
        _tgt = self.norm2(tgt)

        curr_QV = self.linear_QV(_tgt)
        curr_QV = torch.split(curr_QV, self.d_model, dim=2)
        curr_Q = curr_K = curr_QV[0]
        curr_V = curr_QV[1]

        local_Q = seq_to_2d(curr_Q, size_2d)

        if curr_id_emb is not None:
            global_K, global_V = self.fuse_key_value_id(
                curr_K, curr_V, curr_id_emb)
            local_K = seq_to_2d(global_K, size_2d)
            local_V = seq_to_2d(global_V, size_2d)
        else:
            global_K, global_V = long_term_memory
            local_K, local_V = short_term_memory

        tgt2 = self.long_term_attn(curr_Q, global_K, global_V)[0]
        tgt3 = self.short_term_attn(local_Q, local_K, local_V)[0]

        if self.droppath_lst:
            tgt = tgt + self.droppath(tgt2 + tgt3)
        else:
            tgt = tgt + self.lst_dropout(tgt2 + tgt3)

        # Feed-forward
        _tgt = self.norm3(tgt)

        tgt2 = self.linear2(self.activation(self.linear1(_tgt), size_2d))

        tgt = tgt + self.droppath(tgt2)

        return tgt, [[curr_K, curr_V], [global_K, global_V],
                     [local_K, local_V]]

    def fuse_key_value_id(self, key, value, id_emb):
        ID_KV = self.linear_ID_KV(id_emb)
        ID_K, ID_V = torch.split(ID_KV, [self.att_nhead, self.d_model], dim=2)
        bs = key.size(1)
        K = key.view(-1, bs, self.att_nhead, self.d_model //
                     self.att_nhead) * (1 + torch.tanh(ID_K)).unsqueeze(-1)
        K = K.view(-1, bs, self.d_model)
        V = value + ID_V
        return K, V

    def _init_weight(self):
        for p in self.parameters():
            if p.dim() > 1:
                nn.init.xavier_uniform_(p)
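

# GatedPropagationModule is the dual-branch block used by DualBranchGPM. It
# keeps a visual branch (`tgt`) and an ID branch (`tgt_id`), propagates both
# through GatedPropagation (long-term) and LocalGatedPropagation (short-term)
# over concatenated visual+ID values, and finishes with a joint gated
# self-propagation. The first layer (layer_idx == 0) builds the ID branch
# directly from `curr_id_emb`; deeper layers mix the incoming ID features with
# the embedding via `linear_ID_V`.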
class GatedPropagationModule(nn.Module):
    def __init__(self,
                 d_model,
                 self_nhead,
                 att_nhead,
                 dim_feedforward=1024,
                 droppath=0.1,
                 lt_dropout=0.,
                 st_dropout=0.,
                 droppath_lst=False,
                 activation="gelu",
                 local_dilation=1,
                 enable_corr=True,
                 max_local_dis=7,
                 layer_idx=0,
                 expand_ratio=2.):
        super().__init__()
        expand_d_model = int(d_model * expand_ratio)
        self.expand_d_model = expand_d_model

        self.d_model = d_model
        self.att_nhead = att_nhead

        d_att = d_model // 2 if att_nhead == 1 else d_model // att_nhead
        self.d_att = d_att
        self.layer_idx = layer_idx

        # Long Short-Term Attention
        self.norm1 = _get_norm(d_model)
        self.linear_QV = nn.Linear(d_model, d_att * att_nhead + expand_d_model)
        self.linear_U = nn.Linear(d_model, expand_d_model)

        if layer_idx == 0:
            self.linear_ID_V = nn.Linear(d_model, expand_d_model)
        else:
            self.id_norm1 = _get_norm(d_model)
            self.linear_ID_V = nn.Linear(d_model * 2, expand_d_model)
            self.linear_ID_U = nn.Linear(d_model, expand_d_model)

        self.long_term_attn = GatedPropagation(d_qk=self.d_model,
                                               d_vu=self.d_model * 2,
                                               num_head=att_nhead,
                                               use_linear=False,
                                               dropout=lt_dropout,
                                               d_att=d_att,
                                               top_k=-1,
                                               expand_ratio=expand_ratio)

        if enable_corr:
            try:
                import spatial_correlation_sampler
            except Exception as inst:
                print(inst)
                print("Failed to import PyTorch Correlation. "
                      "For better efficiency, please install it.")
                enable_corr = False
        self.short_term_attn = LocalGatedPropagation(d_qk=self.d_model,
                                                     d_vu=self.d_model * 2,
                                                     num_head=att_nhead,
                                                     dilation=local_dilation,
                                                     use_linear=False,
                                                     enable_corr=enable_corr,
                                                     dropout=st_dropout,
                                                     d_att=d_att,
                                                     max_dis=max_local_dis,
                                                     expand_ratio=expand_ratio)

        self.lst_dropout = nn.Dropout(max(lt_dropout, st_dropout), True)
        self.droppath_lst = droppath_lst

        # Self-attention
        self.norm2 = _get_norm(d_model)
        self.id_norm2 = _get_norm(d_model)
        self.self_attn = GatedPropagation(d_model * 2,
                                          d_model * 2,
                                          self_nhead,
                                          d_att=d_att)

        self.droppath = DropPath(droppath, batch_dim=1)
        self._init_weight()
    def with_pos_embed(self, tensor, pos=None):
        size = tensor.size()
        if len(size) == 4 and pos is not None:
            n, c, h, w = size
            pos = pos.view(h, w, n, c).permute(2, 3, 0, 1)
        return tensor if pos is None else tensor + pos

    def forward(self,
                tgt,
                tgt_id=None,
                long_term_memory=None,
                short_term_memory=None,
                curr_id_emb=None,
                self_pos=None,
                size_2d=(30, 30)):

        # Long Short-Term Attention
        _tgt = self.norm1(tgt)

        curr_QV = self.linear_QV(_tgt)
        curr_QV = torch.split(
            curr_QV, [self.d_att * self.att_nhead, self.expand_d_model], dim=2)
        curr_Q = curr_K = curr_QV[0]
        local_Q = seq_to_2d(curr_Q, size_2d)
        curr_V = silu(curr_QV[1])

        curr_U = self.linear_U(_tgt)

        if tgt_id is None:
            tgt_id = 0
            cat_curr_U = torch.cat(
                [silu(curr_U), torch.ones_like(curr_U)], dim=-1)
            curr_ID_V = None
        else:
            _tgt_id = self.id_norm1(tgt_id)
            curr_ID_V = _tgt_id
            curr_ID_U = self.linear_ID_U(_tgt_id)
            cat_curr_U = silu(torch.cat([curr_U, curr_ID_U], dim=-1))

        if curr_id_emb is not None:
            global_K, global_V = curr_K, curr_V
            local_K = seq_to_2d(global_K, size_2d)
            local_V = seq_to_2d(global_V, size_2d)

            _, global_ID_V = self.fuse_key_value_id(None, curr_ID_V,
                                                    curr_id_emb)
            local_ID_V = seq_to_2d(global_ID_V, size_2d)
        else:
            global_K, global_V, _, global_ID_V = long_term_memory
            local_K, local_V, _, local_ID_V = short_term_memory

        cat_global_V = torch.cat([global_V, global_ID_V], dim=-1)
        cat_local_V = torch.cat([local_V, local_ID_V], dim=1)

        cat_tgt2, _ = self.long_term_attn(curr_Q, global_K, cat_global_V,
                                          cat_curr_U, size_2d)
        cat_tgt3, _ = self.short_term_attn(local_Q, local_K, cat_local_V,
                                           cat_curr_U, size_2d)

        tgt2, tgt_id2 = torch.split(cat_tgt2, self.d_model, dim=-1)
        tgt3, tgt_id3 = torch.split(cat_tgt3, self.d_model, dim=-1)

        if self.droppath_lst:
            tgt = tgt + self.droppath(tgt2 + tgt3)
            tgt_id = tgt_id + self.droppath(tgt_id2 + tgt_id3)
        else:
            tgt = tgt + self.lst_dropout(tgt2 + tgt3)
            tgt_id = tgt_id + self.lst_dropout(tgt_id2 + tgt_id3)

        # Self-attention
        _tgt = self.norm2(tgt)
        _tgt_id = self.id_norm2(tgt_id)
        q = k = v = u = torch.cat([_tgt, _tgt_id], dim=-1)

        cat_tgt2, _ = self.self_attn(q, k, v, u, size_2d)

        tgt2, tgt_id2 = torch.split(cat_tgt2, self.d_model, dim=-1)

        tgt = tgt + self.droppath(tgt2)
        tgt_id = tgt_id + self.droppath(tgt_id2)

        return tgt, tgt_id, [[curr_K, curr_V, None, curr_ID_V],
                             [global_K, global_V, None, global_ID_V],
                             [local_K, local_V, None, local_ID_V]]

    def fuse_key_value_id(self, key, value, id_emb):
        ID_K = None
        if value is not None:
            ID_V = silu(self.linear_ID_V(torch.cat([value, id_emb], dim=2)))
        else:
            ID_V = silu(self.linear_ID_V(id_emb))
        return ID_K, ID_V

    def _init_weight(self):
        for p in self.parameters():
            if p.dim() > 1:
                nn.init.xavier_uniform_(p)
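

# ---------------------------------------------------------------------------
# Minimal usage sketch (illustrative only, not part of the original module).
# It assumes the `networks.layers` package from this repository is importable
# and that the attention layers accept flattened (H*W, batch, C) inputs with
# size_2d=(H, W); the shapes and hyper-parameters below are examples, not
# values mandated by the code above.
# ---------------------------------------------------------------------------
if __name__ == "__main__":
    h, w, bs, c = 30, 30, 1, 256
    lstt = LongShortTermTransformer(num_layers=2, d_model=c, block_version="v1")

    tgt = torch.randn(h * w, bs, c)      # flattened frame features
    id_emb = torch.randn(h * w, bs, c)   # identity embedding for a reference frame
    pos = torch.randn(h * w, bs, c)      # positional embedding in sequence layout

    # With curr_id_emb given, each block builds its long-/short-term memories
    # from the current frame, so no external memories are needed yet.
    out, mems = lstt(tgt, None, None,
                     curr_id_emb=id_emb,
                     self_pos=pos,
                     size_2d=(h, w))
    print(out.shape)  # expected: torch.Size([900, 1, 256])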