# Hiro / app.py — commit 75f43fb ("Update app.py", verified), author: Shirosawa
# app.py — 译文情感 VAD 对比(本地Transformers稳健版 + Router可选 + 简易兜底)
import os, json, math
import threading
from typing import Dict, Tuple

import gradio as gr
import requests
from requests.adapters import HTTPAdapter
from urllib3.util.retry import Retry
# UI title and description shown at the top of the page.
APP_TITLE = "译文情感 VAD 对比"
APP_DESC = (
    "左侧输入学生译文与参考译文;右侧显示 V/A/D 与差异。"
    "默认跑本地Transformers;Router仅在目标模型被Provider托管时可用。"
)

# ===== Environment & endpoint (Router, optional) =====
# Access token for the Router backend; empty string when not configured.
HF_TOKEN = os.environ.get("HF_TOKEN", "").strip()
# Model repo id used by both the local and the Router backends.
MODEL_ID = os.environ.get("VAD_MODEL_ID", "RobroKools/vad-bert").strip()
# Base URL of the Hugging Face router's hf-inference provider.
HF_API_URL = "https://router.huggingface.co/hf-inference"
# ===== HTTP session (retries + connection close) =====
# Shared session for Router calls: "Connection: close" on every request, and
# HTTPS requests retried up to 3 times (exponential backoff factor 0.5) on
# connect/read errors and on 502/503/504, for POST only.
_session = requests.Session()
_session.headers.update({"Connection": "close"})
_session.mount(
    "https://",
    HTTPAdapter(
        max_retries=Retry(
            total=3,
            connect=3,
            read=3,
            backoff_factor=0.5,
            status_forcelist=[502, 503, 504],
            allowed_methods=frozenset(["POST"]),
        )
    ),
)
def _trim(s: str, n: int = 2000) -> str:
    """Return at most the first *n* characters of *s*; None or "" yields ""."""
    text = s if s else ""
    return text[:n]
def _c01(x: float) -> float:
    """Convert *x* with float() and clamp it to the closed interval [0, 1]."""
    # NOTE: min(1.0, nan) returns 1.0, so a NaN input maps to 1.0.
    capped = min(1.0, float(x))
    return max(0.0, capped)
# ===== Simple heuristic VAD (fallback) =====
# Positive cue words: English first, then Chinese.
_POS = [
    "good", "great", "excellent", "love", "like",
    "happy", "joy", "awesome", "amazing", "wonderful",
    "赞", "好", "喜欢", "开心", "愉快",
    "优秀", "棒", "太好了", "满意", "值得",
]
# Negative cue words: English first, then Chinese.
_NEG = [
    "bad", "terrible", "awful", "hate", "dislike",
    "sad", "angry", "worse", "worst", "horrible",
    "差", "坏", "讨厌", "生气", "愤怒",
    "悲伤", "糟糕", "失望", "不满",
]
def simple_vad(text: str) -> Dict[str, float]:
    """Estimate V/A/D from cue words and punctuation (fallback backend).

    Args:
        text: input text; None is treated as "".

    Returns:
        dict with "valence", "arousal", "dominance", each clamped to [0, 1]:
        * valence   = 0.5 + 0.12 * (positive hits - negative hits)
                      - 0.05 * question marks
        * arousal   = 0.3 + 0.7 * tanh((exclamations + question marks
                      + uppercase chars) / (length / 30 + 1))
        * dominance = 0.4 + 0.4 * (distinct characters / length)
    """
    t = text or ""
    n = max(1, len(t))  # avoids division by zero for empty input
    ex = t.count("!") + t.count("!")  # ASCII + full-width exclamation marks
    q = t.count("?") + t.count("?")  # ASCII + full-width question marks
    caps = sum(c.isupper() for c in t)
    # Count each cue once, in the lower-cased text, so matching is
    # case-insensitive and no entry is counted twice. Counting in both the raw
    # and the lower-cased text scored "good" as 2 hits, "Good" as 1, and
    # every Chinese cue as 2.
    # NOTE: str.count matches substrings, so e.g. "dislike" also hits "like".
    tl = t.lower()
    pos = sum(tl.count(w) for w in _POS)
    neg = sum(tl.count(w) for w in _NEG)
    v = 0.5 + 0.12 * (pos - neg) - 0.05 * q
    a = 0.3 + 0.7 * math.tanh((ex + q + caps) / (n / 30 + 1))
    d = 0.4 + 0.4 * (len(set(t)) / n)
    return {"valence": _c01(v), "arousal": _c01(a), "dominance": _c01(d)}
