# diff-storyboard / examples/vram_management/flux_text_to_image.py
import torch
from diffsynth import ModelManager, FluxImagePipeline
model_manager = ModelManager(
    file_path_list=[
        "models/FLUX/FLUX.1-dev/text_encoder/model.safetensors",
        "models/FLUX/FLUX.1-dev/text_encoder_2",
        "models/FLUX/FLUX.1-dev/flux1-dev.safetensors",
        "models/FLUX/FLUX.1-dev/ae.safetensors",
    ],
    torch_dtype=torch.float8_e4m3fn,  # load weights in FP8 to reduce the memory footprint
    device="cpu",  # keep weights in RAM; they are moved to the GPU on demand during inference
)
pipe = FluxImagePipeline.from_model_manager(model_manager, torch_dtype=torch.bfloat16, device="cuda")
# Enable VRAM management
# `num_persistent_param_in_dit` sets how many parameters of the DiT model reside persistently in VRAM.
# When `num_persistent_param_in_dit=None`, all parameters reside persistently in VRAM.
# When `num_persistent_param_in_dit=7*10**9`, 7 billion parameters reside persistently in VRAM.
# When `num_persistent_param_in_dit=0`, no parameters reside persistently in VRAM; they are loaded layer by layer during inference.
pipe.enable_vram_management(num_persistent_param_in_dit=None)
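# The call above keeps all DiT parameters resident for maximum speed. On GPUs with
# less VRAM, a smaller persistence budget can be used instead; the commented-out
# calls below are sketches of the other two settings described above, and the
# numeric value is only illustrative and should be tuned to your hardware.
# pipe.enable_vram_management(num_persistent_param_in_dit=7*10**9)  # keep ~7B parameters resident
# pipe.enable_vram_management(num_persistent_param_in_dit=0)        # offload everything; load layer by layer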
image = pipe(prompt="a beautiful orange cat", seed=0)
image.save("image.jpg")