File size: 12,599 Bytes
75f43fb
8497de8
 
 
 
 
 
 
 
6339887
75f43fb
8497de8
75f43fb
8497de8
 
084a24b
8497de8
6339887
 
084a24b
 
 
 
6339887
 
8497de8
 
 
75f43fb
8497de8
 
75f43fb
6339887
45a8814
6339887
 
 
 
 
 
 
 
 
 
 
 
75f43fb
6339887
75f43fb
8497de8
 
 
 
 
 
6339887
8497de8
 
6339887
8497de8
6339887
8497de8
 
6339887
 
 
 
 
 
 
8497de8
 
 
 
75f43fb
084a24b
8497de8
 
084a24b
8497de8
 
 
 
084a24b
8497de8
084a24b
 
75f43fb
8497de8
 
 
 
 
 
75f43fb
 
 
 
8497de8
084a24b
75f43fb
 
 
 
 
 
 
 
 
 
 
084a24b
 
 
 
75f43fb
 
084a24b
75f43fb
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
084a24b
6339887
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
75f43fb
6339887
 
 
75f43fb
084a24b
75f43fb
 
6339887
 
 
 
 
 
 
 
 
75f43fb
6bbc97c
6339887
4d5273d
084a24b
 
 
 
6339887
 
 
 
084a24b
45a8814
6339887
 
 
 
 
 
75f43fb
084a24b
6339887
 
 
75f43fb
084a24b
6339887
 
 
 
 
 
45a8814
 
75f43fb
45a8814
 
6339887
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
# app.py — 译文情感 VAD 对比(本地Transformers稳健版 + Router可选 + 简易兜底)
import os, json, math
from typing import Dict, Tuple

import gradio as gr
import requests
from requests.adapters import HTTPAdapter
from urllib3.util.retry import Retry

# Page title and description rendered at the top of the Gradio UI.
APP_TITLE = "译文情感 VAD 对比"
APP_DESC = "左侧输入学生译文与参考译文;右侧显示 V/A/D 与差异。默认跑本地Transformers;Router仅在目标模型被Provider托管时可用。"

# ===== Environment and endpoint (Router, optional) =====
# Access token read from the HF_TOKEN environment variable ("" when unset).
HF_TOKEN = os.getenv("HF_TOKEN", "").strip()
# Model identifier; can be overridden through the VAD_MODEL_ID environment variable.
MODEL_ID = os.getenv("VAD_MODEL_ID", "RobroKools/vad-bert").strip()
# URL that Router inference requests are POSTed to.
HF_API_URL = "https://router.huggingface.co/hf-inference"

# ===== HTTP session (retries + connection close) =====
# One shared requests session used for every Router call in this file.
_session = requests.Session()
# For https:// URLs, retry POST requests up to 3 times (connect errors, read
# errors, and 502/503/504 responses) with a backoff factor of 0.5.
_session.mount("https://", HTTPAdapter(max_retries=Retry(
    total=3, connect=3, read=3, backoff_factor=0.5,
    status_forcelist=[502, 503, 504], allowed_methods=frozenset(["POST"])
)))
# Send "Connection: close" on every request made through this session.
_session.headers.update({"Connection": "close"})

def _trim(s: str, n: int = 2000) -> str:
    """Return at most the first ``n`` characters of ``s``.

    A falsy ``s`` (``None`` or the empty string) yields ``""``.
    """
    text = s if s else ""
    return text[:n]

def _c01(x: float) -> float:
    """Clamp ``x`` to the closed interval [0.0, 1.0] and return it as a float."""
    return max(0.0, min(1.0, float(x)))


# ===== Simple VAD (fallback) =====
# Positive / negative cue words (English and Chinese), all lower-case.
_POS = ["good","great","excellent","love","like","happy","joy","awesome","amazing","wonderful","赞","好","喜欢","开心","愉快","优秀","棒","太好了","满意","值得"]
_NEG = ["bad","terrible","awful","hate","dislike","sad","angry","worse","worst","horrible","差","坏","讨厌","生气","愤怒","悲伤","糟糕","失望","不满"]


