SDXL-LLM-Distill: LLM-Powered Visual Conditioning

This repository contains training code for a knowledge distillation project. It replaces the standard CLIP text encoders in Stable Diffusion XL with an LLM (specifically Qwen3-4B in this proof of concept).

The Concept

Standard CLIP encoders have a limited context window and often fail to understand complex spatial relationships or negative prompts. This project uses Qwen3-4B as a "student" model: it learns to mimic the output format of the SDXL CLIP models while keeping its own linguistic reasoning. The current implementation is a proof of concept: the LLM is quantized to 4-bit precision and trained for one epoch on 10,000 SFW image-caption pairs from the SPRIGHT-T2I dataset.

Note

This is a proof-of-concept implementation: a limited dataset (only 10k image-caption pairs), a single epoch of training, and 4-bit precision. If this is of interest to the community, I will organize a larger-scale training run.

Improvements

Replacing CLIP with LLM-based conditioning is intended to improve the following areas:

  • Spatial Adherence: Improves understanding of position words like "left of", "inside", and "under".
  • Complex Negation: Handles negative statements more accurately.
  • Extended Context: Supports captions longer than the standard 77-token limit (up to 256 tokens).
  • Natural-Language Prompting: According to community experience, CLIP works best with tag-based prompts. This solution enables the use of natural-language prompts for SDXL.

Tradeoffs

  • LoRA Incompatibility: LoRAs without trigger words (such as sliders) should still work, but LoRAs that expect specific trigger words will likely break, as the text encoder is entirely different.
  • Increased Memory Footprint: An LLM takes more memory than CLIP. Inference speed is slower due to the larger parameter count.

Architecture

The system does not use the final text output of the LLM. It extracts the hidden states from the last transformer layer. These states are high-dimensional vector representations that contain a semantic summary of the prompt.
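
As a toy illustration of that design decision (a small stand-in network, not the actual Qwen3 model), the conditioning path keeps the last-layer activations and never touches the token logits:

```python
import torch
import torch.nn as nn

# Toy stand-in for the LLM; layer sizes and names are illustrative,
# not the actual Qwen3-4B implementation.
class ToyLM(nn.Module):
    def __init__(self, vocab=1000, hidden=64, layers=2):
        super().__init__()
        self.embed = nn.Embedding(vocab, hidden)
        block = nn.TransformerEncoderLayer(hidden, nhead=4, batch_first=True)
        self.blocks = nn.TransformerEncoder(block, num_layers=layers)
        self.lm_head = nn.Linear(vocab if False else hidden, vocab)

    def forward(self, ids):
        h = self.blocks(self.embed(ids))  # last-layer hidden states
        return h, self.lm_head(h)         # hidden states vs. token logits

model = ToyLM()
ids = torch.randint(0, 1000, (1, 12))     # a 12-token "prompt"
hidden, logits = model(ids)
# The conditioning pipeline keeps `hidden` ([1, 12, 64] here; [1, seq, 3072]
# for Qwen3-4B) and discards the token logits entirely.
```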

The Perceiver Resampler

Because LLM hidden states are not directly compatible with the SDXL UNet, a Perceiver Resampler acts as a translation bridge: it resamples the variable-length LLM hidden states into the fixed-size tensors the UNet expects.

Technical Specifications:

  • Layers: 4 layers of cross-attention and feed-forward blocks.
  • Latent Queries: 77 learnable query vectors.
  • Input Dimension: 3072 (Qwen3-4B hidden size).
  • Output Dimensions:
    • Prompt Embeddings: [1, 77, 2048]
    • Pooled Output: [1, 1280]
  • Total Parameters: ~278M, saved in full-precision float32 (4 bytes per parameter).
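
The specs above can be sketched as a small PyTorch module. Class and argument names (`PerceiverResampler`, `in_proj`, etc.) are illustrative assumptions, not the repository's actual API:

```python
import torch
import torch.nn as nn

# Minimal Perceiver Resampler sketch matching the specs above.
class PerceiverResampler(nn.Module):
    def __init__(self, in_dim=3072, out_dim=2048, pooled_dim=1280,
                 num_latents=77, depth=4, heads=8):
        super().__init__()
        # 77 learnable latent queries, one per output "token".
        self.latents = nn.Parameter(torch.randn(num_latents, out_dim) * 0.02)
        self.in_proj = nn.Linear(in_dim, out_dim)  # 3072 -> 2048
        self.layers = nn.ModuleList()
        for _ in range(depth):
            self.layers.append(nn.ModuleList([
                nn.MultiheadAttention(out_dim, heads, batch_first=True),
                nn.Sequential(nn.LayerNorm(out_dim),
                              nn.Linear(out_dim, out_dim * 4),
                              nn.GELU(),
                              nn.Linear(out_dim * 4, out_dim)),
            ]))
        self.pooled_proj = nn.Linear(out_dim, pooled_dim)  # 2048 -> 1280

    def forward(self, llm_hidden):                         # [B, seq, in_dim]
        kv = self.in_proj(llm_hidden)
        x = self.latents.unsqueeze(0).expand(llm_hidden.size(0), -1, -1)
        for attn, ff in self.layers:
            x = x + attn(x, kv, kv, need_weights=False)[0]  # cross-attention
            x = x + ff(x)                                   # feed-forward
        pooled = self.pooled_proj(x.mean(dim=1))
        return x, pooled          # [B, 77, out_dim], [B, pooled_dim]
```

Cross-attention lets the 77 fixed queries pull information from however many LLM tokens the prompt produced, which is what makes the output shape independent of the input length.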

The Qwen LoRA

Fine-tuning the LLM with LoRA is necessary to shift its internal attention. While a base LLM is trained for text prediction, the LoRA helps it emphasize visual features like spatial coordinates and textures that the Resampler needs to translate for the UNet.
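
The LoRA technique itself can be sketched in a few lines of PyTorch. The class below is a generic illustration of a low-rank adapter, not the PEFT implementation the repository uses:

```python
import torch
import torch.nn as nn

# Generic LoRA sketch: the frozen base weight gets a low-rank trainable
# update, so only A and B receive gradients during fine-tuning.
class LoRALinear(nn.Module):
    def __init__(self, base: nn.Linear, rank=16, alpha=32):
        super().__init__()
        self.base = base
        for p in self.base.parameters():
            p.requires_grad_(False)  # freeze the pretrained weight
        self.A = nn.Parameter(torch.randn(rank, base.in_features) * 0.01)
        self.B = nn.Parameter(torch.zeros(base.out_features, rank))  # starts at zero
        self.scale = alpha / rank

    def forward(self, x):
        # Base projection plus the scaled low-rank correction.
        return self.base(x) + (x @ self.A.T @ self.B.T) * self.scale
```

Because B is initialized to zero, the adapter is a no-op at the start of training and gradually shifts the model's attention toward visually relevant features as distillation proceeds.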

Reproduction

Training

The training process is split into two stages to save VRAM.

Training Process

The training uses knowledge distillation: the combined Qwen3 + Resampler model acts as a "student," learning to mimic a "teacher" model.

Distillation Methodology

The original, frozen SDXL CLIP-L and CLIP-G encoders serve as the "teacher" model. Their outputs for the entire training dataset are pre-computed and cached to save VRAM during the main training loop. The "student" model, which learns to mimic the teacher, is a composite of the 4-bit quantized Qwen3-4B model, PEFT LoRA adapters applied to its attention blocks, and the Perceiver Resampler module.
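
A sketch of the caching step, with a deterministic dummy function standing in for the real SDXL CLIP-L/CLIP-G encoders. The file layout and function names are illustrative, not the repository's actual format:

```python
import torch
from pathlib import Path

# Dummy stand-in for the frozen SDXL text encoders; real code would run
# CLIP-L and CLIP-G and concatenate their outputs.
def fake_clip_encode(caption: str) -> dict:
    g = torch.Generator().manual_seed(abs(hash(caption)) % (2**31))
    return {
        "prompt_embeds": torch.randn(77, 2048, generator=g),  # per-token targets
        "pooled_embeds": torch.randn(1280, generator=g),      # global target
    }

def cache_targets(captions, cache_dir="./teacher_cache"):
    out = Path(cache_dir)
    out.mkdir(parents=True, exist_ok=True)
    for i, caption in enumerate(captions):
        # One file per sample; the teacher never runs again after this pass.
        torch.save(fake_clip_encode(caption), out / f"{i:06d}.pt")
    return out

cache_dir = cache_targets(["a red cube to the left of a blue sphere"])
sample = torch.load(cache_dir / "000000.pt")
```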

Loss Function

The training objective is to minimize MSE between the student's predictions and the teacher's cached outputs. A separate MSE loss is calculated for the detailed prompt embeddings and the global pooled embedding; these two losses are summed to form the final loss for backpropagation.
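
The objective described above reduces to two MSE terms, summed:

```python
import torch
import torch.nn.functional as F

# Distillation loss: separate MSE terms for the sequence embeddings
# ([B, 77, 2048]) and the pooled embedding ([B, 1280]), summed for backprop.
def distill_loss(student_embeds, student_pooled, teacher_embeds, teacher_pooled):
    seq_loss = F.mse_loss(student_embeds, teacher_embeds)
    pooled_loss = F.mse_loss(student_pooled, teacher_pooled)
    return seq_loss + pooled_loss
```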

Optimization

The Resampler and the Qwen LoRA adapters are trained simultaneously using the AdamW optimizer. To account for their different roles, a differential learning rate is applied, with the Resampler (trained from scratch) using a higher rate than the LoRA adapters (fine-tuning). To fit training on consumer GPUs, gradient accumulation is used to simulate a larger effective batch size. The entire process, including loss metrics and gradient norms, is tracked using Weights & Biases for analysis.
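
The optimizer setup can be sketched with AdamW parameter groups and a gradient-accumulation loop. The learning rates, accumulation count, and stand-in modules below are illustrative, not the repository's actual hyperparameters:

```python
import torch

# Tiny stand-ins for the Resampler and the LoRA adapter parameters.
resampler = torch.nn.Linear(8, 8)
lora_params = [torch.nn.Parameter(torch.ones(8))]

# Differential learning rates via AdamW parameter groups.
optimizer = torch.optim.AdamW([
    {"params": resampler.parameters(), "lr": 1e-4},  # trained from scratch
    {"params": lora_params, "lr": 1e-5},             # fine-tuning
])

accum_steps = 4  # effective batch size = micro-batch size * 4
for step in range(8):
    x = torch.randn(2, 8)
    loss = resampler(x).pow(2).mean() + lora_params[0].pow(2).sum()
    (loss / accum_steps).backward()  # scale so accumulated grads average
    if (step + 1) % accum_steps == 0:
        torch.nn.utils.clip_grad_norm_(resampler.parameters(), 1.0)
        optimizer.step()
        optimizer.zero_grad()
```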

Stage 1: Teacher Caching

This script runs the frozen SDXL CLIP models and saves their output to disk.

python src/cache_targets.py --data_dir ./my_dataset

Stage 2: Distillation

This script trains the Qwen3 LoRA adapters and the Perceiver Resampler to match the cached CLIP outputs.

python src/train.py --batch_size 16 --epochs 1