# ===== Parse V/A/D structures (used for Router responses) =====
def _parse_vad_from_hf(obj) -> Tuple[float, float, float]:
    """Extract a (valence, arousal, dominance) triple from a decoded JSON response.

    Recognized shapes:
      * dict with "valence"/"arousal"/"dominance" keys (case-insensitive);
      * dict whose "embedding"/"vector"/"vad" entry is a list/tuple of >= 3 numbers;
      * list whose first three items are numbers;
      * list of {"label", "score"} dicts whose labels name the dimensions
        ("valence"/"v", "arousal"/"a", "dominance"/"d");
      * any of the above wrapped in single-element lists.

    Raises:
        ValueError: if no V/A/D triple can be recognized.
    """
    # Unwrap single-element list wrappers such as [[{...}, {...}, {...}]],
    # [[v, a, d]] or [{...}], which a batch-of-one response may use.
    while isinstance(obj, list) and len(obj) == 1:
        obj = obj[0]

    if isinstance(obj, dict):
        k = {str(kk).lower(): vv for kk, vv in obj.items()}
        if all(x in k for x in ("valence", "arousal", "dominance")):
            return float(k["valence"]), float(k["arousal"]), float(k["dominance"])
        for key in ("embedding", "vector", "vad"):
            vec = k.get(key)
            if isinstance(vec, (list, tuple)) and len(vec) >= 3:
                return float(vec[0]), float(vec[1]), float(vec[2])

    if isinstance(obj, list) and len(obj) >= 3:
        if all(isinstance(x, (int, float)) for x in obj[:3]):
            return float(obj[0]), float(obj[1]), float(obj[2])
        if all(isinstance(x, dict) for x in obj[:3]):
            m = {}
            for it in obj:
                if not isinstance(it, dict):
                    continue  # tolerate stray non-dict entries past the first three
                lab = str(it.get("label", "")).lower()
                sc = it.get("score")
                if sc is None:
                    continue
                if "valence" in lab or lab == "v":
                    m["valence"] = float(sc)
                elif "arousal" in lab or lab == "a":
                    m["arousal"] = float(sc)
                elif "dominance" in lab or lab == "d":
                    m["dominance"] = float(sc)
            if all(t in m for t in ("valence", "arousal", "dominance")):
                return m["valence"], m["arousal"], m["dominance"]

    raise ValueError("无法从模型返回中解析 V/A/D")
# ===== Router inference (works only if MODEL_ID is hosted by an Inference Provider) =====
def hf_router_vad(text: str, timeout: float = 90.0) -> Dict[str, float]:
    """Score *text* remotely through the Hugging Face router (hf-inference).

    Args:
        text: input text; trimmed to 2000 characters (None is treated as "").
        timeout: read timeout in seconds; the connect timeout is fixed at 8 s.

    Returns:
        dict with "valence", "arousal", "dominance", each clamped to [0, 1].

    Raises:
        gr.Error: missing HF_TOKEN, network failure or exhausted retries,
            HTTP error status, non-JSON body, or no recognizable V/A/D triple.
    """
    if not HF_TOKEN:
        raise gr.Error("未配置 HF_TOKEN(Settings → Variables & secrets)。")
    # hf-inference endpoints are addressed per model:
    # <router>/hf-inference/models/<model id>.
    url = f"{HF_API_URL}/models/{MODEL_ID}"
    payload = {"inputs": _trim(text, 2000)}
    headers = {
        "Authorization": f"Bearer {HF_TOKEN}",
        "Content-Type": "application/json",
        "Connection": "close",
        "X-Wait-For-Model": "true",
    }
    try:
        r = _session.post(url, headers=headers, json=payload, timeout=(8, timeout))
    except requests.exceptions.RetryError as e:
        # _session retries 502/503/504; once those retries are exhausted,
        # requests raises RetryError rather than returning the response.
        raise gr.Error(f"HF API 重试后仍失败(可能是模型冷启动或服务繁忙):{e}") from e
    except requests.exceptions.RequestException as e:
        raise gr.Error(f"HF API 网络错误:{type(e).__name__}: {e}") from e
    if r.status_code == 404:
        raise gr.Error("Router 404:该模型未由任何 Inference Provider 托管。改用“本地VAD”或换模型。")
    if r.status_code == 503:
        raise gr.Error("模型冷启动(503)。稍后重试。")
    if r.status_code >= 400:
        raise gr.Error(f"HF API 错误 {r.status_code}: {r.text[:200]}")
    try:
        data = r.json()
    except ValueError as e:  # requests' JSON decode errors subclass ValueError
        raise gr.Error(f"HF API 返回非 JSON:{r.text[:200]}") from e
    try:
        v, a, d = _parse_vad_from_hf(data)
    except ValueError as e:
        raise gr.Error(f"{e}:{str(data)[:200]}") from e
    return {"valence": _c01(v), "arousal": _c01(a), "dominance": _c01(d)}
