# app.py: Translation sentiment VAD comparison
# (robust local Transformers backend + optional HF Router + simple built-in fallback)
import os, json, math
from typing import Dict, Tuple

import gradio as gr
import requests
from requests.adapters import HTTPAdapter
from urllib3.util.retry import Retry

APP_TITLE = "译文情感 VAD 对比"
APP_DESC = "左侧输入学生译文与参考译文;右侧显示 V/A/D 与差异。默认跑本地Transformers;Router仅在目标模型被Provider托管时可用。"

# ===== Environment & endpoint (Router, optional) =====
HF_TOKEN = os.getenv("HF_TOKEN", "").strip()
MODEL_ID = os.getenv("VAD_MODEL_ID", "RobroKools/vad-bert").strip()
HF_API_URL = "https://router.huggingface.co/hf-inference"

# ===== HTTP session (retries + Connection: close) =====
_session = requests.Session()
_session.mount("https://", HTTPAdapter(max_retries=Retry(
    total=3, connect=3, read=3, backoff_factor=0.5,
    status_forcelist=[502, 503, 504],
    allowed_methods=frozenset(["POST"])
)))
_session.headers.update({"Connection": "close"})


def _trim(s: str, n: int = 2000) -> str:
    return (s or "")[:n]


def _c01(x: float) -> float:
    # Clamp a value into [0, 1].
    return max(0.0, min(1.0, float(x)))


# ===== Simple lexicon-based VAD (fallback) =====
_POS = ["good", "great", "excellent", "love", "like", "happy", "joy", "awesome", "amazing", "wonderful",
        "赞", "好", "喜欢", "开心", "愉快", "优秀", "棒", "太好了", "满意", "值得"]
_NEG = ["bad", "terrible", "awful", "hate", "dislike", "sad", "angry", "worse", "worst", "horrible",
        "差", "坏", "讨厌", "生气", "愤怒", "悲伤", "糟糕", "失望", "不满"]


def simple_vad(text: str) -> Dict[str, float]:
    t = text or ""
    n = max(1, len(t))
    ex = t.count("!") + t.count("!")  # ASCII and full-width exclamation marks
    q = t.count("?") + t.count("?")   # ASCII and full-width question marks
    caps = sum(c.isupper() for c in t)
    tl = t.lower()
    # Count lexicon hits on the lowercased text: English matches become case-insensitive,
    # and the Chinese terms are unaffected by lower().
    pos = sum(tl.count(w) for w in _POS)
    neg = sum(tl.count(w) for w in _NEG)
    v = 0.5 + 0.12 * (pos - neg) - 0.05 * q
    a = 0.3 + 0.7 * math.tanh((ex + q + caps) / (n / 30 + 1))
    d = 0.4 + 0.4 * (len(set(t)) / n)
    return {"valence": _c01(v), "arousal": _c01(a), "dominance": _c01(d)}
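
# A minimal sanity check for the heuristic above (an illustrative sketch, not part of the
# app flow; the sample sentences and the VAD_SELFTEST opt-in flag are assumptions).
if os.getenv("VAD_SELFTEST") == "1":
    for _sample in ("I love this translation! 非常好!", "This is terrible. 太糟糕了,很失望。"):
        _scores = simple_vad(_sample)
        # _c01() clamps every dimension, so all three values must lie in [0, 1].
        assert all(0.0 <= _scores[k] <= 1.0 for k in ("valence", "arousal", "dominance"))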

# ===== Parsing a VAD structure (for Router responses) =====
def _parse_vad_from_hf(obj) -> Tuple[float, float, float]:
    # Text-classification endpoints often wrap results in an extra outer list; unwrap it first.
    if isinstance(obj, list) and len(obj) == 1 and isinstance(obj[0], (list, dict)):
        obj = obj[0]
    if isinstance(obj, dict):
        k = {kk.lower(): vv for kk, vv in obj.items()}
        if all(x in k for x in ("valence", "arousal", "dominance")):
            return float(k["valence"]), float(k["arousal"]), float(k["dominance"])
        for key in ("embedding", "vector", "vad"):
            if key in k and isinstance(k[key], (list, tuple)) and len(k[key]) >= 3:
                return float(k[key][0]), float(k[key][1]), float(k[key][2])
    if isinstance(obj, list) and len(obj) >= 3:
        # Either a plain [v, a, d] vector or a list of {"label": ..., "score": ...} items.
        if all(isinstance(x, (int, float)) for x in obj[:3]):
            return float(obj[0]), float(obj[1]), float(obj[2])
        if all(isinstance(x, dict) for x in obj[:3]):
            m = {}
            for it in obj:
                lab = str(it.get("label", "")).lower()
                sc = it.get("score", None)
                if sc is None:
                    continue
                if "valence" in lab or lab == "v":
                    m["valence"] = float(sc)
                elif "arousal" in lab or lab == "a":
                    m["arousal"] = float(sc)
                elif "dominance" in lab or lab == "d":
                    m["dominance"] = float(sc)
            if all(t in m for t in ("valence", "arousal", "dominance")):
                return m["valence"], m["arousal"], m["dominance"]
    raise ValueError("无法从模型返回中解析 V/A/D")


# ===== Router inference (only usable when the model is hosted by an Inference Provider) =====
def hf_router_vad(text: str, timeout: float = 90.0) -> Dict[str, float]:
    if not HF_TOKEN:
        raise gr.Error("未配置 HF_TOKEN(Settings → Variables & secrets)。")
    payload = {"model": MODEL_ID, "inputs": _trim(text, 2000)}
    headers = {
        "Authorization": f"Bearer {HF_TOKEN}",
        "Content-Type": "application/json",
        "Connection": "close",
        "X-Wait-For-Model": "true",
    }
    r = _session.post(HF_API_URL, headers=headers, json=payload, timeout=(8, timeout))
    if r.status_code == 404:
        raise gr.Error("Router 404:该模型未由任何 Inference Provider 托管。改用“本地VAD”或换模型。")
    if r.status_code == 503:
        raise gr.Error("模型冷启动(503)。稍后重试。")
    if r.status_code >= 400:
        raise gr.Error(f"HF API 错误 {r.status_code}: {r.text[:200]}")
    data = r.json()
    v, a, d = _parse_vad_from_hf(data)
    return {"valence": _c01(v), "arousal": _c01(a), "dominance": _c01(d)}


# ===== Local Transformers backend (robust adapter) =====
_local = {"tok": None, "model": None, "cfg": None}


def _ensure_local():
    if _local["tok"] is not None:
        return
    # Lazy import to keep startup time down.
    from transformers import AutoConfig, AutoTokenizer, AutoModelForSequenceClassification
    _local["cfg"] = AutoConfig.from_pretrained(MODEL_ID, trust_remote_code=True)
    _local["tok"] = AutoTokenizer.from_pretrained(MODEL_ID, use_fast=True, trust_remote_code=True)
    _local["model"] = AutoModelForSequenceClassification.from_pretrained(MODEL_ID, trust_remote_code=True)
    _local["model"].eval()


def _sig(x: float) -> float:
    return 1.0 / (1.0 + math.exp(-x))


def local_vad(text: str) -> Dict[str, float]:
    _ensure_local()
    import torch
    s = _trim(text, 512)
    inputs = _local["tok"](s, return_tensors="pt", truncation=True, max_length=256)
    with torch.no_grad():
        out = _local["model"](**inputs)

    # 1) Standard classification output
    if hasattr(out, "logits"):
        logits = out.logits.squeeze()
        # 1a) id2label is present and names the V/A/D dimensions
        id2label = getattr(_local["cfg"], "id2label", None)
        if id2label and isinstance(id2label, dict):
            lab = {int(k): str(v).lower() for k, v in id2label.items()}
            scores = logits.tolist() if hasattr(logits, "tolist") else list(logits)
            if not isinstance(scores, list):  # 0-dim tensor from a single-label model
                scores = [scores]
            m = {}
            for i, sc in enumerate(scores):
                name = lab.get(i, "")
                if "valence" in name or name == "v":
                    m["valence"] = float(sc)
                if "arousal" in name or name == "a":
                    m["arousal"] = float(sc)
                if "dominance" in name or name == "d":
                    m["dominance"] = float(sc)
            if len(m) == 3:
                return {"valence": _c01(_sig(m["valence"])),
                        "arousal": _c01(_sig(m["arousal"])),
                        "dominance": _c01(_sig(m["dominance"]))}
        # 1b) No usable labels, but num_labels >= 3: take the first three dimensions
        if logits.numel() >= 3:
            v, a, d = [float(logits[i].item()) for i in range(3)]
            return {"valence": _c01(_sig(v)), "arousal": _c01(_sig(a)), "dominance": _c01(_sig(d))}

    # 2) Some custom models may expose VAD as out.vad / out.scores / out.preds
    for key in ("vad", "scores", "preds"):
        if hasattr(out, key):
            vec = getattr(out, key)
            try:
                vec = list(vec)[:3]
                v, a, d = float(vec[0]), float(vec[1]), float(vec[2])
                return {"valence": _c01(_sig(v)), "arousal": _c01(_sig(a)), "dominance": _c01(_sig(d))}
            except Exception:
                pass

    raise gr.Error("本地VAD解析失败:模型输出不含可识别的 V/A/D 三维。请换兼容模型或改用简易VAD。")


# ===== Metrics & visualization =====
def metrics(v1: Dict[str, float], v2: Dict[str, float]) -> Dict[str, float]:
    dv = v1["valence"] - v2["valence"]
    da = v1["arousal"] - v2["arousal"]
    dd = v1["dominance"] - v2["dominance"]
    l2 = math.sqrt(dv * dv + da * da + dd * dd)
    n1 = math.sqrt(v1["valence"] ** 2 + v1["arousal"] ** 2 + v1["dominance"] ** 2)
    n2 = math.sqrt(v2["valence"] ** 2 + v2["arousal"] ** 2 + v2["dominance"] ** 2)
    dot = v1["valence"] * v2["valence"] + v1["arousal"] * v2["arousal"] + v1["dominance"] * v2["dominance"]
    cos = dot / (n1 * n2) if n1 > 0 and n2 > 0 else 0.0
    return {"ΔV": dv, "ΔA": da, "ΔD": dd, "L2_distance": l2, "cosine_similarity": cos}
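
# Worked example for metrics() (illustrative numbers, not app data): with
# s = {V: 0.8, A: 0.4, D: 0.6} and r = {V: 0.6, A: 0.4, D: 0.6}, the deltas are
# ΔV = 0.2, ΔA = 0.0, ΔD = 0.0, so L2_distance = sqrt(0.2**2) = 0.2 and
# cosine_similarity = 1.0 / (sqrt(1.16) * sqrt(0.88)) ≈ 0.99.
if os.getenv("VAD_SELFTEST") == "1":  # same hypothetical opt-in flag as the simple_vad check above
    _m = metrics({"valence": 0.8, "arousal": 0.4, "dominance": 0.6},
                 {"valence": 0.6, "arousal": 0.4, "dominance": 0.6})
    assert abs(_m["L2_distance"] - 0.2) < 1e-9 and _m["cosine_similarity"] > 0.98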

def bar_html(s: Dict[str, float], r: Dict[str, float]) -> str:
    # Render paired horizontal bars (student vs. reference) for each VAD dimension.
    # The markup and CSS are kept minimal and self-contained so gr.HTML can display them directly.
    def row(label, sv, rv):
        return f"""
        <div class="vad-row">
          <div class="vad-label">{label}</div>
          <div class="vad-track"><div class="vad-bar vad-stu" style="width:{sv * 100:.1f}%"></div><span>学生 {sv:.3f}</span></div>
          <div class="vad-track"><div class="vad-bar vad-ref" style="width:{rv * 100:.1f}%"></div><span>参考 {rv:.3f}</span></div>
        </div>"""

    css = """
    <style>
      .vad-wrap {font-family: sans-serif;}
      .vad-row {margin: 10px 0;}
      .vad-label {font-weight: 600; margin-bottom: 2px;}
      .vad-track {position: relative; height: 20px; margin: 2px 0; background: #eee; border-radius: 4px;}
      .vad-track span {position: absolute; left: 6px; top: 2px; font-size: 12px;}
      .vad-bar {height: 100%; border-radius: 4px;}
      .vad-stu {background: #4e79a7;}
      .vad-ref {background: #f28e2b;}
      .vad-legend {margin-top: 10px; font-size: 13px;}
      .vad-swatch {display: inline-block; width: 12px; height: 12px; border-radius: 2px; vertical-align: middle;}
    </style>"""

    return f"""{css}
    <div class="vad-wrap">
      {row("Valence", s['valence'], r['valence'])}
      {row("Arousal", s['arousal'], r['arousal'])}
      {row("Dominance", s['dominance'], r['dominance'])}
      <div class="vad-legend">
        <span class="vad-swatch" style="background:#4e79a7"></span> 学生译文&nbsp;&nbsp;
        <span class="vad-swatch" style="background:#f28e2b"></span> 参考译文
      </div>
    </div>"""


# ===== Main flow & diagnostics =====
def run(student_text: str, reference_text: str, backend: str):
    if not (student_text.strip() or reference_text.strip()):
        raise gr.Error("请至少输入一段文本。")
    if backend == "本地VAD(Transformers, CPU)":
        s = local_vad(student_text or "")
        r = local_vad(reference_text or "")
    elif backend == "HF Router(服务端推理)":
        s = hf_router_vad(student_text or "")
        r = hf_router_vad(reference_text or "")
    else:
        s = simple_vad(student_text or "")
        r = simple_vad(reference_text or "")
    m = metrics(s, r)
    rpt = (f"学生译文 VAD: V={s['valence']:.3f}, A={s['arousal']:.3f}, D={s['dominance']:.3f}\n"
           f"参考译文 VAD: V={r['valence']:.3f}, A={r['arousal']:.3f}, D={r['dominance']:.3f}\n"
           f"差异: ΔV={m['ΔV']:.3f}, ΔA={m['ΔA']:.3f}, ΔD={m['ΔD']:.3f}\n"
           f"L2 距离={m['L2_distance']:.3f},余弦相似度={m['cosine_similarity']:.3f}")
    return (bar_html(s, r), rpt,
            json.dumps({"student": s, "reference": r, "metrics": m}, ensure_ascii=False, indent=2))


def diagnose_router():
    if not HF_TOKEN:
        return "未检测到 HF_TOKEN", ""
    try:
        res = _session.post(HF_API_URL,
                            headers={"Authorization": f"Bearer {HF_TOKEN}", "Content-Type": "application/json"},
                            json={"model": MODEL_ID, "inputs": "ok"},
                            timeout=(8, 30))
        return f"HTTP {res.status_code}", res.text[:500]
    except Exception as e:
        return f"异常:{type(e).__name__}: {e}", ""


# ===== UI =====
with gr.Blocks(title=APP_TITLE, css=".wrap {max-width: 1200px; margin: 0 auto;}") as demo:
    gr.Markdown(f"# {APP_TITLE}\n{APP_DESC}")
    with gr.Row(elem_classes=["wrap"]):
        with gr.Column(scale=5):
            student = gr.Textbox(label="学生译文", placeholder="粘贴学生译文…", lines=10)
            reference = gr.Textbox(label="参考译文", placeholder="粘贴参考译文…", lines=10)
            backend = gr.Radio(
                ["本地VAD(Transformers, CPU)", "HF Router(服务端推理)", "内置简易VAD(备用)"],
                value="本地VAD(Transformers, CPU)",
                label="分析后端",
            )
            run_btn = gr.Button("运行对比")
            gr.Markdown("### 诊断(Router)")
            chk_btn = gr.Button("测试 HF Router")
            api_status = gr.Textbox(label="接口状态", lines=1)
            api_body = gr.Textbox(label="返回片段", lines=5)
        with gr.Column(scale=5):
            chart = gr.HTML(label="VAD 对比柱状图")
            report = gr.Textbox(label="摘要结果", lines=4)
            raw_json = gr.Code(label="JSON 输出", language="json")

    run_btn.click(run, [student, reference, backend], [chart, report, raw_json], concurrency_limit=4)
    chk_btn.click(diagnose_router, [], [api_status, api_body], concurrency_limit=2)

demo.queue()
app = demo

if __name__ == "__main__":
    demo.launch()