SDXL-LLM-Distill: LLM-Powered Visual Conditioning
This repository contains training code for a knowledge distillation project that replaces the standard CLIP text encoders in Stable Diffusion XL with an LLM (specifically Qwen3-4B in this PoC).
The Concept
Standard CLIP encoders have a limited context window and often fail to understand complex spatial relationships or negative prompts. This project uses Qwen3-4B as a "student" model: it learns to mimic the output format of the SDXL CLIP encoders while keeping its own linguistic reasoning. The current implementation is a proof of concept trained for one epoch in 4-bit precision on 10,000 SFW image-caption pairs from the SPRIGHT-T2I dataset.
Note
This is a PoC implementation: a limited dataset (only 10k image-caption pairs), a single training epoch, and 4-bit precision. If this is of interest to the community, I will organize a larger-scale training run.
Improvements
Replacing CLIP with LLM-based conditioning is intended to improve the following areas:
- Spatial Adherence: Improves understanding of position words like "left of", "inside", and "under".
- Complex Negation: Handles negative statements more accurately.
- Extended Context: Supports captions longer than the standard 77-token limit (up to 256 tokens).
- Natural-Language Prompting: According to community experience, CLIP works best with tag-based prompts. This solution enables the use of natural-language prompts for SDXL.
Tradeoffs
- LoRA Incompatibility: LoRAs without trigger words (such as sliders) should still work. LoRAs that expect specific trigger words will likely break, as the text encoder is entirely different.
- Increased Memory Footprint: An LLM requires more memory than CLIP, and inference is slower due to the larger parameter count.
Architecture
The system does not use the final text output of the LLM. It extracts the hidden states from the last transformer layer. These states are high-dimensional vector representations that contain a semantic summary of the prompt.
The Perceiver Resampler
Because LLM hidden states are not compatible with the SDXL UNet, a Perceiver Resampler acts as a translation bridge. It resamples the variable-length LLM hidden states into fixed-size tensors.
Technical Specifications:
- Layers: 4 layers of cross-attention and feed-forward blocks.
- Latent Queries: 77 learnable query vectors.
- Input Dimension: 3072 (Qwen3-4B hidden size).
- Output Dimensions:
  - Prompt Embeddings: [1, 77, 2048]
  - Pooled Output: [1, 1280]
- Total Parameters: ~278M, saved in full-precision float32 (4 bytes per parameter).
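The specification above can be sketched as a PyTorch module. This is a minimal illustration of the architecture described (77 learnable queries, 4 cross-attention + feed-forward layers, 3072-dim input, 2048-dim prompt embeddings, 1280-dim pooled output); the head count, pooling strategy, and all names are assumptions, not the repository's actual code:

```python
import torch
import torch.nn as nn

class PerceiverResampler(nn.Module):
    """Sketch of the resampler bridge (hyperparameters beyond the listed
    spec are assumptions)."""

    def __init__(self, input_dim=3072, latent_dim=2048, pooled_dim=1280,
                 num_latents=77, num_layers=4, num_heads=8):
        super().__init__()
        # 77 learnable query vectors that attend over the LLM hidden states.
        self.latents = nn.Parameter(torch.randn(num_latents, latent_dim) * 0.02)
        self.input_proj = nn.Linear(input_dim, latent_dim)
        self.layers = nn.ModuleList([
            nn.ModuleDict({
                "norm1": nn.LayerNorm(latent_dim),
                "attn": nn.MultiheadAttention(latent_dim, num_heads,
                                              batch_first=True),
                "norm2": nn.LayerNorm(latent_dim),
                "ff": nn.Sequential(nn.Linear(latent_dim, 4 * latent_dim),
                                    nn.GELU(),
                                    nn.Linear(4 * latent_dim, latent_dim)),
            })
            for _ in range(num_layers)
        ])
        self.pooled_proj = nn.Linear(latent_dim, pooled_dim)

    def forward(self, hidden_states):
        # hidden_states: [batch, seq_len, input_dim] from the LLM's last layer.
        kv = self.input_proj(hidden_states)
        x = self.latents.unsqueeze(0).expand(hidden_states.size(0), -1, -1)
        for layer in self.layers:
            # Cross-attention: latent queries attend to the projected LLM states.
            attn_out, _ = layer["attn"](layer["norm1"](x), kv, kv)
            x = x + attn_out
            x = x + layer["ff"](layer["norm2"](x))
        prompt_embeds = x                          # [batch, 77, 2048]
        pooled = self.pooled_proj(x.mean(dim=1))   # [batch, 1280]
        return prompt_embeds, pooled
```

With the default sizes, a caption of any length (up to the 256-token cap) maps to the fixed SDXL shapes [1, 77, 2048] and [1, 1280].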
The Qwen LoRA
Fine-tuning the LLM with LoRA is necessary to shift its internal attention. While a base LLM is trained for text prediction, the LoRA helps it emphasize visual features like spatial coordinates and textures that the Resampler needs to translate for the UNet.
Reproduction
Training
The training process has two stages; pre-computing the teacher outputs in the first stage saves VRAM during the main training loop.
Training Process
The training uses knowledge distillation, in which the combined Qwen3 + Resampler model acts as a "student" learning to mimic a "teacher" (the frozen SDXL CLIP encoders).
Distillation Methodology
The original, frozen SDXL CLIP-L and CLIP-G encoders serve as the "teacher" model. Their outputs for the entire training dataset are pre-computed and cached to save VRAM during the main training loop. The "student" model, which learns to mimic the teacher, is a composite of the 4-bit quantized Qwen3-4B model, PEFT LoRA adapters applied to its attention blocks, and the Perceiver Resampler module.
Loss Function
The training objective is to minimize MSE between the student's predictions and the teacher's cached outputs. A separate MSE loss is calculated for the detailed prompt embeddings and the global pooled embedding; these two losses are summed to form the final loss for backpropagation.
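The objective described above can be written as a small function. This is a sketch of the summed two-term MSE; the function and argument names are assumptions:

```python
import torch
import torch.nn.functional as F

def distillation_loss(student_prompt, student_pooled,
                      teacher_prompt, teacher_pooled):
    # Separate MSE terms for the per-token prompt embeddings and the
    # global pooled embedding, summed into one scalar for backprop.
    loss_prompt = F.mse_loss(student_prompt, teacher_prompt)
    loss_pooled = F.mse_loss(student_pooled, teacher_pooled)
    return loss_prompt + loss_pooled
```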
Optimization
The Resampler and the Qwen LoRA adapters are trained simultaneously using the AdamW optimizer. To account for their different roles, a differential learning rate is applied, with the Resampler (trained from scratch) using a higher rate than the LoRA adapters (fine-tuning). To fit training on consumer GPUs, gradient accumulation is used to simulate a larger effective batch size. The entire process, including loss metrics and gradient norms, is tracked using Weights & Biases for analysis.
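The optimizer setup and accumulation scheme can be sketched as follows. The learning rates, accumulation count, and helper names are assumptions for illustration, not values from the repository:

```python
import torch

def build_optimizer(resampler_params, lora_params,
                    lr_resampler=1e-4, lr_lora=1e-5):
    # Differential learning rates: the from-scratch Resampler gets a higher
    # rate than the LoRA adapters, which only fine-tune the LLM.
    return torch.optim.AdamW([
        {"params": resampler_params, "lr": lr_resampler},
        {"params": lora_params, "lr": lr_lora},
    ])

def train_steps(model_step, optimizer, batches, accum_steps=4):
    # Gradient accumulation: average the loss over `accum_steps` micro-batches
    # before each optimizer update, simulating a larger effective batch size.
    optimizer.zero_grad()
    for i, batch in enumerate(batches):
        loss = model_step(batch) / accum_steps
        loss.backward()
        if (i + 1) % accum_steps == 0:
            optimizer.step()
            optimizer.zero_grad()
```

Here `model_step` stands in for one forward pass of the student plus the distillation loss.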
Stage 1: Teacher Caching
This script runs the frozen SDXL CLIP models and saves their output to disk.
python src/cache_targets.py --data_dir ./my_dataset
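The caching idea behind this stage can be sketched in a few lines. This is not the contents of `src/cache_targets.py`; `teacher_encode` is a hypothetical stand-in for the frozen CLIP-L/CLIP-G forward pass:

```python
import torch
from pathlib import Path

def cache_teacher_targets(captions, teacher_encode, out_dir="cache"):
    # `teacher_encode` should return (prompt_embeds, pooled) for one caption.
    out = Path(out_dir)
    out.mkdir(parents=True, exist_ok=True)
    with torch.no_grad():
        for i, caption in enumerate(captions):
            prompt_embeds, pooled = teacher_encode(caption)
            # One file per caption; the training loop later loads these
            # instead of running the CLIP encoders, saving VRAM.
            torch.save({"prompt": prompt_embeds.cpu(), "pooled": pooled.cpu()},
                       out / f"{i:06d}.pt")
```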
Stage 2: Distillation
This script trains the Qwen3 LoRA adapters and the Perceiver Resampler to match the cached CLIP outputs.
python src/train.py --batch_size 16 --epochs 1