Update pipeline_bria_controlnet.py
Browse files
pipeline_bria_controlnet.py
CHANGED
|
@@ -25,9 +25,9 @@ from diffusers.schedulers import FlowMatchEulerDiscreteScheduler
|
|
| 25 |
from diffusers.schedulers import KarrasDiffusionSchedulers
|
| 26 |
from diffusers.utils import logging
|
| 27 |
from diffusers.pipelines.flux.pipeline_output import FluxPipelineOutput
|
| 28 |
-
from
|
| 29 |
from diffusers.pipelines.flux.pipeline_flux import retrieve_timesteps, calculate_shift
|
| 30 |
-
from
|
| 31 |
from transformer_bria import BriaTransformer2DModel
|
| 32 |
from bria_utils import get_original_sigmas
|
| 33 |
import numpy as np
|
|
@@ -397,7 +397,10 @@ class BriaControlNetPipeline(BriaPipeline):
|
|
| 397 |
|
| 398 |
if isinstance(self.scheduler,FlowMatchEulerDiscreteScheduler) and self.scheduler.config['use_dynamic_shifting']:
|
| 399 |
sigmas = np.linspace(1.0, 1 / num_inference_steps, num_inference_steps)
|
| 400 |
-
|
|
|
|
|
|
|
|
|
|
| 401 |
print(f"Using dynamic shift in pipeline with sequence length {image_seq_len}")
|
| 402 |
|
| 403 |
mu = calculate_shift(
|
|
|
|
| 25 |
from diffusers.schedulers import KarrasDiffusionSchedulers
|
| 26 |
from diffusers.utils import logging
|
| 27 |
from diffusers.pipelines.flux.pipeline_output import FluxPipelineOutput
|
| 28 |
+
from controlnet_bria import BriaControlNetModel, BriaMultiControlNetModel
|
| 29 |
from diffusers.pipelines.flux.pipeline_flux import retrieve_timesteps, calculate_shift
|
| 30 |
+
from pipeline_bria import BriaPipeline
|
| 31 |
from transformer_bria import BriaTransformer2DModel
|
| 32 |
from bria_utils import get_original_sigmas
|
| 33 |
import numpy as np
|
|
|
|
| 397 |
|
| 398 |
if isinstance(self.scheduler,FlowMatchEulerDiscreteScheduler) and self.scheduler.config['use_dynamic_shifting']:
|
| 399 |
sigmas = np.linspace(1.0, 1 / num_inference_steps, num_inference_steps)
|
| 400 |
+
if type(control_image) == list:
|
| 401 |
+
image_seq_len = control_image[0].shape[1]
|
| 402 |
+
else:
|
| 403 |
+
image_seq_len = control_image.shape[1]
|
| 404 |
print(f"Using dynamic shift in pipeline with sequence length {image_seq_len}")
|
| 405 |
|
| 406 |
mu = calculate_shift(
|