# app.py — 译文情感 VAD 对比(本地Transformers稳健版 + Router可选 + 简易兜底) import os, json, math from typing import Dict, Tuple import gradio as gr import requests from requests.adapters import HTTPAdapter from urllib3.util.retry import Retry APP_TITLE = "译文情感 VAD 对比" APP_DESC = "左侧输入学生译文与参考译文;右侧显示 V/A/D 与差异。默认跑本地Transformers;Router仅在目标模型被Provider托管时可用。" # ===== 环境与端点(Router,可选)===== HF_TOKEN = os.getenv("HF_TOKEN", "").strip() MODEL_ID = os.getenv("VAD_MODEL_ID", "RobroKools/vad-bert").strip() HF_API_URL = "https://router.huggingface.co/hf-inference" # ===== HTTP 会话(重试 + 连接关闭)===== _session = requests.Session() _session.mount("https://", HTTPAdapter(max_retries=Retry( total=3, connect=3, read=3, backoff_factor=0.5, status_forcelist=[502, 503, 504], allowed_methods=frozenset(["POST"]) ))) _session.headers.update({"Connection": "close"}) def _trim(s: str, n: int = 2000) -> str: return (s or "")[:n] def _c01(x: float) -> float: return max(0.0, min(1.0, float(x))) # ===== 简易VAD(兜底)===== _POS = ["good","great","excellent","love","like","happy","joy","awesome","amazing","wonderful","赞","好","喜欢","开心","愉快","优秀","棒","太好了","满意","值得"] _NEG = ["bad","terrible","awful","hate","dislike","sad","angry","worse","worst","horrible","差","坏","讨厌","生气","愤怒","悲伤","糟糕","失望","不满"] def simple_vad(text: str) -> Dict[str, float]: t = text or "" n = max(1, len(t)) ex = t.count("!") + t.count("!") q = t.count("?") + t.count("?") caps = sum(c.isupper() for c in t) tl = t.lower() pos = sum(t.count(w) for w in _POS) + sum(tl.count(w) for w in _POS) neg = sum(t.count(w) for w in _NEG) + sum(tl.count(w) for w in _NEG) v = 0.5 + 0.12*(pos - neg) - 0.05*q a = 0.3 + 0.7*math.tanh((ex + q + caps) / (n / 30 + 1)) d = 0.4 + 0.4*(len(set(t)) / n) return {"valence": _c01(v), "arousal": _c01(a), "dominance": _c01(d)} # ===== 解析 VAD 结构(用于Router返回)===== def _parse_vad_from_hf(obj) -> Tuple[float, float, float]: if isinstance(obj, dict): k = {kk.lower(): vv for kk, vv in obj.items()} if all(x in k for x in ("valence","arousal","dominance")): return float(k["valence"]), float(k["arousal"]), float(k["dominance"]) for key in ("embedding","vector","vad"): if key in k and isinstance(k[key], (list, tuple)) and len(k[key]) >= 3: return float(k[key][0]), float(k[key][1]), float(k[key][2]) if isinstance(obj, list) and len(obj) >= 3: if all(isinstance(x, (int, float)) for x in obj[:3]): return float(obj[0]), float(obj[1]), float(obj[2]) if all(isinstance(x, dict) for x in obj[:3]): m = {} for it in obj: lab = str(it.get("label","")).lower() sc = it.get("score", None) if sc is None: continue if "valence" in lab or lab == "v": m["valence"] = float(sc) elif "arousal" in lab or lab == "a": m["arousal"] = float(sc) elif "dominance" in lab or lab == "d": m["dominance"] = float(sc) if all(t in m for t in ("valence","arousal","dominance")): return m["valence"], m["arousal"], m["dominance"] raise ValueError("无法从模型返回中解析 V/A/D") # ===== Router 推理(仅当该模型被 Provider 托管时可用)===== def hf_router_vad(text: str, timeout: float = 90.0) -> Dict[str, float]: if not HF_TOKEN: raise gr.Error("未配置 HF_TOKEN(Settings → Variables & secrets)。") payload = {"model": MODEL_ID, "inputs": _trim(text, 2000)} headers = { "Authorization": f"Bearer {HF_TOKEN}", "Content-Type": "application/json", "Connection": "close", "X-Wait-For-Model": "true", } r = _session.post(HF_API_URL, headers=headers, json=payload, timeout=(8, timeout)) if r.status_code == 404: raise gr.Error("Router 404:该模型未由任何 Inference Provider 托管。改用“本地VAD”或换模型。") if r.status_code == 503: raise gr.Error("模型冷启动(503)。稍后重试。") if r.status_code >= 400: raise gr.Error(f"HF API 错误 {r.status_code}: {r.text[:200]}") data = r.json() v, a, d = _parse_vad_from_hf(data) return {"valence": _c01(v), "arousal": _c01(a), "dominance": _c01(d)} # ===== 本地 Transformers(稳健适配)===== _local = {"tok": None, "model": None, "cfg": None} def _ensure_local(): if _local["tok"] is not None: return # 延迟导入,减少启动时间 from transformers import AutoConfig, AutoTokenizer, AutoModelForSequenceClassification _local["cfg"] = AutoConfig.from_pretrained(MODEL_ID, trust_remote_code=True) _local["tok"] = AutoTokenizer.from_pretrained(MODEL_ID, use_fast=True, trust_remote_code=True) _local["model"] = AutoModelForSequenceClassification.from_pretrained(MODEL_ID, trust_remote_code=True) _local["model"].eval() def _sig(x: float) -> float: return 1.0 / (1.0 + math.exp(-x)) def local_vad(text: str) -> Dict[str, float]: _ensure_local() import torch s = _trim(text, 512) inputs = _local["tok"](s, return_tensors="pt", truncation=True, max_length=256) with torch.no_grad(): out = _local["model"](**inputs) # 1) 标准分类输出 if hasattr(out, "logits"): logits = out.logits.squeeze() # 1a) 有 id2label 且包含 V/A/D id2label = getattr(_local["cfg"], "id2label", None) if id2label and isinstance(id2label, dict): lab = {int(k): str(v).lower() for k, v in id2label.items()} scores = logits.tolist() if hasattr(logits, "tolist") else list(logits) m = {} for i, sc in enumerate(scores): name = lab.get(i, "") if "valence" in name or name == "v": m["valence"] = float(sc) if "arousal" in name or name == "a": m["arousal"] = float(sc) if "dominance" in name or name == "d": m["dominance"] = float(sc) if len(m) == 3: return {"valence": _c01(_sig(m["valence"])), "arousal": _c01(_sig(m["arousal"])), "dominance": _c01(_sig(m["dominance"]))} # 1b) 无明确标签,但 num_labels>=3,取前三维 if logits.numel() >= 3: v, a, d = [float(logits[i].item()) for i in range(3)] return {"valence": _c01(_sig(v)), "arousal": _c01(_sig(a)), "dominance": _c01(_sig(d))} # 2) 某些自定义模型可能把 VAD 放在 out.vad 或 out[...] for key in ("vad", "scores", "preds"): if hasattr(out, key): vec = getattr(out, key) try: vec = list(vec)[:3] v, a, d = float(vec[0]), float(vec[1]), float(vec[2]) return {"valence": _c01(_sig(v)), "arousal": _c01(_sig(a)), "dominance": _c01(_sig(d))} except Exception: pass raise gr.Error("本地VAD解析失败:模型输出不含可识别的 V/A/D 三维。请换兼容模型或改用简易VAD。") # ===== 指标与可视化 ===== def metrics(v1: Dict[str, float], v2: Dict[str, float]) -> Dict[str, float]: dv = v1["valence"] - v2["valence"] da = v1["arousal"] - v2["arousal"] dd = v1["dominance"] - v2["dominance"] l2 = math.sqrt(dv*dv + da*da + dd*dd) n1 = math.sqrt(v1["valence"]**2 + v1["arousal"]**2 + v1["dominance"]**2) n2 = math.sqrt(v2["valence"]**2 + v2["arousal"]**2 + v2["dominance"]**2) cos = (v1["valence"]*v2["valence"] + v1["arousal"]*v2["arousal"] + v1["dominance"]*v2["dominance"]) / (n1*n2) if n1>0 and n2>0 else 0.0 return {"ΔV": dv, "ΔA": da, "ΔD": dd, "L2_distance": l2, "cosine_similarity": cos} def bar_html(s: Dict[str, float], r: Dict[str, float]) -> str: def row(label, sv, rv): return f"""