import json
import math
import os
from typing import Dict, Tuple

import gradio as gr
import requests
from requests.adapters import HTTPAdapter
from urllib3.util.retry import Retry

APP_TITLE = "译文情感 VAD 对比"
APP_DESC = "左侧输入学生译文与参考译文；右侧显示两段文本的 V/A/D 及差异指标。默认使用本地 Transformers 推理；HF Router 仅在目标模型被 Inference Provider 托管时可用。"

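# Configuration: the HF token and target VAD model id are read from environment
# variables; HF_API_URL is the Hugging Face Inference Router endpoint used by
# the "HF Router" backend.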
HF_TOKEN = os.getenv("HF_TOKEN", "").strip()
MODEL_ID = os.getenv("VAD_MODEL_ID", "RobroKools/vad-bert").strip()
HF_API_URL = "https://router.huggingface.co/hf-inference"

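# Shared HTTP session with automatic retries on transient gateway errors.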
_session = requests.Session()
_session.mount("https://", HTTPAdapter(max_retries=Retry(
    total=3, connect=3, read=3, backoff_factor=0.5,
    status_forcelist=[502, 503, 504], allowed_methods=frozenset(["POST"]),
)))
_session.headers.update({"Connection": "close"})


def _trim(s: str, n: int = 2000) -> str:
    """Clamp text to at most n characters (treats None as empty)."""
    return (s or "")[:n]


def _c01(x: float) -> float:
    """Clamp a float into [0, 1]."""
    return max(0.0, min(1.0, float(x)))


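# Fallback heuristic used by the built-in simple-VAD backend: a small bilingual
# sentiment lexicon plus punctuation and casing cues.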
_POS = ["good", "great", "excellent", "love", "like", "happy", "joy", "awesome", "amazing", "wonderful",
        "赞", "好", "喜欢", "开心", "愉快", "优秀", "棒", "太好了", "满意", "值得"]
_NEG = ["bad", "terrible", "awful", "hate", "dislike", "sad", "angry", "worse", "worst", "horrible",
        "差", "坏", "讨厌", "生气", "愤怒", "悲伤", "糟糕", "失望", "不满"]


def simple_vad(text: str) -> Dict[str, float]:
    t = text or ""
    n = max(1, len(t))
    ex = t.count("!") + t.count("！")  # ASCII and fullwidth exclamation marks
    q = t.count("?") + t.count("？")   # ASCII and fullwidth question marks
    caps = sum(c.isupper() for c in t)
    tl = t.lower()
    # Count lexicon hits case-insensitively (each hit counted once).
    pos = sum(tl.count(w) for w in _POS)
    neg = sum(tl.count(w) for w in _NEG)
    v = 0.5 + 0.12 * (pos - neg) - 0.05 * q
    a = 0.3 + 0.7 * math.tanh((ex + q + caps) / (n / 30 + 1))
    d = 0.4 + 0.4 * (len(set(t)) / n)
    return {"valence": _c01(v), "arousal": _c01(a), "dominance": _c01(d)}


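# The Router may return V/A/D as a dict, a plain 3-vector, or a list of
# {label, score} entries depending on the model; try each shape in turn.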
def _parse_vad_from_hf(obj) -> Tuple[float, float, float]:
    if isinstance(obj, dict):
        k = {kk.lower(): vv for kk, vv in obj.items()}
        if all(x in k for x in ("valence", "arousal", "dominance")):
            return float(k["valence"]), float(k["arousal"]), float(k["dominance"])
        for key in ("embedding", "vector", "vad"):
            if key in k and isinstance(k[key], (list, tuple)) and len(k[key]) >= 3:
                return float(k[key][0]), float(k[key][1]), float(k[key][2])
    if isinstance(obj, list) and len(obj) >= 3:
        if all(isinstance(x, (int, float)) for x in obj[:3]):
            return float(obj[0]), float(obj[1]), float(obj[2])
        if all(isinstance(x, dict) for x in obj[:3]):
            m = {}
            for it in obj:
                lab = str(it.get("label", "")).lower()
                sc = it.get("score", None)
                if sc is None:
                    continue
                if "valence" in lab or lab == "v":
                    m["valence"] = float(sc)
                elif "arousal" in lab or lab == "a":
                    m["arousal"] = float(sc)
                elif "dominance" in lab or lab == "d":
                    m["dominance"] = float(sc)
            if all(t in m for t in ("valence", "arousal", "dominance")):
                return m["valence"], m["arousal"], m["dominance"]
    raise ValueError("无法从模型返回中解析 V/A/D")


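# Server-side inference via the HF Router; raises a user-facing gr.Error on a
# missing token, unhosted models (404), cold starts (503), and other HTTP errors.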
def hf_router_vad(text: str, timeout: float = 90.0) -> Dict[str, float]:
    if not HF_TOKEN:
        raise gr.Error("未配置 HF_TOKEN(Settings → Variables & secrets)。")
    payload = {"model": MODEL_ID, "inputs": _trim(text, 2000)}
    headers = {
        "Authorization": f"Bearer {HF_TOKEN}",
        "Content-Type": "application/json",
        "Connection": "close",
        "X-Wait-For-Model": "true",
    }
    r = _session.post(HF_API_URL, headers=headers, json=payload, timeout=(8, timeout))
    if r.status_code == 404:
        raise gr.Error("Router 404:该模型未由任何 Inference Provider 托管。改用“本地VAD”或换模型。")
    if r.status_code == 503:
        raise gr.Error("模型冷启动(503)。稍后重试。")
    if r.status_code >= 400:
        raise gr.Error(f"HF API 错误 {r.status_code}: {r.text[:200]}")
    data = r.json()
    v, a, d = _parse_vad_from_hf(data)
    return {"valence": _c01(v), "arousal": _c01(a), "dominance": _c01(d)}


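# Local backend: the tokenizer, model, and config are loaded lazily on first
# use and cached in this module-level dict.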
_local = {"tok": None, "model": None, "cfg": None}


def _ensure_local():
    if _local["tok"] is not None:
        return
    from transformers import AutoConfig, AutoTokenizer, AutoModelForSequenceClassification
    _local["cfg"] = AutoConfig.from_pretrained(MODEL_ID, trust_remote_code=True)
    _local["tok"] = AutoTokenizer.from_pretrained(MODEL_ID, use_fast=True, trust_remote_code=True)
    _local["model"] = AutoModelForSequenceClassification.from_pretrained(MODEL_ID, trust_remote_code=True)
    _local["model"].eval()


def _sig(x: float) -> float:
    return 1.0 / (1.0 + math.exp(-x))


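# Local inference path: run the sequence-classification model on CPU, then map
# the outputs to V/A/D via id2label names or, failing that, by position,
# squashing each score into [0, 1] with a sigmoid.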
def local_vad(text: str) -> Dict[str, float]:
    _ensure_local()
    import torch
    s = _trim(text, 512)
    inputs = _local["tok"](s, return_tensors="pt", truncation=True, max_length=256)
    with torch.no_grad():
        out = _local["model"](**inputs)

    if hasattr(out, "logits"):
        logits = out.logits.squeeze()

        # Preferred path: map each score to V/A/D by its label name in the model config.
        id2label = getattr(_local["cfg"], "id2label", None)
        if id2label and isinstance(id2label, dict):
            lab = {int(k): str(v).lower() for k, v in id2label.items()}
            scores = logits.tolist() if hasattr(logits, "tolist") else list(logits)
            m = {}
            for i, sc in enumerate(scores):
                name = lab.get(i, "")
                if "valence" in name or name == "v":
                    m["valence"] = float(sc)
                if "arousal" in name or name == "a":
                    m["arousal"] = float(sc)
                if "dominance" in name or name == "d":
                    m["dominance"] = float(sc)
            if len(m) == 3:
                return {"valence": _c01(_sig(m["valence"])),
                        "arousal": _c01(_sig(m["arousal"])),
                        "dominance": _c01(_sig(m["dominance"]))}

        # Fallback: read the first three logits positionally as V, A, D.
        if logits.numel() >= 3:
            v, a, d = [float(logits[i].item()) for i in range(3)]
            return {"valence": _c01(_sig(v)), "arousal": _c01(_sig(a)), "dominance": _c01(_sig(d))}

    # Some custom heads expose predictions under a different output attribute.
    for key in ("vad", "scores", "preds"):
        if hasattr(out, key):
            vec = getattr(out, key)
            try:
                vec = list(vec)[:3]
                v, a, d = float(vec[0]), float(vec[1]), float(vec[2])
                return {"valence": _c01(_sig(v)), "arousal": _c01(_sig(a)), "dominance": _c01(_sig(d))}
            except Exception:
                pass

    raise gr.Error("本地VAD解析失败:模型输出不含可识别的 V/A/D 三维。请换兼容模型或改用简易VAD。")


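# Per-dimension differences plus two aggregate measures: Euclidean (L2)
# distance and cosine similarity between the two VAD vectors.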
def metrics(v1: Dict[str, float], v2: Dict[str, float]) -> Dict[str, float]:
    dv = v1["valence"] - v2["valence"]
    da = v1["arousal"] - v2["arousal"]
    dd = v1["dominance"] - v2["dominance"]
    l2 = math.sqrt(dv * dv + da * da + dd * dd)
    n1 = math.sqrt(v1["valence"] ** 2 + v1["arousal"] ** 2 + v1["dominance"] ** 2)
    n2 = math.sqrt(v2["valence"] ** 2 + v2["arousal"] ** 2 + v2["dominance"] ** 2)
    dot = v1["valence"] * v2["valence"] + v1["arousal"] * v2["arousal"] + v1["dominance"] * v2["dominance"]
    cos = dot / (n1 * n2) if n1 > 0 and n2 > 0 else 0.0
    return {"ΔV": dv, "ΔA": da, "ΔD": dd, "L2_distance": l2, "cosine_similarity": cos}


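# Render the two VAD profiles as a small self-contained HTML bar chart
# (inline CSS, no external assets) for the gr.HTML output.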
def bar_html(s: Dict[str, float], r: Dict[str, float]) -> str:
    def row(label, sv, rv):
        return f"""
        <div class="row"><div class="lab">{label}</div>
          <div class="bars">
            <div class="bar s" style="width:{int(100*sv)}%;"><span>学生 {sv:.3f}</span></div>
            <div class="bar r" style="width:{int(100*rv)}%;"><span>参考 {rv:.3f}</span></div>
          </div>
        </div>"""

    css = """
    <style>
    .chart{font-family:ui-sans-serif,system-ui,-apple-system,Segoe UI,Roboto,Arial}
    .row{margin:8px 0}.lab{width:90px;display:inline-block;font-weight:600}
    .bars{display:inline-block;width:70%;vertical-align:middle}
    .bar{height:20px;margin:4px 0;position:relative;background:#eee;border-radius:6px;overflow:hidden}
    .bar.s{background:#cfe7ff}.bar.r{background:#ffd6cc}
    .bar span{position:absolute;right:8px;top:0;font-size:12px;line-height:20px;color:#222}
    .legend{margin-top:12px;font-size:12px;color:#555}
    .legend .swatch{display:inline-block;width:12px;height:12px;vertical-align:middle;margin-right:6px;border-radius:3px}
    </style>"""

    return f"""{css}
    <div class="chart">
      {row("Valence", s['valence'], r['valence'])}
      {row("Arousal", s['arousal'], r['arousal'])}
      {row("Dominance", s['dominance'], r['dominance'])}
      <div class="legend"><span class="swatch" style="background:#cfe7ff"></span>学生译文
        <span class="swatch" style="background:#ffd6cc"></span>参考译文</div>
    </div>"""


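# Gradio callback: score both texts with the selected backend, then return the
# HTML chart, a plain-text summary, and the raw JSON payload.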
def run(student_text: str, reference_text: str, backend: str):
    if not (student_text.strip() or reference_text.strip()):
        raise gr.Error("请至少输入一段文本。")
    if backend == "本地VAD(Transformers, CPU)":
        s = local_vad(student_text or "")
        r = local_vad(reference_text or "")
    elif backend == "HF Router(服务端推理)":
        s = hf_router_vad(student_text or "")
        r = hf_router_vad(reference_text or "")
    else:
        s = simple_vad(student_text or "")
        r = simple_vad(reference_text or "")
    m = metrics(s, r)
    rpt = (f"学生译文 VAD: V={s['valence']:.3f}, A={s['arousal']:.3f}, D={s['dominance']:.3f}\n"
           f"参考译文 VAD: V={r['valence']:.3f}, A={r['arousal']:.3f}, D={r['dominance']:.3f}\n"
           f"差异: ΔV={m['ΔV']:.3f}, ΔA={m['ΔA']:.3f}, ΔD={m['ΔD']:.3f}\n"
           f"L2 距离={m['L2_distance']:.3f},余弦相似度={m['cosine_similarity']:.3f}")
    return bar_html(s, r), rpt, json.dumps({"student": s, "reference": r, "metrics": m}, ensure_ascii=False, indent=2)


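# Diagnostic helper wired to the "测试 HF Router" button: sends a minimal
# request and surfaces the HTTP status plus a snippet of the response body.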
def diagnose_router():
    if not HF_TOKEN:
        return "未检测到 HF_TOKEN", ""
    try:
        res = _session.post(
            HF_API_URL,
            headers={"Authorization": f"Bearer {HF_TOKEN}", "Content-Type": "application/json"},
            json={"model": MODEL_ID, "inputs": "ok"},
            timeout=(8, 30),
        )
        return f"HTTP {res.status_code}", res.text[:500]
    except Exception as e:
        return f"异常:{type(e).__name__}: {e}", ""


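# UI: two-column layout with the inputs, backend selector, and Router
# diagnostics on the left; chart, summary, and raw JSON on the right.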
with gr.Blocks(title=APP_TITLE, css=".wrap {max-width: 1200px; margin: 0 auto;}") as demo:
    gr.Markdown(f"# {APP_TITLE}\n{APP_DESC}")
    with gr.Row(elem_classes=["wrap"]):
        with gr.Column(scale=5):
            student = gr.Textbox(label="学生译文", placeholder="粘贴学生译文…", lines=10)
            reference = gr.Textbox(label="参考译文", placeholder="粘贴参考译文…", lines=10)
            backend = gr.Radio(
                ["本地VAD(Transformers, CPU)", "HF Router(服务端推理)", "内置简易VAD(备用)"],
                value="本地VAD(Transformers, CPU)",
                label="分析后端",
            )
            run_btn = gr.Button("运行对比")
            gr.Markdown("### 诊断(Router)")
            chk_btn = gr.Button("测试 HF Router")
            api_status = gr.Textbox(label="接口状态", lines=1)
            api_body = gr.Textbox(label="返回片段", lines=5)
        with gr.Column(scale=5):
            chart = gr.HTML(label="VAD 对比柱状图")
            report = gr.Textbox(label="摘要结果", lines=4)
            raw_json = gr.Code(label="JSON 输出", language="json")

    run_btn.click(run, [student, reference, backend], [chart, report, raw_json], concurrency_limit=4)
    chk_btn.click(diagnose_router, [], [api_status, api_body], concurrency_limit=2)

demo.queue()
app = demo

if __name__ == "__main__":
    demo.launch()