Instructions to use neuralvfx/LibreFlux-IP-Adapter with libraries, inference providers, notebooks, and local apps. Follow these links to get started.
- Libraries
- Diffusers
How to use neuralvfx/LibreFlux-IP-Adapter with Diffusers:
pip install -U diffusers transformers accelerate
```python
import torch
from diffusers import DiffusionPipeline

# switch to "mps" for Apple devices
pipe = DiffusionPipeline.from_pretrained("neuralvfx/LibreFlux-IP-Adapter", dtype=torch.bfloat16, device_map="cuda")

prompt = "Astronaut in a jungle, cold color palette, muted colors, detailed, 8k"
image = pipe(prompt).images[0]
```
- Notebooks
- Google Colab
- Kaggle
- Local Apps Settings
- Draw Things
- DiffusionBee
| from itertools import chain | |
| import torch | |
| from torch import nn | |
| from diffusers.models.attention_processor import ( | |
| Attention, | |
| AttentionProcessor, | |
| ) | |
| from diffusers.models.transformers.transformer_sd3 import SD3Transformer2DModel | |
| import torch.nn.functional as F | |
| from diffusers.utils import USE_PEFT_BACKEND, is_torch_version, logging, scale_lora_layers, unscale_lora_layers | |
| from diffusers.models.attention_processor import Attention | |
| import inspect | |
| from functools import partial | |
| from diffusers.models.normalization import RMSNorm | |
| from typing import Any, Dict, List, Optional, Union | |
| import torch | |
| import torch.nn as nn | |
class IPFluxAttnProcessor2_0(nn.Module):
    """Flux attention processor with an added IP-Adapter (image prompt) branch.

    Replaces the processor of a diffusers ``Attention`` module. With
    ``encoder_hidden_states`` it runs the double-stream (text + image) path and
    returns ``(hidden_states, encoder_hidden_states)``; without it, it runs the
    single-stream path and returns ``hidden_states``. When
    ``ip_encoder_hidden_states`` is given, the queries additionally attend to keys
    and values projected from those IP tokens, and that output is added to the
    attention result scaled by ``scale * layer_scale``.

    Args:
        hidden_size: Attention inner width; output width of to_k_ip / to_v_ip.
        cross_attention_dim: Width of the IP tokens; defaults to ``hidden_size``.
        scale: Global multiplier applied to the IP branch.
        num_tokens: Number of IP tokens; stored but not used by this class.
        num_heads: Accepted for API compatibility; not used by this class.
    """

    def __init__(self, hidden_size, cross_attention_dim=None, scale=1.0, num_tokens=4, num_heads=0):
        super().__init__()
        self.hidden_size = hidden_size
        self.cross_attention_dim = cross_attention_dim
        self.scale = scale
        self.num_tokens = num_tokens
        ip_in_dim = cross_attention_dim or hidden_size
        # Trainable projections of the IP tokens into the attention key/value space.
        # Attribute names are part of the saved state dict -- do not rename.
        self.to_k_ip = nn.Linear(ip_in_dim, hidden_size, bias=False)
        self.to_v_ip = nn.Linear(ip_in_dim, hidden_size, bias=False)
        # Parameter-free RMSNorm applied per head to the IP keys. 128 is presumably
        # the Flux per-head dim -- TODO confirm if reused with another model.
        self.norm_added_k = RMSNorm(128, eps=1e-5, elementwise_affine=False)

    def _inject_ip(self, hidden_states, ip_hidden_states, layer_scale):
        """Add the scaled IP-attention output to ``hidden_states``.

        Returns ``hidden_states`` unchanged when ``ip_hidden_states`` is None. A
        ``layer_scale`` of None is treated as 1.0.
        """
        if ip_hidden_states is None:
            return hidden_states
        ip_scale = self.scale
        if layer_scale is not None:
            # view(-1, 1, 1) lets a per-batch (or single-element) scale broadcast
            # over the (batch, seq, dim) attention output.
            ip_scale = self.scale * layer_scale.view(-1, 1, 1)
        return hidden_states + ip_scale * ip_hidden_states

    def __call__(
        self,
        attn,
        hidden_states: torch.FloatTensor,
        encoder_hidden_states: torch.FloatTensor = None,
        ip_encoder_hidden_states: torch.FloatTensor = None,
        attention_mask: Optional[torch.FloatTensor] = None,
        image_rotary_emb: Optional[torch.Tensor] = None,
        layer_scale: Optional[torch.Tensor] = None,
    ) -> torch.FloatTensor:
        """Run joint attention plus the optional IP-Adapter branch.

        Args:
            attn: The diffusers ``Attention`` module this processor is attached to;
                its to_q/to_k/to_v (and, when used, add_*_proj, norm_* , to_out,
                to_add_out) submodules are applied here.
            hidden_states: 3-D token tensor, unpacked as (batch, seq, dim).
            encoder_hidden_states: Text-stream tokens; None in single-stream blocks.
            ip_encoder_hidden_states: IP tokens from the image projection model, or
                None to skip the IP branch.
            attention_mask: Optional mask expanded with unsqueeze(1).unsqueeze(2);
                presumably (batch, total_seq) -- TODO confirm against the caller.
            image_rotary_emb: Rotary embedding applied to query/key after the text
                and image streams are concatenated.
            layer_scale: Optional tensor multiplier for the IP branch; None means 1.0.

        Returns:
            ``(hidden_states, encoder_hidden_states)`` when ``encoder_hidden_states``
            is given, otherwise ``hidden_states`` alone.
        """
        batch_size, _, _ = hidden_states.shape if encoder_hidden_states is None else encoder_hidden_states.shape
        ip_hidden_states = ip_encoder_hidden_states

        # `sample` projections, split into heads: (batch, heads, seq, head_dim).
        query = attn.to_q(hidden_states)
        key = attn.to_k(hidden_states)
        value = attn.to_v(hidden_states)

        inner_dim = key.shape[-1]
        head_dim = inner_dim // attn.heads

        query = query.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
        key = key.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
        value = value.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)

        if attn.norm_q is not None:
            query = attn.norm_q(query)
        if attn.norm_k is not None:
            key = attn.norm_k(key)

        # IP branch runs FIRST: `query` has not yet been concatenated with the
        # text-stream queries nor had rotary embeddings applied, so the IP output
        # lines up token-for-token with `hidden_states`.
        if ip_hidden_states is not None:
            ip_key = self.to_k_ip(ip_hidden_states)
            ip_value = self.to_v_ip(ip_hidden_states)
            ip_key = ip_key.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
            ip_value = ip_value.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
            ip_key = self.norm_added_k(ip_key)
            ip_hidden_states = F.scaled_dot_product_attention(
                query,
                ip_key,
                ip_value,
                dropout_p=0.0,
                is_causal=False,
                attn_mask=None,
            )
            ip_hidden_states = ip_hidden_states.transpose(1, 2).reshape(
                batch_size, -1, attn.heads * head_dim
            )

        # The attention in FluxSingleTransformerBlock does not use `encoder_hidden_states`.
        if encoder_hidden_states is not None:
            # `context` projections.
            encoder_hidden_states_query_proj = attn.add_q_proj(encoder_hidden_states)
            encoder_hidden_states_key_proj = attn.add_k_proj(encoder_hidden_states)
            encoder_hidden_states_value_proj = attn.add_v_proj(encoder_hidden_states)

            encoder_hidden_states_query_proj = encoder_hidden_states_query_proj.view(
                batch_size, -1, attn.heads, head_dim
            ).transpose(1, 2)
            encoder_hidden_states_key_proj = encoder_hidden_states_key_proj.view(
                batch_size, -1, attn.heads, head_dim
            ).transpose(1, 2)
            encoder_hidden_states_value_proj = encoder_hidden_states_value_proj.view(
                batch_size, -1, attn.heads, head_dim
            ).transpose(1, 2)

            if attn.norm_added_q is not None:
                encoder_hidden_states_query_proj = attn.norm_added_q(encoder_hidden_states_query_proj)
            if attn.norm_added_k is not None:
                encoder_hidden_states_key_proj = attn.norm_added_k(encoder_hidden_states_key_proj)

            # Joint attention: text tokens first, then image tokens.
            query = torch.cat([encoder_hidden_states_query_proj, query], dim=2)
            key = torch.cat([encoder_hidden_states_key_proj, key], dim=2)
            value = torch.cat([encoder_hidden_states_value_proj, value], dim=2)

        if image_rotary_emb is not None:
            from diffusers.models.embeddings import apply_rotary_emb

            query = apply_rotary_emb(query, image_rotary_emb)
            key = apply_rotary_emb(key, image_rotary_emb)

        if attention_mask is not None:
            attention_mask = attention_mask.unsqueeze(1).unsqueeze(2)
            attention_mask = (attention_mask > 0).bool()
            # NOTE(review): casting the boolean mask to query.dtype makes it an
            # *additive* float mask (a 0/1 bias) for scaled_dot_product_attention,
            # not a masking one. Kept as-is to match the behaviour the weights were
            # presumably trained with -- confirm before changing.
            attention_mask = attention_mask.to(device=hidden_states.device, dtype=query.dtype)

        hidden_states = F.scaled_dot_product_attention(
            query,
            key,
            value,
            dropout_p=0.0,
            is_causal=False,
            attn_mask=attention_mask,
        )
        hidden_states = hidden_states.transpose(1, 2).reshape(batch_size, -1, attn.heads * head_dim)
        hidden_states = hidden_states.to(query.dtype)

        if encoder_hidden_states is not None:
            # Split the joint sequence back into text and image streams.
            txt_len = encoder_hidden_states.shape[1]
            encoder_hidden_states, hidden_states = (
                hidden_states[:, :txt_len],
                hidden_states[:, txt_len:],
            )
            # Final injection of the IP-Adapter output (image stream only).
            hidden_states = self._inject_ip(hidden_states, ip_hidden_states, layer_scale)
            # linear proj
            hidden_states = attn.to_out[0](hidden_states)
            # dropout
            hidden_states = attn.to_out[1](hidden_states)
            encoder_hidden_states = attn.to_add_out(encoder_hidden_states)
            return hidden_states, encoder_hidden_states

        # Final injection of the IP-Adapter output.
        hidden_states = self._inject_ip(hidden_states, ip_hidden_states, layer_scale)
        if attn.to_out is not None:
            hidden_states = attn.to_out[0](hidden_states)
            hidden_states = attn.to_out[1](hidden_states)
        return hidden_states
