gary-boon and Claude Opus 4.5 committed
Commit d0b7e29 · 1 parent: a79cb83
Revert QKV visualization fixes - need better approach for data streaming
Reverts commits:
- a79cb83 Add safety checks for missing QKV keys
- decb5ab Limit QKV matrices to top 5 heads per layer
- 9056859 Fix QKV matrix extraction for Mistral/Devstral architecture
- 4ec134b Fix QKV visualization for Mistral/Devstral architecture
The QKV fixes made response sizes explode, which caused 504 timeouts.
Need to implement a better approach (lazy loading) before re-enabling.
🤖 Generated with [Claude Code](https://claude.com/claude-code)
Co-Authored-By: Claude Opus 4.5 <[email protected]>
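For context, the "lazy loading" direction mentioned above could look roughly like the sketch below: the analysis response keeps only lightweight head metadata, and a separate endpoint serves the Q/K/V matrices for a single (layer, head) pair on request. The route, the `QKV_CACHE` store, and the payload shape are illustrative assumptions, not code from this repository.

```python
# Hypothetical sketch of lazy QKV loading -- not part of this commit.
# Assumes a FastAPI app and an in-memory cache filled during analysis.
from typing import Dict, List, Tuple

from fastapi import FastAPI, HTTPException

app = FastAPI()

# (layer_idx, head_idx) -> {"q": [...], "k": [...], "v": [...]}
QKV_CACHE: Dict[Tuple[int, int], Dict[str, List[List[float]]]] = {}

@app.get("/research/qkv/{layer_idx}/{head_idx}")
async def get_qkv(layer_idx: int, head_idx: int):
    """Serve Q/K/V for one head only, so the main analysis response stays small."""
    key = (layer_idx, head_idx)
    if key not in QKV_CACHE:
        raise HTTPException(status_code=404, detail="QKV not captured for this head")
    return QKV_CACHE[key]
```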
backend/model_service.py  CHANGED  (+28 -271)
```diff
@@ -1571,11 +1571,7 @@ async def analyze_research_attention(request: Dict[str, Any], authenticated: boo
 
     n_heads = manager.model.config.n_head if hasattr(manager.model.config, 'n_head') else manager.model.config.num_attention_heads
     d_model = manager.model.config.n_embd if hasattr(manager.model.config, 'n_embd') else manager.model.config.hidden_size
-
-    if hasattr(manager.model.config, 'head_dim'):
-        head_dim = manager.model.config.head_dim
-    else:
-        head_dim = d_model // n_heads
+    head_dim = d_model // n_heads
 
     # Generation loop with full instrumentation
     layer_data_by_token = []  # Store layer data for each generated token
```
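The hunk above reverts to the plain `d_model // n_heads` computation. The deleted branch preferred an explicit `config.head_dim` when present, which matters for configs that define it independently of the hidden size; a hedged sketch of that defensive lookup (standard Hugging Face config attributes only) is:

```python
# Sketch: resolve head_dim defensively. Some configs (e.g. Mistral-family ones
# in recent transformers releases) expose head_dim explicitly; older GPT-style
# configs only give n_embd / n_head, so fall back to the division.
def resolve_head_dim(config) -> int:
    n_heads = getattr(config, "n_head", None) or config.num_attention_heads
    d_model = getattr(config, "n_embd", None) or config.hidden_size
    explicit = getattr(config, "head_dim", None)
    return explicit if explicit is not None else d_model // n_heads
```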
```diff
@@ -1586,7 +1582,6 @@ async def analyze_research_attention(request: Dict[str, Any], authenticated: boo
     hooks = []
 
     def make_qkv_hook(layer_idx):
-        """Hook for combined QKV projection (CodeGen/GPT-NeoX style)"""
         def hook(module, input, output):
             try:
                 # output shape: [batch, seq_len, 3 * hidden_size]
```
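`make_qkv_hook` captures the output of a fused QKV projection, whose last dimension packs queries, keys and values together. A toy, self-contained illustration of that kind of split is below; it assumes a simple [Q | K | V] concatenation, which is a simplification of CodeGen's actual interleaved layout.

```python
import torch

# Toy shapes only -- no real model involved.
batch, seq_len, n_heads, head_dim = 1, 6, 4, 8
hidden = n_heads * head_dim
fused = torch.randn(batch, seq_len, 3 * hidden)   # what a fused QKV linear emits

q, k, v = fused.split(hidden, dim=-1)             # each: [batch, seq_len, hidden]
q = q[0].reshape(seq_len, n_heads, head_dim)      # drop batch, split heads
k = k[0].reshape(seq_len, n_heads, head_dim)
v = v[0].reshape(seq_len, n_heads, head_dim)
print(q.shape, k.shape, v.shape)                  # torch.Size([6, 4, 8]) x 3
```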
```diff
@@ -1610,122 +1605,18 @@ async def analyze_research_attention(request: Dict[str, Any], authenticated: boo
                 pass
         return hook
 
-
-    def make_separate_q_hook(layer_idx):
-        """Hook for separate Q projection (Mistral/LLaMA style)"""
-        def hook(module, input, output):
-            try:
-                if layer_idx not in qkv_captures:
-                    qkv_captures[layer_idx] = {}
-                # Handle both tuple and tensor outputs
-                if isinstance(output, tuple):
-                    out = output[0]
-                else:
-                    out = output
-                out = out.detach().cpu()
-                # output shape: [batch, seq_len, num_heads * head_dim]
-                # If 3D, take first batch element
-                if out.dim() == 3:
-                    out = out[0]  # [seq_len, num_heads * head_dim]
-                # Reshape to [seq_len, num_heads, head_dim]
-                seq_len = out.shape[0]
-                out = out.reshape(seq_len, n_heads, head_dim)
-                qkv_captures[layer_idx]['q'] = out
-            except Exception as e:
-                logger.warning(f"Q hook error layer {layer_idx}: {e}")
-        return hook
-
-    def make_separate_k_hook(layer_idx):
-        """Hook for separate K projection (Mistral/LLaMA style)"""
-        def hook(module, input, output):
-            try:
-                if layer_idx not in qkv_captures:
-                    qkv_captures[layer_idx] = {}
-                # Handle both tuple and tensor outputs
-                if isinstance(output, tuple):
-                    out = output[0]
-                else:
-                    out = output
-                out = out.detach().cpu()
-                # If 3D, take first batch element
-                if out.dim() == 3:
-                    out = out[0]  # [seq_len, kv_heads * head_dim]
-                # For GQA models, K has fewer heads (kv_heads)
-                seq_len = out.shape[0]
-                hidden_size = out.shape[1]
-                actual_kv_heads = hidden_size // head_dim
-                out = out.reshape(seq_len, actual_kv_heads, head_dim)
-                # If GQA, repeat KV heads to match Q heads
-                if actual_kv_heads != n_heads:
-                    repeat_factor = n_heads // actual_kv_heads
-                    out = out.repeat_interleave(repeat_factor, dim=1)
-                qkv_captures[layer_idx]['k'] = out
-            except Exception as e:
-                logger.warning(f"K hook error layer {layer_idx}: {e}")
-        return hook
-
-    def make_separate_v_hook(layer_idx):
-        """Hook for separate V projection (Mistral/LLaMA style)"""
-        def hook(module, input, output):
-            try:
-                if layer_idx not in qkv_captures:
-                    qkv_captures[layer_idx] = {}
-                # Handle both tuple and tensor outputs
-                if isinstance(output, tuple):
-                    out = output[0]
-                else:
-                    out = output
-                out = out.detach().cpu()
-                # If 3D, take first batch element
-                if out.dim() == 3:
-                    out = out[0]  # [seq_len, kv_heads * head_dim]
-                # For GQA models, V has fewer heads (kv_heads)
-                seq_len = out.shape[0]
-                hidden_size = out.shape[1]
-                actual_kv_heads = hidden_size // head_dim
-                out = out.reshape(seq_len, actual_kv_heads, head_dim)
-                # If GQA, repeat KV heads to match Q heads
-                if actual_kv_heads != n_heads:
-                    repeat_factor = n_heads // actual_kv_heads
-                    out = out.repeat_interleave(repeat_factor, dim=1)
-                qkv_captures[layer_idx]['v'] = out
-            except Exception as e:
-                logger.warning(f"V hook error layer {layer_idx}: {e}")
-        return hook
-
-    # Register hooks - support both CodeGen and Mistral/Devstral architectures
+    # Register hooks on all qkv_proj modules (if available)
+    # This is model-specific - CodeGen uses different architecture
     try:
-
-
-
-
-
-
-
-
-
-        # Fallback for CodeGen if adapter doesn't work
-        if layers is None and hasattr(manager.model, 'transformer') and hasattr(manager.model.transformer, 'h'):
-            layers = manager.model.transformer.h
-            logger.info(f"Using transformer.h for {len(layers)} layers for QKV hooks")
-
-        if layers:
-            for layer_idx, layer in enumerate(layers):
-                # Mistral/Devstral: separate Q, K, V projections
-                if hasattr(layer, 'self_attn'):
-                    attn = layer.self_attn
-                    if hasattr(attn, 'q_proj') and hasattr(attn, 'k_proj') and hasattr(attn, 'v_proj'):
-                        hooks.append(attn.q_proj.register_forward_hook(make_separate_q_hook(layer_idx)))
-                        hooks.append(attn.k_proj.register_forward_hook(make_separate_k_hook(layer_idx)))
-                        hooks.append(attn.v_proj.register_forward_hook(make_separate_v_hook(layer_idx)))
-                # CodeGen/GPT-NeoX: combined QKV projection
-                elif hasattr(layer, 'attn'):
-                    if hasattr(layer.attn, 'qkv_proj'):
-                        hooks.append(layer.attn.qkv_proj.register_forward_hook(make_qkv_hook(layer_idx)))
-                    elif hasattr(layer.attn, 'c_attn'):
-                        # GPT-2 style attention
-                        hooks.append(layer.attn.c_attn.register_forward_hook(make_qkv_hook(layer_idx)))
-
-        logger.info(f"Registered {len(hooks)} QKV hooks")
+        if hasattr(manager.model, 'transformer') and hasattr(manager.model.transformer, 'h'):
+            for layer_idx, layer in enumerate(manager.model.transformer.h):
+                if hasattr(layer, 'attn') and hasattr(layer.attn, 'qkv_proj'):
+                    hook = layer.attn.qkv_proj.register_forward_hook(make_qkv_hook(layer_idx))
+                    hooks.append(hook)
+                elif hasattr(layer, 'attn') and hasattr(layer.attn, 'c_attn'):
+                    # GPT-2 style attention
+                    hook = layer.attn.c_attn.register_forward_hook(make_qkv_hook(layer_idx))
+                    hooks.append(hook)
     except Exception as hook_error:
         logger.warning(f"Could not register QKV hooks: {hook_error}")
 
```
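The deleted `make_separate_k_hook` / `make_separate_v_hook` closures also handled grouped-query attention, where K and V have fewer heads than Q and must be repeated to line up per query head. A standalone sketch of just that expansion:

```python
import torch

# GQA expansion as in the deleted hooks: kv_heads < n_heads, so each captured
# K/V head is repeated to cover its group of query heads.
seq_len, n_heads, kv_heads, head_dim = 5, 8, 2, 16

k = torch.randn(seq_len, kv_heads, head_dim)       # captured K projection
repeat_factor = n_heads // kv_heads                # 4 query heads per KV head
k_expanded = k.repeat_interleave(repeat_factor, dim=1)
assert k_expanded.shape == (seq_len, n_heads, head_dim)
```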
```diff
@@ -1859,16 +1750,11 @@ async def analyze_research_attention(request: Dict[str, Any], authenticated: boo
                 k_matrix = None
                 v_matrix = None
                 if layer_idx in qkv_captures:
-                    layer_qkv = qkv_captures[layer_idx]
                     # Q/K/V shape: [seq_len, n_heads, head_dim]
                     # Convert to float32 for numpy (bfloat16 not supported)
-
-
-
-                    if 'k' in layer_qkv:
-                        k_matrix = layer_qkv['k'][:, head_idx, :].float().numpy().tolist()
-                    if 'v' in layer_qkv:
-                        v_matrix = layer_qkv['v'][:, head_idx, :].float().numpy().tolist()
+                    q_matrix = qkv_captures[layer_idx]['q'][:, head_idx, :].float().numpy().tolist()
+                    k_matrix = qkv_captures[layer_idx]['k'][:, head_idx, :].float().numpy().tolist()
+                    v_matrix = qkv_captures[layer_idx]['v'][:, head_idx, :].float().numpy().tolist()
 
                 critical_heads.append({
                     "head_idx": head_idx,
```
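The context comment notes that bfloat16 tensors cannot be handed to NumPy directly, which is why every captured matrix goes through `.float()` before `.numpy().tolist()`. A minimal check of that conversion path:

```python
import torch

t = torch.randn(4, 16, dtype=torch.bfloat16)
# t.numpy() would fail here because NumPy has no bfloat16 dtype;
# upcasting first mirrors what the endpoint does before serializing.
matrix = t.float().numpy().tolist()   # JSON-serializable nested lists
print(len(matrix), len(matrix[0]))    # 4 16
```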
```diff
@@ -1887,14 +1773,6 @@ async def analyze_research_attention(request: Dict[str, Any], authenticated: boo
             # Sort by max_weight (return all heads, frontend will decide how many to display)
             critical_heads.sort(key=lambda h: h["max_weight"], reverse=True)
 
-            # Only keep QKV matrices for top 5 heads to avoid massive response sizes
-            # (40 layers × 32 heads × 3 matrices × seq_len × head_dim is too much data)
-            for i, head in enumerate(critical_heads):
-                if i >= 5:  # Keep QKV only for top 5 heads
-                    head["q_matrix"] = None
-                    head["k_matrix"] = None
-                    head["v_matrix"] = None
-
             # Detect layer-level pattern (percentage-based for any layer count)
             layer_pattern = None
             layer_fraction = (layer_idx + 1) / n_layers  # 1-indexed fraction
```
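The deleted comment in this hunk quantifies why the truncation existed at all. A quick back-of-the-envelope estimate makes the point; the layer/head counts are the ones from the deleted comment, while the sequence length and per-value JSON cost are assumptions:

```python
# Rough size of a response that ships every Q/K/V matrix as JSON text.
n_layers, n_heads, matrices = 40, 32, 3      # figures from the deleted comment
seq_len, head_dim = 256, 128                 # assumed example values
bytes_per_value_as_json = 10                 # rough average for a text float

values = n_layers * n_heads * matrices * seq_len * head_dim
print(f"{values:,} floats ~= {values * bytes_per_value_as_json / 1e9:.1f} GB of JSON")
```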
```diff
@@ -2133,11 +2011,7 @@ async def analyze_research_attention_stream(request: Dict[str, Any], authenticat
 
     n_heads = manager.model.config.n_head if hasattr(manager.model.config, 'n_head') else manager.model.config.num_attention_heads
     d_model = manager.model.config.n_embd if hasattr(manager.model.config, 'n_embd') else manager.model.config.hidden_size
-
-    if hasattr(manager.model.config, 'head_dim'):
-        head_dim = manager.model.config.head_dim
-    else:
-        head_dim = d_model // n_heads
+    head_dim = d_model // n_heads
 
     # === STAGE 2: GENERATING ===
     layer_data_by_token = []
```
```diff
@@ -2148,7 +2022,6 @@ async def analyze_research_attention_stream(request: Dict[str, Any], authenticat
     hooks = []
 
     def make_qkv_hook(layer_idx):
-        """Hook for combined QKV projection (CodeGen/GPT-NeoX style)"""
         def hook(module, input, output):
             try:
                 if output.dim() != 3:
```
```diff
@@ -2168,120 +2041,16 @@ async def analyze_research_attention_stream(request: Dict[str, Any], authenticat
                 pass
         return hook
 
-
-    def make_separate_q_hook(layer_idx):
-        """Hook for separate Q projection (Mistral/LLaMA style)"""
-        def hook(module, input, output):
-            try:
-                if layer_idx not in qkv_captures:
-                    qkv_captures[layer_idx] = {}
-                # Handle both tuple and tensor outputs
-                if isinstance(output, tuple):
-                    out = output[0]
-                else:
-                    out = output
-                out = out.detach().cpu()
-                # If 3D, take first batch element
-                if out.dim() == 3:
-                    out = out[0]  # [seq_len, num_heads * head_dim]
-                # Reshape to [seq_len, num_heads, head_dim]
-                seq_len = out.shape[0]
-                out = out.reshape(seq_len, n_heads, head_dim)
-                qkv_captures[layer_idx]['q'] = out
-            except Exception as e:
-                logger.warning(f"[Stream] Q hook error layer {layer_idx}: {e}")
-        return hook
-
-    def make_separate_k_hook(layer_idx):
-        """Hook for separate K projection (Mistral/LLaMA style)"""
-        def hook(module, input, output):
-            try:
-                if layer_idx not in qkv_captures:
-                    qkv_captures[layer_idx] = {}
-                # Handle both tuple and tensor outputs
-                if isinstance(output, tuple):
-                    out = output[0]
-                else:
-                    out = output
-                out = out.detach().cpu()
-                # If 3D, take first batch element
-                if out.dim() == 3:
-                    out = out[0]  # [seq_len, kv_heads * head_dim]
-                # For GQA models, K has fewer heads (kv_heads)
-                seq_len = out.shape[0]
-                hidden_size = out.shape[1]
-                actual_kv_heads = hidden_size // head_dim
-                out = out.reshape(seq_len, actual_kv_heads, head_dim)
-                # If GQA, repeat KV heads to match Q heads
-                if actual_kv_heads != n_heads:
-                    repeat_factor = n_heads // actual_kv_heads
-                    out = out.repeat_interleave(repeat_factor, dim=1)
-                qkv_captures[layer_idx]['k'] = out
-            except Exception as e:
-                logger.warning(f"[Stream] K hook error layer {layer_idx}: {e}")
-        return hook
-
-    def make_separate_v_hook(layer_idx):
-        """Hook for separate V projection (Mistral/LLaMA style)"""
-        def hook(module, input, output):
-            try:
-                if layer_idx not in qkv_captures:
-                    qkv_captures[layer_idx] = {}
-                # Handle both tuple and tensor outputs
-                if isinstance(output, tuple):
-                    out = output[0]
-                else:
-                    out = output
-                out = out.detach().cpu()
-                # If 3D, take first batch element
-                if out.dim() == 3:
-                    out = out[0]  # [seq_len, kv_heads * head_dim]
-                # For GQA models, V has fewer heads (kv_heads)
-                seq_len = out.shape[0]
-                hidden_size = out.shape[1]
-                actual_kv_heads = hidden_size // head_dim
-                out = out.reshape(seq_len, actual_kv_heads, head_dim)
-                # If GQA, repeat KV heads to match Q heads
-                if actual_kv_heads != n_heads:
-                    repeat_factor = n_heads // actual_kv_heads
-                    out = out.repeat_interleave(repeat_factor, dim=1)
-                qkv_captures[layer_idx]['v'] = out
-            except Exception as e:
-                logger.warning(f"[Stream] V hook error layer {layer_idx}: {e}")
-        return hook
-
-    # Register hooks - support both CodeGen and Mistral/Devstral architectures
+    # Register hooks
     try:
-
-
-
-
-
-
-
-
-
-        # Fallback for CodeGen if adapter doesn't work
-        if layers is None and hasattr(manager.model, 'transformer') and hasattr(manager.model.transformer, 'h'):
-            layers = manager.model.transformer.h
-            logger.info(f"[Stream] Using transformer.h for {len(layers)} layers for QKV hooks")
-
-        if layers:
-            for layer_idx, layer in enumerate(layers):
-                # Mistral/Devstral: separate Q, K, V projections
-                if hasattr(layer, 'self_attn'):
-                    attn = layer.self_attn
-                    if hasattr(attn, 'q_proj') and hasattr(attn, 'k_proj') and hasattr(attn, 'v_proj'):
-                        hooks.append(attn.q_proj.register_forward_hook(make_separate_q_hook(layer_idx)))
-                        hooks.append(attn.k_proj.register_forward_hook(make_separate_k_hook(layer_idx)))
-                        hooks.append(attn.v_proj.register_forward_hook(make_separate_v_hook(layer_idx)))
-                # CodeGen/GPT-NeoX: combined QKV projection
-                elif hasattr(layer, 'attn'):
-                    if hasattr(layer.attn, 'qkv_proj'):
-                        hooks.append(layer.attn.qkv_proj.register_forward_hook(make_qkv_hook(layer_idx)))
-                    elif hasattr(layer.attn, 'c_attn'):
-                        hooks.append(layer.attn.c_attn.register_forward_hook(make_qkv_hook(layer_idx)))
-
-        logger.info(f"[Stream] Registered {len(hooks)} QKV hooks")
+        if hasattr(manager.model, 'transformer') and hasattr(manager.model.transformer, 'h'):
+            for layer_idx, layer in enumerate(manager.model.transformer.h):
+                if hasattr(layer, 'attn') and hasattr(layer.attn, 'qkv_proj'):
+                    hook = layer.attn.qkv_proj.register_forward_hook(make_qkv_hook(layer_idx))
+                    hooks.append(hook)
+                elif hasattr(layer, 'attn') and hasattr(layer.attn, 'c_attn'):
+                    hook = layer.attn.c_attn.register_forward_hook(make_qkv_hook(layer_idx))
+                    hooks.append(hook)
     except Exception as hook_error:
         logger.warning(f"Could not register QKV hooks: {hook_error}")
 
```
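The streaming endpoint exists to dodge exactly the 504s described in the commit message, and any re-enabled QKV support will likely need to keep per-layer payloads small as they are emitted. A hedged sketch of that shape, assuming FastAPI's `StreamingResponse` and NDJSON chunks (route and payload are illustrative only):

```python
import json

from fastapi import FastAPI
from fastapi.responses import StreamingResponse

app = FastAPI()

@app.get("/research/attention/stream")
async def attention_stream():
    """Illustrative only: emit one small JSON object per layer instead of one huge body."""
    async def gen():
        for layer_idx in range(40):                               # assumed layer count
            payload = {"layer": layer_idx, "critical_heads": []}  # trimmed per-layer data
            yield json.dumps(payload) + "\n"
    return StreamingResponse(gen(), media_type="application/x-ndjson")
```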
```diff
@@ -2419,14 +2188,9 @@ async def analyze_research_attention_stream(request: Dict[str, Any], authenticat
                 k_matrix = None
                 v_matrix = None
                 if layer_idx in qkv_captures:
-
-
-
-                    q_matrix = layer_qkv['q'][:, head_idx, :].float().numpy().tolist()
-                    if 'k' in layer_qkv:
-                        k_matrix = layer_qkv['k'][:, head_idx, :].float().numpy().tolist()
-                    if 'v' in layer_qkv:
-                        v_matrix = layer_qkv['v'][:, head_idx, :].float().numpy().tolist()
+                    q_matrix = qkv_captures[layer_idx]['q'][:, head_idx, :].float().numpy().tolist()
+                    k_matrix = qkv_captures[layer_idx]['k'][:, head_idx, :].float().numpy().tolist()
+                    v_matrix = qkv_captures[layer_idx]['v'][:, head_idx, :].float().numpy().tolist()
 
                 critical_heads.append({
                     "head_idx": head_idx,
```
```diff
@@ -2441,13 +2205,6 @@ async def analyze_research_attention_stream(request: Dict[str, Any], authenticat
 
             critical_heads.sort(key=lambda h: h["max_weight"], reverse=True)
 
-            # Only keep QKV matrices for top 5 heads to avoid massive response sizes
-            for i, head in enumerate(critical_heads):
-                if i >= 5:
-                    head["q_matrix"] = None
-                    head["k_matrix"] = None
-                    head["v_matrix"] = None
-
             layer_pattern = None
             layer_fraction = (layer_idx + 1) / n_layers
             if layer_idx == 0:
```