diff --git a/src/diffusers/modular_pipelines/flux/before_denoise.py b/src/diffusers/modular_pipelines/flux/before_denoise.py index c28154775f5a..25e677aadeb5 100644 --- a/src/diffusers/modular_pipelines/flux/before_denoise.py +++ b/src/diffusers/modular_pipelines/flux/before_denoise.py @@ -17,7 +17,6 @@ import numpy as np import torch -from ...pipelines import FluxPipeline from ...schedulers import FlowMatchEulerDiscreteScheduler from ...utils import logging from ...utils.torch_utils import randn_tensor @@ -29,6 +28,28 @@ logger = logging.get_logger(__name__) # pylint: disable=invalid-name +def _pack_latents(latents, batch_size, num_channels_latents, height, width): + latents = latents.view(batch_size, num_channels_latents, height // 2, 2, width // 2, 2) + latents = latents.permute(0, 2, 4, 1, 3, 5) + latents = latents.reshape(batch_size, (height // 2) * (width // 2), num_channels_latents * 4) + + return latents + + +def _prepare_latent_image_ids(batch_size, height, width, device, dtype): + latent_image_ids = torch.zeros(height, width, 3) + latent_image_ids[..., 1] = latent_image_ids[..., 1] + torch.arange(height)[:, None] + latent_image_ids[..., 2] = latent_image_ids[..., 2] + torch.arange(width)[None, :] + + latent_image_id_height, latent_image_id_width, latent_image_id_channels = latent_image_ids.shape + + latent_image_ids = latent_image_ids.reshape( + latent_image_id_height * latent_image_id_width, latent_image_id_channels + ) + + return latent_image_ids.to(device=device, dtype=dtype) + + # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.retrieve_timesteps def retrieve_timesteps( scheduler, @@ -390,7 +411,7 @@ def prepare_latents( # TODO: move packing latents code to a patchifier similar to Qwen latents = randn_tensor(shape, generator=generator, device=device, dtype=dtype) - latents = FluxPipeline._pack_latents(latents, batch_size, num_channels_latents, height, width) + latents = _pack_latents(latents, batch_size, num_channels_latents, height, width) return latents @@ -541,7 +562,7 @@ def __call__(self, components: FluxModularPipeline, state: PipelineState) -> Pip height = 2 * (int(block_state.height) // (components.vae_scale_factor * 2)) width = 2 * (int(block_state.width) // (components.vae_scale_factor * 2)) - block_state.img_ids = FluxPipeline._prepare_latent_image_ids(None, height // 2, width // 2, device, dtype) + block_state.img_ids = _prepare_latent_image_ids(None, height // 2, width // 2, device, dtype) self.set_block_state(state, block_state) @@ -598,7 +619,7 @@ def __call__(self, components: FluxModularPipeline, state: PipelineState) -> Pip ): image_latent_height = 2 * (int(block_state.image_height) // (components.vae_scale_factor * 2)) image_latent_width = 2 * (int(block_state.image_width) // (components.vae_scale_factor * 2)) - img_ids = FluxPipeline._prepare_latent_image_ids( + img_ids = _prepare_latent_image_ids( None, image_latent_height // 2, image_latent_width // 2, device, dtype ) # image ids are the same as latent ids with the first dimension set to 1 instead of 0 @@ -606,7 +627,7 @@ def __call__(self, components: FluxModularPipeline, state: PipelineState) -> Pip height = 2 * (int(block_state.height) // (components.vae_scale_factor * 2)) width = 2 * (int(block_state.width) // (components.vae_scale_factor * 2)) - latent_ids = FluxPipeline._prepare_latent_image_ids(None, height // 2, width // 2, device, dtype) + latent_ids = _prepare_latent_image_ids(None, height // 2, width // 2, device, dtype) if img_ids is not None: latent_ids = torch.cat([latent_ids, img_ids], dim=0) diff --git a/src/diffusers/modular_pipelines/flux/encoders.py b/src/diffusers/modular_pipelines/flux/encoders.py index 583c139ff22e..c58675eefbf0 100644 --- a/src/diffusers/modular_pipelines/flux/encoders.py +++ b/src/diffusers/modular_pipelines/flux/encoders.py @@ -34,6 +34,26 @@ logger = logging.get_logger(__name__) # pylint: disable=invalid-name +PREFERRED_KONTEXT_RESOLUTIONS = [ + (672, 1568), + (688, 1504), + (720, 1456), + (752, 1392), + (800, 1328), + (832, 1248), + (880, 1184), + (944, 1104), + (1024, 1024), + (1104, 944), + (1184, 880), + (1248, 832), + (1328, 800), + (1392, 752), + (1456, 720), + (1504, 688), + (1568, 672), +] + def basic_clean(text): text = ftfy.fix_text(text) @@ -170,8 +190,6 @@ def intermediate_outputs(self) -> list[OutputParam]: @torch.no_grad() def __call__(self, components: FluxModularPipeline, state: PipelineState): - from ...pipelines.flux.pipeline_flux_kontext import PREFERRED_KONTEXT_RESOLUTIONS - block_state = self.get_block_state(state) images = block_state.image diff --git a/src/diffusers/modular_pipelines/flux/inputs.py b/src/diffusers/modular_pipelines/flux/inputs.py index 9d2f69dbe26f..b0cee24fa4fa 100644 --- a/src/diffusers/modular_pipelines/flux/inputs.py +++ b/src/diffusers/modular_pipelines/flux/inputs.py @@ -15,19 +15,112 @@ import torch -from ...pipelines import FluxPipeline from ...utils import logging from ..modular_pipeline import ModularPipelineBlocks, PipelineState from ..modular_pipeline_utils import InputParam, OutputParam - -# TODO: consider making these common utilities for modular if they are not pipeline-specific. -from ..qwenimage.inputs import calculate_dimension_from_latents, repeat_tensor_to_batch_size from .modular_pipeline import FluxModularPipeline logger = logging.get_logger(__name__) +def _pack_latents(latents, batch_size, num_channels_latents, height, width): + latents = latents.view(batch_size, num_channels_latents, height // 2, 2, width // 2, 2) + latents = latents.permute(0, 2, 4, 1, 3, 5) + latents = latents.reshape(batch_size, (height // 2) * (width // 2), num_channels_latents * 4) + + return latents + + +# Copied from diffusers.modular_pipelines.qwenimage.inputs.repeat_tensor_to_batch_size +def repeat_tensor_to_batch_size( + input_name: str, + input_tensor: torch.Tensor, + batch_size: int, + num_images_per_prompt: int = 1, +) -> torch.Tensor: + """Repeat tensor elements to match the final batch size. + + This function expands a tensor's batch dimension to match the final batch size (batch_size * num_images_per_prompt) + by repeating each element along dimension 0. + + The input tensor must have batch size 1 or batch_size. The function will: + - If batch size is 1: repeat each element (batch_size * num_images_per_prompt) times + - If batch size equals batch_size: repeat each element num_images_per_prompt times + + Args: + input_name (str): Name of the input tensor (used for error messages) + input_tensor (torch.Tensor): The tensor to repeat. Must have batch size 1 or batch_size. + batch_size (int): The base batch size (number of prompts) + num_images_per_prompt (int, optional): Number of images to generate per prompt. Defaults to 1. + + Returns: + torch.Tensor: The repeated tensor with final batch size (batch_size * num_images_per_prompt) + + Raises: + ValueError: If input_tensor is not a torch.Tensor or has invalid batch size + + Examples: + tensor = torch.tensor([[1, 2, 3]]) # shape: [1, 3] repeated = repeat_tensor_to_batch_size("image", tensor, + batch_size=2, num_images_per_prompt=2) repeated # tensor([[1, 2, 3], [1, 2, 3], [1, 2, 3], [1, 2, 3]]) - shape: + [4, 3] + + tensor = torch.tensor([[1, 2, 3], [4, 5, 6]]) # shape: [2, 3] repeated = repeat_tensor_to_batch_size("image", + tensor, batch_size=2, num_images_per_prompt=2) repeated # tensor([[1, 2, 3], [1, 2, 3], [4, 5, 6], [4, 5, 6]]) + - shape: [4, 3] + """ + # make sure input is a tensor + if not isinstance(input_tensor, torch.Tensor): + raise ValueError(f"`{input_name}` must be a tensor") + + # make sure input tensor e.g. image_latents has batch size 1 or batch_size same as prompts + if input_tensor.shape[0] == 1: + repeat_by = batch_size * num_images_per_prompt + elif input_tensor.shape[0] == batch_size: + repeat_by = num_images_per_prompt + else: + raise ValueError( + f"`{input_name}` must have have batch size 1 or {batch_size}, but got {input_tensor.shape[0]}" + ) + + # expand the tensor to match the batch_size * num_images_per_prompt + input_tensor = input_tensor.repeat_interleave(repeat_by, dim=0) + + return input_tensor + + +# Copied from diffusers.modular_pipelines.qwenimage.inputs.calculate_dimension_from_latents +def calculate_dimension_from_latents(latents: torch.Tensor, vae_scale_factor: int) -> tuple[int, int]: + """Calculate image dimensions from latent tensor dimensions. + + This function converts latent space dimensions to image space dimensions by multiplying the latent height and width + by the VAE scale factor. + + Args: + latents (torch.Tensor): The latent tensor. Must have 4 or 5 dimensions. + Expected shapes: [batch, channels, height, width] or [batch, channels, frames, height, width] + vae_scale_factor (int): The scale factor used by the VAE to compress images. + Typically 8 for most VAEs (image is 8x larger than latents in each dimension) + + Returns: + tuple[int, int]: The calculated image dimensions as (height, width) + + Raises: + ValueError: If latents tensor doesn't have 4 or 5 dimensions + + """ + # make sure the latents are not packed + if latents.ndim != 4 and latents.ndim != 5: + raise ValueError(f"unpacked latents must have 4 or 5 dimensions, but got {latents.ndim}") + + latent_height, latent_width = latents.shape[-2:] + + height = latent_height * vae_scale_factor + width = latent_width * vae_scale_factor + + return height, width + + class FluxTextInputStep(ModularPipelineBlocks): model_name = "flux" @@ -209,7 +302,7 @@ def __call__(self, components: FluxModularPipeline, state: PipelineState) -> Pip # 2. Patchify the image latent tensor # TODO: Implement patchifier for Flux. latent_height, latent_width = image_latent_tensor.shape[2:] - image_latent_tensor = FluxPipeline._pack_latents( + image_latent_tensor = _pack_latents( image_latent_tensor, block_state.batch_size, image_latent_tensor.shape[1], latent_height, latent_width ) @@ -266,7 +359,7 @@ def __call__(self, components: FluxModularPipeline, state: PipelineState) -> Pip # 2. Patchify the image latent tensor # TODO: Implement patchifier for Flux. latent_height, latent_width = image_latent_tensor.shape[2:] - image_latent_tensor = FluxPipeline._pack_latents( + image_latent_tensor = _pack_latents( image_latent_tensor, block_state.batch_size, image_latent_tensor.shape[1], latent_height, latent_width )