lvyufeng commited on
Commit
4b90c4f
·
verified ·
1 Parent(s): 586c396

add model.combine_moe() to speed up inference

Browse files
Files changed (1) hide show
  1. modeling_deepseekocr.py +38 -0
modeling_deepseekocr.py CHANGED
@@ -1,6 +1,8 @@
1
  import os
2
  import math
3
  import re
 
 
4
  from tqdm import tqdm
5
  from abc import ABC
6
  from typing import List, Optional, Tuple, Union
@@ -15,6 +17,7 @@ from torch.nn import CrossEntropyLoss
15
  from torchvision import transforms
16
 
17
  from transformers.cache_utils import Cache
 
18
  from transformers.modeling_outputs import BaseModelOutputWithPast, CausalLMOutputWithPast
19
  from transformers import DeepseekV2Model, DeepseekV2ForCausalLM
20
  from transformers import DeepseekV2Config
@@ -1058,3 +1061,38 @@ class DeepseekOCRForCausalLM(DeepseekV2ForCausalLM):
1058
  plt.close()
1059
 
1060
  result.save(f"{output_path}/result_with_boxes.jpg")
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
  import os
2
  import math
3
  import re
4
+ import gc
5
+ import types
6
  from tqdm import tqdm
7
  from abc import ABC
8
  from typing import List, Optional, Tuple, Union
 
17
  from torchvision import transforms
18
 
19
  from transformers.cache_utils import Cache
20
+ from transformers.activations import ACT2FN
21
  from transformers.modeling_outputs import BaseModelOutputWithPast, CausalLMOutputWithPast
22
  from transformers import DeepseekV2Model, DeepseekV2ForCausalLM
23
  from transformers import DeepseekV2Config
 
1061
  plt.close()
1062
 
1063
  result.save(f"{output_path}/result_with_boxes.jpg")
1064
+
1065
def combine_moe(self):
    """Fuse each MoE layer's per-expert MLP weights into stacked 3-D tensors.

    For every ``DeepseekV2MoE`` layer, the routed experts' ``gate_proj``,
    ``down_proj`` and ``up_proj`` weights are transposed and stacked into
    single parameters ``w1``/``w2``/``w3`` of shape
    ``(n_routed_experts, in_features, out_features)``. The individual expert
    modules are then dropped (``experts`` is set to ``None``) and the layer's
    ``forward`` is rebound to ``new_forward_for_moe``, which evaluates all
    experts with batched matmuls instead of a Python loop — a speed-up for
    inference.

    Notes:
        - Call after weights are loaded; the change is irreversible for this
          model instance.
        - Fused weights are created with ``requires_grad=False``, so this is
          intended for inference only.
    """

    def _stack_expert_weights(moe, proj_name):
        # One (E, in, out) tensor holding the transposed `proj_name` weight
        # of every routed expert. torch.stack copies, so the original expert
        # modules can be freed afterwards.
        num_experts = moe.config.n_routed_experts
        return nn.Parameter(
            torch.stack(
                [getattr(moe.experts[i], proj_name).weight.T for i in range(num_experts)]
            ),
            requires_grad=False,
        )

    for layer in self.model.layers:
        if not isinstance(layer.mlp, DeepseekV2MoE):
            continue
        moe_layer = layer.mlp
        moe_layer.w1 = _stack_expert_weights(moe_layer, "gate_proj")
        moe_layer.w2 = _stack_expert_weights(moe_layer, "down_proj")
        moe_layer.w3 = _stack_expert_weights(moe_layer, "up_proj")
        # Drop the per-expert modules now that their weights are copied.
        # Collect per layer so each layer's experts are reclaimed before the
        # next layer's stacked copies are allocated (keeps peak memory down).
        del moe_layer.experts
        gc.collect()
        # Keep the attribute (as None) so attribute access elsewhere does not raise.
        moe_layer.experts = None
        moe_layer.act = ACT2FN[moe_layer.config.hidden_act]
        moe_layer.forward = types.MethodType(new_forward_for_moe, moe_layer)
1078
+
1079
+
1080
def new_forward_for_moe(self, hidden_states):
    """MoE forward pass over fused expert weights (installed by ``combine_moe``).

    Runs every routed expert on every token via broadcasted/batched matmuls,
    then mixes the expert outputs with the router's scores (zero for
    unselected experts, so unrouted expert outputs cancel out).

    Args:
        hidden_states: ``(batch, seq_len, hidden_dim)`` activations.

    Returns:
        Tensor of the same shape: the routed-expert mixture, plus the shared
        experts' output when ``config.n_shared_experts`` is set.
    """
    batch_size, sequence_length, hidden_dim = hidden_states.shape
    # Router: per-token top-k expert indices and their weights.
    selected_experts, routing_weights = self.gate(hidden_states)
    # Dense (tokens, n_experts) score matrix; zero for unselected experts.
    router_scores = torch.zeros(
        size=(batch_size * sequence_length, self.config.n_routed_experts),
        device=hidden_states.device,
        dtype=hidden_states.dtype,
    )
    # We cast back to the input dtype before scattering.
    routing_weights = routing_weights.to(hidden_states.dtype)
    router_scores = torch.scatter_add(router_scores, -1, selected_experts, routing_weights)

    hidden_states = hidden_states.view(-1, hidden_dim)
    if self.config.n_shared_experts is not None:
        shared_expert_output = self.shared_experts(hidden_states)

    # (tokens, H) @ (E, H, I) broadcasts to (E, tokens, I): all experts, all tokens.
    hidden_w1 = torch.matmul(hidden_states, self.w1)
    hidden_w3 = torch.matmul(hidden_states, self.w3)
    hidden_states = self.act(hidden_w1) * hidden_w3
    # (E, tokens, I) @ (E, I, H) -> (E, tokens, H), weighted per expert by the
    # router scores, then summed over the expert dimension.
    hidden_states = torch.bmm(hidden_states, self.w2) * torch.transpose(router_scores, 0, 1).unsqueeze(-1)
    final_hidden_states = hidden_states.sum(dim=0, dtype=hidden_states.dtype)
    if self.config.n_shared_experts is not None:
        final_hidden_states = final_hidden_states + shared_expert_output
    # BUG FIX: the original returned `hidden_states` here, which when
    # n_shared_experts is None is still the pre-sum (E, tokens, H) tensor and
    # cannot be viewed back to the input shape. Always return the reduced sum.
    return final_hidden_states.view(batch_size, sequence_length, hidden_dim)