class ImageProjModel(nn.Module):
    """Turn a single image embedding into a fixed set of IP-Adapter tokens.

    A two-layer MLP (Linear -> GELU -> Linear) widens the embedding to
    ``num_tokens * cross_attention_dim`` values, which are reshaped to
    ``(batch, num_tokens, cross_attention_dim)`` and layer-normalised over the
    last dimension.

    Args:
        clip_dim: Width of the incoming image embedding (last input dimension).
        cross_attention_dim: Width of each produced token.
        num_tokens: Number of tokens produced per input.
    """

    def __init__(self, clip_dim=768, cross_attention_dim=4096, num_tokens=16):
        super().__init__()
        self.num_tokens = num_tokens
        self.cross_attention_dim = cross_attention_dim
        self.clip_dim = clip_dim

        mlp_width = clip_dim * 2
        token_block = cross_attention_dim * num_tokens
        # Layer order fixes the checkpoint keys proj.0.* and proj.2.*.
        self.proj = torch.nn.Sequential(
            torch.nn.Linear(clip_dim, mlp_width),
            torch.nn.GELU(),
            torch.nn.Linear(mlp_width, token_block),
        )
        self.norm = torch.nn.LayerNorm(cross_attention_dim)

    def forward(self, input):
        """Return normalised tokens of shape (batch, num_tokens, cross_attention_dim)."""
        batch = input.shape[0]
        flat = self.proj(input)
        tokens = flat.reshape(batch, self.num_tokens, self.cross_attention_dim)
        return self.norm(tokens)
