Add model.combine_moe() to speed up MoE inference
Browse files- modeling_deepseekocr.py +38 -0
modeling_deepseekocr.py
CHANGED
|
@@ -1,6 +1,8 @@
|
|
| 1 |
import os
|
| 2 |
import math
|
| 3 |
import re
|
|
|
|
|
|
|
| 4 |
from tqdm import tqdm
|
| 5 |
from abc import ABC
|
| 6 |
from typing import List, Optional, Tuple, Union
|
|
@@ -15,6 +17,7 @@ from torch.nn import CrossEntropyLoss
|
|
| 15 |
from torchvision import transforms
|
| 16 |
|
| 17 |
from transformers.cache_utils import Cache
|
|
|
|
| 18 |
from transformers.modeling_outputs import BaseModelOutputWithPast, CausalLMOutputWithPast
|
| 19 |
from transformers import DeepseekV2Model, DeepseekV2ForCausalLM
|
| 20 |
from transformers import DeepseekV2Config
|
|
@@ -1058,3 +1061,38 @@ class DeepseekOCRForCausalLM(DeepseekV2ForCausalLM):
|
|
| 1058 |
plt.close()
|
| 1059 |
|
| 1060 |
result.save(f"{output_path}/result_with_boxes.jpg")
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
import os
|
| 2 |
import math
|
| 3 |
import re
|
| 4 |
+
import gc
|
| 5 |
+
import types
|
| 6 |
from tqdm import tqdm
|
| 7 |
from abc import ABC
|
| 8 |
from typing import List, Optional, Tuple, Union
|
|
|
|
| 17 |
from torchvision import transforms
|
| 18 |
|
| 19 |
from transformers.cache_utils import Cache
|
| 20 |
+
from transformers.activations import ACT2FN
|
| 21 |
from transformers.modeling_outputs import BaseModelOutputWithPast, CausalLMOutputWithPast
|
| 22 |
from transformers import DeepseekV2Model, DeepseekV2ForCausalLM
|
| 23 |
from transformers import DeepseekV2Config
|
|
|
|
| 1061 |
plt.close()
|
| 1062 |
|
| 1063 |
result.save(f"{output_path}/result_with_boxes.jpg")
|
| 1064 |
+
|
| 1065 |
+
def combine_moe(self):
    """Fuse every MoE layer's per-expert MLP weights into three stacked tensors.

    For each decoder layer whose ``mlp`` is a ``DeepseekV2MoE``, the
    gate/up/down projection weights of all routed experts are stacked
    (transposed) into ``w1``/``w3``/``w2`` so a single batched matmul can
    serve all experts at once, and the layer's ``forward`` is rebound to
    ``new_forward_for_moe``. The original per-expert modules are released.
    """
    for decoder_layer in self.model.layers:
        mlp = decoder_layer.mlp
        if not isinstance(mlp, DeepseekV2MoE):
            continue
        n_experts = mlp.config.n_routed_experts
        # Stack transposed expert weights: shape (n_experts, in_dim, out_dim),
        # ready for broadcast matmul / bmm in the fused forward pass.
        gate_weights = [mlp.experts[i].gate_proj.weight.T for i in range(n_experts)]
        down_weights = [mlp.experts[i].down_proj.weight.T for i in range(n_experts)]
        up_weights = [mlp.experts[i].up_proj.weight.T for i in range(n_experts)]
        mlp.w1 = nn.Parameter(torch.stack(gate_weights), requires_grad=False)
        mlp.w2 = nn.Parameter(torch.stack(down_weights), requires_grad=False)
        mlp.w3 = nn.Parameter(torch.stack(up_weights), requires_grad=False)
        # Drop the now-redundant expert modules and reclaim their memory.
        del mlp.experts
        gc.collect()
        mlp.experts = None
        mlp.act = ACT2FN[mlp.config.hidden_act]
        # Route this layer through the fused implementation.
        mlp.forward = types.MethodType(new_forward_for_moe, mlp)
|
| 1078 |
+
|
| 1079 |
+
|
| 1080 |
+
def new_forward_for_moe(self, hidden_states):
    """Batched MoE forward over experts fused by ``combine_moe``.

    Computes every routed expert in one broadcast matmul/bmm instead of a
    Python loop, weights each expert's output by its routing score, sums
    over experts, and adds the shared-expert output when configured.

    Args:
        hidden_states: tensor of shape (batch, seq_len, hidden_dim).

    Returns:
        Tensor of shape (batch, seq_len, hidden_dim).
    """
    batch_size, sequence_length, hidden_dim = hidden_states.shape
    # gate() returns the top-k expert indices and their routing weights,
    # both of shape (batch * seq_len, top_k).
    selected_experts, routing_weights = self.gate(hidden_states)
    router_scores = torch.zeros(
        size=(batch_size * sequence_length, self.config.n_routed_experts),
        device=hidden_states.device,
        dtype=hidden_states.dtype,
    )
    # Cast back to the input dtype before scattering into the dense
    # (tokens, n_experts) score matrix; non-selected experts keep score 0.
    routing_weights = routing_weights.to(hidden_states.dtype)
    router_scores = torch.scatter_add(router_scores, -1, selected_experts, routing_weights)
    hidden_states = hidden_states.view(-1, hidden_dim)

    shared_expert_output = None
    if self.config.n_shared_experts is not None:
        shared_expert_output = self.shared_experts(hidden_states)

    # (tokens, dim) @ (n_experts, dim, inter) -> (n_experts, tokens, inter)
    hidden_w1 = torch.matmul(hidden_states, self.w1)
    hidden_w3 = torch.matmul(hidden_states, self.w3)
    hidden_states = self.act(hidden_w1) * hidden_w3
    # Down-project per expert, then zero out non-selected experts by
    # multiplying with the (n_experts, tokens, 1) routing scores.
    hidden_states = torch.bmm(hidden_states, self.w2) * torch.transpose(router_scores, 0, 1).unsqueeze(-1)
    # Sum expert contributions: (n_experts, tokens, dim) -> (tokens, dim).
    final_hidden_states = hidden_states.sum(dim=0, dtype=hidden_states.dtype)
    if shared_expert_output is not None:
        final_hidden_states = final_hidden_states + shared_expert_output
    # BUG FIX: the original returned the un-summed (n_experts, tokens, dim)
    # tensor when n_shared_experts was None, which cannot be viewed back to
    # (batch, seq_len, dim); always return the summed result.
    return final_hidden_states.view(batch_size, sequence_length, hidden_dim)
|