gary-boon (Claude Opus 4.5) committed
Commit 9056859 · 1 Parent(s): 4ec134b
Fix QKV matrix extraction for Mistral/Devstral architecture
- Use explicit head_dim from model config (Mistral uses 128, not hidden_size/num_heads=160)
- Handle both tuple and tensor outputs from projection hooks
- Properly reshape Q/K/V tensors to [seq_len, num_heads, head_dim]
- Support GQA by expanding K/V heads (8) to match Q heads (32); see the sketch after the commit message
- Add warning logs for debugging hook failures
This fixes the "too many indices for tensor of dimension 2" and
shape mismatch errors when extracting QKV matrices from Devstral.
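For context, the quoted IndexError is what indexing a 2D projection output as if it were 3D produces; a minimal reproduction (shapes are illustrative, not the original code):

```python
import torch

out = torch.randn(10, 4096)   # 2D projection output: [seq_len, hidden]
out[0, :, :]                  # treats it as 3D [batch, seq_len, hidden]
# IndexError: too many indices for tensor of dimension 2
```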
🤖 Generated with [Claude Code](https://claude.com/claude-code)
Co-Authored-By: Claude Opus 4.5 <[email protected]>
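As a standalone illustration of the GQA expansion described in the bullets above (head counts and head_dim are Devstral's, per the commit message; the tensor itself is dummy data):

```python
import torch

seq_len, n_heads, kv_heads, head_dim = 10, 32, 8, 128
k = torch.randn(seq_len, kv_heads * head_dim)    # raw k_proj output, batch removed

k = k.reshape(seq_len, kv_heads, head_dim)       # [10, 8, 128]
repeat_factor = n_heads // kv_heads              # 32 // 8 = 4
k = k.repeat_interleave(repeat_factor, dim=1)    # [10, 32, 128], now matches Q
```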
- backend/model_service.py +124 -20
backend/model_service.py
CHANGED
@@ -1571,7 +1571,11 @@ async def analyze_research_attention(request: Dict[str, Any], authenticated: bool
 
     n_heads = manager.model.config.n_head if hasattr(manager.model.config, 'n_head') else manager.model.config.num_attention_heads
     d_model = manager.model.config.n_embd if hasattr(manager.model.config, 'n_embd') else manager.model.config.hidden_size
-    head_dim = d_model // n_heads
+    # Use explicit head_dim from config if available (Mistral models have this)
+    if hasattr(manager.model.config, 'head_dim'):
+        head_dim = manager.model.config.head_dim
+    else:
+        head_dim = d_model // n_heads
 
     # Generation loop with full instrumentation
     layer_data_by_token = []  # Store layer data for each generated token
@@ -1612,10 +1616,22 @@ async def analyze_research_attention(request: Dict[str, Any], authenticated: bool
             try:
                 if layer_idx not in qkv_captures:
                     qkv_captures[layer_idx] = {}
+                # Handle both tuple and tensor outputs
+                if isinstance(output, tuple):
+                    out = output[0]
+                else:
+                    out = output
+                out = out.detach().cpu()
                 # output shape: [batch, seq_len, num_heads * head_dim]
-
-
-
+                # If 3D, take first batch element
+                if out.dim() == 3:
+                    out = out[0]  # [seq_len, num_heads * head_dim]
+                # Reshape to [seq_len, num_heads, head_dim]
+                seq_len = out.shape[0]
+                out = out.reshape(seq_len, n_heads, head_dim)
+                qkv_captures[layer_idx]['q'] = out
+            except Exception as e:
+                logger.warning(f"Q hook error layer {layer_idx}: {e}")
         return hook
 
     def make_separate_k_hook(layer_idx):
@@ -1624,9 +1640,27 @@ async def analyze_research_attention(request: Dict[str, Any], authenticated: bool
             try:
                 if layer_idx not in qkv_captures:
                     qkv_captures[layer_idx] = {}
-
-
-
+                # Handle both tuple and tensor outputs
+                if isinstance(output, tuple):
+                    out = output[0]
+                else:
+                    out = output
+                out = out.detach().cpu()
+                # If 3D, take first batch element
+                if out.dim() == 3:
+                    out = out[0]  # [seq_len, kv_heads * head_dim]
+                # For GQA models, K has fewer heads (kv_heads)
+                seq_len = out.shape[0]
+                hidden_size = out.shape[1]
+                actual_kv_heads = hidden_size // head_dim
+                out = out.reshape(seq_len, actual_kv_heads, head_dim)
+                # If GQA, repeat KV heads to match Q heads
+                if actual_kv_heads != n_heads:
+                    repeat_factor = n_heads // actual_kv_heads
+                    out = out.repeat_interleave(repeat_factor, dim=1)
+                qkv_captures[layer_idx]['k'] = out
+            except Exception as e:
+                logger.warning(f"K hook error layer {layer_idx}: {e}")
         return hook
 
     def make_separate_v_hook(layer_idx):
@@ -1635,9 +1669,27 @@ async def analyze_research_attention(request: Dict[str, Any], authenticated: bool
             try:
                 if layer_idx not in qkv_captures:
                     qkv_captures[layer_idx] = {}
-
-
-
+                # Handle both tuple and tensor outputs
+                if isinstance(output, tuple):
+                    out = output[0]
+                else:
+                    out = output
+                out = out.detach().cpu()
+                # If 3D, take first batch element
+                if out.dim() == 3:
+                    out = out[0]  # [seq_len, kv_heads * head_dim]
+                # For GQA models, V has fewer heads (kv_heads)
+                seq_len = out.shape[0]
+                hidden_size = out.shape[1]
+                actual_kv_heads = hidden_size // head_dim
+                out = out.reshape(seq_len, actual_kv_heads, head_dim)
+                # If GQA, repeat KV heads to match Q heads
+                if actual_kv_heads != n_heads:
+                    repeat_factor = n_heads // actual_kv_heads
+                    out = out.repeat_interleave(repeat_factor, dim=1)
+                qkv_captures[layer_idx]['v'] = out
+            except Exception as e:
+                logger.warning(f"V hook error layer {layer_idx}: {e}")
         return hook
 
     # Register hooks - support both CodeGen and Mistral/Devstral architectures
@@ -2068,7 +2120,11 @@ async def analyze_research_attention_stream(request: Dict[str, Any], authenticated: bool
 
     n_heads = manager.model.config.n_head if hasattr(manager.model.config, 'n_head') else manager.model.config.num_attention_heads
     d_model = manager.model.config.n_embd if hasattr(manager.model.config, 'n_embd') else manager.model.config.hidden_size
-    head_dim = d_model // n_heads
+    # Use explicit head_dim from config if available (Mistral models have this)
+    if hasattr(manager.model.config, 'head_dim'):
+        head_dim = manager.model.config.head_dim
+    else:
+        head_dim = d_model // n_heads
 
     # === STAGE 2: GENERATING ===
     layer_data_by_token = []
@@ -2105,9 +2161,21 @@ async def analyze_research_attention_stream(request: Dict[str, Any], authenticated: bool
             try:
                 if layer_idx not in qkv_captures:
                     qkv_captures[layer_idx] = {}
-
-
-
+                # Handle both tuple and tensor outputs
+                if isinstance(output, tuple):
+                    out = output[0]
+                else:
+                    out = output
+                out = out.detach().cpu()
+                # If 3D, take first batch element
+                if out.dim() == 3:
+                    out = out[0]  # [seq_len, num_heads * head_dim]
+                # Reshape to [seq_len, num_heads, head_dim]
+                seq_len = out.shape[0]
+                out = out.reshape(seq_len, n_heads, head_dim)
+                qkv_captures[layer_idx]['q'] = out
+            except Exception as e:
+                logger.warning(f"[Stream] Q hook error layer {layer_idx}: {e}")
         return hook
 
     def make_separate_k_hook(layer_idx):
@@ -2116,9 +2184,27 @@ async def analyze_research_attention_stream(request: Dict[str, Any], authenticated: bool
             try:
                 if layer_idx not in qkv_captures:
                     qkv_captures[layer_idx] = {}
-
-
-
+                # Handle both tuple and tensor outputs
+                if isinstance(output, tuple):
+                    out = output[0]
+                else:
+                    out = output
+                out = out.detach().cpu()
+                # If 3D, take first batch element
+                if out.dim() == 3:
+                    out = out[0]  # [seq_len, kv_heads * head_dim]
+                # For GQA models, K has fewer heads (kv_heads)
+                seq_len = out.shape[0]
+                hidden_size = out.shape[1]
+                actual_kv_heads = hidden_size // head_dim
+                out = out.reshape(seq_len, actual_kv_heads, head_dim)
+                # If GQA, repeat KV heads to match Q heads
+                if actual_kv_heads != n_heads:
+                    repeat_factor = n_heads // actual_kv_heads
+                    out = out.repeat_interleave(repeat_factor, dim=1)
+                qkv_captures[layer_idx]['k'] = out
+            except Exception as e:
+                logger.warning(f"[Stream] K hook error layer {layer_idx}: {e}")
         return hook
 
     def make_separate_v_hook(layer_idx):
@@ -2127,9 +2213,27 @@ async def analyze_research_attention_stream(request: Dict[str, Any], authenticated: bool
             try:
                 if layer_idx not in qkv_captures:
                     qkv_captures[layer_idx] = {}
-
-
-
+                # Handle both tuple and tensor outputs
+                if isinstance(output, tuple):
+                    out = output[0]
+                else:
+                    out = output
+                out = out.detach().cpu()
+                # If 3D, take first batch element
+                if out.dim() == 3:
+                    out = out[0]  # [seq_len, kv_heads * head_dim]
+                # For GQA models, V has fewer heads (kv_heads)
+                seq_len = out.shape[0]
+                hidden_size = out.shape[1]
+                actual_kv_heads = hidden_size // head_dim
+                out = out.reshape(seq_len, actual_kv_heads, head_dim)
+                # If GQA, repeat KV heads to match Q heads
+                if actual_kv_heads != n_heads:
+                    repeat_factor = n_heads // actual_kv_heads
+                    out = out.repeat_interleave(repeat_factor, dim=1)
+                qkv_captures[layer_idx]['v'] = out
+            except Exception as e:
+                logger.warning(f"[Stream] V hook error layer {layer_idx}: {e}")
         return hook
 
     # Register hooks - support both CodeGen and Mistral/Devstral architectures
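For reference, a minimal sketch of how these per-projection hooks are attached and removed. The Mistral-style module path (manager.model.model.layers[i].self_attn.q_proj) is an assumption here; the make_separate_*_hook factories are the ones defined in the diff above:

```python
# Sketch only: assumes a Mistral/Devstral-style layer layout.
# CodeGen exposes a fused qkv_proj instead, which needs a different hook.
handles = []
for layer_idx, layer in enumerate(manager.model.model.layers):
    attn = layer.self_attn
    handles.append(attn.q_proj.register_forward_hook(make_separate_q_hook(layer_idx)))
    handles.append(attn.k_proj.register_forward_hook(make_separate_k_hook(layer_idx)))
    handles.append(attn.v_proj.register_forward_hook(make_separate_v_hook(layer_idx)))

# ... run generation; hooks fill qkv_captures[layer_idx]['q'/'k'/'v'] ...

for h in handles:   # always detach hooks when done
    h.remove()
```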
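Design note: the new head_dim logic simply prefers the config's explicit value over the derived one. A compact equivalent (a hypothetical getattr-based refactor, not what the commit ships):

```python
cfg = manager.model.config
head_dim = getattr(cfg, 'head_dim', None) or d_model // n_heads
# Devstral: d_model // n_heads gives 160, but cfg.head_dim is 128,
# so the fallback division is wrong for this model family.
```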