def simple_vad(text: str) -> Dict[str, float]:
    """Estimate valence/arousal/dominance from surface features of ``text``.

    This is a model-free heuristic used as a fallback backend:

    * valence   - starts at 0.5, moves +0.12 per positive cue and -0.12 per
      negative cue, and -0.05 per question mark;
    * arousal   - grows with exclamation marks, question marks and upper-case
      characters, scaled by text length;
    * dominance - grows with the ratio of distinct characters to length.

    Args:
        text: Input text; ``None`` is treated as the empty string.

    Returns:
        Dict with keys ``"valence"``, ``"arousal"`` and ``"dominance"``, each
        clamped to [0.0, 1.0].
    """
    t = text or ""
    n = max(1, len(t))  # avoid division by zero for empty input
    # Count both the ASCII marks and their full-width forms (U+FF01 / U+FF1F),
    # each occurrence exactly once.
    ex = t.count("!") + t.count("\uff01")
    q = t.count("?") + t.count("\uff1f")
    caps = sum(c.isupper() for c in t)
    # The lexicons are lower-case, so match case-insensitively and count each
    # hit once regardless of how the word was capitalised in the input.
    # NOTE: matching is by substring, so e.g. "dislike" also contains "like".
    tl = t.lower()
    pos = sum(tl.count(w) for w in _POS)
    neg = sum(tl.count(w) for w in _NEG)
    v = 0.5 + 0.12 * (pos - neg) - 0.05 * q
    a = 0.3 + 0.7 * math.tanh((ex + q + caps) / (n / 30 + 1))
    d = 0.4 + 0.4 * (len(set(t)) / n)
    return {"valence": _c01(v), "arousal": _c01(a), "dominance": _c01(d)}

# ===== Parse the VAD structure (used for Router responses) =====
def _parse_vad_from_hf(obj) -> Tuple[float, float, float]:
    """Extract ``(valence, arousal, dominance)`` from a decoded JSON response.

    Accepted shapes:

    * a dict with ``valence`` / ``arousal`` / ``dominance`` keys
      (case-insensitive);
    * a dict whose ``embedding`` / ``vector`` / ``vad`` entry is a list or
      tuple with at least three items (the first three are used);
    * a list whose first three items are numbers;
    * a list of ``{"label": ..., "score": ...}`` dicts whose labels contain
      ``valence`` / ``arousal`` / ``dominance`` (or are ``v`` / ``a`` / ``d``);
    * any of the above wrapped in single-element lists (a batch of one).

    Raises:
        ValueError: if none of the shapes match.
    """
    # Unwrap batch-of-one nesting such as [[{...}, {...}, {...}]].
    while isinstance(obj, list) and len(obj) == 1 and isinstance(obj[0], (list, dict)):
        obj = obj[0]

    if isinstance(obj, dict):
        k = {kk.lower(): vv for kk, vv in obj.items()}
        if all(x in k for x in ("valence", "arousal", "dominance")):
            return float(k["valence"]), float(k["arousal"]), float(k["dominance"])
        for key in ("embedding", "vector", "vad"):
            if key in k and isinstance(k[key], (list, tuple)) and len(k[key]) >= 3:
                return float(k[key][0]), float(k[key][1]), float(k[key][2])

    if isinstance(obj, list) and len(obj) >= 3:
        if all(isinstance(x, (int, float)) for x in obj[:3]):
            return float(obj[0]), float(obj[1]), float(obj[2])
        if all(isinstance(x, dict) for x in obj[:3]):
            m = {}
            for it in obj:
                # Only the first three items were type-checked above; skip
                # any later item that is not a dict instead of crashing.
                if not isinstance(it, dict):
                    continue
                lab = str(it.get("label", "")).lower()
                sc = it.get("score", None)
                if sc is None:
                    continue
                if "valence" in lab or lab == "v":
                    m["valence"] = float(sc)
                elif "arousal" in lab or lab == "a":
                    m["arousal"] = float(sc)
                elif "dominance" in lab or lab == "d":
                    m["dominance"] = float(sc)
            if all(t in m for t in ("valence", "arousal", "dominance")):
                return m["valence"], m["arousal"], m["dominance"]

    raise ValueError("无法从模型返回中解析 V/A/D")

