Shirosawa committed on
Commit 75f43fb · verified · 1 Parent(s): 4f92415

Update app.py

Files changed (1)
  1. app.py +69 -33
app.py CHANGED
@@ -1,4 +1,4 @@
-# app.py — 译文情感 VAD 对比(Router / 本地Transformers / 简易兜底)
+# app.py — 译文情感 VAD 对比(本地Transformers稳健版 + Router可选 + 简易兜底)
 import os, json, math
 from typing import Dict, Tuple
 
@@ -8,9 +8,9 @@ from requests.adapters import HTTPAdapter
 from urllib3.util.retry import Retry
 
 APP_TITLE = "译文情感 VAD 对比"
-APP_DESC = "左侧输入学生译文与参考译文;右侧显示 V/A/D 分数、差异与简易柱状图。可选:HF Router、本地Transformers、内置简易VAD。"
+APP_DESC = "左侧输入学生译文与参考译文;右侧显示 V/A/D 与差异。默认跑本地Transformers;Router仅在目标模型被Provider托管时可用。"
 
-# ===== 环境与端点(Router) =====
+# ===== 环境与端点(Router,可选)=====
 HF_TOKEN = os.getenv("HF_TOKEN", "").strip()
 MODEL_ID = os.getenv("VAD_MODEL_ID", "RobroKools/vad-bert").strip()
 HF_API_URL = "https://router.huggingface.co/hf-inference"
@@ -26,13 +26,12 @@ _session.headers.update({"Connection": "close"})
 def _trim(s: str, n: int = 2000) -> str:
     return (s or "")[:n]
 
-def _clamp01(x: float) -> float:
+def _c01(x: float) -> float:
     return max(0.0, min(1.0, float(x)))
 
-# ===== 简易VAD(兜底,仅演示)=====
+# ===== 简易VAD(兜底)=====
 _POS = ["good","great","excellent","love","like","happy","joy","awesome","amazing","wonderful","赞","好","喜欢","开心","愉快","优秀","棒","太好了","满意","值得"]
 _NEG = ["bad","terrible","awful","hate","dislike","sad","angry","worse","worst","horrible","差","坏","讨厌","生气","愤怒","悲伤","糟糕","失望","不满"]
-
 def simple_vad(text: str) -> Dict[str, float]:
     t = text or ""
     n = max(1, len(t))
@@ -45,9 +44,9 @@ def simple_vad(text: str) -> Dict[str, float]:
     v = 0.5 + 0.12*(pos - neg) - 0.05*q
     a = 0.3 + 0.7*math.tanh((ex + q + caps) / (n / 30 + 1))
     d = 0.4 + 0.4*(len(set(t)) / n)
-    return {"valence": _clamp01(v), "arousal": _clamp01(a), "dominance": _clamp01(d)}
+    return {"valence": _c01(v), "arousal": _c01(a), "dominance": _c01(d)}
 
-# ===== 解析 HF 返回 =====
+# ===== 解析 VAD 结构(用于Router返回)=====
 def _parse_vad_from_hf(obj) -> Tuple[float, float, float]:
     if isinstance(obj, dict):
         k = {kk.lower(): vv for kk, vv in obj.items()}
@@ -73,7 +72,7 @@ def _parse_vad_from_hf(obj) -> Tuple[float, float, float]:
         return m["valence"], m["arousal"], m["dominance"]
     raise ValueError("无法从模型返回中解析 V/A/D")
 
-# ===== HF Router 推理(若该模型被提供商部署才会成功)=====
+# ===== Router 推理(仅当该模型被 Provider 托管时可用)=====
 def hf_router_vad(text: str, timeout: float = 90.0) -> Dict[str, float]:
     if not HF_TOKEN:
         raise gr.Error("未配置 HF_TOKEN(Settings → Variables & secrets)。")
@@ -86,35 +85,72 @@ def hf_router_vad(text: str, timeout: float = 90.0) -> Dict[str, float]:
     }
     r = _session.post(HF_API_URL, headers=headers, json=payload, timeout=(8, timeout))
     if r.status_code == 404:
-        raise gr.Error("Router 404:该模型未由任何 Inference Provider 部署。请改用“本地VAD”后端。")
+        raise gr.Error("Router 404:该模型未由任何 Inference Provider 托管。改用“本地VAD”或换模型。")
     if r.status_code == 503:
         raise gr.Error("模型冷启动(503)。稍后重试。")
     if r.status_code >= 400:
         raise gr.Error(f"HF API 错误 {r.status_code}: {r.text[:200]}")
     data = r.json()
     v, a, d = _parse_vad_from_hf(data)
-    return {"valence": _clamp01(v), "arousal": _clamp01(a), "dominance": _clamp01(d)}
+    return {"valence": _c01(v), "arousal": _c01(a), "dominance": _c01(d)}
+
+# ===== 本地 Transformers(稳健适配)=====
+_local = {"tok": None, "model": None, "cfg": None}
 
-# ===== 本地 Transformers 推理(CPU)=====
-_local = {"tok": None, "model": None}
 def _ensure_local():
-    if _local["tok"] is None or _local["model"] is None:
-        from transformers import AutoTokenizer, AutoModelForSequenceClassification
-        import torch  # noqa: F401
-        _local["tok"] = AutoTokenizer.from_pretrained(MODEL_ID)
-        _local["model"] = AutoModelForSequenceClassification.from_pretrained(MODEL_ID)
-        _local["model"].eval()
+    if _local["tok"] is not None:
+        return
+    # 延迟导入,减少启动时间
+    from transformers import AutoConfig, AutoTokenizer, AutoModelForSequenceClassification
+    _local["cfg"] = AutoConfig.from_pretrained(MODEL_ID, trust_remote_code=True)
+    _local["tok"] = AutoTokenizer.from_pretrained(MODEL_ID, use_fast=True, trust_remote_code=True)
+    _local["model"] = AutoModelForSequenceClassification.from_pretrained(MODEL_ID, trust_remote_code=True)
+    _local["model"].eval()
+
+def _sig(x: float) -> float:
+    return 1.0 / (1.0 + math.exp(-x))
 
 def local_vad(text: str) -> Dict[str, float]:
     _ensure_local()
     import torch
-    inputs = _local["tok"](_trim(text, 512), return_tensors="pt", truncation=True, max_length=256)
+    s = _trim(text, 512)
+    inputs = _local["tok"](s, return_tensors="pt", truncation=True, max_length=256)
     with torch.no_grad():
-        logits = _local["model"](**inputs).logits.squeeze().tolist()
-    v, a, d = [float(x) for x in logits[:3]]
-    # 统一到 [0,1] 便于可视化(模型原始输出为回归值)
-    sig = lambda x: 1.0/(1.0+math.exp(-x))
-    return {"valence": _clamp01(sig(v)), "arousal": _clamp01(sig(a)), "dominance": _clamp01(sig(d))}
+        out = _local["model"](**inputs)
+
+    # 1) 标准分类输出
+    if hasattr(out, "logits"):
+        logits = out.logits.squeeze()
+        # 1a) 有 id2label 且包含 V/A/D
+        id2label = getattr(_local["cfg"], "id2label", None)
+        if id2label and isinstance(id2label, dict):
+            lab = {int(k): str(v).lower() for k, v in id2label.items()}
+            scores = logits.tolist() if hasattr(logits, "tolist") else list(logits)
+            m = {}
+            for i, sc in enumerate(scores):
+                name = lab.get(i, "")
+                if "valence" in name or name == "v": m["valence"] = float(sc)
+                if "arousal" in name or name == "a": m["arousal"] = float(sc)
+                if "dominance" in name or name == "d": m["dominance"] = float(sc)
+            if len(m) == 3:
+                return {"valence": _c01(_sig(m["valence"])), "arousal": _c01(_sig(m["arousal"])), "dominance": _c01(_sig(m["dominance"]))}
+        # 1b) 无明确标签,但 num_labels>=3,取前三维
+        if logits.numel() >= 3:
+            v, a, d = [float(logits[i].item()) for i in range(3)]
+            return {"valence": _c01(_sig(v)), "arousal": _c01(_sig(a)), "dominance": _c01(_sig(d))}
+
+    # 2) 某些自定义模型可能把 VAD 放在 out.vad 或 out[...]
+    for key in ("vad", "scores", "preds"):
+        if hasattr(out, key):
+            vec = getattr(out, key)
+            try:
+                vec = list(vec)[:3]
+                v, a, d = float(vec[0]), float(vec[1]), float(vec[2])
+                return {"valence": _c01(_sig(v)), "arousal": _c01(_sig(a)), "dominance": _c01(_sig(d))}
+            except Exception:
+                pass
+
+    raise gr.Error("本地VAD解析失败:模型输出不含可识别的 V/A/D 三维。请换兼容模型或改用简易VAD。")
 
 # ===== 指标与可视化 =====
 def metrics(v1: Dict[str, float], v2: Dict[str, float]) -> Dict[str, float]:
@@ -156,14 +192,14 @@ def bar_html(s: Dict[str, float], r: Dict[str, float]) -> str:
     <span class="swatch" style="background:#ffd6cc"></span>参考译文</div>
     </div>"""
 
-# ===== 主流程与自检 =====
+# ===== 主流程与诊断 =====
 def run(student_text: str, reference_text: str, backend: str):
     if not (student_text.strip() or reference_text.strip()):
         raise gr.Error("请至少输入一段文本。")
-    if backend == "HF Router(服务端推理)":
-        s = hf_router_vad(student_text or ""); r = hf_router_vad(reference_text or "")
-    elif backend == "本地VAD(Transformers, CPU)":
+    if backend == "本地VAD(Transformers, CPU)":
         s = local_vad(student_text or ""); r = local_vad(reference_text or "")
+    elif backend == "HF Router(服务端推理)":
+        s = hf_router_vad(student_text or ""); r = hf_router_vad(reference_text or "")
     else:
         s = simple_vad(student_text or ""); r = simple_vad(reference_text or "")
     m = metrics(s, r)
@@ -173,7 +209,7 @@ def run(student_text: str, reference_text: str, backend: str):
         f"L2 距离={m['L2_distance']:.3f},余弦相似度={m['cosine_similarity']:.3f}")
     return bar_html(s, r), rpt, json.dumps({"student": s, "reference": r, "metrics": m}, ensure_ascii=False, indent=2)
 
-def diagnose():
+def diagnose_router():
     if not HF_TOKEN:
         return "未检测到 HF_TOKEN", ""
     try:
@@ -193,12 +229,12 @@ with gr.Blocks(title=APP_TITLE, css=".wrap {max-width: 1200px; margin: 0 auto;}"
         student = gr.Textbox(label="学生译文", placeholder="粘贴学生译文…", lines=10)
         reference = gr.Textbox(label="参考译文", placeholder="粘贴参考译文…", lines=10)
         backend = gr.Radio(
-            ["HF Router(服务端推理)", "本地VAD(Transformers, CPU)", "内置简易VAD(备用)"],
+            ["本地VAD(Transformers, CPU)", "HF Router(服务端推理)", "内置简易VAD(备用)"],
            value="本地VAD(Transformers, CPU)",
            label="分析后端",
        )
        run_btn = gr.Button("运行对比")
-        gr.Markdown("### 自检(Router)")
+        gr.Markdown("### 诊断(Router)")
        chk_btn = gr.Button("测试 HF Router")
        api_status = gr.Textbox(label="接口状态", lines=1)
        api_body = gr.Textbox(label="返回片段", lines=5)
@@ -208,7 +244,7 @@ with gr.Blocks(title=APP_TITLE, css=".wrap {max-width: 1200px; margin: 0 auto;}"
        raw_json = gr.Code(label="JSON 输出", language="json")
 
    run_btn.click(run, [student, reference, backend], [chart, report, raw_json], concurrency_limit=4)
-    chk_btn.click(diagnose, [], [api_status, api_body], concurrency_limit=2)
+    chk_btn.click(diagnose_router, [], [api_status, api_body], concurrency_limit=2)
 
 demo.queue()
 app = demo
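
The label-matching branch the new local_vad() relies on can be exercised without downloading the model. The sketch below is hypothetical and not part of the commit: map_vad(), the sample id2label dict, and the sample scores are made up for illustration; it only mirrors branch 1a plus the sigmoid squash into [0, 1].

# Hypothetical sketch (not part of the commit): exercise the id2label -> V/A/D
# mapping from local_vad() in isolation, with stand-ins for _sig and _c01.
import math

def _sig(x: float) -> float:
    return 1.0 / (1.0 + math.exp(-x))

def _c01(x: float) -> float:
    return max(0.0, min(1.0, float(x)))

def map_vad(id2label: dict, scores: list) -> dict:
    # Same idea as branch 1a: match output dims to V/A/D by label name,
    # then squash the raw scores into [0, 1] for display.
    lab = {int(k): str(v).lower() for k, v in id2label.items()}
    m = {}
    for i, sc in enumerate(scores):
        name = lab.get(i, "")
        if "valence" in name or name == "v": m["valence"] = float(sc)
        if "arousal" in name or name == "a": m["arousal"] = float(sc)
        if "dominance" in name or name == "d": m["dominance"] = float(sc)
    return {k: _c01(_sig(x)) for k, x in m.items()} if len(m) == 3 else {}

print(map_vad({"0": "valence", "1": "arousal", "2": "dominance"}, [1.2, -0.4, 0.0]))
# -> {'valence': ~0.77, 'arousal': ~0.40, 'dominance': 0.5}

A checkpoint whose id2label only carries generic names such as LABEL_0/LABEL_1/LABEL_2 matches nothing here, which is the case where the committed code falls through to the positional branch 1b (first three logits).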