Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
31 changes: 26 additions & 5 deletions src/diffusers/modular_pipelines/flux/before_denoise.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,6 @@
import numpy as np
import torch

from ...pipelines import FluxPipeline
from ...schedulers import FlowMatchEulerDiscreteScheduler
from ...utils import logging
from ...utils.torch_utils import randn_tensor
Expand All @@ -29,6 +28,28 @@
logger = logging.get_logger(__name__) # pylint: disable=invalid-name


def _pack_latents(latents, batch_size, num_channels_latents, height, width):

Copy link
Copy Markdown
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Let's try using "# Copied from ..."?

latents = latents.view(batch_size, num_channels_latents, height // 2, 2, width // 2, 2)
latents = latents.permute(0, 2, 4, 1, 3, 5)
latents = latents.reshape(batch_size, (height // 2) * (width // 2), num_channels_latents * 4)

return latents


def _prepare_latent_image_ids(batch_size, height, width, device, dtype):
latent_image_ids = torch.zeros(height, width, 3)
latent_image_ids[..., 1] = latent_image_ids[..., 1] + torch.arange(height)[:, None]
latent_image_ids[..., 2] = latent_image_ids[..., 2] + torch.arange(width)[None, :]

latent_image_id_height, latent_image_id_width, latent_image_id_channels = latent_image_ids.shape

latent_image_ids = latent_image_ids.reshape(
latent_image_id_height * latent_image_id_width, latent_image_id_channels
)

return latent_image_ids.to(device=device, dtype=dtype)


# Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.retrieve_timesteps
def retrieve_timesteps(
scheduler,
Expand Down Expand Up @@ -390,7 +411,7 @@ def prepare_latents(

# TODO: move packing latents code to a patchifier similar to Qwen
latents = randn_tensor(shape, generator=generator, device=device, dtype=dtype)
latents = FluxPipeline._pack_latents(latents, batch_size, num_channels_latents, height, width)
latents = _pack_latents(latents, batch_size, num_channels_latents, height, width)

return latents

Expand Down Expand Up @@ -541,7 +562,7 @@ def __call__(self, components: FluxModularPipeline, state: PipelineState) -> Pip

height = 2 * (int(block_state.height) // (components.vae_scale_factor * 2))
width = 2 * (int(block_state.width) // (components.vae_scale_factor * 2))
block_state.img_ids = FluxPipeline._prepare_latent_image_ids(None, height // 2, width // 2, device, dtype)
block_state.img_ids = _prepare_latent_image_ids(None, height // 2, width // 2, device, dtype)

self.set_block_state(state, block_state)

Expand Down Expand Up @@ -598,15 +619,15 @@ def __call__(self, components: FluxModularPipeline, state: PipelineState) -> Pip
):
image_latent_height = 2 * (int(block_state.image_height) // (components.vae_scale_factor * 2))
image_latent_width = 2 * (int(block_state.image_width) // (components.vae_scale_factor * 2))
img_ids = FluxPipeline._prepare_latent_image_ids(
img_ids = _prepare_latent_image_ids(
None, image_latent_height // 2, image_latent_width // 2, device, dtype
)
# image ids are the same as latent ids with the first dimension set to 1 instead of 0
img_ids[..., 0] = 1

height = 2 * (int(block_state.height) // (components.vae_scale_factor * 2))
width = 2 * (int(block_state.width) // (components.vae_scale_factor * 2))
latent_ids = FluxPipeline._prepare_latent_image_ids(None, height // 2, width // 2, device, dtype)
latent_ids = _prepare_latent_image_ids(None, height // 2, width // 2, device, dtype)

if img_ids is not None:
latent_ids = torch.cat([latent_ids, img_ids], dim=0)
Expand Down
22 changes: 20 additions & 2 deletions src/diffusers/modular_pipelines/flux/encoders.py
Original file line number Diff line number Diff line change
Expand Up @@ -34,6 +34,26 @@

logger = logging.get_logger(__name__) # pylint: disable=invalid-name

PREFERRED_KONTEXT_RESOLUTIONS = [
(672, 1568),
(688, 1504),
(720, 1456),
(752, 1392),
(800, 1328),
(832, 1248),
(880, 1184),
(944, 1104),
(1024, 1024),
(1104, 944),
(1184, 880),
(1248, 832),
(1328, 800),
(1392, 752),
(1456, 720),
(1504, 688),
(1568, 672),
]


def basic_clean(text):
text = ftfy.fix_text(text)
Expand Down Expand Up @@ -170,8 +190,6 @@ def intermediate_outputs(self) -> list[OutputParam]:

@torch.no_grad()
def __call__(self, components: FluxModularPipeline, state: PipelineState):
from ...pipelines.flux.pipeline_flux_kontext import PREFERRED_KONTEXT_RESOLUTIONS

block_state = self.get_block_state(state)
images = block_state.image

Expand Down
93 changes: 87 additions & 6 deletions src/diffusers/modular_pipelines/flux/inputs.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,19 +15,100 @@

import torch

from ...pipelines import FluxPipeline
from ...utils import logging
from ..modular_pipeline import ModularPipelineBlocks, PipelineState
from ..modular_pipeline_utils import InputParam, OutputParam

# TODO: consider making these common utilities for modular if they are not pipeline-specific.
from ..qwenimage.inputs import calculate_dimension_from_latents, repeat_tensor_to_batch_size
from .modular_pipeline import FluxModularPipeline


logger = logging.get_logger(__name__)


def _pack_latents(latents, batch_size, num_channels_latents, height, width):
latents = latents.view(batch_size, num_channels_latents, height // 2, 2, width // 2, 2)
latents = latents.permute(0, 2, 4, 1, 3, 5)
latents = latents.reshape(batch_size, (height // 2) * (width // 2), num_channels_latents * 4)

return latents


def repeat_tensor_to_batch_size(
input_name: str,
input_tensor: torch.Tensor,
batch_size: int,
num_images_per_prompt: int = 1,
) -> torch.Tensor:
"""Repeat tensor elements to match the final batch size.

This function expands a tensor's batch dimension to match the final batch size (batch_size * num_images_per_prompt)
by repeating each element along dimension 0.

The input tensor must have batch size 1 or batch_size. The function will:
- If batch size is 1: repeat each element (batch_size * num_images_per_prompt) times
- If batch size equals batch_size: repeat each element num_images_per_prompt times

Args:
input_name (str): Name of the input tensor (used for error messages)
input_tensor (torch.Tensor): The tensor to repeat. Must have batch size 1 or batch_size.
batch_size (int): The base batch size (number of prompts)
num_images_per_prompt (int, optional): Number of images to generate per prompt. Defaults to 1.

Returns:
torch.Tensor: The repeated tensor with final batch size (batch_size * num_images_per_prompt)

Raises:
ValueError: If input_tensor is not a torch.Tensor or has invalid batch size
"""
# make sure input is a tensor
if not isinstance(input_tensor, torch.Tensor):
raise ValueError(f"`{input_name}` must be a tensor")

# make sure input tensor e.g. image_latents has batch size 1 or batch_size same as prompts
if input_tensor.shape[0] == 1:
repeat_by = batch_size * num_images_per_prompt
elif input_tensor.shape[0] == batch_size:
repeat_by = num_images_per_prompt
else:
raise ValueError(
f"`{input_name}` must have have batch size 1 or {batch_size}, but got {input_tensor.shape[0]}"
)

# expand the tensor to match the batch_size * num_images_per_prompt
input_tensor = input_tensor.repeat_interleave(repeat_by, dim=0)

return input_tensor


def calculate_dimension_from_latents(latents: torch.Tensor, vae_scale_factor: int) -> tuple[int, int]:
"""Calculate image dimensions from latent tensor dimensions.

This function converts latent space dimensions to image space dimensions by multiplying the latent height and width
by the VAE scale factor.

Args:
latents (torch.Tensor): The latent tensor. Must have 4 or 5 dimensions.
Expected shapes: [batch, channels, height, width] or [batch, channels, frames, height, width]
vae_scale_factor (int): The scale factor used by the VAE to compress images.
Typically 8 for most VAEs (image is 8x larger than latents in each dimension)

Returns:
tuple[int, int]: The calculated image dimensions as (height, width)

Raises:
ValueError: If latents tensor doesn't have 4 or 5 dimensions
"""
# make sure the latents are not packed
if latents.ndim != 4 and latents.ndim != 5:
raise ValueError(f"unpacked latents must have 4 or 5 dimensions, but got {latents.ndim}")

latent_height, latent_width = latents.shape[-2:]

height = latent_height * vae_scale_factor
width = latent_width * vae_scale_factor

return height, width


class FluxTextInputStep(ModularPipelineBlocks):
model_name = "flux"

Expand Down Expand Up @@ -209,7 +290,7 @@ def __call__(self, components: FluxModularPipeline, state: PipelineState) -> Pip
# 2. Patchify the image latent tensor
# TODO: Implement patchifier for Flux.
latent_height, latent_width = image_latent_tensor.shape[2:]
image_latent_tensor = FluxPipeline._pack_latents(
image_latent_tensor = _pack_latents(
image_latent_tensor, block_state.batch_size, image_latent_tensor.shape[1], latent_height, latent_width
)

Expand Down Expand Up @@ -266,7 +347,7 @@ def __call__(self, components: FluxModularPipeline, state: PipelineState) -> Pip
# 2. Patchify the image latent tensor
# TODO: Implement patchifier for Flux.
latent_height, latent_width = image_latent_tensor.shape[2:]
image_latent_tensor = FluxPipeline._pack_latents(
image_latent_tensor = _pack_latents(
image_latent_tensor, block_state.batch_size, image_latent_tensor.shape[1], latent_height, latent_width
)

Expand Down
Loading