# ===== Router inference (only works when a Provider hosts the model) =====
def hf_router_vad(text: str, timeout: float = 90.0) -> Dict[str, float]:
    """Score ``text`` by POSTing it to ``HF_API_URL`` and parsing the reply.

    Args:
        text: Input text; it is truncated to 2000 characters before sending.
        timeout: Read timeout in seconds (the connect timeout is fixed at 8).

    Returns:
        Dict with ``"valence"``, ``"arousal"`` and ``"dominance"`` clamped to
        [0.0, 1.0].

    Raises:
        gr.Error: when HF_TOKEN is missing, the request fails at the transport
            level, the server answers with a status >= 400, or the response
            body cannot be parsed into V/A/D values.
    """
    if not HF_TOKEN:
        raise gr.Error("未配置 HF_TOKEN(Settings → Variables & secrets)。")
    payload = {"model": MODEL_ID, "inputs": _trim(text, 2000)}
    headers = {
        "Authorization": f"Bearer {HF_TOKEN}",
        "Content-Type": "application/json",
        "Connection": "close",
        "X-Wait-For-Model": "true",
    }
    try:
        r = _session.post(HF_API_URL, headers=headers, json=payload, timeout=(8, timeout))
    except requests.RequestException as exc:
        # Covers timeouts, connection failures and exhausted retries, so the
        # UI shows a readable message instead of a raw traceback.
        raise gr.Error(f"HF Router 请求失败:{type(exc).__name__}: {exc}") from exc
    if r.status_code == 404:
        raise gr.Error("Router 404:该模型未由任何 Inference Provider 托管。改用“本地VAD”或换模型。")
    if r.status_code == 503:
        raise gr.Error("模型冷启动(503)。稍后重试。")
    if r.status_code >= 400:
        raise gr.Error(f"HF API 错误 {r.status_code}: {r.text[:200]}")
    try:
        # r.json() raises a ValueError subclass on a non-JSON body, and the
        # parser raises ValueError on an unrecognised structure.
        v, a, d = _parse_vad_from_hf(r.json())
    except (ValueError, TypeError) as exc:
        raise gr.Error(f"无法解析 Router 返回:{exc}") from exc
    return {"valence": _c01(v), "arousal": _c01(a), "dominance": _c01(d)}

# ===== Local Transformers (robust adapter) =====
# Lazily populated cache for the local tokenizer, model and config.
_local = {"tok": None, "model": None, "cfg": None}

def _ensure_local():
    """Load config, tokenizer and model for ``MODEL_ID`` into ``_local`` once.

    The three objects are loaded into local variables first and only stored
    in the cache after all of them succeeded. The model is stored last and is
    the readiness flag, so a failed load or a concurrent caller never sees a
    tokenizer without a model.
    """
    if _local["model"] is not None:
        return
    # Deferred import to keep application start-up fast.
    from transformers import AutoConfig, AutoTokenizer, AutoModelForSequenceClassification
    # NOTE(review): trust_remote_code=True allows the model repository to
    # supply Python code that runs locally; MODEL_ID must be a trusted repo.
    cfg = AutoConfig.from_pretrained(MODEL_ID, trust_remote_code=True)
    tok = AutoTokenizer.from_pretrained(MODEL_ID, use_fast=True, trust_remote_code=True)
    model = AutoModelForSequenceClassification.from_pretrained(MODEL_ID, trust_remote_code=True)
    model.eval()
    _local["cfg"] = cfg
    _local["tok"] = tok
    _local["model"] = model

