Spaces:
Sleeping
gary-boon
Claude Opus 4.5
committed on
Commit 4ec134b · 1 Parent(s): 3e67ea2
Fix QKV visualization for Mistral/Devstral architecture
- Add separate Q, K, V hook makers for Mistral-style architectures that use separate projections (q_proj, k_proj, v_proj)
- Update layer iteration to use the adapter's _get_layers() method for architecture-agnostic layer access
- Maintain backwards compatibility with CodeGen's combined QKV projection
- Apply the same fix to both the regular and streaming endpoints
This fixes the QKV visualization, which was silently failing for Devstral because the hook registration only supported CodeGen's transformer.h structure.
🤖 Generated with [Claude Code](https://claude.com/claude-code)
Co-Authored-By: Claude Opus 4.5 <[email protected]>
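For readers skimming the diff below: the core of the change is registering PyTorch forward hooks on each layer's attention projections so per-layer Q/K/V activations can be captured for visualization. The following is a minimal, illustrative sketch of that pattern, not the service code itself; `model`, `captures`, and the `model.model.layers` attribute path are assumptions that hold for Mistral/LLaMA-style checkpoints.

```python
import torch

def register_qkv_hooks(model, captures):
    """Attach forward hooks that record per-layer Q/K/V projection outputs.

    `captures` is a caller-supplied dict filled as
    {layer_idx: {'q': tensor, 'k': tensor, 'v': tensor}}.
    """
    handles = []

    def make_hook(layer_idx, key):
        def hook(module, inputs, output):
            # For a separate projection, output is [batch, seq_len, num_heads * head_dim].
            captures.setdefault(layer_idx, {})[key] = output.detach().cpu()
        return hook

    # Mistral/LLaMA-style models expose decoder blocks under model.model.layers,
    # each with self_attn.{q,k,v}_proj. (CodeGen-style models instead use
    # transformer.h with a combined projection; see the actual diff below.)
    layers = getattr(getattr(model, "model", model), "layers", [])
    for idx, layer in enumerate(layers):
        attn = getattr(layer, "self_attn", None)
        if attn is not None and hasattr(attn, "q_proj"):
            handles.append(attn.q_proj.register_forward_hook(make_hook(idx, "q")))
            handles.append(attn.k_proj.register_forward_hook(make_hook(idx, "k")))
            handles.append(attn.v_proj.register_forward_hook(make_hook(idx, "v")))
    return handles

# Usage (illustrative):
#   captures = {}
#   handles = register_qkv_hooks(model, captures)
#   with torch.no_grad():
#       model(**inputs)
#   for h in handles:
#       h.remove()
```

The commit's actual implementation additionally resolves the layer list through the model adapter and falls back to CodeGen's `transformer.h` layout with its combined `qkv_proj`/`c_attn` modules, as shown in the hunks below.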
- backend/model_service.py (+134 / -20)

backend/model_service.py CHANGED
@@ -1582,6 +1582,7 @@ async def analyze_research_attention(request: Dict[str, Any], authenticated: boo
     hooks = []

     def make_qkv_hook(layer_idx):
+        """Hook for combined QKV projection (CodeGen/GPT-NeoX style)"""
         def hook(module, input, output):
             try:
                 # output shape: [batch, seq_len, 3 * hidden_size]
@@ -1605,18 +1606,74 @@ async def analyze_research_attention(request: Dict[str, Any], authenticated: boo
                 pass
         return hook

- [2 old lines removed; content not shown in this view]
+    def make_separate_q_hook(layer_idx):
+        """Hook for separate Q projection (Mistral/LLaMA style)"""
+        def hook(module, input, output):
+            try:
+                if layer_idx not in qkv_captures:
+                    qkv_captures[layer_idx] = {}
+                # output shape: [batch, seq_len, num_heads * head_dim]
+                qkv_captures[layer_idx]['q'] = output[0].detach().cpu()
+            except Exception:
+                pass
+        return hook
+
+    def make_separate_k_hook(layer_idx):
+        """Hook for separate K projection (Mistral/LLaMA style)"""
+        def hook(module, input, output):
+            try:
+                if layer_idx not in qkv_captures:
+                    qkv_captures[layer_idx] = {}
+                qkv_captures[layer_idx]['k'] = output[0].detach().cpu()
+            except Exception:
+                pass
+        return hook
+
+    def make_separate_v_hook(layer_idx):
+        """Hook for separate V projection (Mistral/LLaMA style)"""
+        def hook(module, input, output):
+            try:
+                if layer_idx not in qkv_captures:
+                    qkv_captures[layer_idx] = {}
+                qkv_captures[layer_idx]['v'] = output[0].detach().cpu()
+            except Exception:
+                pass
+        return hook
+
+    # Register hooks - support both CodeGen and Mistral/Devstral architectures
     try:
- [9 old lines removed; content not shown in this view]
+        # Try to get layers via adapter first (works for all model types)
+        layers = None
+        if manager.adapter:
+            try:
+                layers = manager.adapter._get_layers()
+                logger.info(f"Using adapter to get {len(layers)} layers for QKV hooks")
+            except Exception:
+                pass
+
+        # Fallback for CodeGen if adapter doesn't work
+        if layers is None and hasattr(manager.model, 'transformer') and hasattr(manager.model.transformer, 'h'):
+            layers = manager.model.transformer.h
+            logger.info(f"Using transformer.h for {len(layers)} layers for QKV hooks")
+
+        if layers:
+            for layer_idx, layer in enumerate(layers):
+                # Mistral/Devstral: separate Q, K, V projections
+                if hasattr(layer, 'self_attn'):
+                    attn = layer.self_attn
+                    if hasattr(attn, 'q_proj') and hasattr(attn, 'k_proj') and hasattr(attn, 'v_proj'):
+                        hooks.append(attn.q_proj.register_forward_hook(make_separate_q_hook(layer_idx)))
+                        hooks.append(attn.k_proj.register_forward_hook(make_separate_k_hook(layer_idx)))
+                        hooks.append(attn.v_proj.register_forward_hook(make_separate_v_hook(layer_idx)))
+                # CodeGen/GPT-NeoX: combined QKV projection
+                elif hasattr(layer, 'attn'):
+                    if hasattr(layer.attn, 'qkv_proj'):
+                        hooks.append(layer.attn.qkv_proj.register_forward_hook(make_qkv_hook(layer_idx)))
+                    elif hasattr(layer.attn, 'c_attn'):
+                        # GPT-2 style attention
+                        hooks.append(layer.attn.c_attn.register_forward_hook(make_qkv_hook(layer_idx)))
+
+        logger.info(f"Registered {len(hooks)} QKV hooks")
     except Exception as hook_error:
         logger.warning(f"Could not register QKV hooks: {hook_error}")

@@ -2022,6 +2079,7 @@ async def analyze_research_attention_stream(request: Dict[str, Any], authenticat
     hooks = []

     def make_qkv_hook(layer_idx):
+        """Hook for combined QKV projection (CodeGen/GPT-NeoX style)"""
         def hook(module, input, output):
             try:
                 if output.dim() != 3:
@@ -2041,16 +2099,72 @@ async def analyze_research_attention_stream(request: Dict[str, Any], authenticat
                 pass
         return hook

- [1 old line removed; content not shown in this view]
+    def make_separate_q_hook(layer_idx):
+        """Hook for separate Q projection (Mistral/LLaMA style)"""
+        def hook(module, input, output):
+            try:
+                if layer_idx not in qkv_captures:
+                    qkv_captures[layer_idx] = {}
+                qkv_captures[layer_idx]['q'] = output[0].detach().cpu()
+            except Exception:
+                pass
+        return hook
+
+    def make_separate_k_hook(layer_idx):
+        """Hook for separate K projection (Mistral/LLaMA style)"""
+        def hook(module, input, output):
+            try:
+                if layer_idx not in qkv_captures:
+                    qkv_captures[layer_idx] = {}
+                qkv_captures[layer_idx]['k'] = output[0].detach().cpu()
+            except Exception:
+                pass
+        return hook
+
+    def make_separate_v_hook(layer_idx):
+        """Hook for separate V projection (Mistral/LLaMA style)"""
+        def hook(module, input, output):
+            try:
+                if layer_idx not in qkv_captures:
+                    qkv_captures[layer_idx] = {}
+                qkv_captures[layer_idx]['v'] = output[0].detach().cpu()
+            except Exception:
+                pass
+        return hook
+
+    # Register hooks - support both CodeGen and Mistral/Devstral architectures
     try:
- [8 old lines removed; content not shown in this view]
+        # Try to get layers via adapter first (works for all model types)
+        layers = None
+        if manager.adapter:
+            try:
+                layers = manager.adapter._get_layers()
+                logger.info(f"[Stream] Using adapter to get {len(layers)} layers for QKV hooks")
+            except Exception:
+                pass
+
+        # Fallback for CodeGen if adapter doesn't work
+        if layers is None and hasattr(manager.model, 'transformer') and hasattr(manager.model.transformer, 'h'):
+            layers = manager.model.transformer.h
+            logger.info(f"[Stream] Using transformer.h for {len(layers)} layers for QKV hooks")
+
+        if layers:
+            for layer_idx, layer in enumerate(layers):
+                # Mistral/Devstral: separate Q, K, V projections
+                if hasattr(layer, 'self_attn'):
+                    attn = layer.self_attn
+                    if hasattr(attn, 'q_proj') and hasattr(attn, 'k_proj') and hasattr(attn, 'v_proj'):
+                        hooks.append(attn.q_proj.register_forward_hook(make_separate_q_hook(layer_idx)))
+                        hooks.append(attn.k_proj.register_forward_hook(make_separate_k_hook(layer_idx)))
+                        hooks.append(attn.v_proj.register_forward_hook(make_separate_v_hook(layer_idx)))
+                # CodeGen/GPT-NeoX: combined QKV projection
+                elif hasattr(layer, 'attn'):
+                    if hasattr(layer.attn, 'qkv_proj'):
+                        hooks.append(layer.attn.qkv_proj.register_forward_hook(make_qkv_hook(layer_idx)))
+                    elif hasattr(layer.attn, 'c_attn'):
+                        hooks.append(layer.attn.c_attn.register_forward_hook(make_qkv_hook(layer_idx)))
+
+        logger.info(f"[Stream] Registered {len(hooks)} QKV hooks")
     except Exception as hook_error:
         logger.warning(f"Could not register QKV hooks: {hook_error}")

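Aside (not part of the commit): the body of make_qkv_hook is mostly elided in the hunks above; only the `# output shape: [batch, seq_len, 3 * hidden_size]` comment is visible. Below is a hedged sketch of how such a combined projection output could be split into Q, K, V, assuming a plain [Q | K | V] layout along the last dimension; CodeGen's real qkv_proj interleaves heads, so the service's actual hook likely uses a head-aware reshape instead.

```python
import torch

def split_combined_qkv(output: torch.Tensor):
    """Split a combined QKV projection output into Q, K, V tensors.

    Assumes the last dimension is 3 * hidden_size laid out as [Q | K | V].
    """
    hidden = output.shape[-1] // 3
    # Each chunk: [batch, seq_len, hidden_size]
    q, k, v = output.split(hidden, dim=-1)
    return q.detach().cpu(), k.detach().cpu(), v.detach().cpu()

# Example with dummy activations: batch=1, seq_len=4, hidden_size=8
q, k, v = split_combined_qkv(torch.randn(1, 4, 3 * 8))
```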