# ===== Local Transformers backend (robust adaptation) =====
# Lazily populated cache of the config / tokenizer / model for MODEL_ID.
_local = {"tok": None, "model": None, "cfg": None}
# Serializes the one-time load: run() is wired with concurrency_limit=4, so
# several requests can reach _ensure_local() at the same moment.
_local_lock = threading.Lock()


def _ensure_local():
    """Load config, tokenizer and model for MODEL_ID into _local once.

    Safe to call concurrently. The model entry acts as the "loaded" flag and
    is published last, so callers never observe a half-loaded cache (a
    tokenizer present while the model is still None).
    """
    if _local["model"] is not None:
        return
    with _local_lock:
        if _local["model"] is not None:  # loaded by another thread while we waited
            return
        # Deferred import keeps app start-up fast.
        from transformers import AutoConfig, AutoTokenizer, AutoModelForSequenceClassification
        # NOTE(security): trust_remote_code=True executes code shipped in the
        # model repo; only point VAD_MODEL_ID at repositories you trust.
        cfg = AutoConfig.from_pretrained(MODEL_ID, trust_remote_code=True)
        tok = AutoTokenizer.from_pretrained(MODEL_ID, use_fast=True, trust_remote_code=True)
        model = AutoModelForSequenceClassification.from_pretrained(MODEL_ID, trust_remote_code=True)
        model.eval()
        _local["cfg"] = cfg
        _local["tok"] = tok
        _local["model"] = model
def _sig(x: float) -> float:
    """Logistic sigmoid 1 / (1 + e^-x), computed without overflow for any finite x."""
    if x >= 0:
        return 1.0 / (1.0 + math.exp(-x))
    # For negative x, math.exp(-x) overflows once -x exceeds ~709; use the
    # algebraically equivalent e^x / (1 + e^x), where e^x merely underflows to 0.
    z = math.exp(x)
    return z / (1.0 + z)
def _squash_vad(v: float, a: float, d: float) -> Dict[str, float]:
    """Map three raw model outputs to a V/A/D dict via sigmoid, clamped to [0, 1]."""
    # NOTE(review): the sigmoid assumes unbounded logits; if the model head
    # already regresses V/A/D onto a fixed scale this distorts it — confirm
    # against the model card.
    return {"valence": _c01(_sig(v)), "arousal": _c01(_sig(a)), "dominance": _c01(_sig(d))}


def local_vad(text: str) -> Dict[str, float]:
    """Score *text* with the locally loaded Transformers model (CPU).

    The text is cut to 512 characters, then tokenized with truncation to 256
    tokens. Three strategies are tried in order:
      1a. logits whose config.id2label entries name valence/arousal/dominance;
      1b. the first three logits, taken as V, A, D;
      2.  an output attribute "vad", "scores" or "preds" with >= 3 values.
    Raw values pass through a sigmoid and are clamped to [0, 1].

    Raises:
        gr.Error: when no strategy yields three values.
    """
    _ensure_local()
    import torch

    s = _trim(text, 512)
    inputs = _local["tok"](s, return_tensors="pt", truncation=True, max_length=256)
    with torch.no_grad():
        out = _local["model"](**inputs)

    # 1) Standard classification/regression output.
    logits = getattr(out, "logits", None)
    if logits is not None:
        # Flatten the single-input output to a plain list. squeeze() turned a
        # one-label output into a 0-d tensor whose tolist() is a bare float,
        # which then failed to iterate.
        scores = [float(x) for x in logits.reshape(-1).tolist()]

        # 1a) config.id2label names the V/A/D outputs.
        id2label = getattr(_local["cfg"], "id2label", None)
        if isinstance(id2label, dict) and id2label:
            lab = {int(k): str(v).lower() for k, v in id2label.items()}
            m = {}
            for i, sc in enumerate(scores):
                name = lab.get(i, "")
                if "valence" in name or name == "v":
                    m["valence"] = sc
                if "arousal" in name or name == "a":
                    m["arousal"] = sc
                if "dominance" in name or name == "d":
                    m["dominance"] = sc
            if len(m) == 3:
                return _squash_vad(m["valence"], m["arousal"], m["dominance"])

        # 1b) No usable labels but >= 3 outputs: take the first three.
        #     NOTE(review): the V, A, D positional order is an assumption.
        if len(scores) >= 3:
            return _squash_vad(scores[0], scores[1], scores[2])

    # 2) Some custom models expose V/A/D on a differently named attribute.
    for key in ("vad", "scores", "preds"):
        vec = getattr(out, key, None)
        if vec is None:
            continue
        if torch.is_tensor(vec):
            # Flatten so a (1, 3) batch tensor yields three scalars.
            vec = vec.reshape(-1).tolist()
        try:
            v, a, d = (float(x) for x in list(vec)[:3])
        except (TypeError, ValueError, RuntimeError):
            # Not a sequence of >= 3 numbers; try the next attribute.
            continue
        return _squash_vad(v, a, d)

    raise gr.Error("本地VAD解析失败:模型输出不含可识别的 V/A/D 三维。请换兼容模型或改用简易VAD。")
