gary-boon Claude Opus 4.5 committed on
Commit 4ec134b · 1 Parent(s): 3e67ea2

Fix QKV visualization for Mistral/Devstral architecture


- Add separate Q, K, V hook makers for Mistral-style architectures
that use separate projections (q_proj, k_proj, v_proj)
- Update layer iteration to use adapter's _get_layers() method
for architecture-agnostic layer access
- Maintain backwards compatibility with CodeGen's combined QKV projection
- Apply same fix to both regular and streaming endpoints

This fixes the QKV visualization, which was silently failing for Devstral because the hook registration only supported CodeGen's transformer.h structure.

🤖 Generated with [Claude Code](https://claude.com/claude-code)

Co-Authored-By: Claude Opus 4.5 <[email protected]>
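
In essence, the fix registers PyTorch forward hooks on each decoder layer's attention projection modules, preferring separate q_proj/k_proj/v_proj (Mistral/LLaMA style) and falling back to a combined qkv_proj or c_attn (CodeGen/GPT-2 style). Below is a minimal, self-contained sketch of that pattern; the checkpoint name, the `captures` dict, and the `make_hook` helper are illustrative placeholders, not code from this repository.

```python
import torch
from transformers import AutoModelForCausalLM, AutoTokenizer

# Placeholder checkpoint, small enough to run anywhere; not the model this service loads.
model_name = "sshleifer/tiny-gpt2"
tokenizer = AutoTokenizer.from_pretrained(model_name)
model = AutoModelForCausalLM.from_pretrained(model_name)

captures, hooks = {}, []

def make_hook(layer_idx, key):
    def hook(module, inputs, output):
        # Linear/Conv1D projections return a plain tensor: [batch, seq_len, proj_dim]
        captures.setdefault(layer_idx, {})[key] = output.detach().cpu()
    return hook

# Locate the decoder layers regardless of architecture naming.
if hasattr(model, "model") and hasattr(model.model, "layers"):            # Mistral/LLaMA style
    layers = model.model.layers
elif hasattr(model, "transformer") and hasattr(model.transformer, "h"):   # GPT-2/CodeGen style
    layers = model.transformer.h
else:
    layers = []

for i, layer in enumerate(layers):
    attn = getattr(layer, "self_attn", None) or getattr(layer, "attn", None)
    if attn is None:
        continue
    if all(hasattr(attn, f"{k}_proj") for k in ("q", "k", "v")):
        # Separate projections: one hook per matrix
        for k in ("q", "k", "v"):
            hooks.append(getattr(attn, f"{k}_proj").register_forward_hook(make_hook(i, k)))
    elif hasattr(attn, "qkv_proj") or hasattr(attn, "c_attn"):
        # Combined projection: a single hook captures Q, K, V fused along the last dim
        proj = getattr(attn, "qkv_proj", None) or getattr(attn, "c_attn")
        hooks.append(proj.register_forward_hook(make_hook(i, "qkv")))

inputs = tokenizer("def hello():", return_tensors="pt")
with torch.no_grad():
    model(**inputs)

for h in hooks:
    h.remove()   # always unregister hooks once the capture is done

print({idx: list(parts.keys()) for idx, parts in captures.items()})
```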

Files changed (1)
  1. backend/model_service.py +134 -20
backend/model_service.py CHANGED
@@ -1582,6 +1582,7 @@ async def analyze_research_attention(request: Dict[str, Any], authenticated: boo
     hooks = []
 
     def make_qkv_hook(layer_idx):
+        """Hook for combined QKV projection (CodeGen/GPT-NeoX style)"""
         def hook(module, input, output):
             try:
                 # output shape: [batch, seq_len, 3 * hidden_size]
@@ -1605,18 +1606,74 @@ async def analyze_research_attention(request: Dict[str, Any], authenticated: boo
                 pass
         return hook
 
-    # Register hooks on all qkv_proj modules (if available)
-    # This is model-specific - CodeGen uses different architecture
+    def make_separate_q_hook(layer_idx):
+        """Hook for separate Q projection (Mistral/LLaMA style)"""
+        def hook(module, input, output):
+            try:
+                if layer_idx not in qkv_captures:
+                    qkv_captures[layer_idx] = {}
+                # output shape: [batch, seq_len, num_heads * head_dim]
+                qkv_captures[layer_idx]['q'] = output[0].detach().cpu()
+            except Exception:
+                pass
+        return hook
+
+    def make_separate_k_hook(layer_idx):
+        """Hook for separate K projection (Mistral/LLaMA style)"""
+        def hook(module, input, output):
+            try:
+                if layer_idx not in qkv_captures:
+                    qkv_captures[layer_idx] = {}
+                qkv_captures[layer_idx]['k'] = output[0].detach().cpu()
+            except Exception:
+                pass
+        return hook
+
+    def make_separate_v_hook(layer_idx):
+        """Hook for separate V projection (Mistral/LLaMA style)"""
+        def hook(module, input, output):
+            try:
+                if layer_idx not in qkv_captures:
+                    qkv_captures[layer_idx] = {}
+                qkv_captures[layer_idx]['v'] = output[0].detach().cpu()
+            except Exception:
+                pass
+        return hook
+
+    # Register hooks - support both CodeGen and Mistral/Devstral architectures
     try:
-        if hasattr(manager.model, 'transformer') and hasattr(manager.model.transformer, 'h'):
-            for layer_idx, layer in enumerate(manager.model.transformer.h):
-                if hasattr(layer, 'attn') and hasattr(layer.attn, 'qkv_proj'):
-                    hook = layer.attn.qkv_proj.register_forward_hook(make_qkv_hook(layer_idx))
-                    hooks.append(hook)
-                elif hasattr(layer, 'attn') and hasattr(layer.attn, 'c_attn'):
-                    # GPT-2 style attention
-                    hook = layer.attn.c_attn.register_forward_hook(make_qkv_hook(layer_idx))
-                    hooks.append(hook)
+        # Try to get layers via adapter first (works for all model types)
+        layers = None
+        if manager.adapter:
+            try:
+                layers = manager.adapter._get_layers()
+                logger.info(f"Using adapter to get {len(layers)} layers for QKV hooks")
+            except Exception:
+                pass
+
+        # Fallback for CodeGen if adapter doesn't work
+        if layers is None and hasattr(manager.model, 'transformer') and hasattr(manager.model.transformer, 'h'):
+            layers = manager.model.transformer.h
+            logger.info(f"Using transformer.h for {len(layers)} layers for QKV hooks")
+
+        if layers:
+            for layer_idx, layer in enumerate(layers):
+                # Mistral/Devstral: separate Q, K, V projections
+                if hasattr(layer, 'self_attn'):
+                    attn = layer.self_attn
+                    if hasattr(attn, 'q_proj') and hasattr(attn, 'k_proj') and hasattr(attn, 'v_proj'):
+                        hooks.append(attn.q_proj.register_forward_hook(make_separate_q_hook(layer_idx)))
+                        hooks.append(attn.k_proj.register_forward_hook(make_separate_k_hook(layer_idx)))
+                        hooks.append(attn.v_proj.register_forward_hook(make_separate_v_hook(layer_idx)))
+                # CodeGen/GPT-NeoX: combined QKV projection
+                elif hasattr(layer, 'attn'):
+                    if hasattr(layer.attn, 'qkv_proj'):
+                        hooks.append(layer.attn.qkv_proj.register_forward_hook(make_qkv_hook(layer_idx)))
+                    elif hasattr(layer.attn, 'c_attn'):
+                        # GPT-2 style attention
+                        hooks.append(layer.attn.c_attn.register_forward_hook(make_qkv_hook(layer_idx)))
+
+        logger.info(f"Registered {len(hooks)} QKV hooks")
     except Exception as hook_error:
         logger.warning(f"Could not register QKV hooks: {hook_error}")
 
@@ -2022,6 +2079,7 @@ async def analyze_research_attention_stream(request: Dict[str, Any], authenticat
     hooks = []
 
     def make_qkv_hook(layer_idx):
+        """Hook for combined QKV projection (CodeGen/GPT-NeoX style)"""
         def hook(module, input, output):
             try:
                 if output.dim() != 3:
@@ -2041,16 +2099,72 @@ async def analyze_research_attention_stream(request: Dict[str, Any], authenticat
                 pass
         return hook
 
-    # Register hooks
+    def make_separate_q_hook(layer_idx):
+        """Hook for separate Q projection (Mistral/LLaMA style)"""
+        def hook(module, input, output):
+            try:
+                if layer_idx not in qkv_captures:
+                    qkv_captures[layer_idx] = {}
+                qkv_captures[layer_idx]['q'] = output[0].detach().cpu()
+            except Exception:
+                pass
+        return hook
+
+    def make_separate_k_hook(layer_idx):
+        """Hook for separate K projection (Mistral/LLaMA style)"""
+        def hook(module, input, output):
+            try:
+                if layer_idx not in qkv_captures:
+                    qkv_captures[layer_idx] = {}
+                qkv_captures[layer_idx]['k'] = output[0].detach().cpu()
+            except Exception:
+                pass
+        return hook
+
+    def make_separate_v_hook(layer_idx):
+        """Hook for separate V projection (Mistral/LLaMA style)"""
+        def hook(module, input, output):
+            try:
+                if layer_idx not in qkv_captures:
+                    qkv_captures[layer_idx] = {}
+                qkv_captures[layer_idx]['v'] = output[0].detach().cpu()
+            except Exception:
+                pass
+        return hook
+
+    # Register hooks - support both CodeGen and Mistral/Devstral architectures
     try:
-        if hasattr(manager.model, 'transformer') and hasattr(manager.model.transformer, 'h'):
-            for layer_idx, layer in enumerate(manager.model.transformer.h):
-                if hasattr(layer, 'attn') and hasattr(layer.attn, 'qkv_proj'):
-                    hook = layer.attn.qkv_proj.register_forward_hook(make_qkv_hook(layer_idx))
-                    hooks.append(hook)
-                elif hasattr(layer, 'attn') and hasattr(layer.attn, 'c_attn'):
-                    hook = layer.attn.c_attn.register_forward_hook(make_qkv_hook(layer_idx))
-                    hooks.append(hook)
+        # Try to get layers via adapter first (works for all model types)
+        layers = None
+        if manager.adapter:
+            try:
+                layers = manager.adapter._get_layers()
+                logger.info(f"[Stream] Using adapter to get {len(layers)} layers for QKV hooks")
+            except Exception:
+                pass
+
+        # Fallback for CodeGen if adapter doesn't work
+        if layers is None and hasattr(manager.model, 'transformer') and hasattr(manager.model.transformer, 'h'):
+            layers = manager.model.transformer.h
+            logger.info(f"[Stream] Using transformer.h for {len(layers)} layers for QKV hooks")
+
+        if layers:
+            for layer_idx, layer in enumerate(layers):
+                # Mistral/Devstral: separate Q, K, V projections
+                if hasattr(layer, 'self_attn'):
+                    attn = layer.self_attn
+                    if hasattr(attn, 'q_proj') and hasattr(attn, 'k_proj') and hasattr(attn, 'v_proj'):
+                        hooks.append(attn.q_proj.register_forward_hook(make_separate_q_hook(layer_idx)))
+                        hooks.append(attn.k_proj.register_forward_hook(make_separate_k_hook(layer_idx)))
+                        hooks.append(attn.v_proj.register_forward_hook(make_separate_v_hook(layer_idx)))
+                # CodeGen/GPT-NeoX: combined QKV projection
+                elif hasattr(layer, 'attn'):
+                    if hasattr(layer.attn, 'qkv_proj'):
+                        hooks.append(layer.attn.qkv_proj.register_forward_hook(make_qkv_hook(layer_idx)))
+                    elif hasattr(layer.attn, 'c_attn'):
+                        hooks.append(layer.attn.c_attn.register_forward_hook(make_qkv_hook(layer_idx)))
+
+        logger.info(f"[Stream] Registered {len(hooks)} QKV hooks")
     except Exception as hook_error:
         logger.warning(f"Could not register QKV hooks: {hook_error}")
 
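
A note for anyone consuming these captures: in Mistral-style models with grouped-query attention, k_proj and v_proj are narrower than q_proj because K and V use fewer heads, so the captured K/V tensors will not match Q's width. A rough shape sketch, using illustrative config values rather than Devstral's actual configuration:

```python
# Illustrative shape arithmetic for grouped-query attention (example values only).
hidden_size = 4096
num_attention_heads = 32        # Q heads
num_key_value_heads = 8         # K/V heads (grouped-query attention)
head_dim = hidden_size // num_attention_heads   # 128

q_width = num_attention_heads * head_dim        # 4096
kv_width = num_key_value_heads * head_dim       # 1024

# Because the hooks store output[0] (batch element 0), a captured sequence of
# length L would yield, per layer:
#   qkv_captures[layer]['q'].shape == (L, q_width)
#   qkv_captures[layer]['k'].shape == (L, kv_width)
#   qkv_captures[layer]['v'].shape == (L, kv_width)
```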