LokeZhou committed
Commit cf5d8cf · 1 Parent(s): 9c6c306

ERNIE-4.5-VL-28B-A3B-Thinking demo

Files changed (1): app.py (+224, -4)
app.py CHANGED
@@ -1,7 +1,227 @@
  import gradio as gr

- def greet(name):
-     return "Hello " + name + "!!"

- demo = gr.Interface(fn=greet, inputs="text", outputs="text")
- demo.launch()
+ import torch
+ from transformers import AutoModelForCausalLM, AutoProcessor, TextIteratorStreamer
+ from PIL import Image
+ import base64
+ import io
+ import re
+ from typing import Generator, List, Tuple, Optional
+ import threading
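+
+ # TextIteratorStreamer queues decoded tokens as generate() produces them,
+ # letting generation run on a background thread while the handler loop
+ # yields partial text to the Gradio UI.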
 
+ MAX_HISTORY = 5
+ model_path = 'baidu/ERNIE-4.5-VL-28B-A3B-Thinking'
+ model = AutoModelForCausalLM.from_pretrained(
+     model_path,
+     device_map="auto",
+     torch_dtype=torch.bfloat16,
+     trust_remote_code=True
+ )
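+ # device_map="auto" lets Accelerate spread the weights over the available
+ # devices; trust_remote_code loads ERNIE's custom model class from the Hub.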
 
+ processor = AutoProcessor.from_pretrained(model_path, trust_remote_code=True)
+ processor.eval()
+ # add_image_preprocess hooks the processor's image pipeline into the model
+ # (both methods come from ERNIE's remote code).
+ model.add_image_preprocess(processor)
+
+
+ def encode_image(image: Image.Image) -> str:
+     # Serialize a PIL image to a base64-encoded PNG string.
+     if image is None:
+         return ""
+     buffer = io.BytesIO()
+     image.save(buffer, format="PNG")
+     return base64.b64encode(buffer.getvalue()).decode("utf-8")
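+ # e.g. (hypothetical): encode_image(Image.new("RGB", (2, 2))) -> "iVBORw0KGgo..."
+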
+ def extract_text_from_html(html: str) -> str:
+     # Strip the embedded <img> tag and any remaining markup, then drop the
+     # "user: " / "assistant: " role prefix.
+     text = re.sub(r'<img.*?>', '', html)
+     text = re.sub(r'<.*?>', '', text)
+     if text.startswith("user: "):
+         return text[len("user: "):].strip()
+     elif text.startswith("assistant: "):
+         return text[len("assistant: "):].strip()
+     return text.strip()
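+ # e.g. extract_text_from_html('user: hi<br><img src="...">') -> 'hi'
+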
+ def process_chat(
+     message: str,
+     image: Optional[Image.Image],
+     chat_history: List[Tuple[str, str, Optional[str]]],
+     max_new_tokens: int,
+     temperature: float
+ ) -> Generator[List[Tuple[str, str]], None, None]:
+ """处理聊天输入,流式生成回应"""
49
+
+     current_image_b64 = encode_image(image) if image else None
+     image_html = ""
+     if current_image_b64:
+         image_html = f'<br><img src="data:image/png;base64,{current_image_b64}" style="max-width:300px; border-radius:4px;">'
+
+     user_text = message
+     user_message_html = f"user: {user_text}{image_html}"
+
+     # Append the new turn with an empty assistant slot, filled in as tokens stream.
+     temp_history = chat_history + [(user_message_html, "", current_image_b64)]
+
+     model_messages = []
+
+     # Replay earlier turns (all but the in-progress one) in the format the
+     # ERNIE chat template expects; prior responses use the "bot" role.
+     for hist in temp_history[:-1]:
+         user_html, assistant_text, hist_image_b64 = hist
+         user_text_clean = extract_text_from_html(user_html)
+
+         user_content = [{"type": "text", "text": user_text_clean}]
+         if hist_image_b64:
+             user_content.insert(0, {"type": "image_url", "image_url": {"url": hist_image_b64}})
+
+         model_messages.append({"role": "user", "content": user_content})
+         assistant_content = [{"type": "text", "text": assistant_text}]
+         model_messages.append({"role": "bot", "content": assistant_content})
+
+     current_user_content = [{"type": "text", "text": user_text}]
+     if current_image_b64:
+         current_user_content.insert(0, {"type": "image_url", "image_url": {"url": current_image_b64}})
+
+     model_messages.append({"role": "user", "content": current_user_content})
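+
+     # model_messages now alternates user / "bot" turns, e.g.:
+     #   {"role": "user", "content": [{"type": "image_url", ...}, {"type": "text", ...}]}
+     #   {"role": "bot",  "content": [{"type": "text", ...}]}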
+
+     text = processor.tokenizer.apply_chat_template(
+         model_messages, tokenize=False, add_generation_prompt=True, enable_thinking=False
+     )
+
+     image_inputs, video_inputs = processor.process_vision_info(model_messages)
+
+     inputs = processor(
+         text=[text],
+         images=image_inputs,
+         videos=video_inputs,
+         padding=True,
+         return_tensors="pt",
+     )
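+     # `inputs` bundles the tokenized prompt with the preprocessed image/video
+     # tensors; the exact tensor keys are defined by ERNIE's remote-code processor.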
+
+     device = next(model.parameters()).device
+     inputs = inputs.to(device)
+
+     streamer = TextIteratorStreamer(processor.tokenizer, skip_prompt=True, skip_special_tokens=True)
+     generation_kwargs = {
+         **inputs,
+         "streamer": streamer,
+         "max_new_tokens": max_new_tokens,
+         "do_sample": True,  # without sampling, the temperature setting would be ignored
+         "temperature": temperature,
+         "use_cache": False  # recomputes past states each step; saves memory at the cost of speed
+     }
+
+     # Run generate() on a worker thread; the streamer below yields tokens
+     # as they are produced so the UI updates incrementally.
+     thread = threading.Thread(target=model.generate, kwargs=generation_kwargs)
+     thread.start()
+
+     generated_text = ""
+     for new_token in streamer:
+         generated_text += new_token
+
+         temp_history[-1] = (user_message_html, f"assistant: {generated_text}", current_image_b64)
+
+         # Show only the last MAX_HISTORY turns, as (user_html, assistant_text) pairs.
+         display_history = [(h[0], h[1]) for h in temp_history[-MAX_HISTORY:]]
+         yield display_history
+
+     thread.join()
+
+
+ def chat_interface(
+     message: str,
+     image: Optional[Image.Image],
+     chat_history: List[Tuple[str, str, Optional[str]]],
+     max_new_tokens: int,
+     temperature: float
+ ) -> Generator[tuple, None, None]:
+     # Bridge between the UI and process_chat: stream display updates while
+     # rebuilding the full history (which also stores each turn's image).
+     for updated_display_history in process_chat(message, image, chat_history, max_new_tokens, temperature):
+         updated_full_history = []
+         for i, display_item in enumerate(updated_display_history):
+             # Reuse the stored triple when this turn already exists in history.
+             full_item = next((h for h in chat_history if h[0] == display_item[0] and h[1] == display_item[1]), None)
+             if full_item:
+                 updated_full_history.append(full_item)
+             else:
+                 if i == len(updated_display_history) - 1:
+                     # The newest turn keeps the freshly uploaded image, if any.
+                     img_b64 = encode_image(image) if image else None
+                     updated_full_history.append((display_item[0], display_item[1], img_b64))
+                 else:
+                     updated_full_history.append((display_item[0], display_item[1], None))
+
+         # Clear the textbox and image widget, then update state and chat display.
+         yield "", None, updated_full_history, updated_display_history
+
+
+ with gr.Blocks(title="ERNIE-4.5-VL-28B-A3B-Thinking", theme=gr.themes.Soft()) as demo:
+     # State holds the full (user_html, assistant_text, image_b64) triples;
+     # the Chatbot component only ever sees (user_html, assistant_text) pairs.
+     full_chat_history = gr.State([])
+
+     with gr.Row():
+         with gr.Column(scale=3):
+             chat_display = gr.Chatbot(
+                 label="chat_bot",
+                 height=500,
+                 bubble_full_width=False
+             )
+
+         with gr.Column(scale=1):
+             gr.Markdown("generation kwargs")
+             max_new_tokens = gr.Slider(
+                 minimum=64, maximum=2048, value=512, step=64,
+                 label="max_new_tokens"
+             )
+             temperature = gr.Slider(
+                 minimum=0.1, maximum=2.0, value=0.7, step=0.1,
+                 label="temperature"
+             )
+             clear_btn = gr.Button("clear", variant="stop")
+
+     with gr.Row():
+         text_input = gr.Textbox(
+             label="input text",
+             placeholder="input text messages...",
+             lines=2
+         )
+         image_input = gr.Image(
+             label="input image",
+             type="pil",
+             height=100
+         )
+         submit_btn = gr.Button("submit", variant="primary")
+
+     submit_btn.click(
+         fn=chat_interface,
+         inputs=[text_input, image_input, full_chat_history, max_new_tokens, temperature],
+         outputs=[text_input, image_input, full_chat_history, chat_display]
+     )
+
+     # Pressing Enter in the textbox triggers the same handler as the button.
+     text_input.submit(
+         fn=chat_interface,
+         inputs=[text_input, image_input, full_chat_history, max_new_tokens, temperature],
+         outputs=[text_input, image_input, full_chat_history, chat_display]
+     )
+
+     def clear_chat():
+         # Reset both the stored state and the visible chat.
+         return [], []
+
+     clear_btn.click(
+         fn=clear_chat,
+         inputs=[],
+         outputs=[full_chat_history, chat_display]
+     )
+
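+ # Note: streamed (generator) handlers run through Gradio's queue; recent
+ # Gradio versions enable it automatically, older ones may need demo.queue().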
+ if __name__ == "__main__":
+     demo.launch(server_name="0.0.0.0", server_port=8100, share=False)