# ===== Metrics & visualization =====
def metrics(v1: Dict[str, float], v2: Dict[str, float]) -> Dict[str, float]:
    """Compare two V/A/D dicts.

    Returns:
        "ΔV", "ΔA", "ΔD": per-dimension differences v1 - v2;
        "L2_distance": Euclidean distance between the two points;
        "cosine_similarity": cosine of the angle between them, or 0.0 when
        either vector has zero norm.
    """
    dims = ("valence", "arousal", "dominance")
    sv, sa, sd = (v1[k] for k in dims)
    rv, ra, rd = (v2[k] for k in dims)
    dv, da, dd = sv - rv, sa - ra, sd - rd
    l2 = math.sqrt(dv * dv + da * da + dd * dd)
    norm1 = math.sqrt(sv ** 2 + sa ** 2 + sd ** 2)
    norm2 = math.sqrt(rv ** 2 + ra ** 2 + rd ** 2)
    if norm1 > 0 and norm2 > 0:
        cos = (sv * rv + sa * ra + sd * rd) / (norm1 * norm2)
    else:
        cos = 0.0
    return {"ΔV": dv, "ΔA": da, "ΔD": dd, "L2_distance": l2, "cosine_similarity": cos}
def bar_html(s: Dict[str, float], r: Dict[str, float]) -> str:
    """Render an HTML bar chart comparing student (*s*) and reference (*r*) V/A/D.

    Each dimension gets two horizontal bars whose width is int(100 * value)%
    and whose caption shows the value to three decimals. Inline <style> and
    a colour legend are included.
    """
    def _bar_row(title, student_val, ref_val):
        # One labelled row holding the student bar and the reference bar.
        return "\n".join([
            "",
            f'<div class="row"><div class="lab">{title}</div>',
            '<div class="bars">',
            f'<div class="bar s" style="width:{int(100 * student_val)}%;"><span>学生 {student_val:.3f}</span></div>',
            f'<div class="bar r" style="width:{int(100 * ref_val)}%;"><span>参考 {ref_val:.3f}</span></div>',
            "</div>",
            "</div>",
        ])

    css = "\n".join([
        "",
        "<style>",
        ".chart{font-family:ui-sans-serif,system-ui,-apple-system,Segoe UI,Roboto,Arial}",
        ".row{margin:8px 0}.lab{width:90px;display:inline-block;font-weight:600}",
        ".bars{display:inline-block;width:70%;vertical-align:middle}",
        ".bar{height:20px;margin:4px 0;position:relative;background:#eee;border-radius:6px;overflow:hidden}",
        ".bar.s{background:#cfe7ff}.bar.r{background:#ffd6cc}",
        ".bar span{position:absolute;right:8px;top:0;font-size:12px;line-height:20px;color:#222}",
        ".legend{margin-top:12px;font-size:12px;color:#555}",
        ".legend .swatch{display:inline-block;width:12px;height:12px;vertical-align:middle;margin-right:6px;border-radius:3px}",
        "</style>",
    ])

    rows = [
        _bar_row(title, s[key], r[key])
        for title, key in (("Valence", "valence"), ("Arousal", "arousal"), ("Dominance", "dominance"))
    ]
    return "\n".join([
        css,
        '<div class="chart">',
        *rows,
        '<div class="legend"><span class="swatch" style="background:#cfe7ff"></span>学生译文&nbsp;&nbsp;',
        '<span class="swatch" style="background:#ffd6cc"></span>参考译文</div>',
        "</div>",
    ])