def _sig(x: float) -> float:
    """Return the logistic sigmoid ``1 / (1 + e**-x)`` of ``x``.

    The negative branch is evaluated as ``e**x / (1 + e**x)`` so that very
    negative inputs underflow to 0.0 instead of raising ``OverflowError``.
    """
    if x >= 0:
        return 1.0 / (1.0 + math.exp(-x))
    e = math.exp(x)
    return e / (1.0 + e)

def local_vad(text: str) -> Dict[str, float]:
    """Score ``text`` with the locally loaded Transformers model.

    The text is cut to 512 characters and tokenised with truncation at 256
    tokens. Each raw score taken from the model output is passed through the
    logistic sigmoid and clamped to [0.0, 1.0].

    Output interpretation, in order:

    1. ``out.logits`` with an ``id2label`` mapping that names all of
       valence / arousal / dominance (or ``v`` / ``a`` / ``d``);
    2. ``out.logits`` with at least three values (first three are used);
    3. an ``out.vad`` / ``out.scores`` / ``out.preds`` attribute holding at
       least three values.

    Returns:
        Dict with ``"valence"``, ``"arousal"`` and ``"dominance"``.

    Raises:
        gr.Error: if no V/A/D triple can be recognised in the model output.
    """
    _ensure_local()
    import torch

    s = _trim(text, 512)
    inputs = _local["tok"](s, return_tensors="pt", truncation=True, max_length=256)
    with torch.no_grad():
        out = _local["model"](**inputs)

    def _squash(v: float, a: float, d: float) -> Dict[str, float]:
        # Map raw scores to [0, 1] via sigmoid + clamp.
        return {"valence": _c01(_sig(v)), "arousal": _c01(_sig(a)), "dominance": _c01(_sig(d))}

    # 1) Standard classification output.
    logits = getattr(out, "logits", None)
    if logits is not None:
        # Flatten to 1-D: unlike squeeze(), this keeps a single-output head
        # iterable (a 0-d tensor would break enumerate() below).
        scores = [float(x) for x in logits.reshape(-1).tolist()]
        # 1a) id2label is present and names V/A/D.
        id2label = getattr(_local["cfg"], "id2label", None)
        if id2label and isinstance(id2label, dict):
            lab = {int(k): str(v).lower() for k, v in id2label.items()}
            m = {}
            for i, sc in enumerate(scores):
                name = lab.get(i, "")
                if "valence" in name or name == "v":
                    m["valence"] = sc
                if "arousal" in name or name == "a":
                    m["arousal"] = sc
                if "dominance" in name or name == "d":
                    m["dominance"] = sc
            if len(m) == 3:
                return _squash(m["valence"], m["arousal"], m["dominance"])
        # 1b) No usable labels but at least three outputs: take the first three.
        if len(scores) >= 3:
            return _squash(scores[0], scores[1], scores[2])

    # 2) Some custom models may expose VAD under another attribute.
    for key in ("vad", "scores", "preds"):
        vec = getattr(out, key, None)
        if vec is None:
            continue
        try:
            # Flatten so a leading batch dimension does not hide the values.
            flat = torch.as_tensor(vec).reshape(-1).tolist()
            v, a, d = (float(x) for x in flat[:3])
        except (TypeError, ValueError, RuntimeError):
            # Best effort: an attribute of an unexpected shape/type is skipped.
            continue
        return _squash(v, a, d)

    raise gr.Error("本地VAD解析失败:模型输出不含可识别的 V/A/D 三维。请换兼容模型或改用简易VAD。")

