gary-boon · Claude Opus 4.5 committed
Commit 9056859 · 1 parent: 4ec134b

Fix QKV matrix extraction for Mistral/Devstral architecture


- Use explicit head_dim from model config (Mistral uses 128, not hidden_size/num_heads=160)
- Handle both tuple and tensor outputs from projection hooks
- Properly reshape Q/K/V tensors to [seq_len, num_heads, head_dim]
- Support GQA by expanding K/V heads (8) to match Q heads (32)
- Add warning logs for debugging hook failures

This fixes the "too many indices for tensor of dimension 2" and shape mismatch errors
when extracting QKV matrices from Devstral; the core reshape logic is sketched below.

🤖 Generated with [Claude Code](https://claude.com/claude-code)

Co-Authored-By: Claude Opus 4.5 <[email protected]>
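For reference, the reshape logic described in the bullets above boils down to one routine. The following is a minimal sketch, not code from this commit: the helper name `reshape_proj_output` is hypothetical, and the example numbers (32 query heads, 8 KV heads, head_dim 128) are the Devstral figures quoted in the message above.

```python
# Minimal sketch of the hook-side reshape logic; `reshape_proj_output` is a
# hypothetical helper, not part of backend/model_service.py.
import torch

def reshape_proj_output(output, n_heads: int, head_dim: int) -> torch.Tensor:
    """Normalize a q/k/v projection hook output to [seq_len, n_heads, head_dim]."""
    out = output[0] if isinstance(output, tuple) else output  # hooks may see tuples
    out = out.detach().cpu()
    if out.dim() == 3:            # [batch, seq_len, heads * head_dim]
        out = out[0]              # keep the first batch element
    seq_len, hidden = out.shape
    actual_heads = hidden // head_dim              # 8 for Devstral K/V, 32 for Q
    out = out.reshape(seq_len, actual_heads, head_dim)
    if actual_heads != n_heads:                    # GQA: expand K/V heads to match Q
        out = out.repeat_interleave(n_heads // actual_heads, dim=1)
    return out

# Example: a K projection output of shape [1, 42, 1024] becomes [42, 32, 128]
# with n_heads=32 and head_dim=128 (8 KV heads repeated 4x).
```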

Files changed (1)
  1. backend/model_service.py +124 -20
backend/model_service.py CHANGED
@@ -1571,7 +1571,11 @@ async def analyze_research_attention(request: Dict[str, Any], authenticated: boo
 
     n_heads = manager.model.config.n_head if hasattr(manager.model.config, 'n_head') else manager.model.config.num_attention_heads
     d_model = manager.model.config.n_embd if hasattr(manager.model.config, 'n_embd') else manager.model.config.hidden_size
-    head_dim = d_model // n_heads
+    # Use explicit head_dim from config if available (Mistral models have this)
+    if hasattr(manager.model.config, 'head_dim'):
+        head_dim = manager.model.config.head_dim
+    else:
+        head_dim = d_model // n_heads
 
     # Generation loop with full instrumentation
     layer_data_by_token = []  # Store layer data for each generated token
@@ -1612,10 +1616,22 @@ async def analyze_research_attention(request: Dict[str, Any], authenticated: boo
             try:
                 if layer_idx not in qkv_captures:
                     qkv_captures[layer_idx] = {}
+                # Handle both tuple and tensor outputs
+                if isinstance(output, tuple):
+                    out = output[0]
+                else:
+                    out = output
+                out = out.detach().cpu()
                 # output shape: [batch, seq_len, num_heads * head_dim]
-                qkv_captures[layer_idx]['q'] = output[0].detach().cpu()
-            except Exception:
-                pass
+                # If 3D, take first batch element
+                if out.dim() == 3:
+                    out = out[0]  # [seq_len, num_heads * head_dim]
+                # Reshape to [seq_len, num_heads, head_dim]
+                seq_len = out.shape[0]
+                out = out.reshape(seq_len, n_heads, head_dim)
+                qkv_captures[layer_idx]['q'] = out
+            except Exception as e:
+                logger.warning(f"Q hook error layer {layer_idx}: {e}")
         return hook
 
     def make_separate_k_hook(layer_idx):
@@ -1624,9 +1640,27 @@ async def analyze_research_attention(request: Dict[str, Any], authenticated: boo
             try:
                 if layer_idx not in qkv_captures:
                     qkv_captures[layer_idx] = {}
-                qkv_captures[layer_idx]['k'] = output[0].detach().cpu()
-            except Exception:
-                pass
+                # Handle both tuple and tensor outputs
+                if isinstance(output, tuple):
+                    out = output[0]
+                else:
+                    out = output
+                out = out.detach().cpu()
+                # If 3D, take first batch element
+                if out.dim() == 3:
+                    out = out[0]  # [seq_len, kv_heads * head_dim]
+                # For GQA models, K has fewer heads (kv_heads)
+                seq_len = out.shape[0]
+                hidden_size = out.shape[1]
+                actual_kv_heads = hidden_size // head_dim
+                out = out.reshape(seq_len, actual_kv_heads, head_dim)
+                # If GQA, repeat KV heads to match Q heads
+                if actual_kv_heads != n_heads:
+                    repeat_factor = n_heads // actual_kv_heads
+                    out = out.repeat_interleave(repeat_factor, dim=1)
+                qkv_captures[layer_idx]['k'] = out
+            except Exception as e:
+                logger.warning(f"K hook error layer {layer_idx}: {e}")
         return hook
 
     def make_separate_v_hook(layer_idx):
@@ -1635,9 +1669,27 @@ async def analyze_research_attention(request: Dict[str, Any], authenticated: boo
             try:
                 if layer_idx not in qkv_captures:
                     qkv_captures[layer_idx] = {}
-                qkv_captures[layer_idx]['v'] = output[0].detach().cpu()
-            except Exception:
-                pass
+                # Handle both tuple and tensor outputs
+                if isinstance(output, tuple):
+                    out = output[0]
+                else:
+                    out = output
+                out = out.detach().cpu()
+                # If 3D, take first batch element
+                if out.dim() == 3:
+                    out = out[0]  # [seq_len, kv_heads * head_dim]
+                # For GQA models, V has fewer heads (kv_heads)
+                seq_len = out.shape[0]
+                hidden_size = out.shape[1]
+                actual_kv_heads = hidden_size // head_dim
+                out = out.reshape(seq_len, actual_kv_heads, head_dim)
+                # If GQA, repeat KV heads to match Q heads
+                if actual_kv_heads != n_heads:
+                    repeat_factor = n_heads // actual_kv_heads
+                    out = out.repeat_interleave(repeat_factor, dim=1)
+                qkv_captures[layer_idx]['v'] = out
+            except Exception as e:
+                logger.warning(f"V hook error layer {layer_idx}: {e}")
         return hook
 
     # Register hooks - support both CodeGen and Mistral/Devstral architectures
@@ -2068,7 +2120,11 @@ async def analyze_research_attention_stream(request: Dict[str, Any], authenticat
 
     n_heads = manager.model.config.n_head if hasattr(manager.model.config, 'n_head') else manager.model.config.num_attention_heads
     d_model = manager.model.config.n_embd if hasattr(manager.model.config, 'n_embd') else manager.model.config.hidden_size
-    head_dim = d_model // n_heads
+    # Use explicit head_dim from config if available (Mistral models have this)
+    if hasattr(manager.model.config, 'head_dim'):
+        head_dim = manager.model.config.head_dim
+    else:
+        head_dim = d_model // n_heads
 
     # === STAGE 2: GENERATING ===
     layer_data_by_token = []
@@ -2105,9 +2161,21 @@ async def analyze_research_attention_stream(request: Dict[str, Any], authenticat
             try:
                 if layer_idx not in qkv_captures:
                     qkv_captures[layer_idx] = {}
-                qkv_captures[layer_idx]['q'] = output[0].detach().cpu()
-            except Exception:
-                pass
+                # Handle both tuple and tensor outputs
+                if isinstance(output, tuple):
+                    out = output[0]
+                else:
+                    out = output
+                out = out.detach().cpu()
+                # If 3D, take first batch element
+                if out.dim() == 3:
+                    out = out[0]  # [seq_len, num_heads * head_dim]
+                # Reshape to [seq_len, num_heads, head_dim]
+                seq_len = out.shape[0]
+                out = out.reshape(seq_len, n_heads, head_dim)
+                qkv_captures[layer_idx]['q'] = out
+            except Exception as e:
+                logger.warning(f"[Stream] Q hook error layer {layer_idx}: {e}")
         return hook
 
     def make_separate_k_hook(layer_idx):
@@ -2116,9 +2184,27 @@ async def analyze_research_attention_stream(request: Dict[str, Any], authenticat
             try:
                 if layer_idx not in qkv_captures:
                     qkv_captures[layer_idx] = {}
-                qkv_captures[layer_idx]['k'] = output[0].detach().cpu()
-            except Exception:
-                pass
+                # Handle both tuple and tensor outputs
+                if isinstance(output, tuple):
+                    out = output[0]
+                else:
+                    out = output
+                out = out.detach().cpu()
+                # If 3D, take first batch element
+                if out.dim() == 3:
+                    out = out[0]  # [seq_len, kv_heads * head_dim]
+                # For GQA models, K has fewer heads (kv_heads)
+                seq_len = out.shape[0]
+                hidden_size = out.shape[1]
+                actual_kv_heads = hidden_size // head_dim
+                out = out.reshape(seq_len, actual_kv_heads, head_dim)
+                # If GQA, repeat KV heads to match Q heads
+                if actual_kv_heads != n_heads:
+                    repeat_factor = n_heads // actual_kv_heads
+                    out = out.repeat_interleave(repeat_factor, dim=1)
+                qkv_captures[layer_idx]['k'] = out
+            except Exception as e:
+                logger.warning(f"[Stream] K hook error layer {layer_idx}: {e}")
         return hook
 
     def make_separate_v_hook(layer_idx):
@@ -2127,9 +2213,27 @@ async def analyze_research_attention_stream(request: Dict[str, Any], authenticat
             try:
                 if layer_idx not in qkv_captures:
                     qkv_captures[layer_idx] = {}
-                qkv_captures[layer_idx]['v'] = output[0].detach().cpu()
-            except Exception:
-                pass
+                # Handle both tuple and tensor outputs
+                if isinstance(output, tuple):
+                    out = output[0]
+                else:
+                    out = output
+                out = out.detach().cpu()
+                # If 3D, take first batch element
+                if out.dim() == 3:
+                    out = out[0]  # [seq_len, kv_heads * head_dim]
+                # For GQA models, V has fewer heads (kv_heads)
+                seq_len = out.shape[0]
+                hidden_size = out.shape[1]
+                actual_kv_heads = hidden_size // head_dim
+                out = out.reshape(seq_len, actual_kv_heads, head_dim)
+                # If GQA, repeat KV heads to match Q heads
+                if actual_kv_heads != n_heads:
+                    repeat_factor = n_heads // actual_kv_heads
+                    out = out.repeat_interleave(repeat_factor, dim=1)
+                qkv_captures[layer_idx]['v'] = out
+            except Exception as e:
+                logger.warning(f"[Stream] V hook error layer {layer_idx}: {e}")
        return hook
 
     # Register hooks - support both CodeGen and Mistral/Devstral architectures
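The registration step referenced by the last context line is not shown in this diff. As a rough sketch only, assuming the standard Hugging Face Mistral/Devstral module layout (`model.model.layers[i].self_attn.q_proj` / `k_proj` / `v_proj`), the hooks above would be attached and cleaned up roughly like this; a CodeGen-style path would hook its fused QKV projection instead:

```python
# Sketch, not part of this commit: register the per-layer Q/K/V hooks on a
# Mistral/Devstral-style model and make sure they are removed afterwards.
handles = []
for layer_idx, layer in enumerate(manager.model.model.layers):
    attn = layer.self_attn
    handles.append(attn.q_proj.register_forward_hook(make_separate_q_hook(layer_idx)))
    handles.append(attn.k_proj.register_forward_hook(make_separate_k_hook(layer_idx)))
    handles.append(attn.v_proj.register_forward_hook(make_separate_v_hook(layer_idx)))
try:
    pass  # run the instrumented generation loop; hooks populate qkv_captures
finally:
    for handle in handles:
        handle.remove()  # always detach hooks so later requests start clean
```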