Qwen3-0.6B Lambda Gates — Baseline (KL-Heavy)

Learnable per-neuron sigmoid gates that classify the FFN intermediate neurons of Qwen/Qwen3-0.6B as "knowledge" or "reasoning" neurons. At inference, neurons whose gate falls below a chosen threshold can be zeroed to suppress factual recall while preserving reasoning ability.
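
Conceptually, the gate sits on the FFN intermediate activation. The sketch below is illustrative only (function and argument names are mine, not the training code): the intermediate activation is scaled elementwise by sigmoid(lambda_logits), and at inference the gate can be binarized so that neurons below the threshold are zeroed.

import torch

def gated_ffn(x, gate_proj, up_proj, down_proj, act_fn, lambda_logits, threshold=None):
    # Per-neuron soft gate in [0, 1], one value per FFN intermediate neuron.
    gate = torch.sigmoid(lambda_logits)
    if threshold is not None:
        # Hard masking at inference: gates below the threshold ("knowledge"
        # neurons) are zeroed outright.
        gate = (gate >= threshold).to(gate.dtype)
    hidden = act_fn(gate_proj(x)) * up_proj(x)  # standard SwiGLU intermediate
    return down_proj(hidden * gate)             # suppress gated-off neurons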

This is the baseline checkpoint used as the primary comparison point in my lambda-gate experiments. It was produced by the KL-heavy distillation training run described under Training Setup below.

Contents

| File | Description |
|---|---|
| lambda_logits.pt | Per-neuron gate logits (pre-sigmoid), 28 layers × 3,072 neurons. torch.load(...) returns a state dict with module.model.layers.<i>.mlp.lambda_logits keys, one (3072,) tensor per layer. |
| neuron_indices.json | Knowledge-neuron indices per layer, extracted at the default threshold 0.5: {"layer_idx": [neuron_idx, ...]}. |
| gate_stats.json | Gate statistics and the five selected operating thresholds (5%, 25%, 50%, 75%, 95% off-fractions). |
| selected_thresholds.txt | Comma-separated threshold list matching gate_stats.json. |
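
The two main artifacts are consistent by construction: the indices in neuron_indices.json should be recoverable from lambda_logits.pt by thresholding the sigmoid gates at 0.5. A rough sketch, assuming the key pattern above and that gates below the threshold mark knowledge neurons:

import json, re, torch

gate_state = torch.load("lambda_logits.pt", map_location="cpu")

recovered = {}
for key, logits in gate_state.items():
    layer_idx = re.search(r"layers\.(\d+)\.", key).group(1)
    gate = torch.sigmoid(logits)  # per-neuron gate, shape (3072,)
    recovered[layer_idx] = (gate < 0.5).nonzero(as_tuple=True)[0].tolist()

with open("neuron_indices.json") as f:
    reference = json.load(f)
# recovered should match reference (up to ordering) at the default 0.5 threshold.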

Gate Statistics

| Metric | Value |
|---|---|
| Total gates | 86,016 (28 layers × 3,072 intermediate neurons) |
| Mean sigmoid gate | 0.501 |
| Std | 0.044 |
| Min / Max | 0.126 / 0.957 |

Selected thresholds

| Threshold | Neurons masked | Role |
|---|---|---|
| 0.432 | 5.0% | Very mild masking |
| 0.478 | 25.4% | Moderate |
| 0.502 | 50.1% | Median split |
| 0.525 | 74.6% | Heavy |
| 0.568 | 94.9% | Near-total |
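
Each threshold is, by definition, a percentile of the pooled gate distribution: masking x% of neurons means cutting at the x-th percentile of sigmoid(lambda_logits). A sketch of how the table can be reproduced from lambda_logits.pt (the values stored in gate_stats.json are authoritative):

import torch

gate_state = torch.load("lambda_logits.pt", map_location="cpu")
all_gates = torch.cat([torch.sigmoid(v).float().flatten() for v in gate_state.values()])  # (86016,)

print(all_gates.mean(), all_gates.std(), all_gates.min(), all_gates.max())

# Thresholds that switch off roughly 5 / 25 / 50 / 75 / 95 % of neurons.
for frac in (0.05, 0.25, 0.50, 0.75, 0.95):
    print(f"{frac:.0%} -> {torch.quantile(all_gates, frac).item():.3f}")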

Training Setup

  • Base model: Qwen/Qwen3-0.6B
  • Forget data: PopQA-mini entity-masked knowledge text
  • Reasoning data: NuminaMath-CoT (10k seed subset)
  • Loss: softplus(β − entity_CE) on the forget set + KL(base || gated) distillation on the reasoning set (minimal sketch after this list)
  • Hyperparameters: β=4.0, λ_f=0.1, λ_r=0.5, distill T=2.0, forget_retain_ratio=1:2, lr=1e-2, cosine, 3 epochs
  • Hardware: 8×RTX 3090, DDP
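
A minimal sketch of the objective on per-batch scalars (the exact reduction, token masking, and batching in the original training code may differ; the T^2 scaling on the KL term is the conventional distillation choice):

import torch.nn.functional as F

BETA, LAMBDA_F, LAMBDA_R, T = 4.0, 0.1, 0.5, 2.0

def lambda_gate_loss(entity_ce, base_logits, gated_logits):
    # Forget term: pushes the gated model's cross-entropy on masked entity
    # tokens above beta; softplus flattens out once entity_ce >> beta.
    forget_loss = F.softplus(BETA - entity_ce)
    # Retain term: KL(base || gated) over reasoning tokens at temperature T.
    log_p_base = F.log_softmax(base_logits / T, dim=-1)
    log_q_gated = F.log_softmax(gated_logits / T, dim=-1)
    retain_loss = F.kl_div(log_q_gated, log_p_base, log_target=True,
                           reduction="batchmean") * T * T
    return LAMBDA_F * forget_loss + LAMBDA_R * retain_loss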

Usage

import torch, json
from transformers import AutoModelForCausalLM

model = AutoModelForCausalLM.from_pretrained("Qwen/Qwen3-0.6B", torch_dtype=torch.bfloat16)

# Load gate logits
gate_state = torch.load("lambda_logits.pt", map_location="cpu")
# Keys look like: module.model.layers.<i>.mlp.lambda_logits

# Option 1: Hard masking via neuron_indices at a chosen threshold
with open("neuron_indices.json") as f:
    knowledge_neurons = json.load(f)  # {"layer_idx": [n_idx, ...]}

# Zeroing the corresponding rows of gate_proj and up_proj permanently disables
# those intermediate neurons in each FFN block.
with torch.no_grad():
    for layer_idx, neuron_idxs in knowledge_neurons.items():
        layer = model.model.layers[int(layer_idx)]
        layer.mlp.gate_proj.weight[neuron_idxs, :] = 0
        layer.mlp.up_proj.weight[neuron_idxs, :] = 0

# Option 2: Soft gating at inference via sigmoid(logits)
# See https://github.com/... for the LambdaGatedFFN wrapper used in training.
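
In place of the wrapper, the soft gate can also be applied with forward pre-hooks on each layer's down_proj, reusing model and gate_state from the snippet above (a hook-based sketch, not the original LambdaGatedFFN class):

import re

def make_soft_gate_hook(gate):
    def hook(module, args):
        hidden = args[0]  # (batch, seq, intermediate_size)
        return (hidden * gate.to(device=hidden.device, dtype=hidden.dtype),)
    return hook

for key, logits in gate_state.items():
    layer_idx = int(re.search(r"layers\.(\d+)\.", key).group(1))
    gate = torch.sigmoid(logits)  # (3072,) soft gate per intermediate neuron
    model.model.layers[layer_idx].mlp.down_proj.register_forward_pre_hook(
        make_soft_gate_hook(gate)
    )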

Related Checkpoints

Citation

If you use these gates, please also cite the Qwen3 model card alongside this repository. This checkpoint was produced as part of a research project on knowledge–reasoning disentanglement in LLMs.
