Qwen3 Lambda Gates — Knowledge/Reasoning Disentanglement
Per-neuron learnable sigmoid gates that classify FFN intermediate neurons in Qwen/Qwen3-0.6B as "knowledge" or "reasoning" neurons. At inference, gates below a chosen threshold can be zeroed to suppress factual recall while preserving reasoning ability.
This is the baseline checkpoint used as the primary comparison point in my lambda-gate experiments; it was produced by KL-heavy distillation training.
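In symbols (a sketch consistent with the description above, not taken verbatim from the training code):

$$
g_i = \sigma(\lambda_i), \qquad \tilde{a}_i = g_i \, a_i \ \ \text{(training)}, \qquad \tilde{a}_i = \mathbf{1}[g_i \ge \tau] \, a_i \ \ \text{(inference)}
$$

where `a_i` is the FFN intermediate activation of neuron `i`, `λ_i` its learned logit, and `τ` the chosen masking threshold.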
| File | Description |
|---|---|
| `lambda_logits.pt` | Per-neuron logits (pre-sigmoid), 28 layers × 3072 neurons. `torch.load(...)` returns a state dict with `model.module.model.layers.<i>.mlp.lambda_logits` keys. |
| `neuron_indices.json` | Extracted knowledge-neuron indices per layer at the default threshold 0.5: `{"layer_idx": [neuron_idx, ...]}`. |
| `gate_stats.json` | Gate statistics and the five selected operating thresholds (5%, 25%, 50%, 75%, 95% off-fractions). |
| `selected_thresholds.txt` | Comma-separated threshold list matching `gate_stats.json`. |
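A minimal sketch of reading the two JSON/text artifacts above (paths assume the files sit in the working directory):

```python
import json

# Knowledge-neuron indices per layer at the default 0.5 threshold.
with open("neuron_indices.json") as f:
    knowledge_neurons = json.load(f)
for layer_idx, neuron_idxs in sorted(knowledge_neurons.items(), key=lambda kv: int(kv[0])):
    print(f"layer {layer_idx}: {len(neuron_idxs)} knowledge neurons")

# The five selected operating thresholds, comma-separated.
with open("selected_thresholds.txt") as f:
    thresholds = [float(t) for t in f.read().strip().split(",")]
print(thresholds)
```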
| Metric | Value |
|---|---|
| Total gates | 86,016 (28 layers × 3072 intermediate neurons) |
| Mean sigmoid gate | 0.501 |
| Std | 0.044 |
| Min / Max | 0.126 / 0.957 |
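These numbers can be reproduced from `lambda_logits.pt`; a sketch, assuming the per-layer state-dict layout described in the file table:

```python
import torch

gate_state = torch.load("lambda_logits.pt", map_location="cpu")
# Concatenate all per-layer logit tensors and squash through the sigmoid.
gates = torch.sigmoid(torch.cat([v.flatten().float() for v in gate_state.values()]))
print(gates.numel())                                   # 86016 = 28 * 3072
print(round(gates.mean().item(), 3), round(gates.std().item(), 3))
print(round(gates.min().item(), 3), round(gates.max().item(), 3))
```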
| Threshold | Neurons masked | Role |
|---|---|---|
| 0.432 | 5.0% | Very mild masking |
| 0.478 | 25.4% | Moderate |
| 0.502 | 50.1% | Median split |
| 0.525 | 74.6% | Heavy |
| 0.568 | 94.9% | Near-total |
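Assuming the five operating points were picked as quantiles of the gate distribution (an inference from the off-fractions, not confirmed by the source), they can be recovered like this:

```python
import torch

gate_state = torch.load("lambda_logits.pt", map_location="cpu")
gates = torch.sigmoid(torch.cat([v.flatten().float() for v in gate_state.values()]))

# The q-quantile of the gate distribution masks ~q of all neurons
# when used as the zeroing threshold.
for q in (0.05, 0.25, 0.50, 0.75, 0.95):
    tau = torch.quantile(gates, q).item()
    masked = (gates < tau).float().mean().item()
    print(f"q={q:.2f}  threshold={tau:.3f}  neurons masked={masked:.1%}")
```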
| Setting | Value |
|---|---|
| Base model | Qwen/Qwen3-0.6B |
| Forget loss | softplus(β − entity_CE) on the forget set |
| Retain loss | KL(base || gated) on the reasoning set |
| Hyperparameters | β=4.0, λ_f=0.1, λ_r=0.5, distill T=2.0, forget:retain ratio 1:2, lr=1e-2, cosine schedule, 3 epochs |
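A hedged sketch of that objective; `entity_ce`, `base_logits`, and `gated_logits` are hypothetical stand-ins for the per-batch entity-token cross-entropy on the forget set and the two models' logits on the reasoning set, not the actual training code:

```python
import torch
import torch.nn.functional as F

beta, lam_f, lam_r, T = 4.0, 0.1, 0.5, 2.0

def lambda_gate_loss(entity_ce, base_logits, gated_logits):
    # Forget term: rewards high entity cross-entropy (suppressed recall),
    # saturating once entity_CE exceeds beta.
    forget = F.softplus(beta - entity_ce).mean()
    # Retain term: temperature-scaled KL(base || gated) keeps the gated
    # model's distribution on reasoning data close to the base model's.
    kl = F.kl_div(
        F.log_softmax(gated_logits / T, dim=-1),
        F.log_softmax(base_logits / T, dim=-1),
        log_target=True,
        reduction="batchmean",
    ) * T * T
    return lam_f * forget + lam_r * kl
```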
Usage:

```python
import torch, json
from transformers import AutoModelForCausalLM
model = AutoModelForCausalLM.from_pretrained("Qwen/Qwen3-0.6B", torch_dtype=torch.bfloat16)
# Load gate logits
gate_state = torch.load("lambda_logits.pt", map_location="cpu")
# Keys look like: module.model.layers.<i>.mlp.lambda_logits
# Option 1: hard masking via neuron_indices at a chosen threshold
with open("neuron_indices.json") as f:
    knowledge_neurons = json.load(f)  # {"layer_idx": [n_idx, ...]}

for layer_idx, neuron_idxs in knowledge_neurons.items():
    layer = model.model.layers[int(layer_idx)]
    with torch.no_grad():
        # Zeroing the gate_proj and up_proj rows silences these
        # intermediate neurons: SiLU(0) * 0 = 0.
        layer.mlp.gate_proj.weight[neuron_idxs, :] = 0
        layer.mlp.up_proj.weight[neuron_idxs, :] = 0
# Option 2: soft gating at inference via sigmoid(logits)
# See https://github.com/... for the LambdaGatedFFN wrapper used in training.
```
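For Option 2, a minimal sketch of soft gating via a forward pre-hook on each layer's `down_proj` (whose input is the post-activation intermediate tensor). The key prefix and the hook-based approach are assumptions on my part; the actual LambdaGatedFFN wrapper in the elided repo may differ:

```python
import torch

def attach_soft_gates(model, gate_state):
    # Scale each FFN intermediate neuron by sigmoid(lambda_logits) before
    # down_proj, mirroring the soft gate used during training.
    for i, layer in enumerate(model.model.layers):
        # Key prefix follows the comment above; adjust if your checkpoint
        # uses the "model.module.model...." prefix from the file table.
        logits = gate_state[f"module.model.layers.{i}.mlp.lambda_logits"]
        gates = torch.sigmoid(logits.float())

        def pre_hook(module, args, gates=gates):
            (hidden,) = args  # intermediate activation, shape (..., 3072)
            return (hidden * gates.to(dtype=hidden.dtype, device=hidden.device),)

        layer.mlp.down_proj.register_forward_pre_hook(pre_hook)

attach_soft_gates(model, gate_state)
```

After attaching, run generation as usual; to hard-mask at a threshold instead, binarize `gates` before multiplying.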
If you use these gates, please cite the Qwen3 model card as well. This checkpoint was produced as part of a research project on knowledge–reasoning disentanglement in LLMs.