gary-boon and Claude Opus 4.5 committed
Commit 172a186 · 1 Parent(s): ee0f6c9

Add SSE streaming endpoint for real-time analysis progress


Add /analyze/research/attention/stream endpoint that streams progress
events during attention analysis via Server-Sent Events (SSE).

Progress events are emitted at each stage:
- tokenizing: After prompt tokenization, with token count
- generating: Per-token, with step index and total steps
- extracting: Per-layer (on last token), with layer/head counts
- serializing: During response building, with size estimate
- complete: When analysis finishes, with timing and size metadata
- result: Final payload with complete analysis data

Each event includes (see the example event below):
- type, stage, totalStages, progress, stageProgress, detail
- metadata object with stage-specific info (counts, sizes, timing)
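
For reference, a single progress event produced by sse_event() looks roughly like this on the wire: one "data:" line followed by a blank line (the values here are illustrative):

data: {"type": "extracting", "timestamp": 1733680000000, "stage": 3, "totalStages": 5, "progress": 52.5, "stageProgress": 55.0, "detail": "Processing layer 23/40", "metadata": {"layerIndex": 22, "totalLayers": 40, "headsPerLayer": 16, "stepIndex": 7, "totalSteps": 8}}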

Benefits:
- Frontend gets real progress updates instead of a fake animation (see the minimal client sketch below)
- Users see exactly what's happening: "Processing layer 15/40"
- Response size displayed during serialization
- No more "stuck at 95%" - progress reflects actual work
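
A minimal consumption sketch (not part of this commit), assuming the service runs locally on port 8000 and using the `requests` library; any auth header that verify_api_key expects is omitted here:

import json
import requests

# Stream progress events from the new SSE endpoint (base URL/port are assumptions).
url = "http://localhost:8000/analyze/research/attention/stream"
payload = {"prompt": "def quicksort(arr):", "max_tokens": 8, "temperature": 0.7}

with requests.post(url, json=payload, stream=True) as resp:
    resp.raise_for_status()
    for line in resp.iter_lines(decode_unicode=True):
        if not line or not line.startswith("data: "):
            continue  # skip the blank separator lines between events
        event = json.loads(line[len("data: "):])
        if event["type"] == "result":
            analysis = event["data"]  # full analysis payload
        elif event["type"] == "error":
            raise RuntimeError(event["detail"])
        else:
            print(f"[{event['type']}] {event.get('progress', 0):.0f}% {event.get('detail', '')}")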

🤖 Generated with [Claude Code](https://claude.com/claude-code)

Co-Authored-By: Claude Opus 4.5 <[email protected]>

Files changed (1)
  1. backend/model_service.py +455 -0
backend/model_service.py CHANGED
 
@@ -4,6 +4,7 @@ Combines model loading, generation, and trace extraction into a single service
 """
 
 from fastapi import FastAPI, WebSocket, WebSocketDisconnect, BackgroundTasks, HTTPException, Depends
+from fastapi.responses import StreamingResponse
 from fastapi.middleware.cors import CORSMiddleware
 from pydantic import BaseModel
 import asyncio
@@ -1920,6 +1921,460 @@ async def analyze_research_attention(request: Dict[str, Any], authenticated: boo
         logger.error(traceback.format_exc())
         raise HTTPException(status_code=500, detail=str(e))
 
+
+def sse_event(event_type: str, **kwargs) -> str:
+    """Format data as SSE event"""
+    data = {'type': event_type, 'timestamp': int(time.time() * 1000), **kwargs}
+    return f"data: {json.dumps(data)}\n\n"
+
+
+@app.post("/analyze/research/attention/stream")
+async def analyze_research_attention_stream(request: Dict[str, Any], authenticated: bool = Depends(verify_api_key)):
+    """
+    SSE Streaming version of Research-Grade Attention Analysis
+
+    Emits progress events during each stage:
+    - tokenizing: Initial tokenization
+    - generating: Per-token generation progress
+    - extracting: Per-layer attention extraction
+    - serializing: Building response
+    - complete: Analysis finished
+    - result: Final data payload
+    """
+    async def event_generator():
+        try:
+            import time
+            start_time = time.time()
+
+            # Get parameters
+            prompt = request.get("prompt", "def quicksort(arr):")
+            max_tokens = request.get("max_tokens", 8)
+            temperature = request.get("temperature", 0.7)
+
+            logger.info(f"[SSE] Research attention analysis: prompt_len={len(prompt)}, max_tokens={max_tokens}")
+
+            # === STAGE 1: TOKENIZING ===
+            yield sse_event('tokenizing', stage=1, totalStages=5, progress=2,
+                            stageProgress=0, detail=f'Tokenizing {len(prompt)} characters...')
+
+            # Get model config for prompt formatting
+            from .model_config import get_model_config
+            from .prompt_formatter import format_prompt
+            model_config = get_model_config(manager.model_id)
+
+            # Get optional system prompt override from request
+            system_prompt_override = request.get("system_prompt")
+
+            # Format prompt using the unified formatter
+            formatted_prompt = format_prompt(
+                prompt=prompt,
+                model_config=model_config or {},
+                tokenizer=manager.tokenizer,
+                system_prompt_override=system_prompt_override
+            )
+
+            prompt_style = model_config.get("prompt_style", "completion") if model_config else "completion"
+
+            # Use model's recommended temperature for instruction models
+            if model_config and "recommended_temperature" in model_config:
+                temperature = model_config["recommended_temperature"]
+
+            # Tokenize and prepare - use MistralTokenizer for Devstral
+            if manager.model_id == "devstral-small" and manager.mistral_tokenizer is not None:
+                system_prompt = system_prompt_override or (model_config.get("system_prompt") if model_config else "")
+                prompt_token_ids = manager.mistral_tokenizer.encode_chat(system_prompt, prompt)
+                inputs = {"input_ids": torch.tensor([prompt_token_ids]).to(manager.device)}
+                prompt_length = len(prompt_token_ids)
+                prompt_tokens = [manager.mistral_tokenizer.decode_token(tid) for tid in prompt_token_ids]
+            else:
+                inputs = manager.tokenizer(formatted_prompt, return_tensors="pt").to(manager.device)
+                prompt_length = inputs["input_ids"].shape[1]
+                prompt_token_ids = inputs["input_ids"][0].tolist()
+                prompt_tokens = [manager.tokenizer.decode([tid], skip_special_tokens=False) for tid in prompt_token_ids]
+
+            yield sse_event('tokenizing', stage=1, totalStages=5, progress=8,
+                            stageProgress=100, detail=f'Tokenized into {prompt_length} tokens',
+                            metadata={'tokenCount': prompt_length})
+            await asyncio.sleep(0)  # Yield to event loop
+
+            # Storage for generation
+            generated_token_ids = []
+            generated_tokens = []
+
+            # Model info
+            n_layers = len(list(manager.model.parameters()))
+            if hasattr(manager.model.config, 'n_layer'):
+                n_layers = manager.model.config.n_layer
+            elif hasattr(manager.model.config, 'num_hidden_layers'):
+                n_layers = manager.model.config.num_hidden_layers
+
+            n_heads = manager.model.config.n_head if hasattr(manager.model.config, 'n_head') else manager.model.config.num_attention_heads
+            d_model = manager.model.config.n_embd if hasattr(manager.model.config, 'n_embd') else manager.model.config.hidden_size
+            head_dim = d_model // n_heads
+
+            # === STAGE 2: GENERATING ===
+            layer_data_by_token = []
+            token_alternatives_by_step = []
+
+            # Hook system to capture Q/K/V matrices
+            qkv_captures = {}
+            hooks = []
+
+            def make_qkv_hook(layer_idx):
+                def hook(module, input, output):
+                    try:
+                        if output.dim() != 3:
+                            return
+                        batch_size, seq_len, hidden = output.shape
+                        expected_hidden = 3 * n_heads * head_dim
+                        if hidden != expected_hidden:
+                            return
+                        qkv = output.reshape(batch_size, seq_len, 3, n_heads, head_dim)
+                        q, k, v = qkv[:, :, 0], qkv[:, :, 1], qkv[:, :, 2]
+                        qkv_captures[layer_idx] = {
+                            'q': q[0].detach().cpu(),
+                            'k': k[0].detach().cpu(),
+                            'v': v[0].detach().cpu()
+                        }
+                    except Exception:
+                        pass
+                return hook
+
+            # Register hooks
+            try:
+                if hasattr(manager.model, 'transformer') and hasattr(manager.model.transformer, 'h'):
+                    for layer_idx, layer in enumerate(manager.model.transformer.h):
+                        if hasattr(layer, 'attn') and hasattr(layer.attn, 'qkv_proj'):
+                            hook = layer.attn.qkv_proj.register_forward_hook(make_qkv_hook(layer_idx))
+                            hooks.append(hook)
+                        elif hasattr(layer, 'attn') and hasattr(layer.attn, 'c_attn'):
+                            hook = layer.attn.c_attn.register_forward_hook(make_qkv_hook(layer_idx))
+                            hooks.append(hook)
+            except Exception as hook_error:
+                logger.warning(f"Could not register QKV hooks: {hook_error}")
+
+            with torch.no_grad():
+                current_ids = inputs["input_ids"]
+
+                for step in range(max_tokens):
+                    # Emit progress for this generation step
+                    step_progress = (step / max_tokens) * 100
+                    overall_progress = 10 + (step / max_tokens) * 20  # 10-30%
+                    yield sse_event('generating', stage=2, totalStages=5, progress=overall_progress,
+                                    stageProgress=step_progress,
+                                    detail=f'Generating token {step + 1}/{max_tokens}',
+                                    metadata={'stepIndex': step, 'totalSteps': max_tokens})
+                    await asyncio.sleep(0)
+
+                    qkv_captures.clear()
+
+                    # Forward pass with full outputs
+                    outputs = manager.model(
+                        current_ids,
+                        output_attentions=True,
+                        output_hidden_states=True
+                    )
+
+                    # Get logits for next token
+                    logits = outputs.logits[0, -1, :]
+
+                    # Apply temperature and sample
+                    if temperature > 0:
+                        logits = logits / temperature
+                    probs = torch.softmax(logits, dim=0)
+
+                    if temperature == 0:
+                        next_token_id = torch.argmax(probs, dim=-1).item()
+                    else:
+                        next_token_id = torch.multinomial(probs, 1).item()
+                    next_token_text = manager.tokenizer.decode([next_token_id], skip_special_tokens=False)
+
+                    generated_token_ids.append(next_token_id)
+                    generated_tokens.append(next_token_text)
+
+                    # Capture top-k token alternatives
+                    import math as math_module
+                    top_k = 5
+                    top_probs, top_indices = torch.topk(probs, k=min(top_k, len(probs)))
+                    alternatives = []
+                    for prob, idx in zip(top_probs.tolist(), top_indices.tolist()):
+                        token_text = manager.tokenizer.decode([idx], skip_special_tokens=False)
+                        alternatives.append({
+                            "token": token_text,
+                            "token_id": idx,
+                            "probability": prob,
+                            "log_probability": math_module.log(prob) if prob > 0 else float('-inf')
+                        })
+                    token_alternatives_by_step.append({
+                        "step": step,
+                        "selected_token": next_token_text,
+                        "selected_token_id": next_token_id,
+                        "alternatives": alternatives
+                    })
+
+                    # === STAGE 3: EXTRACTING (per layer within each token) ===
+                    layer_data_this_token = []
+
+                    for layer_idx in range(len(outputs.attentions)):
+                        # Emit extraction progress (within generating stage for combined progress)
+                        if step == max_tokens - 1:  # Only emit detailed layer progress on last token
+                            layer_progress = (layer_idx / len(outputs.attentions)) * 100
+                            overall_progress = 30 + (layer_idx / len(outputs.attentions)) * 40  # 30-70%
+                            yield sse_event('extracting', stage=3, totalStages=5, progress=overall_progress,
+                                            stageProgress=layer_progress,
+                                            detail=f'Processing layer {layer_idx + 1}/{len(outputs.attentions)}',
+                                            metadata={'layerIndex': layer_idx, 'totalLayers': len(outputs.attentions),
+                                                      'headsPerLayer': n_heads, 'stepIndex': step, 'totalSteps': max_tokens})
+                            if layer_idx % 5 == 0:  # Yield every 5 layers to avoid too many events
+                                await asyncio.sleep(0)
+
+                        layer_attn = outputs.attentions[layer_idx][0]
+                        current_hidden = outputs.hidden_states[layer_idx + 1]
+                        if current_hidden.dim() == 3:
+                            current_hidden = current_hidden[0]
+
+                        if layer_idx > 0:
+                            prev_hidden = outputs.hidden_states[layer_idx]
+                            if prev_hidden.dim() == 3:
+                                prev_hidden = prev_hidden[0]
+                            delta_norm = torch.norm(current_hidden - prev_hidden).item()
+                        else:
+                            delta_norm = None
+
+                        activation_magnitude = torch.norm(current_hidden).item()
+                        last_token_hidden = current_hidden[-1]
+                        activation_entropy = torch.std(last_token_hidden).item()
+                        hidden_state_norm = torch.norm(last_token_hidden).item()
+
+                        # Sanitize
+                        activation_magnitude = 0.0 if math.isnan(activation_magnitude) or math.isinf(activation_magnitude) else activation_magnitude
+                        activation_entropy = 0.0 if math.isnan(activation_entropy) or math.isinf(activation_entropy) else activation_entropy
+                        hidden_state_norm = 0.0 if math.isnan(hidden_state_norm) or math.isinf(hidden_state_norm) else hidden_state_norm
+                        if delta_norm is not None:
+                            delta_norm = 0.0 if math.isnan(delta_norm) or math.isinf(delta_norm) else delta_norm
+
+                        # Process heads
+                        critical_heads = []
+                        for head_idx in range(layer_attn.shape[0]):
+                            head_weights = layer_attn[head_idx, -1, :]
+                            max_weight = head_weights.max().item()
+                            entropy = -(head_weights * torch.log(head_weights + 1e-10)).sum().item()
+
+                            max_weight = 0.0 if math.isnan(max_weight) or math.isinf(max_weight) else max_weight
+                            entropy = 0.0 if math.isnan(entropy) or math.isinf(entropy) else entropy
+
+                            pattern_type = None
+                            confidence = 0.0
+
+                            if step > 0 and max_weight > 0.8:
+                                pattern_type = "induction"
+                                confidence = max_weight
+                            elif entropy < 1.0:
+                                pattern_type = "positional"
+                                confidence = 1.0 - entropy
+                            elif 1.0 <= entropy < 2.5:
+                                pattern_type = "semantic"
+                                confidence = min(1.0, entropy / 2.5)
+                            elif max_weight > 0.9 and head_weights[-2].item() > 0.85:
+                                pattern_type = "previous_token"
+                                confidence = head_weights[-2].item()
+
+                            confidence = 0.0 if math.isnan(confidence) or math.isinf(confidence) else confidence
+
+                            attention_matrix = layer_attn[head_idx].cpu().float().numpy().tolist()
+
+                            q_matrix = None
+                            k_matrix = None
+                            v_matrix = None
+                            if layer_idx in qkv_captures:
+                                q_matrix = qkv_captures[layer_idx]['q'][:, head_idx, :].float().numpy().tolist()
+                                k_matrix = qkv_captures[layer_idx]['k'][:, head_idx, :].float().numpy().tolist()
+                                v_matrix = qkv_captures[layer_idx]['v'][:, head_idx, :].float().numpy().tolist()
+
+                            critical_heads.append({
+                                "head_idx": head_idx,
+                                "entropy": entropy,
+                                "max_weight": max_weight,
+                                "attention_weights": attention_matrix,
+                                "q_matrix": q_matrix,
+                                "k_matrix": k_matrix,
+                                "v_matrix": v_matrix,
+                                "pattern": {"type": pattern_type, "confidence": confidence} if pattern_type else None
+                            })
+
+                        critical_heads.sort(key=lambda h: h["max_weight"], reverse=True)
+
+                        layer_pattern = None
+                        layer_fraction = (layer_idx + 1) / n_layers
+                        if layer_idx == 0:
+                            layer_pattern = {"type": "positional", "confidence": 0.78}
+                        elif layer_fraction <= 0.25 and step > 0:
+                            layer_pattern = {"type": "previous_token", "confidence": 0.65}
+                        elif layer_fraction <= 0.75:
+                            layer_pattern = {"type": "induction", "confidence": 0.87}
+                        else:
+                            layer_pattern = {"type": "semantic", "confidence": 0.92}
+
+                        layer_data_this_token.append({
+                            "layer_idx": layer_idx,
+                            "pattern": layer_pattern,
+                            "critical_heads": critical_heads,
+                            "activation_magnitude": activation_magnitude,
+                            "activation_entropy": activation_entropy,
+                            "hidden_state_norm": hidden_state_norm,
+                            "delta_norm": delta_norm
+                        })
+
+                    layer_data_by_token.append(layer_data_this_token)
+
+                    # Update inputs
+                    next_token_tensor = torch.tensor([[next_token_id]], dtype=torch.long, device=manager.device)
+                    current_ids = torch.cat([current_ids, next_token_tensor], dim=1)
+
+                    # Stop on EOS
+                    if next_token_id == manager.tokenizer.eos_token_id:
+                        break
+
+            # Clean up hooks
+            for hook in hooks:
+                hook.remove()
+
+            # === STAGE 4: SERIALIZING ===
+            yield sse_event('serializing', stage=4, totalStages=5, progress=75,
+                            stageProgress=0, detail='Building response data...')
+            await asyncio.sleep(0)
+
+            qkv_by_layer_head = {}
+            generation_time = time.time() - start_time
+
+            # Calculate token section boundaries
+            total_tokens = prompt_length + len(generated_token_ids)
+            system_prompt_text = system_prompt_override or (model_config.get("system_prompt") if model_config else None)
+
+            system_prompt_end = 0
+            if prompt_style == "instruction" and system_prompt_text:
+                if manager.model_id == "devstral-small" and manager.mistral_tokenizer is not None:
+                    try:
+                        no_system_tokens = manager.mistral_tokenizer.encode_chat("", prompt)
+                        system_prompt_end = prompt_length - len(no_system_tokens)
+                        system_prompt_end = max(0, min(system_prompt_end, prompt_length))
+                    except Exception:
+                        system_prompt_end = 0
+                else:
+                    total_chars = len(system_prompt_text or "") + len(prompt)
+                    if total_chars > 0:
+                        system_ratio = len(system_prompt_text or "") / total_chars
+                        system_prompt_end = int(prompt_length * system_ratio)
+
+            token_sections = {
+                "systemPrompt": {
+                    "start": 0,
+                    "end": system_prompt_end,
+                    "text": system_prompt_text,
+                    "tokenCount": system_prompt_end
+                },
+                "userPrompt": {
+                    "start": system_prompt_end,
+                    "end": prompt_length,
+                    "text": prompt,
+                    "tokenCount": prompt_length - system_prompt_end
+                },
+                "output": {
+                    "start": prompt_length,
+                    "end": total_tokens,
+                    "text": "".join(generated_tokens),
+                    "tokenCount": len(generated_token_ids)
+                }
+            }
+
+            yield sse_event('serializing', stage=4, totalStages=5, progress=82,
+                            stageProgress=50, detail='Building token metadata...')
+            await asyncio.sleep(0)
+
+            # Build token metadata
+            from .tokenizer_utils import TokenizerMetadata
+            token_metadata_builder = TokenizerMetadata(manager.tokenizer)
+
+            special_token_ids_set = {
+                manager.tokenizer.eos_token_id,
+                manager.tokenizer.bos_token_id,
+                manager.tokenizer.pad_token_id,
+                manager.tokenizer.unk_token_id
+            }
+
+            def build_token_data(token_ids, token_texts, token_type):
+                multi_split_flags = token_metadata_builder.is_multi_split_identifier(token_ids)
+                result = []
+                for i, (tid, t) in enumerate(zip(token_ids, token_texts)):
+                    bpe_pieces = token_metadata_builder.get_subword_pieces(tid)
+                    result.append({
+                        "text": t,
+                        "idx": tid,
+                        "bytes": len(t.encode('utf-8')),
+                        "type": token_type,
+                        "bpe_pieces": bpe_pieces,
+                        "is_special": tid in special_token_ids_set,
+                        "is_multi_split": multi_split_flags[i] if i < len(multi_split_flags) else False,
+                        "num_pieces": len(bpe_pieces),
+                    })
+                return result
+
+            # Build response
+            response = {
+                "prompt": prompt,
+                "promptTokens": build_token_data(prompt_token_ids, prompt_tokens, "prompt"),
+                "generatedTokens": build_token_data(generated_token_ids, generated_tokens, "generated"),
+                "tokenSections": token_sections,
+                "tokenAlternatives": token_alternatives_by_step,
+                "layersDataByStep": layer_data_by_token,
+                "layersData": layer_data_by_token[-1] if layer_data_by_token else [],
+                "qkvData": qkv_by_layer_head,
+                "modelInfo": {
+                    "numLayers": n_layers,
+                    "numHeads": n_heads,
+                    "modelDimension": d_model,
+                    "headDim": head_dim,
+                    "vocabSize": manager.model.config.vocab_size
+                },
+                "generationTime": generation_time,
+                "numTokensGenerated": len(generated_tokens)
+            }
+
+            # Estimate response size
+            response_json = json.dumps(sanitize_for_json(response))
+            response_size_bytes = len(response_json.encode('utf-8'))
+
+            yield sse_event('serializing', stage=4, totalStages=5, progress=90,
+                            stageProgress=100, detail=f'Response ready ({response_size_bytes / 1024 / 1024:.1f}MB)',
+                            metadata={'responseSizeBytes': response_size_bytes})
+            await asyncio.sleep(0)
+
+            # === STAGE 5: COMPLETE ===
+            yield sse_event('complete', stage=5, totalStages=5, progress=95,
+                            stageProgress=0, detail='Transferring data...',
+                            metadata={'responseSizeBytes': response_size_bytes, 'generationTimeMs': int(generation_time * 1000)})
+
+            logger.info(f"✅ [SSE] Research attention analysis complete: {len(generated_tokens)} tokens, {generation_time:.2f}s, {response_size_bytes / 1024 / 1024:.1f}MB")
+
+            # Send final result
+            yield sse_event('result', data=sanitize_for_json(response))
+
+        except Exception as e:
+            logger.error(f"[SSE] Research attention analysis error: {e}")
+            logger.error(traceback.format_exc())
+            yield sse_event('error', detail=str(e), stage=0, totalStages=5, progress=0, stageProgress=0)
+
+    return StreamingResponse(
+        event_generator(),
+        media_type='text/event-stream',
+        headers={
+            'Cache-Control': 'no-cache, no-store, must-revalidate',
+            'Connection': 'keep-alive',
+            'X-Accel-Buffering': 'no',  # Disable nginx/proxy buffering
+        }
+    )
+
+
 @app.post("/analyze/study")
 async def analyze_study(request: StudyRequest, authenticated: bool = Depends(verify_api_key)):
     """