Update app.py
Browse files
app.py
CHANGED
|
@@ -1,4 +1,4 @@
|
|
| 1 |
-
# app.py — 译文情感 VAD
|
| 2 |
import os, json, math
|
| 3 |
from typing import Dict, Tuple
|
| 4 |
|
|
@@ -8,9 +8,9 @@ from requests.adapters import HTTPAdapter
|
|
| 8 |
from urllib3.util.retry import Retry
|
| 9 |
|
| 10 |
APP_TITLE = "译文情感 VAD 对比"
|
| 11 |
-
APP_DESC = "左侧输入学生译文与参考译文;右侧显示 V/A/D
|
| 12 |
|
| 13 |
-
# ===== 环境与端点(Router
|
| 14 |
HF_TOKEN = os.getenv("HF_TOKEN", "").strip()
|
| 15 |
MODEL_ID = os.getenv("VAD_MODEL_ID", "RobroKools/vad-bert").strip()
|
| 16 |
HF_API_URL = "https://router.huggingface.co/hf-inference"
|
|
@@ -26,13 +26,12 @@ _session.headers.update({"Connection": "close"})
|
|
| 26 |
def _trim(s: str, n: int = 2000) -> str:
|
| 27 |
return (s or "")[:n]
|
| 28 |
|
| 29 |
-
def
|
| 30 |
return max(0.0, min(1.0, float(x)))
|
| 31 |
|
| 32 |
-
# ===== 简易VAD
|
| 33 |
_POS = ["good","great","excellent","love","like","happy","joy","awesome","amazing","wonderful","赞","好","喜欢","开心","愉快","优秀","棒","太好了","满意","值得"]
|
| 34 |
_NEG = ["bad","terrible","awful","hate","dislike","sad","angry","worse","worst","horrible","差","坏","讨厌","生气","愤怒","悲伤","糟糕","失望","不满"]
|
| 35 |
-
|
| 36 |
def simple_vad(text: str) -> Dict[str, float]:
|
| 37 |
t = text or ""
|
| 38 |
n = max(1, len(t))
|
|
@@ -45,9 +44,9 @@ def simple_vad(text: str) -> Dict[str, float]:
|
|
| 45 |
v = 0.5 + 0.12*(pos - neg) - 0.05*q
|
| 46 |
a = 0.3 + 0.7*math.tanh((ex + q + caps) / (n / 30 + 1))
|
| 47 |
d = 0.4 + 0.4*(len(set(t)) / n)
|
| 48 |
-
return {"valence":
|
| 49 |
|
| 50 |
-
# ===== 解析
|
| 51 |
def _parse_vad_from_hf(obj) -> Tuple[float, float, float]:
|
| 52 |
if isinstance(obj, dict):
|
| 53 |
k = {kk.lower(): vv for kk, vv in obj.items()}
|
|
@@ -73,7 +72,7 @@ def _parse_vad_from_hf(obj) -> Tuple[float, float, float]:
|
|
| 73 |
return m["valence"], m["arousal"], m["dominance"]
|
| 74 |
raise ValueError("无法从模型返回中解析 V/A/D")
|
| 75 |
|
| 76 |
-
# =====
|
| 77 |
def hf_router_vad(text: str, timeout: float = 90.0) -> Dict[str, float]:
|
| 78 |
if not HF_TOKEN:
|
| 79 |
raise gr.Error("未配置 HF_TOKEN(Settings → Variables & secrets)。")
|
|
@@ -86,35 +85,72 @@ def hf_router_vad(text: str, timeout: float = 90.0) -> Dict[str, float]:
|
|
| 86 |
}
|
| 87 |
r = _session.post(HF_API_URL, headers=headers, json=payload, timeout=(8, timeout))
|
| 88 |
if r.status_code == 404:
|
| 89 |
-
raise gr.Error("Router 404:该模型未由任何 Inference Provider
|
| 90 |
if r.status_code == 503:
|
| 91 |
raise gr.Error("模型冷启动(503)。稍后重试。")
|
| 92 |
if r.status_code >= 400:
|
| 93 |
raise gr.Error(f"HF API 错误 {r.status_code}: {r.text[:200]}")
|
| 94 |
data = r.json()
|
| 95 |
v, a, d = _parse_vad_from_hf(data)
|
| 96 |
-
return {"valence":
|
|
|
|
|
|
|
|
|
|
| 97 |
|
| 98 |
-
# ===== 本地 Transformers 推理(CPU)=====
|
| 99 |
-
_local = {"tok": None, "model": None}
|
| 100 |
def _ensure_local():
|
| 101 |
-
if _local["tok"] is
|
| 102 |
-
|
| 103 |
-
|
| 104 |
-
|
| 105 |
-
|
| 106 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 107 |
|
| 108 |
def local_vad(text: str) -> Dict[str, float]:
|
| 109 |
_ensure_local()
|
| 110 |
import torch
|
| 111 |
-
|
|
|
|
| 112 |
with torch.no_grad():
|
| 113 |
-
|
| 114 |
-
|
| 115 |
-
#
|
| 116 |
-
|
| 117 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 118 |
|
| 119 |
# ===== 指标与可视化 =====
|
| 120 |
def metrics(v1: Dict[str, float], v2: Dict[str, float]) -> Dict[str, float]:
|
|
@@ -156,14 +192,14 @@ def bar_html(s: Dict[str, float], r: Dict[str, float]) -> str:
|
|
| 156 |
<span class="swatch" style="background:#ffd6cc"></span>参考译文</div>
|
| 157 |
</div>"""
|
| 158 |
|
| 159 |
-
# =====
|
| 160 |
def run(student_text: str, reference_text: str, backend: str):
|
| 161 |
if not (student_text.strip() or reference_text.strip()):
|
| 162 |
raise gr.Error("请至少输入一段文本。")
|
| 163 |
-
if backend == "
|
| 164 |
-
s = hf_router_vad(student_text or ""); r = hf_router_vad(reference_text or "")
|
| 165 |
-
elif backend == "本地VAD(Transformers, CPU)":
|
| 166 |
s = local_vad(student_text or ""); r = local_vad(reference_text or "")
|
|
|
|
|
|
|
| 167 |
else:
|
| 168 |
s = simple_vad(student_text or ""); r = simple_vad(reference_text or "")
|
| 169 |
m = metrics(s, r)
|
|
@@ -173,7 +209,7 @@ def run(student_text: str, reference_text: str, backend: str):
|
|
| 173 |
f"L2 距离={m['L2_distance']:.3f},余弦相似度={m['cosine_similarity']:.3f}")
|
| 174 |
return bar_html(s, r), rpt, json.dumps({"student": s, "reference": r, "metrics": m}, ensure_ascii=False, indent=2)
|
| 175 |
|
| 176 |
-
def
|
| 177 |
if not HF_TOKEN:
|
| 178 |
return "未检测到 HF_TOKEN", ""
|
| 179 |
try:
|
|
@@ -193,12 +229,12 @@ with gr.Blocks(title=APP_TITLE, css=".wrap {max-width: 1200px; margin: 0 auto;}"
|
|
| 193 |
student = gr.Textbox(label="学生译文", placeholder="粘贴学生译文…", lines=10)
|
| 194 |
reference = gr.Textbox(label="参考译文", placeholder="粘贴参考译文…", lines=10)
|
| 195 |
backend = gr.Radio(
|
| 196 |
-
["
|
| 197 |
value="本地VAD(Transformers, CPU)",
|
| 198 |
label="分析后端",
|
| 199 |
)
|
| 200 |
run_btn = gr.Button("运行对比")
|
| 201 |
-
gr.Markdown("###
|
| 202 |
chk_btn = gr.Button("测试 HF Router")
|
| 203 |
api_status = gr.Textbox(label="接口状态", lines=1)
|
| 204 |
api_body = gr.Textbox(label="返回片段", lines=5)
|
|
@@ -208,7 +244,7 @@ with gr.Blocks(title=APP_TITLE, css=".wrap {max-width: 1200px; margin: 0 auto;}"
|
|
| 208 |
raw_json = gr.Code(label="JSON 输出", language="json")
|
| 209 |
|
| 210 |
run_btn.click(run, [student, reference, backend], [chart, report, raw_json], concurrency_limit=4)
|
| 211 |
-
chk_btn.click(
|
| 212 |
|
| 213 |
demo.queue()
|
| 214 |
app = demo
|
|
|
|
| 1 |
+
# app.py — 译文情感 VAD 对比(本地Transformers稳健版 + Router可选 + 简易兜底)
|
| 2 |
import os, json, math
|
| 3 |
from typing import Dict, Tuple
|
| 4 |
|
|
|
|
| 8 |
from urllib3.util.retry import Retry
|
| 9 |
|
| 10 |
APP_TITLE = "译文情感 VAD 对比"
|
| 11 |
+
APP_DESC = "左侧输入学生译文与参考译文;右侧显示 V/A/D 与差异。默认跑本地Transformers;Router仅在目标模型被Provider托管时可用。"
|
| 12 |
|
| 13 |
+
# ===== 环境与端点(Router,可选)=====
|
| 14 |
HF_TOKEN = os.getenv("HF_TOKEN", "").strip()
|
| 15 |
MODEL_ID = os.getenv("VAD_MODEL_ID", "RobroKools/vad-bert").strip()
|
| 16 |
HF_API_URL = "https://router.huggingface.co/hf-inference"
|
|
|
|
| 26 |
def _trim(s: str, n: int = 2000) -> str:
|
| 27 |
return (s or "")[:n]
|
| 28 |
|
| 29 |
+
def _c01(x: float) -> float:
|
| 30 |
return max(0.0, min(1.0, float(x)))
|
| 31 |
|
| 32 |
+
# ===== 简易VAD(兜底)=====
|
| 33 |
_POS = ["good","great","excellent","love","like","happy","joy","awesome","amazing","wonderful","赞","好","喜欢","开心","愉快","优秀","棒","太好了","满意","值得"]
|
| 34 |
_NEG = ["bad","terrible","awful","hate","dislike","sad","angry","worse","worst","horrible","差","坏","讨厌","生气","愤怒","悲伤","糟糕","失望","不满"]
|
|
|
|
| 35 |
def simple_vad(text: str) -> Dict[str, float]:
|
| 36 |
t = text or ""
|
| 37 |
n = max(1, len(t))
|
|
|
|
| 44 |
v = 0.5 + 0.12*(pos - neg) - 0.05*q
|
| 45 |
a = 0.3 + 0.7*math.tanh((ex + q + caps) / (n / 30 + 1))
|
| 46 |
d = 0.4 + 0.4*(len(set(t)) / n)
|
| 47 |
+
return {"valence": _c01(v), "arousal": _c01(a), "dominance": _c01(d)}
|
| 48 |
|
| 49 |
+
# ===== 解析 VAD 结构(用于Router返回)=====
|
| 50 |
def _parse_vad_from_hf(obj) -> Tuple[float, float, float]:
|
| 51 |
if isinstance(obj, dict):
|
| 52 |
k = {kk.lower(): vv for kk, vv in obj.items()}
|
|
|
|
| 72 |
return m["valence"], m["arousal"], m["dominance"]
|
| 73 |
raise ValueError("无法从模型返回中解析 V/A/D")
|
| 74 |
|
| 75 |
+
# ===== Router 推理(仅当该模型被 Provider 托管时可用)=====
|
| 76 |
def hf_router_vad(text: str, timeout: float = 90.0) -> Dict[str, float]:
|
| 77 |
if not HF_TOKEN:
|
| 78 |
raise gr.Error("未配置 HF_TOKEN(Settings → Variables & secrets)。")
|
|
|
|
| 85 |
}
|
| 86 |
r = _session.post(HF_API_URL, headers=headers, json=payload, timeout=(8, timeout))
|
| 87 |
if r.status_code == 404:
|
| 88 |
+
raise gr.Error("Router 404:该模型未由任何 Inference Provider 托管。改用“本地VAD”或换模型。")
|
| 89 |
if r.status_code == 503:
|
| 90 |
raise gr.Error("模型冷启动(503)。稍后重试。")
|
| 91 |
if r.status_code >= 400:
|
| 92 |
raise gr.Error(f"HF API 错误 {r.status_code}: {r.text[:200]}")
|
| 93 |
data = r.json()
|
| 94 |
v, a, d = _parse_vad_from_hf(data)
|
| 95 |
+
return {"valence": _c01(v), "arousal": _c01(a), "dominance": _c01(d)}
|
| 96 |
+
|
| 97 |
+
# ===== 本地 Transformers(稳健适配)=====
|
| 98 |
+
_local = {"tok": None, "model": None, "cfg": None}
|
| 99 |
|
|
|
|
|
|
|
| 100 |
def _ensure_local():
    """Load the local config, tokenizer and model once and cache them in ``_local``.

    The three objects are loaded into local variables first and written to
    ``_local`` together only after every load succeeded. If any load raises,
    the cache is left untouched and the next call retries from scratch, instead
    of returning early with a tokenizer but no model.

    Raises:
        Whatever ``transformers`` raises when ``MODEL_ID`` cannot be resolved
        or loaded.
    """
    if all(_local.get(slot) is not None for slot in ("cfg", "tok", "model")):
        return

    # Deferred import to keep application start-up fast.
    from transformers import AutoConfig, AutoTokenizer, AutoModelForSequenceClassification

    # NOTE(security): trust_remote_code=True lets the model repository run its
    # own Python code in this process, and MODEL_ID can be overridden through
    # the VAD_MODEL_ID environment variable. Only point it at repositories you
    # trust.
    cfg = AutoConfig.from_pretrained(MODEL_ID, trust_remote_code=True)
    tok = AutoTokenizer.from_pretrained(MODEL_ID, use_fast=True, trust_remote_code=True)
    model = AutoModelForSequenceClassification.from_pretrained(MODEL_ID, trust_remote_code=True)
    model.eval()

    # Publish everything in one step so a partially filled cache is never seen.
    _local.update(cfg=cfg, tok=tok, model=model)
|
| 109 |
+
|
| 110 |
+
def _sig(x: float) -> float:
|
| 111 |
+
return 1.0 / (1.0 + math.exp(-x))
|
| 112 |
|
| 113 |
def _vad_from_raw(v: float, a: float, d: float) -> Dict[str, float]:
    """Map three raw model scores to a V/A/D dict via sigmoid, clamped to [0, 1]."""
    return {
        "valence": _c01(_sig(v)),
        "arousal": _c01(_sig(a)),
        "dominance": _c01(_sig(d)),
    }


def local_vad(text: str) -> Dict[str, float]:
    """Score ``text`` with the locally loaded Transformers model.

    The text is cut to 512 characters and tokenised with truncation at 256
    tokens. The model output is then interpreted in this order:

    1. ``out.logits`` with ``id2label`` entries naming valence / arousal /
       dominance (or ``v`` / ``a`` / ``d``);
    2. ``out.logits`` with at least three values, taking the first three;
    3. an ``out.vad`` / ``out.scores`` / ``out.preds`` attribute holding at
       least three values.

    Returns:
        Dict with keys ``valence``, ``arousal`` and ``dominance``, each in [0, 1].

    Raises:
        gr.Error: if none of the strategies yields three values.
    """
    _ensure_local()
    import torch

    snippet = _trim(text, 512)
    inputs = _local["tok"](snippet, return_tensors="pt", truncation=True, max_length=256)
    with torch.no_grad():
        out = _local["model"](**inputs)

    # 1) Standard classification output.
    if hasattr(out, "logits"):
        # reshape(-1) instead of squeeze(): a single-label head would squeeze
        # down to a 0-dim tensor whose tolist() is a bare float, which cannot
        # be enumerated.
        scores = out.logits.reshape(-1).tolist()

        # 1a) id2label is present and names all three dimensions.
        id2label = getattr(_local["cfg"], "id2label", None)
        if id2label and isinstance(id2label, dict):
            labels = {int(k): str(v).lower() for k, v in id2label.items()}
            found = {}
            for idx, score in enumerate(scores):
                name = labels.get(idx, "")
                if "valence" in name or name == "v":
                    found["valence"] = float(score)
                if "arousal" in name or name == "a":
                    found["arousal"] = float(score)
                if "dominance" in name or name == "d":
                    found["dominance"] = float(score)
            if len(found) == 3:
                return _vad_from_raw(found["valence"], found["arousal"], found["dominance"])

        # 1b) No usable labels, but at least three outputs: take the first three.
        if len(scores) >= 3:
            return _vad_from_raw(float(scores[0]), float(scores[1]), float(scores[2]))

    # 2) Some custom models may expose the vector under another attribute.
    for key in ("vad", "scores", "preds"):
        vec = getattr(out, key, None)
        if vec is None:
            continue
        try:
            # Flatten first so a batched (1, 3) tensor is handled like a
            # plain length-3 vector.
            flat = torch.as_tensor(vec).detach().reshape(-1)[:3].tolist()
        except (TypeError, ValueError, RuntimeError):
            # Not convertible to a numeric tensor; try the next candidate.
            continue
        if len(flat) == 3:
            return _vad_from_raw(float(flat[0]), float(flat[1]), float(flat[2]))

    raise gr.Error("本地VAD解析失败:模型输出不含可识别的 V/A/D 三维。请换兼容模型或改用简易VAD。")
|
| 154 |
|
| 155 |
# ===== 指标与可视化 =====
|
| 156 |
def metrics(v1: Dict[str, float], v2: Dict[str, float]) -> Dict[str, float]:
|
|
|
|
| 192 |
<span class="swatch" style="background:#ffd6cc"></span>参考译文</div>
|
| 193 |
</div>"""
|
| 194 |
|
| 195 |
+
# ===== 主流程与诊断 =====
|
| 196 |
def run(student_text: str, reference_text: str, backend: str):
|
| 197 |
if not (student_text.strip() or reference_text.strip()):
|
| 198 |
raise gr.Error("请至少输入一段文本。")
|
| 199 |
+
if backend == "本地VAD(Transformers, CPU)":
|
|
|
|
|
|
|
| 200 |
s = local_vad(student_text or ""); r = local_vad(reference_text or "")
|
| 201 |
+
elif backend == "HF Router(服务端推理)":
|
| 202 |
+
s = hf_router_vad(student_text or ""); r = hf_router_vad(reference_text or "")
|
| 203 |
else:
|
| 204 |
s = simple_vad(student_text or ""); r = simple_vad(reference_text or "")
|
| 205 |
m = metrics(s, r)
|
|
|
|
| 209 |
f"L2 距离={m['L2_distance']:.3f},余弦相似度={m['cosine_similarity']:.3f}")
|
| 210 |
return bar_html(s, r), rpt, json.dumps({"student": s, "reference": r, "metrics": m}, ensure_ascii=False, indent=2)
|
| 211 |
|
| 212 |
+
def diagnose_router():
|
| 213 |
if not HF_TOKEN:
|
| 214 |
return "未检测到 HF_TOKEN", ""
|
| 215 |
try:
|
|
|
|
| 229 |
student = gr.Textbox(label="学生译文", placeholder="粘贴学生译文…", lines=10)
|
| 230 |
reference = gr.Textbox(label="参考译文", placeholder="粘贴参考译文…", lines=10)
|
| 231 |
backend = gr.Radio(
|
| 232 |
+
["本地VAD(Transformers, CPU)", "HF Router(服务端推理)", "内置简易VAD(备用)"],
|
| 233 |
value="本地VAD(Transformers, CPU)",
|
| 234 |
label="分析后端",
|
| 235 |
)
|
| 236 |
run_btn = gr.Button("运行对比")
|
| 237 |
+
gr.Markdown("### 诊断(Router)")
|
| 238 |
chk_btn = gr.Button("测试 HF Router")
|
| 239 |
api_status = gr.Textbox(label="接口状态", lines=1)
|
| 240 |
api_body = gr.Textbox(label="返回片段", lines=5)
|
|
|
|
| 244 |
raw_json = gr.Code(label="JSON 输出", language="json")
|
| 245 |
|
| 246 |
run_btn.click(run, [student, reference, backend], [chart, report, raw_json], concurrency_limit=4)
|
| 247 |
+
chk_btn.click(diagnose_router, [], [api_status, api_body], concurrency_limit=2)
|
| 248 |
|
| 249 |
demo.queue()
|
| 250 |
app = demo
|