class LibreFluxIPAdapter(nn.Module):
    """Wrap a Flux-style transformer with IP-Adapter image conditioning.

    On construction, every ``Attention`` module whose name starts with
    ``transformer_blocks`` or ``single_transformer_blocks`` gets its processor
    replaced by an ``IPFluxAttnProcessor2_0``. ``forward`` projects a reference
    image embedding with ``image_proj_model`` and hands it (plus a layer scale) to
    the transformer through ``joint_attention_kwargs``.

    Args:
        transformer: Transformer whose attention processors are replaced in place.
        image_proj_model: Module mapping reference-image embeddings to IP tokens.
        checkpoint: Optional path passed to ``load_from_checkpoint``.
    """

    # Name prefixes of the attention layers that get the IP processor. Both are
    # required: "single_transformer_blocks" does not start with "transformer_blocks".
    _ATTN_PREFIXES = ("transformer_blocks", "single_transformer_blocks")

    def __init__(self, transformer, image_proj_model, checkpoint=None):
        super().__init__()
        self.transformer = transformer
        self.image_proj_model = image_proj_model

        # Attention layers of both double- and single-stream blocks, in
        # named_modules() order. That order fixes the adapter_modules indices used
        # as keys in saved checkpoints.
        self.culled_transformer_blocks = {
            name: module
            for name, module in self.transformer.named_modules()
            if isinstance(module, Attention) and name.startswith(self._ATTN_PREFIXES)
        }

        # Apply the adapter to the culled blocks.
        self.wrap_attention_blocks()

        if checkpoint:
            self.load_from_checkpoint(checkpoint)

    def wrap_attention_blocks(self, scale=1.0, num_tokens=16):
        """Inject an IP-Adapter processor into every culled attention layer.

        Args:
            scale: IP branch multiplier given to every processor.
            num_tokens: Number of IP tokens recorded on every processor.

        The new processors are also collected in ``self.adapter_modules`` so they
        can be saved and loaded as a single state dict.
        """
        sample_attn = self.transformer.transformer_blocks[0].attn
        hidden_size = sample_attn.inner_dim
        num_heads = sample_attn.heads
        # to_k_ip / to_v_ip consume image_proj_model's output tokens, so their input
        # width must equal its cross_attention_dim (4096 when it is not exposed).
        ip_dim = getattr(self.image_proj_model, "cross_attention_dim", 4096)

        processor_list = []
        for module in self.culled_transformer_blocks.values():
            module.processor = IPFluxAttnProcessor2_0(
                hidden_size=hidden_size,
                cross_attention_dim=ip_dim,
                num_heads=num_heads,
                scale=scale,
                num_tokens=num_tokens,
            )
            processor_list.append(module.processor)

        print(f"Added Attention IP Wrapper to {len(processor_list)} layers")
        # Store adapters as a module list for saving/loading.
        self.adapter_modules = torch.nn.ModuleList(processor_list)

    def parameters(self):
        """Return an iterator over only the adapter's own trainable parameters.

        Yields the parameters of every installed IP processor followed by those of
        ``image_proj_model``. The wrapped transformer's weights are excluded.
        """
        processor_params = [
            module.processor.parameters() for module in self.culled_transformer_blocks.values()
        ]
        return chain(*processor_params, self.image_proj_model.parameters())

    def forward(self, ref_image, *args, layer_scale=None, **kwargs):
        """Project ``ref_image`` and run the wrapped transformer.

        Args:
            ref_image: Input to ``image_proj_model``, or None to run without IP
                conditioning.
            *args: Positional arguments forwarded to the transformer.
            layer_scale: IP branch multiplier as a tensor or number; defaults to
                ``torch.tensor([1.0])``. Moved to image_proj_model's dtype/device.
            **kwargs: Keyword arguments forwarded to the transformer. A copy of
                ``joint_attention_kwargs`` (an explicit None is treated as empty) is
                extended with ``ip_layer_scale`` and ``ip_hidden_states``; the
                caller's dict is not modified.

        Returns:
            Whatever the wrapped transformer returns.
        """
        first_param = next(self.image_proj_model.parameters())
        mod_dtype, mod_device = first_param.dtype, first_param.device

        ip_encoder_hidden_states = None
        if ref_image is not None:
            ip_encoder_hidden_states = self.image_proj_model(ref_image)

        if layer_scale is None:
            layer_scale = torch.tensor([1.0])
        layer_scale = torch.as_tensor(layer_scale).to(dtype=mod_dtype, device=mod_device)

        # Copy so a caller-supplied dict is never mutated.
        joint_attention_kwargs = dict(kwargs.get("joint_attention_kwargs") or {})
        joint_attention_kwargs["ip_layer_scale"] = layer_scale
        joint_attention_kwargs["ip_hidden_states"] = ip_encoder_hidden_states
        kwargs["joint_attention_kwargs"] = joint_attention_kwargs

        return self.transformer(*args, **kwargs)

    def save_pretrained(self, ckpt_path):
        """Save image_proj_model and adapter_modules weights to ``ckpt_path``."""
        state_dict = {
            "image_proj": self.image_proj_model.state_dict(),
            "ip_adapter": self.adapter_modules.state_dict(),
        }
        torch.save(state_dict, ckpt_path)

    @staticmethod
    def _param_checksum(module):
        """Return the sum of all parameter values of ``module`` (cheap change detector)."""
        with torch.no_grad():
            return torch.sum(torch.stack([torch.sum(p) for p in module.parameters()]))

    def load_from_checkpoint(self, ckpt_path):
        """Load weights written by ``save_pretrained`` (loader adapted from the Tencent IP-Adapter repo).

        Both state dicts are loaded with ``strict=True``. As a sanity check, a
        parameter checksum is compared before and after loading, and an
        AssertionError is raised if either module's weights did not change (so
        reloading identical weights also trips it).
        """
        orig_ip_proj_sum = self._param_checksum(self.image_proj_model)
        orig_adapter_sum = self._param_checksum(self.adapter_modules)

        # NOTE(review): torch.load unpickles the file; only load trusted checkpoints
        # (or pass weights_only=True on torch versions that support it).
        state_dict = torch.load(ckpt_path, map_location="cpu")

        self.image_proj_model.load_state_dict(state_dict["image_proj"], strict=True)
        self.adapter_modules.load_state_dict(state_dict["ip_adapter"], strict=True)

        new_ip_proj_sum = self._param_checksum(self.image_proj_model)
        new_adapter_sum = self._param_checksum(self.adapter_modules)

        assert orig_ip_proj_sum != new_ip_proj_sum, "Weights of image_proj_model did not change!"
        assert orig_adapter_sum != new_adapter_sum, "Weights of adapter_modules did not change!"

        print(f"Successfully loaded weights from checkpoint {ckpt_path}")

    def dtype(self):
        """Return the dtype of image_proj_model's first parameter."""
        return next(self.image_proj_model.parameters()).dtype