# ===== Metrics and visualisation =====
def metrics(v1: Dict[str, float], v2: Dict[str, float]) -> Dict[str, float]:
    """Compare two VAD dicts.

    Returns the per-axis differences (``v1 - v2``), the Euclidean distance
    between the two points, and their cosine similarity (0.0 when either
    vector has zero length).
    """
    axes = ("valence", "arousal", "dominance")
    a_v, a_a, a_d = (v1[axis] for axis in axes)
    b_v, b_a, b_d = (v2[axis] for axis in axes)

    dv, da, dd = a_v - b_v, a_a - b_a, a_d - b_d
    l2 = math.sqrt(dv*dv + da*da + dd*dd)

    norm_a = math.sqrt(a_v**2 + a_a**2 + a_d**2)
    norm_b = math.sqrt(b_v**2 + b_a**2 + b_d**2)
    if norm_a > 0 and norm_b > 0:
        cos = (a_v*b_v + a_a*b_a + a_d*b_d) / (norm_a*norm_b)
    else:
        cos = 0.0

    return {"ΔV": dv, "ΔA": da, "ΔD": dd, "L2_distance": l2, "cosine_similarity": cos}

def bar_html(s: Dict[str, float], r: Dict[str, float]) -> str:
    """Render an HTML snippet with paired bars for each VAD axis.

    ``s`` holds the student scores and ``r`` the reference scores. Each bar's
    width is the score times 100, truncated to an integer percentage, and the
    score is printed inside the bar with three decimals.
    """
    style_block = """
    <style>
      .chart{font-family:ui-sans-serif,system-ui,-apple-system,Segoe UI,Roboto,Arial}
      .row{margin:8px 0}.lab{width:90px;display:inline-block;font-weight:600}
      .bars{display:inline-block;width:70%;vertical-align:middle}
      .bar{height:20px;margin:4px 0;position:relative;background:#eee;border-radius:6px;overflow:hidden}
      .bar.s{background:#cfe7ff}.bar.r{background:#ffd6cc}
      .bar span{position:absolute;right:8px;top:0;font-size:12px;line-height:20px;color:#222}
      .legend{margin-top:12px;font-size:12px;color:#555}
      .legend .swatch{display:inline-block;width:12px;height:12px;vertical-align:middle;margin-right:6px;border-radius:3px}
    </style>"""

    def render_row(title, student_val, reference_val):
        # Bar widths are whole percentages (truncated, not rounded).
        student_pct = int(100*student_val)
        reference_pct = int(100*reference_val)
        return f"""
        <div class="row"><div class="lab">{title}</div>
          <div class="bars">
            <div class="bar s" style="width:{student_pct}%;"><span>学生 {student_val:.3f}</span></div>
            <div class="bar r" style="width:{reference_pct}%;"><span>参考 {reference_val:.3f}</span></div>
          </div>
        </div>"""

    axes = (("Valence", "valence"), ("Arousal", "arousal"), ("Dominance", "dominance"))
    # Rows are separated by a newline plus the six-space indent of the chart body.
    rows = "\n      ".join(render_row(title, s[key], r[key]) for title, key in axes)
    return f"""{style_block}
    <div class="chart">
      {rows}
      <div class="legend"><span class="swatch" style="background:#cfe7ff"></span>学生译文&nbsp;&nbsp;
      <span class="swatch" style="background:#ffd6cc"></span>参考译文</div>
    </div>"""

# ===== Main flow and diagnostics =====
def run(student_text: str, reference_text: str, backend: str):
    """Score both texts with the selected backend and build the UI outputs.

    Args:
        student_text: Student translation (``None`` is treated as "").
        reference_text: Reference translation (``None`` is treated as "").
        backend: Label of the selected backend radio option; any value other
            than the local-Transformers or Router labels selects the built-in
            simple heuristic.

    Returns:
        A 3-tuple ``(chart_html, summary_text, json_text)``.

    Raises:
        gr.Error: if both texts are empty or whitespace-only, or if the
            selected backend raises one.
    """
    # Normalise once, so a None input cannot crash the emptiness check below.
    student_text = student_text or ""
    reference_text = reference_text or ""
    if not (student_text.strip() or reference_text.strip()):
        raise gr.Error("请至少输入一段文本。")

    if backend == "本地VAD(Transformers, CPU)":
        scorer = local_vad
    elif backend == "HF Router(服务端推理)":
        scorer = hf_router_vad
    else:
        scorer = simple_vad
    s = scorer(student_text)
    r = scorer(reference_text)

    m = metrics(s, r)
    rpt = (f"学生译文 VAD: V={s['valence']:.3f}, A={s['arousal']:.3f}, D={s['dominance']:.3f}\n"
           f"参考译文 VAD: V={r['valence']:.3f}, A={r['arousal']:.3f}, D={r['dominance']:.3f}\n"
           f"差异: ΔV={m['ΔV']:.3f}, ΔA={m['ΔA']:.3f}, ΔD={m['ΔD']:.3f}\n"
           f"L2 距离={m['L2_distance']:.3f},余弦相似度={m['cosine_similarity']:.3f}")
    return bar_html(s, r), rpt, json.dumps({"student": s, "reference": r, "metrics": m}, ensure_ascii=False, indent=2)

