Add SSE streaming endpoint for real-time analysis progress
Add /analyze/research/attention/stream endpoint that streams progress
events during attention analysis via Server-Sent Events (SSE).
Progress events are emitted at each stage:
- tokenizing: After prompt tokenization, with token count
- generating: Per-token, with step index and total steps
- extracting: Per-layer (on last token), with layer/head counts
- serializing: During response building, with size estimate
- complete: When analysis finishes, with timing and size metadata
- result: Final payload with complete analysis data
Each event includes (see the example frame after this list):
- type, stage, totalStages, progress, stageProgress, detail
- metadata object with stage-specific info (counts, sizes, timing)
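For illustration, one 'extracting' frame on the wire could look like this (the values are made up; headsPerLayer and timestamp depend on the loaded model and wall clock):

data: {"type": "extracting", "timestamp": 1733450000000, "stage": 3, "totalStages": 5, "progress": 44.0, "stageProgress": 35.0, "detail": "Processing layer 15/40", "metadata": {"layerIndex": 14, "totalLayers": 40, "headsPerLayer": 32, "stepIndex": 7, "totalSteps": 8}}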
Benefits:
- Frontend gets real progress updates instead of fake animation
- Users see exactly what's happening: "Processing layer 15/40"
- Response size displayed during serialization
- No more "stuck at 95%" - progress reflects actual work
🤖 Generated with [Claude Code](https://claude.com/claude-code)
Co-Authored-By: Claude Opus 4.5 <[email protected]>
- backend/model_service.py +455 -0

@@ -4,6 +4,7 @@ Combines model loading, generation, and trace extraction into a single service
 """
 
 from fastapi import FastAPI, WebSocket, WebSocketDisconnect, BackgroundTasks, HTTPException, Depends
+from fastapi.responses import StreamingResponse
 from fastapi.middleware.cors import CORSMiddleware
 from pydantic import BaseModel
 import asyncio
@@ -1920,6 +1921,460 @@ async def analyze_research_attention(request: Dict[str, Any], authenticated: boo
         logger.error(traceback.format_exc())
         raise HTTPException(status_code=500, detail=str(e))
 
+
+def sse_event(event_type: str, **kwargs) -> str:
+    """Format data as SSE event"""
+    data = {'type': event_type, 'timestamp': int(time.time() * 1000), **kwargs}
+    return f"data: {json.dumps(data)}\n\n"
+
+
+@app.post("/analyze/research/attention/stream")
+async def analyze_research_attention_stream(request: Dict[str, Any], authenticated: bool = Depends(verify_api_key)):
+    """
+    SSE Streaming version of Research-Grade Attention Analysis
+
+    Emits progress events during each stage:
+    - tokenizing: Initial tokenization
+    - generating: Per-token generation progress
+    - extracting: Per-layer attention extraction
+    - serializing: Building response
+    - complete: Analysis finished
+    - result: Final data payload
+    """
+    async def event_generator():
+        try:
+            import time
+            start_time = time.time()
+
+            # Get parameters
+            prompt = request.get("prompt", "def quicksort(arr):")
+            max_tokens = request.get("max_tokens", 8)
+            temperature = request.get("temperature", 0.7)
+
+            logger.info(f"[SSE] Research attention analysis: prompt_len={len(prompt)}, max_tokens={max_tokens}")
+
+            # === STAGE 1: TOKENIZING ===
+            yield sse_event('tokenizing', stage=1, totalStages=5, progress=2,
+                            stageProgress=0, detail=f'Tokenizing {len(prompt)} characters...')
+
+            # Get model config for prompt formatting
+            from .model_config import get_model_config
+            from .prompt_formatter import format_prompt
+            model_config = get_model_config(manager.model_id)
+
+            # Get optional system prompt override from request
+            system_prompt_override = request.get("system_prompt")
+
+            # Format prompt using the unified formatter
+            formatted_prompt = format_prompt(
+                prompt=prompt,
+                model_config=model_config or {},
+                tokenizer=manager.tokenizer,
+                system_prompt_override=system_prompt_override
+            )
+
+            prompt_style = model_config.get("prompt_style", "completion") if model_config else "completion"
+
+            # Use model's recommended temperature for instruction models
+            if model_config and "recommended_temperature" in model_config:
+                temperature = model_config["recommended_temperature"]
+
+            # Tokenize and prepare - use MistralTokenizer for Devstral
+            if manager.model_id == "devstral-small" and manager.mistral_tokenizer is not None:
+                system_prompt = system_prompt_override or (model_config.get("system_prompt") if model_config else "")
+                prompt_token_ids = manager.mistral_tokenizer.encode_chat(system_prompt, prompt)
+                inputs = {"input_ids": torch.tensor([prompt_token_ids]).to(manager.device)}
+                prompt_length = len(prompt_token_ids)
+                prompt_tokens = [manager.mistral_tokenizer.decode_token(tid) for tid in prompt_token_ids]
+            else:
+                inputs = manager.tokenizer(formatted_prompt, return_tensors="pt").to(manager.device)
+                prompt_length = inputs["input_ids"].shape[1]
+                prompt_token_ids = inputs["input_ids"][0].tolist()
+                prompt_tokens = [manager.tokenizer.decode([tid], skip_special_tokens=False) for tid in prompt_token_ids]
+
+            yield sse_event('tokenizing', stage=1, totalStages=5, progress=8,
+                            stageProgress=100, detail=f'Tokenized into {prompt_length} tokens',
+                            metadata={'tokenCount': prompt_length})
+            await asyncio.sleep(0)  # Yield to event loop
+
+            # Storage for generation
+            generated_token_ids = []
+            generated_tokens = []
+
+            # Model info
+            n_layers = len(list(manager.model.parameters()))
+            if hasattr(manager.model.config, 'n_layer'):
+                n_layers = manager.model.config.n_layer
+            elif hasattr(manager.model.config, 'num_hidden_layers'):
+                n_layers = manager.model.config.num_hidden_layers
+
+            n_heads = manager.model.config.n_head if hasattr(manager.model.config, 'n_head') else manager.model.config.num_attention_heads
+            d_model = manager.model.config.n_embd if hasattr(manager.model.config, 'n_embd') else manager.model.config.hidden_size
+            head_dim = d_model // n_heads
+
+            # === STAGE 2: GENERATING ===
+            layer_data_by_token = []
+            token_alternatives_by_step = []
+
+            # Hook system to capture Q/K/V matrices
+            qkv_captures = {}
+            hooks = []
+
+            def make_qkv_hook(layer_idx):
+                def hook(module, input, output):
+                    try:
+                        if output.dim() != 3:
+                            return
+                        batch_size, seq_len, hidden = output.shape
+                        expected_hidden = 3 * n_heads * head_dim
+                        if hidden != expected_hidden:
+                            return
+                        qkv = output.reshape(batch_size, seq_len, 3, n_heads, head_dim)
+                        q, k, v = qkv[:, :, 0], qkv[:, :, 1], qkv[:, :, 2]
+                        qkv_captures[layer_idx] = {
+                            'q': q[0].detach().cpu(),
+                            'k': k[0].detach().cpu(),
+                            'v': v[0].detach().cpu()
+                        }
+                    except Exception:
+                        pass
+                return hook
+
+            # Register hooks
+            try:
+                if hasattr(manager.model, 'transformer') and hasattr(manager.model.transformer, 'h'):
+                    for layer_idx, layer in enumerate(manager.model.transformer.h):
+                        if hasattr(layer, 'attn') and hasattr(layer.attn, 'qkv_proj'):
+                            hook = layer.attn.qkv_proj.register_forward_hook(make_qkv_hook(layer_idx))
+                            hooks.append(hook)
+                        elif hasattr(layer, 'attn') and hasattr(layer.attn, 'c_attn'):
+                            hook = layer.attn.c_attn.register_forward_hook(make_qkv_hook(layer_idx))
+                            hooks.append(hook)
+            except Exception as hook_error:
+                logger.warning(f"Could not register QKV hooks: {hook_error}")
+
+            with torch.no_grad():
+                current_ids = inputs["input_ids"]
+
+                for step in range(max_tokens):
+                    # Emit progress for this generation step
+                    step_progress = (step / max_tokens) * 100
+                    overall_progress = 10 + (step / max_tokens) * 20  # 10-30%
+                    yield sse_event('generating', stage=2, totalStages=5, progress=overall_progress,
+                                    stageProgress=step_progress,
+                                    detail=f'Generating token {step + 1}/{max_tokens}',
+                                    metadata={'stepIndex': step, 'totalSteps': max_tokens})
+                    await asyncio.sleep(0)
+
+                    qkv_captures.clear()
+
+                    # Forward pass with full outputs
+                    outputs = manager.model(
+                        current_ids,
+                        output_attentions=True,
+                        output_hidden_states=True
+                    )
+
+                    # Get logits for next token
+                    logits = outputs.logits[0, -1, :]
+
+                    # Apply temperature and sample
+                    if temperature > 0:
+                        logits = logits / temperature
+                    probs = torch.softmax(logits, dim=0)
+
+                    if temperature == 0:
+                        next_token_id = torch.argmax(probs, dim=-1).item()
+                    else:
+                        next_token_id = torch.multinomial(probs, 1).item()
+                    next_token_text = manager.tokenizer.decode([next_token_id], skip_special_tokens=False)
+
+                    generated_token_ids.append(next_token_id)
+                    generated_tokens.append(next_token_text)
+
+                    # Capture top-k token alternatives
+                    import math as math_module
+                    top_k = 5
+                    top_probs, top_indices = torch.topk(probs, k=min(top_k, len(probs)))
+                    alternatives = []
+                    for prob, idx in zip(top_probs.tolist(), top_indices.tolist()):
+                        token_text = manager.tokenizer.decode([idx], skip_special_tokens=False)
+                        alternatives.append({
+                            "token": token_text,
+                            "token_id": idx,
+                            "probability": prob,
+                            "log_probability": math_module.log(prob) if prob > 0 else float('-inf')
+                        })
+                    token_alternatives_by_step.append({
+                        "step": step,
+                        "selected_token": next_token_text,
+                        "selected_token_id": next_token_id,
+                        "alternatives": alternatives
+                    })
+
+                    # === STAGE 3: EXTRACTING (per layer within each token) ===
+                    layer_data_this_token = []
+
+                    for layer_idx in range(len(outputs.attentions)):
+                        # Emit extraction progress (within generating stage for combined progress)
+                        if step == max_tokens - 1:  # Only emit detailed layer progress on last token
+                            layer_progress = (layer_idx / len(outputs.attentions)) * 100
+                            overall_progress = 30 + (layer_idx / len(outputs.attentions)) * 40  # 30-70%
+                            yield sse_event('extracting', stage=3, totalStages=5, progress=overall_progress,
+                                            stageProgress=layer_progress,
+                                            detail=f'Processing layer {layer_idx + 1}/{len(outputs.attentions)}',
+                                            metadata={'layerIndex': layer_idx, 'totalLayers': len(outputs.attentions),
+                                                      'headsPerLayer': n_heads, 'stepIndex': step, 'totalSteps': max_tokens})
+                            if layer_idx % 5 == 0:  # Yield every 5 layers to avoid too many events
+                                await asyncio.sleep(0)
+
+                        layer_attn = outputs.attentions[layer_idx][0]
+                        current_hidden = outputs.hidden_states[layer_idx + 1]
+                        if current_hidden.dim() == 3:
+                            current_hidden = current_hidden[0]
+
+                        if layer_idx > 0:
+                            prev_hidden = outputs.hidden_states[layer_idx]
+                            if prev_hidden.dim() == 3:
+                                prev_hidden = prev_hidden[0]
+                            delta_norm = torch.norm(current_hidden - prev_hidden).item()
+                        else:
+                            delta_norm = None
+
+                        activation_magnitude = torch.norm(current_hidden).item()
+                        last_token_hidden = current_hidden[-1]
+                        activation_entropy = torch.std(last_token_hidden).item()
+                        hidden_state_norm = torch.norm(last_token_hidden).item()
+
+                        # Sanitize
+                        activation_magnitude = 0.0 if math.isnan(activation_magnitude) or math.isinf(activation_magnitude) else activation_magnitude
+                        activation_entropy = 0.0 if math.isnan(activation_entropy) or math.isinf(activation_entropy) else activation_entropy
+                        hidden_state_norm = 0.0 if math.isnan(hidden_state_norm) or math.isinf(hidden_state_norm) else hidden_state_norm
+                        if delta_norm is not None:
+                            delta_norm = 0.0 if math.isnan(delta_norm) or math.isinf(delta_norm) else delta_norm
+
+                        # Process heads
+                        critical_heads = []
+                        for head_idx in range(layer_attn.shape[0]):
+                            head_weights = layer_attn[head_idx, -1, :]
+                            max_weight = head_weights.max().item()
+                            entropy = -(head_weights * torch.log(head_weights + 1e-10)).sum().item()
+
+                            max_weight = 0.0 if math.isnan(max_weight) or math.isinf(max_weight) else max_weight
+                            entropy = 0.0 if math.isnan(entropy) or math.isinf(entropy) else entropy
+
+                            pattern_type = None
+                            confidence = 0.0
+
+                            if step > 0 and max_weight > 0.8:
+                                pattern_type = "induction"
+                                confidence = max_weight
+                            elif entropy < 1.0:
+                                pattern_type = "positional"
+                                confidence = 1.0 - entropy
+                            elif 1.0 <= entropy < 2.5:
+                                pattern_type = "semantic"
+                                confidence = min(1.0, entropy / 2.5)
+                            elif max_weight > 0.9 and head_weights[-2].item() > 0.85:
+                                pattern_type = "previous_token"
+                                confidence = head_weights[-2].item()
+
+                            confidence = 0.0 if math.isnan(confidence) or math.isinf(confidence) else confidence
+
+                            attention_matrix = layer_attn[head_idx].cpu().float().numpy().tolist()
+
+                            q_matrix = None
+                            k_matrix = None
+                            v_matrix = None
+                            if layer_idx in qkv_captures:
+                                q_matrix = qkv_captures[layer_idx]['q'][:, head_idx, :].float().numpy().tolist()
+                                k_matrix = qkv_captures[layer_idx]['k'][:, head_idx, :].float().numpy().tolist()
+                                v_matrix = qkv_captures[layer_idx]['v'][:, head_idx, :].float().numpy().tolist()
+
+                            critical_heads.append({
+                                "head_idx": head_idx,
+                                "entropy": entropy,
+                                "max_weight": max_weight,
+                                "attention_weights": attention_matrix,
+                                "q_matrix": q_matrix,
+                                "k_matrix": k_matrix,
+                                "v_matrix": v_matrix,
+                                "pattern": {"type": pattern_type, "confidence": confidence} if pattern_type else None
+                            })
+
+                        critical_heads.sort(key=lambda h: h["max_weight"], reverse=True)
+
+                        layer_pattern = None
+                        layer_fraction = (layer_idx + 1) / n_layers
+                        if layer_idx == 0:
+                            layer_pattern = {"type": "positional", "confidence": 0.78}
+                        elif layer_fraction <= 0.25 and step > 0:
+                            layer_pattern = {"type": "previous_token", "confidence": 0.65}
+                        elif layer_fraction <= 0.75:
+                            layer_pattern = {"type": "induction", "confidence": 0.87}
+                        else:
+                            layer_pattern = {"type": "semantic", "confidence": 0.92}
+
+                        layer_data_this_token.append({
+                            "layer_idx": layer_idx,
+                            "pattern": layer_pattern,
+                            "critical_heads": critical_heads,
+                            "activation_magnitude": activation_magnitude,
+                            "activation_entropy": activation_entropy,
+                            "hidden_state_norm": hidden_state_norm,
+                            "delta_norm": delta_norm
+                        })
+
+                    layer_data_by_token.append(layer_data_this_token)
+
+                    # Update inputs
+                    next_token_tensor = torch.tensor([[next_token_id]], dtype=torch.long, device=manager.device)
+                    current_ids = torch.cat([current_ids, next_token_tensor], dim=1)
+
+                    # Stop on EOS
+                    if next_token_id == manager.tokenizer.eos_token_id:
+                        break
+
+            # Clean up hooks
+            for hook in hooks:
+                hook.remove()
+
+            # === STAGE 4: SERIALIZING ===
+            yield sse_event('serializing', stage=4, totalStages=5, progress=75,
+                            stageProgress=0, detail='Building response data...')
+            await asyncio.sleep(0)
+
+            qkv_by_layer_head = {}
+            generation_time = time.time() - start_time
+
+            # Calculate token section boundaries
+            total_tokens = prompt_length + len(generated_token_ids)
+            system_prompt_text = system_prompt_override or (model_config.get("system_prompt") if model_config else None)
+
+            system_prompt_end = 0
+            if prompt_style == "instruction" and system_prompt_text:
+                if manager.model_id == "devstral-small" and manager.mistral_tokenizer is not None:
+                    try:
+                        no_system_tokens = manager.mistral_tokenizer.encode_chat("", prompt)
+                        system_prompt_end = prompt_length - len(no_system_tokens)
+                        system_prompt_end = max(0, min(system_prompt_end, prompt_length))
+                    except Exception:
+                        system_prompt_end = 0
+                else:
+                    total_chars = len(system_prompt_text or "") + len(prompt)
+                    if total_chars > 0:
+                        system_ratio = len(system_prompt_text or "") / total_chars
+                        system_prompt_end = int(prompt_length * system_ratio)
+
+            token_sections = {
+                "systemPrompt": {
+                    "start": 0,
+                    "end": system_prompt_end,
+                    "text": system_prompt_text,
+                    "tokenCount": system_prompt_end
+                },
+                "userPrompt": {
+                    "start": system_prompt_end,
+                    "end": prompt_length,
+                    "text": prompt,
+                    "tokenCount": prompt_length - system_prompt_end
+                },
+                "output": {
+                    "start": prompt_length,
+                    "end": total_tokens,
+                    "text": "".join(generated_tokens),
+                    "tokenCount": len(generated_token_ids)
+                }
+            }
+
+            yield sse_event('serializing', stage=4, totalStages=5, progress=82,
+                            stageProgress=50, detail='Building token metadata...')
+            await asyncio.sleep(0)
+
+            # Build token metadata
+            from .tokenizer_utils import TokenizerMetadata
+            token_metadata_builder = TokenizerMetadata(manager.tokenizer)
+
+            special_token_ids_set = {
+                manager.tokenizer.eos_token_id,
+                manager.tokenizer.bos_token_id,
+                manager.tokenizer.pad_token_id,
+                manager.tokenizer.unk_token_id
+            }
+
+            def build_token_data(token_ids, token_texts, token_type):
+                multi_split_flags = token_metadata_builder.is_multi_split_identifier(token_ids)
+                result = []
+                for i, (tid, t) in enumerate(zip(token_ids, token_texts)):
+                    bpe_pieces = token_metadata_builder.get_subword_pieces(tid)
+                    result.append({
+                        "text": t,
+                        "idx": tid,
+                        "bytes": len(t.encode('utf-8')),
+                        "type": token_type,
+                        "bpe_pieces": bpe_pieces,
+                        "is_special": tid in special_token_ids_set,
+                        "is_multi_split": multi_split_flags[i] if i < len(multi_split_flags) else False,
+                        "num_pieces": len(bpe_pieces),
+                    })
+                return result
+
+            # Build response
+            response = {
+                "prompt": prompt,
+                "promptTokens": build_token_data(prompt_token_ids, prompt_tokens, "prompt"),
+                "generatedTokens": build_token_data(generated_token_ids, generated_tokens, "generated"),
+                "tokenSections": token_sections,
+                "tokenAlternatives": token_alternatives_by_step,
+                "layersDataByStep": layer_data_by_token,
+                "layersData": layer_data_by_token[-1] if layer_data_by_token else [],
+                "qkvData": qkv_by_layer_head,
+                "modelInfo": {
+                    "numLayers": n_layers,
+                    "numHeads": n_heads,
+                    "modelDimension": d_model,
+                    "headDim": head_dim,
+                    "vocabSize": manager.model.config.vocab_size
+                },
+                "generationTime": generation_time,
+                "numTokensGenerated": len(generated_tokens)
+            }
+
+            # Estimate response size
+            response_json = json.dumps(sanitize_for_json(response))
+            response_size_bytes = len(response_json.encode('utf-8'))
+
+            yield sse_event('serializing', stage=4, totalStages=5, progress=90,
+                            stageProgress=100, detail=f'Response ready ({response_size_bytes / 1024 / 1024:.1f}MB)',
+                            metadata={'responseSizeBytes': response_size_bytes})
+            await asyncio.sleep(0)
+
+            # === STAGE 5: COMPLETE ===
+            yield sse_event('complete', stage=5, totalStages=5, progress=95,
+                            stageProgress=0, detail='Transferring data...',
+                            metadata={'responseSizeBytes': response_size_bytes, 'generationTimeMs': int(generation_time * 1000)})
+
+            logger.info(f"✅ [SSE] Research attention analysis complete: {len(generated_tokens)} tokens, {generation_time:.2f}s, {response_size_bytes / 1024 / 1024:.1f}MB")
+
+            # Send final result
+            yield sse_event('result', data=sanitize_for_json(response))
+
+        except Exception as e:
+            logger.error(f"[SSE] Research attention analysis error: {e}")
+            logger.error(traceback.format_exc())
+            yield sse_event('error', detail=str(e), stage=0, totalStages=5, progress=0, stageProgress=0)
+
+    return StreamingResponse(
+        event_generator(),
+        media_type='text/event-stream',
+        headers={
+            'Cache-Control': 'no-cache, no-store, must-revalidate',
+            'Connection': 'keep-alive',
+            'X-Accel-Buffering': 'no',  # Disable nginx/proxy buffering
+        }
+    )
+
+
 @app.post("/analyze/study")
 async def analyze_study(request: StudyRequest, authenticated: bool = Depends(verify_api_key)):
     """
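For reference, a minimal Python client sketch for consuming the new endpoint. It assumes the service listens on http://localhost:8000 and that verify_api_key reads an X-API-Key header; both are assumptions (the diff shows only the route and the dependency, not how the key is transported).

import json

import requests  # third-party: pip install requests

# Hypothetical base URL and auth header name; adjust to the actual deployment.
URL = "http://localhost:8000/analyze/research/attention/stream"
HEADERS = {"X-API-Key": "changeme"}

payload = {"prompt": "def quicksort(arr):", "max_tokens": 8, "temperature": 0.7}

with requests.post(URL, json=payload, headers=HEADERS, stream=True) as resp:
    resp.raise_for_status()
    result = None
    # Each SSE frame arrives as a "data: {...}" line followed by a blank line.
    for line in resp.iter_lines(decode_unicode=True):
        if not line or not line.startswith("data: "):
            continue
        event = json.loads(line[len("data: "):])
        if event["type"] == "result":
            result = event["data"]  # full analysis payload
        elif event["type"] == "error":
            raise RuntimeError(event["detail"])
        else:
            print(f"[{event['type']}] {event.get('progress', 0):.0f}% - {event.get('detail', '')}")

if result is not None:
    print(f"Generated {result['numTokensGenerated']} tokens in {result['generationTime']:.2f}s")

The event types and payload keys above ('result', 'error', numTokensGenerated, generationTime) mirror the handler added in this commit; only the URL, header name, and key value are invented for the example.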