# File size: 1,154 Bytes
# feb33a0
import torch
from diffsynth import ModelManager, FluxImagePipeline
# Directory holding the FLUX.1-dev checkpoints; every model path below is
# relative to it, so relocating the models only requires changing this line.
FLUX_DEV_DIR = "models/FLUX/FLUX.1-dev"

# Load the four FLUX.1-dev components onto the CPU, storing their weights as
# FP8 (float8_e4m3fn, 1 byte per parameter). `text_encoder_2` has no file
# extension; presumably it is a multi-file model directory rather than a single
# checkpoint (NOTE(review): confirm).
model_manager = ModelManager(
    file_path_list=[
        f"{FLUX_DEV_DIR}/text_encoder/model.safetensors",
        f"{FLUX_DEV_DIR}/text_encoder_2",
        f"{FLUX_DEV_DIR}/flux1-dev.safetensors",
        f"{FLUX_DEV_DIR}/ae.safetensors",
    ],
    torch_dtype=torch.float8_e4m3fn,
    device="cpu",
)

# Build the pipeline targeting CUDA with a bfloat16 dtype, while the weights
# loaded above are stored in FP8. Presumably the pipeline casts weights to
# bfloat16 as they are used (NOTE(review): confirm against FluxImagePipeline).
pipe = FluxImagePipeline.from_model_manager(
    model_manager,
    torch_dtype=torch.bfloat16,
    device="cuda",
)

# Enable VRAM management.
# `num_persistent_param_in_dit` is the number of DiT parameters that reside
# persistently in VRAM:
#   - None      -> all DiT parameters reside persistently in VRAM.
#   - 7 * 10**9 -> 7 billion DiT parameters reside persistently in VRAM.
#   - 0         -> no DiT parameters reside persistently in VRAM; they are
#                  loaded layer by layer during inference.
pipe.enable_vram_management(num_persistent_param_in_dit=None)

# Generate one image with a fixed seed and save it as image.jpg, relative to
# the current working directory.
image = pipe(prompt="a beautiful orange cat", seed=0)
image.save("image.jpg")