# ===== Main flow and diagnostics =====
def run(student_text: str, reference_text: str, backend: str):
    """Score both texts with the chosen backend and build the three UI outputs.

    Args:
        student_text: student translation; None is treated as "".
        reference_text: reference translation; None is treated as "".
        backend: a backend radio label; any unrecognized value falls back to
            the heuristic scorer (simple_vad).

    Returns:
        (chart HTML, plain-text summary, pretty-printed JSON holding both
        V/A/D dicts and the comparison metrics).

    Raises:
        gr.Error: when both texts are empty or whitespace-only; the selected
            backend may raise its own gr.Error as well.
    """
    # Normalize None before calling .strip() on it.
    student_text = student_text or ""
    reference_text = reference_text or ""
    if not (student_text.strip() or reference_text.strip()):
        raise gr.Error("请至少输入一段文本。")
    if backend == "本地VAD(Transformers, CPU)":
        score = local_vad
    elif backend == "HF Router(服务端推理)":
        score = hf_router_vad
    else:
        score = simple_vad
    s = score(student_text)
    r = score(reference_text)
    m = metrics(s, r)
    rpt = (f"学生译文 VAD: V={s['valence']:.3f}, A={s['arousal']:.3f}, D={s['dominance']:.3f}\n"
           f"参考译文 VAD: V={r['valence']:.3f}, A={r['arousal']:.3f}, D={r['dominance']:.3f}\n"
           f"差异: ΔV={m['ΔV']:.3f}, ΔA={m['ΔA']:.3f}, ΔD={m['ΔD']:.3f}\n"
           f"L2 距离={m['L2_distance']:.3f},余弦相似度={m['cosine_similarity']:.3f}")
    return bar_html(s, r), rpt, json.dumps({"student": s, "reference": r, "metrics": m}, ensure_ascii=False, indent=2)
def diagnose_router():
    """Send a tiny probe request to the HF router and report the raw outcome.

    Returns:
        (status, body_excerpt): "HTTP <code>" plus up to 500 characters of the
        response body. Returns a message and "" instead when HF_TOKEN is not
        configured or the request raised.
    """
    if not HF_TOKEN:
        return "未检测到 HF_TOKEN", ""
    # hf-inference endpoints are addressed per model:
    # <router>/hf-inference/models/<model id>.
    url = f"{HF_API_URL}/models/{MODEL_ID}"
    try:
        res = _session.post(
            url,
            headers={"Authorization": f"Bearer {HF_TOKEN}", "Content-Type": "application/json"},
            json={"inputs": "ok"},
            timeout=(8, 30),
        )
    except Exception as e:  # diagnostic boundary: surface any failure in the UI
        return f"异常:{type(e).__name__}: {e}", ""
    return f"HTTP {res.status_code}", res.text[:500]
# ===== UI =====
# Two-column layout: text inputs, backend choice and Router diagnostics on the
# left; bar chart, text summary and raw JSON on the right.
with gr.Blocks(title=APP_TITLE, css=".wrap {max-width: 1200px; margin: 0 auto;}") as demo:
    gr.Markdown(f"# {APP_TITLE}\n{APP_DESC}")
    with gr.Row(elem_classes=["wrap"]):
        with gr.Column(scale=5):
            student = gr.Textbox(label="学生译文", placeholder="粘贴学生译文…", lines=10)
            reference = gr.Textbox(label="参考译文", placeholder="粘贴参考译文…", lines=10)
            # These choice strings are compared verbatim in run() to pick the
            # scorer; keep the two places in sync.
            backend = gr.Radio(
                ["本地VAD(Transformers, CPU)", "HF Router(服务端推理)", "内置简易VAD(备用)"],
                value="本地VAD(Transformers, CPU)",
                label="分析后端",
            )
            run_btn = gr.Button("运行对比")
            gr.Markdown("### 诊断(Router)")
            chk_btn = gr.Button("测试 HF Router")
            api_status = gr.Textbox(label="接口状态", lines=1)
            api_body = gr.Textbox(label="返回片段", lines=5)
        with gr.Column(scale=5):
            chart = gr.HTML(label="VAD 对比柱状图")
            report = gr.Textbox(label="摘要结果", lines=4)
            raw_json = gr.Code(label="JSON 输出", language="json")
    # run() returns (HTML, summary, JSON) in this output order.
    run_btn.click(run, [student, reference, backend], [chart, report, raw_json], concurrency_limit=4)
    # diagnose_router() returns (status, body excerpt).
    chk_btn.click(diagnose_router, [], [api_status, api_body], concurrency_limit=2)
# Enable Gradio's request queue.
demo.queue()
# Module-level alias; presumably for hosts that import `app` — confirm.
app = demo
# Start a local server when executed as a script.
if __name__ == "__main__":
    demo.launch()