def diagnose_router():
    """Send a tiny probe request to the Router and report what came back.

    Returns a ``(status, body)`` pair of strings: the HTTP status line and the
    first 500 characters of the response body, or a description of the
    exception (with an empty body) if the request could not be completed.
    No request is sent when HF_TOKEN is not configured.
    """
    if not HF_TOKEN:
        return "未检测到 HF_TOKEN", ""

    probe_headers = {"Authorization": f"Bearer {HF_TOKEN}", "Content-Type":"application/json"}
    probe_payload = {"model": MODEL_ID, "inputs": "ok"}
    try:
        res = _session.post(
            HF_API_URL,
            headers=probe_headers,
            json=probe_payload,
            timeout=(8, 30),
        )
        status_line = f"HTTP {res.status_code}"
        snippet = res.text[:500]
        return status_line, snippet
    except Exception as e:
        # Diagnostic endpoint: report any failure to the UI instead of raising.
        return f"异常:{type(e).__name__}: {e}", ""

# ===== UI =====
# Build the Gradio page; the CSS rule centres the ".wrap" row and caps its width.
with gr.Blocks(title=APP_TITLE, css=".wrap {max-width: 1200px; margin: 0 auto;}") as demo:
    gr.Markdown(f"# {APP_TITLE}\n{APP_DESC}")
    with gr.Row(elem_classes=["wrap"]):
        # Left column: the two input texts, backend selector and Router diagnostics.
        with gr.Column(scale=5):
            student = gr.Textbox(label="学生译文", placeholder="粘贴学生译文…", lines=10)
            reference = gr.Textbox(label="参考译文", placeholder="粘贴参考译文…", lines=10)
            # The option labels are compared verbatim inside run().
            backend = gr.Radio(
                ["本地VAD(Transformers, CPU)", "HF Router(服务端推理)", "内置简易VAD(备用)"],
                value="本地VAD(Transformers, CPU)",
                label="分析后端",
            )
            run_btn = gr.Button("运行对比")
            gr.Markdown("### 诊断(Router)")
            chk_btn = gr.Button("测试 HF Router")
            api_status = gr.Textbox(label="接口状态", lines=1)
            api_body = gr.Textbox(label="返回片段", lines=5)
        # Right column: chart, text summary and raw JSON produced by run().
        with gr.Column(scale=5):
            chart = gr.HTML(label="VAD 对比柱状图")
            report = gr.Textbox(label="摘要结果", lines=4)
            raw_json = gr.Code(label="JSON 输出", language="json")

    # Wire the buttons: run() fills the three right-hand outputs,
    # diagnose_router() fills the two diagnostic text boxes.
    run_btn.click(run, [student, reference, backend], [chart, report, raw_json], concurrency_limit=4)
    chk_btn.click(diagnose_router, [], [api_status, api_body], concurrency_limit=2)

# Enable request queuing and expose the Blocks object under the name "app".
demo.queue()
app = demo

# Start the Gradio server when this file is executed as a script.
if __name__ == "__main__":
    demo.launch()