gary-boon and Claude Opus 4.5 committed
Commit d0b7e29 · 1 Parent(s): a79cb83

Revert QKV visualization fixes - need better approach for data streaming


Reverts commits:
- a79cb83 Add safety checks for missing QKV keys
- decb5ab Limit QKV matrices to top 5 heads per layer
- 9056859 Fix QKV matrix extraction for Mistral/Devstral architecture
- 4ec134b Fix QKV visualization for Mistral/Devstral architecture

The QKV fixes caused response sizes to explode, leading to 504 timeouts.
A better approach (lazy loading, see the sketch below) is needed before re-enabling.

🤖 Generated with [Claude Code](https://claude.com/claude-code)

Co-Authored-By: Claude Opus 4.5 <[email protected]>
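
For context on the planned fix: the reverted code inlined Q/K/V matrices for every head of every layer into a single JSON response. At roughly 40 layers × 32 heads × 3 matrices × seq_len × head_dim float values (about 250 million numbers at an assumed seq_len of 512 and head_dim of 128), the payload runs to gigabytes of JSON, which explains the 504s. Below is a minimal sketch of the lazy-loading idea; the endpoint shape, the `_qkv_cache` object, and the session keying are illustrative assumptions, not code that exists in this repo.

```python
# Hypothetical lazy-loading sketch (not part of this commit): keep captured
# tensors server-side and serve one (layer, head) slice per request instead of
# inlining every head's Q/K/V in the main analysis response.
from typing import Any, Dict, List, Optional

# Assumed cache populated by the analysis endpoint: session -> layer -> {'q','k','v'},
# with tensors shaped [seq_len, n_heads, head_dim] as in the reverted hooks.
_qkv_cache: Dict[str, Dict[int, Dict[str, Any]]] = {}

async def get_qkv_matrices(session_id: str, layer_idx: int, head_idx: int) -> Dict[str, Optional[List[List[float]]]]:
    """Return Q/K/V for a single (layer, head) on demand."""
    layer_qkv = _qkv_cache.get(session_id, {}).get(layer_idx, {})
    result: Dict[str, Optional[List[List[float]]]] = {"q": None, "k": None, "v": None}
    for key in ("q", "k", "v"):
        if key in layer_qkv:
            # Convert to float32 first; numpy does not support bfloat16
            result[key] = layer_qkv[key][:, head_idx, :].float().numpy().tolist()
    return result
```

With per-head requests the cache only needs a TTL or per-session eviction to stay bounded, and the main analysis response can keep returning attention weights without any QKV payload.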

Files changed (1)
  1. backend/model_service.py +28 -271
backend/model_service.py CHANGED
@@ -1571,11 +1571,7 @@ async def analyze_research_attention(request: Dict[str, Any], authenticated: boo
 
     n_heads = manager.model.config.n_head if hasattr(manager.model.config, 'n_head') else manager.model.config.num_attention_heads
     d_model = manager.model.config.n_embd if hasattr(manager.model.config, 'n_embd') else manager.model.config.hidden_size
-    # Use explicit head_dim from config if available (Mistral models have this)
-    if hasattr(manager.model.config, 'head_dim'):
-        head_dim = manager.model.config.head_dim
-    else:
-        head_dim = d_model // n_heads
+    head_dim = d_model // n_heads
 
     # Generation loop with full instrumentation
     layer_data_by_token = []  # Store layer data for each generated token
@@ -1586,7 +1582,6 @@ async def analyze_research_attention(request: Dict[str, Any], authenticated: boo
     hooks = []
 
     def make_qkv_hook(layer_idx):
-        """Hook for combined QKV projection (CodeGen/GPT-NeoX style)"""
         def hook(module, input, output):
             try:
                 # output shape: [batch, seq_len, 3 * hidden_size]
@@ -1610,122 +1605,18 @@ async def analyze_research_attention(request: Dict[str, Any], authenticated: boo
                 pass
         return hook
 
-    def make_separate_q_hook(layer_idx):
-        """Hook for separate Q projection (Mistral/LLaMA style)"""
-        def hook(module, input, output):
-            try:
-                if layer_idx not in qkv_captures:
-                    qkv_captures[layer_idx] = {}
-                # Handle both tuple and tensor outputs
-                if isinstance(output, tuple):
-                    out = output[0]
-                else:
-                    out = output
-                out = out.detach().cpu()
-                # output shape: [batch, seq_len, num_heads * head_dim]
-                # If 3D, take first batch element
-                if out.dim() == 3:
-                    out = out[0]  # [seq_len, num_heads * head_dim]
-                # Reshape to [seq_len, num_heads, head_dim]
-                seq_len = out.shape[0]
-                out = out.reshape(seq_len, n_heads, head_dim)
-                qkv_captures[layer_idx]['q'] = out
-            except Exception as e:
-                logger.warning(f"Q hook error layer {layer_idx}: {e}")
-        return hook
-
-    def make_separate_k_hook(layer_idx):
-        """Hook for separate K projection (Mistral/LLaMA style)"""
-        def hook(module, input, output):
-            try:
-                if layer_idx not in qkv_captures:
-                    qkv_captures[layer_idx] = {}
-                # Handle both tuple and tensor outputs
-                if isinstance(output, tuple):
-                    out = output[0]
-                else:
-                    out = output
-                out = out.detach().cpu()
-                # If 3D, take first batch element
-                if out.dim() == 3:
-                    out = out[0]  # [seq_len, kv_heads * head_dim]
-                # For GQA models, K has fewer heads (kv_heads)
-                seq_len = out.shape[0]
-                hidden_size = out.shape[1]
-                actual_kv_heads = hidden_size // head_dim
-                out = out.reshape(seq_len, actual_kv_heads, head_dim)
-                # If GQA, repeat KV heads to match Q heads
-                if actual_kv_heads != n_heads:
-                    repeat_factor = n_heads // actual_kv_heads
-                    out = out.repeat_interleave(repeat_factor, dim=1)
-                qkv_captures[layer_idx]['k'] = out
-            except Exception as e:
-                logger.warning(f"K hook error layer {layer_idx}: {e}")
-        return hook
-
-    def make_separate_v_hook(layer_idx):
-        """Hook for separate V projection (Mistral/LLaMA style)"""
-        def hook(module, input, output):
-            try:
-                if layer_idx not in qkv_captures:
-                    qkv_captures[layer_idx] = {}
-                # Handle both tuple and tensor outputs
-                if isinstance(output, tuple):
-                    out = output[0]
-                else:
-                    out = output
-                out = out.detach().cpu()
-                # If 3D, take first batch element
-                if out.dim() == 3:
-                    out = out[0]  # [seq_len, kv_heads * head_dim]
-                # For GQA models, V has fewer heads (kv_heads)
-                seq_len = out.shape[0]
-                hidden_size = out.shape[1]
-                actual_kv_heads = hidden_size // head_dim
-                out = out.reshape(seq_len, actual_kv_heads, head_dim)
-                # If GQA, repeat KV heads to match Q heads
-                if actual_kv_heads != n_heads:
-                    repeat_factor = n_heads // actual_kv_heads
-                    out = out.repeat_interleave(repeat_factor, dim=1)
-                qkv_captures[layer_idx]['v'] = out
-            except Exception as e:
-                logger.warning(f"V hook error layer {layer_idx}: {e}")
-        return hook
-
-    # Register hooks - support both CodeGen and Mistral/Devstral architectures
+    # Register hooks on all qkv_proj modules (if available)
+    # This is model-specific - CodeGen uses different architecture
     try:
-        # Try to get layers via adapter first (works for all model types)
-        layers = None
-        if manager.adapter:
-            try:
-                layers = manager.adapter._get_layers()
-                logger.info(f"Using adapter to get {len(layers)} layers for QKV hooks")
-            except Exception:
-                pass
-
-        # Fallback for CodeGen if adapter doesn't work
-        if layers is None and hasattr(manager.model, 'transformer') and hasattr(manager.model.transformer, 'h'):
-            layers = manager.model.transformer.h
-            logger.info(f"Using transformer.h for {len(layers)} layers for QKV hooks")
-
-        if layers:
-            for layer_idx, layer in enumerate(layers):
-                # Mistral/Devstral: separate Q, K, V projections
-                if hasattr(layer, 'self_attn'):
-                    attn = layer.self_attn
-                    if hasattr(attn, 'q_proj') and hasattr(attn, 'k_proj') and hasattr(attn, 'v_proj'):
-                        hooks.append(attn.q_proj.register_forward_hook(make_separate_q_hook(layer_idx)))
-                        hooks.append(attn.k_proj.register_forward_hook(make_separate_k_hook(layer_idx)))
-                        hooks.append(attn.v_proj.register_forward_hook(make_separate_v_hook(layer_idx)))
-                # CodeGen/GPT-NeoX: combined QKV projection
-                elif hasattr(layer, 'attn'):
-                    if hasattr(layer.attn, 'qkv_proj'):
-                        hooks.append(layer.attn.qkv_proj.register_forward_hook(make_qkv_hook(layer_idx)))
-                    elif hasattr(layer.attn, 'c_attn'):
-                        # GPT-2 style attention
-                        hooks.append(layer.attn.c_attn.register_forward_hook(make_qkv_hook(layer_idx)))
-
-        logger.info(f"Registered {len(hooks)} QKV hooks")
+        if hasattr(manager.model, 'transformer') and hasattr(manager.model.transformer, 'h'):
+            for layer_idx, layer in enumerate(manager.model.transformer.h):
+                if hasattr(layer, 'attn') and hasattr(layer.attn, 'qkv_proj'):
+                    hook = layer.attn.qkv_proj.register_forward_hook(make_qkv_hook(layer_idx))
+                    hooks.append(hook)
+                elif hasattr(layer, 'attn') and hasattr(layer.attn, 'c_attn'):
+                    # GPT-2 style attention
+                    hook = layer.attn.c_attn.register_forward_hook(make_qkv_hook(layer_idx))
+                    hooks.append(hook)
     except Exception as hook_error:
         logger.warning(f"Could not register QKV hooks: {hook_error}")
 
@@ -1859,16 +1750,11 @@ async def analyze_research_attention(request: Dict[str, Any], authenticated: boo
             k_matrix = None
             v_matrix = None
             if layer_idx in qkv_captures:
-                layer_qkv = qkv_captures[layer_idx]
                 # Q/K/V shape: [seq_len, n_heads, head_dim]
                 # Convert to float32 for numpy (bfloat16 not supported)
-                # Check each key exists (hooks may have failed for some)
-                if 'q' in layer_qkv:
-                    q_matrix = layer_qkv['q'][:, head_idx, :].float().numpy().tolist()
-                if 'k' in layer_qkv:
-                    k_matrix = layer_qkv['k'][:, head_idx, :].float().numpy().tolist()
-                if 'v' in layer_qkv:
-                    v_matrix = layer_qkv['v'][:, head_idx, :].float().numpy().tolist()
+                q_matrix = qkv_captures[layer_idx]['q'][:, head_idx, :].float().numpy().tolist()
+                k_matrix = qkv_captures[layer_idx]['k'][:, head_idx, :].float().numpy().tolist()
+                v_matrix = qkv_captures[layer_idx]['v'][:, head_idx, :].float().numpy().tolist()
 
             critical_heads.append({
                 "head_idx": head_idx,
@@ -1887,14 +1773,6 @@ async def analyze_research_attention(request: Dict[str, Any], authenticated: boo
         # Sort by max_weight (return all heads, frontend will decide how many to display)
         critical_heads.sort(key=lambda h: h["max_weight"], reverse=True)
 
-        # Only keep QKV matrices for top 5 heads to avoid massive response sizes
-        # (40 layers × 32 heads × 3 matrices × seq_len × head_dim is too much data)
-        for i, head in enumerate(critical_heads):
-            if i >= 5:  # Keep QKV only for top 5 heads
-                head["q_matrix"] = None
-                head["k_matrix"] = None
-                head["v_matrix"] = None
-
         # Detect layer-level pattern (percentage-based for any layer count)
         layer_pattern = None
         layer_fraction = (layer_idx + 1) / n_layers  # 1-indexed fraction
@@ -2133,11 +2011,7 @@ async def analyze_research_attention_stream(request: Dict[str, Any], authenticat
 
     n_heads = manager.model.config.n_head if hasattr(manager.model.config, 'n_head') else manager.model.config.num_attention_heads
     d_model = manager.model.config.n_embd if hasattr(manager.model.config, 'n_embd') else manager.model.config.hidden_size
-    # Use explicit head_dim from config if available (Mistral models have this)
-    if hasattr(manager.model.config, 'head_dim'):
-        head_dim = manager.model.config.head_dim
-    else:
-        head_dim = d_model // n_heads
+    head_dim = d_model // n_heads
 
     # === STAGE 2: GENERATING ===
     layer_data_by_token = []
@@ -2148,7 +2022,6 @@ async def analyze_research_attention_stream(request: Dict[str, Any], authenticat
     hooks = []
 
    def make_qkv_hook(layer_idx):
-        """Hook for combined QKV projection (CodeGen/GPT-NeoX style)"""
        def hook(module, input, output):
            try:
                if output.dim() != 3:
@@ -2168,120 +2041,16 @@ async def analyze_research_attention_stream(request: Dict[str, Any], authenticat
                 pass
         return hook
 
-    def make_separate_q_hook(layer_idx):
-        """Hook for separate Q projection (Mistral/LLaMA style)"""
-        def hook(module, input, output):
-            try:
-                if layer_idx not in qkv_captures:
-                    qkv_captures[layer_idx] = {}
-                # Handle both tuple and tensor outputs
-                if isinstance(output, tuple):
-                    out = output[0]
-                else:
-                    out = output
-                out = out.detach().cpu()
-                # If 3D, take first batch element
-                if out.dim() == 3:
-                    out = out[0]  # [seq_len, num_heads * head_dim]
-                # Reshape to [seq_len, num_heads, head_dim]
-                seq_len = out.shape[0]
-                out = out.reshape(seq_len, n_heads, head_dim)
-                qkv_captures[layer_idx]['q'] = out
-            except Exception as e:
-                logger.warning(f"[Stream] Q hook error layer {layer_idx}: {e}")
-        return hook
-
-    def make_separate_k_hook(layer_idx):
-        """Hook for separate K projection (Mistral/LLaMA style)"""
-        def hook(module, input, output):
-            try:
-                if layer_idx not in qkv_captures:
-                    qkv_captures[layer_idx] = {}
-                # Handle both tuple and tensor outputs
-                if isinstance(output, tuple):
-                    out = output[0]
-                else:
-                    out = output
-                out = out.detach().cpu()
-                # If 3D, take first batch element
-                if out.dim() == 3:
-                    out = out[0]  # [seq_len, kv_heads * head_dim]
-                # For GQA models, K has fewer heads (kv_heads)
-                seq_len = out.shape[0]
-                hidden_size = out.shape[1]
-                actual_kv_heads = hidden_size // head_dim
-                out = out.reshape(seq_len, actual_kv_heads, head_dim)
-                # If GQA, repeat KV heads to match Q heads
-                if actual_kv_heads != n_heads:
-                    repeat_factor = n_heads // actual_kv_heads
-                    out = out.repeat_interleave(repeat_factor, dim=1)
-                qkv_captures[layer_idx]['k'] = out
-            except Exception as e:
-                logger.warning(f"[Stream] K hook error layer {layer_idx}: {e}")
-        return hook
-
-    def make_separate_v_hook(layer_idx):
-        """Hook for separate V projection (Mistral/LLaMA style)"""
-        def hook(module, input, output):
-            try:
-                if layer_idx not in qkv_captures:
-                    qkv_captures[layer_idx] = {}
-                # Handle both tuple and tensor outputs
-                if isinstance(output, tuple):
-                    out = output[0]
-                else:
-                    out = output
-                out = out.detach().cpu()
-                # If 3D, take first batch element
-                if out.dim() == 3:
-                    out = out[0]  # [seq_len, kv_heads * head_dim]
-                # For GQA models, V has fewer heads (kv_heads)
-                seq_len = out.shape[0]
-                hidden_size = out.shape[1]
-                actual_kv_heads = hidden_size // head_dim
-                out = out.reshape(seq_len, actual_kv_heads, head_dim)
-                # If GQA, repeat KV heads to match Q heads
-                if actual_kv_heads != n_heads:
-                    repeat_factor = n_heads // actual_kv_heads
-                    out = out.repeat_interleave(repeat_factor, dim=1)
-                qkv_captures[layer_idx]['v'] = out
-            except Exception as e:
-                logger.warning(f"[Stream] V hook error layer {layer_idx}: {e}")
-        return hook
-
-    # Register hooks - support both CodeGen and Mistral/Devstral architectures
+    # Register hooks
     try:
-        # Try to get layers via adapter first (works for all model types)
-        layers = None
-        if manager.adapter:
-            try:
-                layers = manager.adapter._get_layers()
-                logger.info(f"[Stream] Using adapter to get {len(layers)} layers for QKV hooks")
-            except Exception:
-                pass
-
-        # Fallback for CodeGen if adapter doesn't work
-        if layers is None and hasattr(manager.model, 'transformer') and hasattr(manager.model.transformer, 'h'):
-            layers = manager.model.transformer.h
-            logger.info(f"[Stream] Using transformer.h for {len(layers)} layers for QKV hooks")
-
-        if layers:
-            for layer_idx, layer in enumerate(layers):
-                # Mistral/Devstral: separate Q, K, V projections
-                if hasattr(layer, 'self_attn'):
-                    attn = layer.self_attn
-                    if hasattr(attn, 'q_proj') and hasattr(attn, 'k_proj') and hasattr(attn, 'v_proj'):
-                        hooks.append(attn.q_proj.register_forward_hook(make_separate_q_hook(layer_idx)))
-                        hooks.append(attn.k_proj.register_forward_hook(make_separate_k_hook(layer_idx)))
-                        hooks.append(attn.v_proj.register_forward_hook(make_separate_v_hook(layer_idx)))
-                # CodeGen/GPT-NeoX: combined QKV projection
-                elif hasattr(layer, 'attn'):
-                    if hasattr(layer.attn, 'qkv_proj'):
-                        hooks.append(layer.attn.qkv_proj.register_forward_hook(make_qkv_hook(layer_idx)))
-                    elif hasattr(layer.attn, 'c_attn'):
-                        hooks.append(layer.attn.c_attn.register_forward_hook(make_qkv_hook(layer_idx)))
-
-        logger.info(f"[Stream] Registered {len(hooks)} QKV hooks")
+        if hasattr(manager.model, 'transformer') and hasattr(manager.model.transformer, 'h'):
+            for layer_idx, layer in enumerate(manager.model.transformer.h):
+                if hasattr(layer, 'attn') and hasattr(layer.attn, 'qkv_proj'):
+                    hook = layer.attn.qkv_proj.register_forward_hook(make_qkv_hook(layer_idx))
+                    hooks.append(hook)
+                elif hasattr(layer, 'attn') and hasattr(layer.attn, 'c_attn'):
+                    hook = layer.attn.c_attn.register_forward_hook(make_qkv_hook(layer_idx))
+                    hooks.append(hook)
     except Exception as hook_error:
         logger.warning(f"Could not register QKV hooks: {hook_error}")
 
@@ -2419,14 +2188,9 @@ async def analyze_research_attention_stream(request: Dict[str, Any], authenticat
             k_matrix = None
             v_matrix = None
             if layer_idx in qkv_captures:
-                layer_qkv = qkv_captures[layer_idx]
-                # Check each key exists (hooks may have failed for some)
-                if 'q' in layer_qkv:
-                    q_matrix = layer_qkv['q'][:, head_idx, :].float().numpy().tolist()
-                if 'k' in layer_qkv:
-                    k_matrix = layer_qkv['k'][:, head_idx, :].float().numpy().tolist()
-                if 'v' in layer_qkv:
-                    v_matrix = layer_qkv['v'][:, head_idx, :].float().numpy().tolist()
+                q_matrix = qkv_captures[layer_idx]['q'][:, head_idx, :].float().numpy().tolist()
+                k_matrix = qkv_captures[layer_idx]['k'][:, head_idx, :].float().numpy().tolist()
+                v_matrix = qkv_captures[layer_idx]['v'][:, head_idx, :].float().numpy().tolist()
 
             critical_heads.append({
                 "head_idx": head_idx,
@@ -2441,13 +2205,6 @@ async def analyze_research_attention_stream(request: Dict[str, Any], authenticat
 
         critical_heads.sort(key=lambda h: h["max_weight"], reverse=True)
 
-        # Only keep QKV matrices for top 5 heads to avoid massive response sizes
-        for i, head in enumerate(critical_heads):
-            if i >= 5:
-                head["q_matrix"] = None
-                head["k_matrix"] = None
-                head["v_matrix"] = None
-
        layer_pattern = None
        layer_fraction = (layer_idx + 1) / n_layers
        if layer_idx == 0: