From 3d78f9d17d26b348755a342a2547e90d49842afb Mon Sep 17 00:00:00 2001 From: Pham Hong Vinh Date: Fri, 9 Jan 2026 19:45:20 +0700 Subject: [PATCH 01/26] add constants for distill sigmas values and allow ltx pipeline to pass in sigmas --- src/diffusers/pipelines/ltx2/pipeline_ltx2.py | 7 ++++++- src/diffusers/pipelines/ltx2/utils.py | 4 ++++ 2 files changed, 10 insertions(+), 1 deletion(-) create mode 100644 src/diffusers/pipelines/ltx2/utils.py diff --git a/src/diffusers/pipelines/ltx2/pipeline_ltx2.py b/src/diffusers/pipelines/ltx2/pipeline_ltx2.py index 99d6b71ec3d7..7d1cd3ce656c 100644 --- a/src/diffusers/pipelines/ltx2/pipeline_ltx2.py +++ b/src/diffusers/pipelines/ltx2/pipeline_ltx2.py @@ -750,6 +750,7 @@ def __call__( num_frames: int = 121, frame_rate: float = 24.0, num_inference_steps: int = 40, + sigmas: Optional[List[float]] = None, timesteps: List[int] = None, guidance_scale: float = 4.0, guidance_rescale: float = 0.0, @@ -788,6 +789,10 @@ def __call__( num_inference_steps (`int`, *optional*, defaults to 40): The number of denoising steps. More denoising steps usually lead to a higher quality image at the expense of slower inference. + sigmas (`List[float]`, *optional*): + Custom sigmas to use for the denoising process with schedulers which support a `sigmas` argument in + their `set_timesteps` method. If not defined, the default behavior when `num_inference_steps` is passed + will be used. timesteps (`List[int]`, *optional*): Custom timesteps to use for the denoising process with schedulers which support a `timesteps` argument in their `set_timesteps` method. If not defined, the default behavior when `num_inference_steps` is @@ -958,7 +963,7 @@ def __call__( ) # 5. Prepare timesteps - sigmas = np.linspace(1.0, 1 / num_inference_steps, num_inference_steps) + sigmas = np.linspace(1.0, 1 / num_inference_steps, num_inference_steps) if sigmas is None else sigmas mu = calculate_shift( video_sequence_length, self.scheduler.config.get("base_image_seq_len", 1024), diff --git a/src/diffusers/pipelines/ltx2/utils.py b/src/diffusers/pipelines/ltx2/utils.py new file mode 100644 index 000000000000..99c82c82c1af --- /dev/null +++ b/src/diffusers/pipelines/ltx2/utils.py @@ -0,0 +1,4 @@ +DISTILLED_SIGMA_VALUES = [1.0, 0.99375, 0.9875, 0.98125, 0.975, 0.909375, 0.725, 0.421875, 0.0] + +# Reduced schedule for super-resolution stage 2 (subset of distilled values) +STAGE_2_DISTILLED_SIGMA_VALUES = [0.909375, 0.725, 0.421875, 0.0] \ No newline at end of file From 9c754a46aa768d30216a5580a91a2923e25cbf8a Mon Sep 17 00:00:00 2001 From: Pham Hong Vinh Date: Sun, 11 Jan 2026 22:13:58 +0700 Subject: [PATCH 02/26] add time conditioning conversion and token packing for latents --- scripts/convert_ltx2_to_diffusers.py | 15 +++++++++------ src/diffusers/pipelines/ltx2/pipeline_ltx2.py | 10 ++++++++++ 2 files changed, 19 insertions(+), 6 deletions(-) diff --git a/scripts/convert_ltx2_to_diffusers.py b/scripts/convert_ltx2_to_diffusers.py index 5367113365a2..2794feffed6f 100644 --- a/scripts/convert_ltx2_to_diffusers.py +++ b/scripts/convert_ltx2_to_diffusers.py @@ -63,6 +63,8 @@ "up_blocks.4": "up_blocks.1", "up_blocks.5": "up_blocks.2.upsamplers.0", "up_blocks.6": "up_blocks.2", + "last_time_embedder": "time_embedder", + "last_scale_shift_table": "scale_shift_table", # Common # For all 3D ResNets "res_blocks": "resnets", @@ -372,7 +374,7 @@ def convert_ltx2_connectors(original_state_dict: Dict[str, Any], version: str) - return connectors -def get_ltx2_video_vae_config(version: str) -> Tuple[Dict[str, Any], Dict[str, Any], Dict[str, Any]]: +def get_ltx2_video_vae_config(version: str, timestep_conditioning: bool = False) -> Tuple[Dict[str, Any], Dict[str, Any], Dict[str, Any]]: if version == "test": config = { "model_id": "diffusers-internal-dev/dummy-ltx2", @@ -396,7 +398,7 @@ def get_ltx2_video_vae_config(version: str) -> Tuple[Dict[str, Any], Dict[str, A "downsample_type": ("spatial", "temporal", "spatiotemporal", "spatiotemporal"), "upsample_residual": (True, True, True), "upsample_factor": (2, 2, 2), - "timestep_conditioning": False, + "timestep_conditioning": timestep_conditioning, "patch_size": 4, "patch_size_t": 1, "resnet_norm_eps": 1e-6, @@ -433,7 +435,7 @@ def get_ltx2_video_vae_config(version: str) -> Tuple[Dict[str, Any], Dict[str, A "downsample_type": ("spatial", "temporal", "spatiotemporal", "spatiotemporal"), "upsample_residual": (True, True, True), "upsample_factor": (2, 2, 2), - "timestep_conditioning": False, + "timestep_conditioning": timestep_conditioning, "patch_size": 4, "patch_size_t": 1, "resnet_norm_eps": 1e-6, @@ -450,8 +452,8 @@ def get_ltx2_video_vae_config(version: str) -> Tuple[Dict[str, Any], Dict[str, A return config, rename_dict, special_keys_remap -def convert_ltx2_video_vae(original_state_dict: Dict[str, Any], version: str) -> Dict[str, Any]: - config, rename_dict, special_keys_remap = get_ltx2_video_vae_config(version) +def convert_ltx2_video_vae(original_state_dict: Dict[str, Any], version: str, timestep_conditioning: bool) -> Dict[str, Any]: + config, rename_dict, special_keys_remap = get_ltx2_video_vae_config(version, timestep_conditioning) diffusers_config = config["diffusers_config"] with init_empty_weights(): @@ -717,6 +719,7 @@ def get_args(): help="Latent upsampler filename", ) + parser.add_argument("--timestep_conditioning", action="store_true", help="Whether to add timestep condition to the video VAE model") parser.add_argument("--vae", action="store_true", help="Whether to convert the video VAE model") parser.add_argument("--audio_vae", action="store_true", help="Whether to convert the audio VAE model") parser.add_argument("--dit", action="store_true", help="Whether to convert the DiT model") @@ -786,7 +789,7 @@ def main(args): original_vae_ckpt = load_hub_or_local_checkpoint(filename=args.vae_filename) elif combined_ckpt is not None: original_vae_ckpt = get_model_state_dict_from_combined_ckpt(combined_ckpt, args.vae_prefix) - vae = convert_ltx2_video_vae(original_vae_ckpt, version=args.version) + vae = convert_ltx2_video_vae(original_vae_ckpt, version=args.version, timestep_conditioning=args.timestep_conditioning) if not args.full_pipeline and not args.upsample_pipeline: vae.to(vae_dtype).save_pretrained(os.path.join(args.output_path, "vae")) diff --git a/src/diffusers/pipelines/ltx2/pipeline_ltx2.py b/src/diffusers/pipelines/ltx2/pipeline_ltx2.py index 7d1cd3ce656c..c662d9c16745 100644 --- a/src/diffusers/pipelines/ltx2/pipeline_ltx2.py +++ b/src/diffusers/pipelines/ltx2/pipeline_ltx2.py @@ -653,6 +653,11 @@ def prepare_latents( latents: Optional[torch.Tensor] = None, ) -> torch.Tensor: if latents is not None: + if latents.ndim == 5: + # latents are of shape [B, C, F, H, W], need to be packed + latents = self._pack_latents( + latents, self.transformer_spatial_patch_size, self.transformer_temporal_patch_size + ) return latents.to(device=device, dtype=dtype) height = height // self.vae_spatial_compression_ratio @@ -694,6 +699,9 @@ def prepare_audio_latents( latent_length = round(duration_s * latents_per_second) if latents is not None: + if latents.ndim == 4: + # latents are of shape [B, C, L, M], need to be packed + latents = self._pack_audio_latents(latents) return latents.to(device=device, dtype=dtype), latent_length # TODO: confirm whether this logic is correct @@ -1097,6 +1105,8 @@ def __call__( self.transformer_spatial_patch_size, self.transformer_temporal_patch_size, ) + prenorm_latents = latents + prenorm_audio_latents = audio_latents latents = self._denormalize_latents( latents, self.vae.latents_mean, self.vae.latents_std, self.vae.config.scaling_factor ) From 6fbeacf53bcc4b3c6281eb5e52e1bd81cf152555 Mon Sep 17 00:00:00 2001 From: Pham Hong Vinh Date: Sun, 11 Jan 2026 23:00:02 +0700 Subject: [PATCH 03/26] make style & quality --- scripts/convert_ltx2_to_diffusers.py | 16 ++++++++++++---- src/diffusers/pipelines/ltx2/utils.py | 2 +- 2 files changed, 13 insertions(+), 5 deletions(-) diff --git a/scripts/convert_ltx2_to_diffusers.py b/scripts/convert_ltx2_to_diffusers.py index 2794feffed6f..72b334b71e71 100644 --- a/scripts/convert_ltx2_to_diffusers.py +++ b/scripts/convert_ltx2_to_diffusers.py @@ -374,7 +374,9 @@ def convert_ltx2_connectors(original_state_dict: Dict[str, Any], version: str) - return connectors -def get_ltx2_video_vae_config(version: str, timestep_conditioning: bool = False) -> Tuple[Dict[str, Any], Dict[str, Any], Dict[str, Any]]: +def get_ltx2_video_vae_config( + version: str, timestep_conditioning: bool = False +) -> Tuple[Dict[str, Any], Dict[str, Any], Dict[str, Any]]: if version == "test": config = { "model_id": "diffusers-internal-dev/dummy-ltx2", @@ -452,7 +454,9 @@ def get_ltx2_video_vae_config(version: str, timestep_conditioning: bool = False) return config, rename_dict, special_keys_remap -def convert_ltx2_video_vae(original_state_dict: Dict[str, Any], version: str, timestep_conditioning: bool) -> Dict[str, Any]: +def convert_ltx2_video_vae( + original_state_dict: Dict[str, Any], version: str, timestep_conditioning: bool +) -> Dict[str, Any]: config, rename_dict, special_keys_remap = get_ltx2_video_vae_config(version, timestep_conditioning) diffusers_config = config["diffusers_config"] @@ -719,7 +723,9 @@ def get_args(): help="Latent upsampler filename", ) - parser.add_argument("--timestep_conditioning", action="store_true", help="Whether to add timestep condition to the video VAE model") + parser.add_argument( + "--timestep_conditioning", action="store_true", help="Whether to add timestep condition to the video VAE model" + ) parser.add_argument("--vae", action="store_true", help="Whether to convert the video VAE model") parser.add_argument("--audio_vae", action="store_true", help="Whether to convert the audio VAE model") parser.add_argument("--dit", action="store_true", help="Whether to convert the DiT model") @@ -789,7 +795,9 @@ def main(args): original_vae_ckpt = load_hub_or_local_checkpoint(filename=args.vae_filename) elif combined_ckpt is not None: original_vae_ckpt = get_model_state_dict_from_combined_ckpt(combined_ckpt, args.vae_prefix) - vae = convert_ltx2_video_vae(original_vae_ckpt, version=args.version, timestep_conditioning=args.timestep_conditioning) + vae = convert_ltx2_video_vae( + original_vae_ckpt, version=args.version, timestep_conditioning=args.timestep_conditioning + ) if not args.full_pipeline and not args.upsample_pipeline: vae.to(vae_dtype).save_pretrained(os.path.join(args.output_path, "vae")) diff --git a/src/diffusers/pipelines/ltx2/utils.py b/src/diffusers/pipelines/ltx2/utils.py index 99c82c82c1af..bd0ae08c1073 100644 --- a/src/diffusers/pipelines/ltx2/utils.py +++ b/src/diffusers/pipelines/ltx2/utils.py @@ -1,4 +1,4 @@ DISTILLED_SIGMA_VALUES = [1.0, 0.99375, 0.9875, 0.98125, 0.975, 0.909375, 0.725, 0.421875, 0.0] # Reduced schedule for super-resolution stage 2 (subset of distilled values) -STAGE_2_DISTILLED_SIGMA_VALUES = [0.909375, 0.725, 0.421875, 0.0] \ No newline at end of file +STAGE_2_DISTILLED_SIGMA_VALUES = [0.909375, 0.725, 0.421875, 0.0] From 82c2e7f068692ae689eaf872030503f5b8024eaf Mon Sep 17 00:00:00 2001 From: Pham Hong Vinh Date: Sun, 11 Jan 2026 23:01:34 +0700 Subject: [PATCH 04/26] remove prenorm --- src/diffusers/pipelines/ltx2/pipeline_ltx2.py | 2 -- 1 file changed, 2 deletions(-) diff --git a/src/diffusers/pipelines/ltx2/pipeline_ltx2.py b/src/diffusers/pipelines/ltx2/pipeline_ltx2.py index c662d9c16745..b26ccca55c0a 100644 --- a/src/diffusers/pipelines/ltx2/pipeline_ltx2.py +++ b/src/diffusers/pipelines/ltx2/pipeline_ltx2.py @@ -1105,8 +1105,6 @@ def __call__( self.transformer_spatial_patch_size, self.transformer_temporal_patch_size, ) - prenorm_latents = latents - prenorm_audio_latents = audio_latents latents = self._denormalize_latents( latents, self.vae.latents_mean, self.vae.latents_std, self.vae.config.scaling_factor ) From 837fd85c76148b1703636a8c321ce3ff163d8ab9 Mon Sep 17 00:00:00 2001 From: Pham Hong Vinh Date: Mon, 12 Jan 2026 11:26:31 +0700 Subject: [PATCH 05/26] add sigma param to ltx2 i2v --- src/diffusers/pipelines/ltx2/pipeline_ltx2_image2video.py | 7 ++++++- 1 file changed, 6 insertions(+), 1 deletion(-) diff --git a/src/diffusers/pipelines/ltx2/pipeline_ltx2_image2video.py b/src/diffusers/pipelines/ltx2/pipeline_ltx2_image2video.py index b1711e283191..a33462f70c50 100644 --- a/src/diffusers/pipelines/ltx2/pipeline_ltx2_image2video.py +++ b/src/diffusers/pipelines/ltx2/pipeline_ltx2_image2video.py @@ -811,6 +811,7 @@ def __call__( num_frames: int = 121, frame_rate: float = 24.0, num_inference_steps: int = 40, + sigmas: Optional[List[float]] = None, timesteps: List[int] = None, guidance_scale: float = 4.0, guidance_rescale: float = 0.0, @@ -851,6 +852,10 @@ def __call__( num_inference_steps (`int`, *optional*, defaults to 40): The number of denoising steps. More denoising steps usually lead to a higher quality image at the expense of slower inference. + sigmas (`List[float]`, *optional*): + Custom sigmas to use for the denoising process with schedulers which support a `sigmas` argument in + their `set_timesteps` method. If not defined, the default behavior when `num_inference_steps` is passed + will be used. timesteps (`List[int]`, *optional*): Custom timesteps to use for the denoising process with schedulers which support a `timesteps` argument in their `set_timesteps` method. If not defined, the default behavior when `num_inference_steps` is @@ -1028,7 +1033,7 @@ def __call__( latent_width = width // self.vae_spatial_compression_ratio video_sequence_length = latent_num_frames * latent_height * latent_width - sigmas = np.linspace(1.0, 1 / num_inference_steps, num_inference_steps) + sigmas = np.linspace(1.0, 1 / num_inference_steps, num_inference_steps) if sigmas is None else sigmas mu = calculate_shift( video_sequence_length, self.scheduler.config.get("base_image_seq_len", 1024), From 96fbcd8301a81deea58773a74ca620f12d70cebc Mon Sep 17 00:00:00 2001 From: Pham Hong Vinh Date: Mon, 12 Jan 2026 11:30:31 +0700 Subject: [PATCH 06/26] fix copies and add pack latents to i2v --- src/diffusers/pipelines/ltx2/pipeline_ltx2_image2video.py | 8 ++++++++ 1 file changed, 8 insertions(+) diff --git a/src/diffusers/pipelines/ltx2/pipeline_ltx2_image2video.py b/src/diffusers/pipelines/ltx2/pipeline_ltx2_image2video.py index a33462f70c50..92206cee4e9b 100644 --- a/src/diffusers/pipelines/ltx2/pipeline_ltx2_image2video.py +++ b/src/diffusers/pipelines/ltx2/pipeline_ltx2_image2video.py @@ -689,6 +689,11 @@ def prepare_latents( conditioning_mask = self._pack_latents( conditioning_mask, self.transformer_spatial_patch_size, self.transformer_temporal_patch_size ).squeeze(-1) + if latents.ndim == 5: + # latents are of shape [B, C, F, H, W], need to be packed + latents = self._pack_latents( + latents, self.transformer_spatial_patch_size, self.transformer_temporal_patch_size + ) if latents.ndim != 3 or latents.shape[:2] != conditioning_mask.shape: raise ValueError( f"Provided `latents` tensor has shape {latents.shape}, but the expected shape is {conditioning_mask.shape + (num_channels_latents,)}." @@ -754,6 +759,9 @@ def prepare_audio_latents( latent_length = round(duration_s * latents_per_second) if latents is not None: + if latents.ndim == 4: + # latents are of shape [B, C, L, M], need to be packed + latents = self._pack_audio_latents(latents) return latents.to(device=device, dtype=dtype), latent_length # TODO: confirm whether this logic is correct From 9575e0632afb5f02833b33e4d4f1da37274779e7 Mon Sep 17 00:00:00 2001 From: "Vinh H. Pham" Date: Tue, 13 Jan 2026 12:01:03 +0700 Subject: [PATCH 07/26] Apply suggestions from code review Co-authored-by: dg845 <58458699+dg845@users.noreply.github.com> --- src/diffusers/pipelines/ltx2/utils.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/src/diffusers/pipelines/ltx2/utils.py b/src/diffusers/pipelines/ltx2/utils.py index bd0ae08c1073..7e143edc9bb6 100644 --- a/src/diffusers/pipelines/ltx2/utils.py +++ b/src/diffusers/pipelines/ltx2/utils.py @@ -1,4 +1,4 @@ -DISTILLED_SIGMA_VALUES = [1.0, 0.99375, 0.9875, 0.98125, 0.975, 0.909375, 0.725, 0.421875, 0.0] +DISTILLED_SIGMA_VALUES = [1.0, 0.99375, 0.9875, 0.98125, 0.975, 0.909375, 0.725, 0.421875] # Reduced schedule for super-resolution stage 2 (subset of distilled values) -STAGE_2_DISTILLED_SIGMA_VALUES = [0.909375, 0.725, 0.421875, 0.0] +STAGE_2_DISTILLED_SIGMA_VALUES = [0.909375, 0.725, 0.421875] From eb01780ada0f9c2f9aa3c2a6b9c48329295985c3 Mon Sep 17 00:00:00 2001 From: Daniel Gu Date: Wed, 14 Jan 2026 03:09:53 +0100 Subject: [PATCH 08/26] Infer latent dims if latents/audio_latents is supplied --- src/diffusers/pipelines/ltx2/pipeline_ltx2.py | 47 ++++++++------- .../ltx2/pipeline_ltx2_image2video.py | 57 +++++++++++-------- 2 files changed, 61 insertions(+), 43 deletions(-) diff --git a/src/diffusers/pipelines/ltx2/pipeline_ltx2.py b/src/diffusers/pipelines/ltx2/pipeline_ltx2.py index 588e05737e88..54f1061da5c9 100644 --- a/src/diffusers/pipelines/ltx2/pipeline_ltx2.py +++ b/src/diffusers/pipelines/ltx2/pipeline_ltx2.py @@ -682,32 +682,23 @@ def prepare_audio_latents( self, batch_size: int = 1, num_channels_latents: int = 8, + audio_latent_length: int = 1, # 1 is just a dummy value num_mel_bins: int = 64, - num_frames: int = 121, - frame_rate: float = 25.0, - sampling_rate: int = 16000, - hop_length: int = 160, dtype: Optional[torch.dtype] = None, device: Optional[torch.device] = None, generator: Optional[torch.Generator] = None, latents: Optional[torch.Tensor] = None, ) -> torch.Tensor: - duration_s = num_frames / frame_rate - latents_per_second = ( - float(sampling_rate) / float(hop_length) / float(self.audio_vae_temporal_compression_ratio) - ) - latent_length = round(duration_s * latents_per_second) - if latents is not None: if latents.ndim == 4: # latents are of shape [B, C, L, M], need to be packed latents = self._pack_audio_latents(latents) - return latents.to(device=device, dtype=dtype), latent_length + return latents.to(device=device, dtype=dtype) # TODO: confirm whether this logic is correct latent_mel_bins = num_mel_bins // self.audio_vae_mel_compression_ratio - shape = (batch_size, num_channels_latents, latent_length, latent_mel_bins) + shape = (batch_size, num_channels_latents, audio_latent_length, latent_mel_bins) if isinstance(generator, list) and len(generator) != batch_size: raise ValueError( @@ -717,7 +708,7 @@ def prepare_audio_latents( latents = randn_tensor(shape, generator=generator, device=device, dtype=dtype) latents = self._pack_audio_latents(latents) - return latents, latent_length + return latents @property def guidance_scale(self): @@ -935,6 +926,14 @@ def __call__( latent_num_frames = (num_frames - 1) // self.vae_temporal_compression_ratio + 1 latent_height = height // self.vae_spatial_compression_ratio latent_width = width // self.vae_spatial_compression_ratio + if latents is not None: + if latents.ndim == 5: + _, _, latent_num_frames, latent_height, latent_width = latents.shape # [B, C, F, H, W] + else: + logger.warning( + f"You have supplied packed `latents` of shape {latents.shape}, so the latent dims cannot be" + f" inferred. Make sure the supplied `height`, `width`, and `num_frames` are correct." + ) video_sequence_length = latent_num_frames * latent_height * latent_width num_channels_latents = self.transformer.config.in_channels @@ -950,20 +949,30 @@ def __call__( latents, ) + duration_s = num_frames / frame_rate + audio_latents_per_second = ( + self.audio_sampling_rate / self.audio_hop_length / float(self.audio_vae_temporal_compression_ratio) + ) + audio_num_frames = round(duration_s * audio_latents_per_second) + if audio_latents is not None: + if audio_latents.ndim == 4: + _, _, audio_num_frames, _ = audio_latents.shape # [B, C, L, M] + else: + logger.warning( + f"You have supplied packed `audio_latents` of shape {audio_latents.shape}, so the latent dims" + f" cannot be inferred. Make sure the supplied `num_frames` is correct." + ) + num_mel_bins = self.audio_vae.config.mel_bins if getattr(self, "audio_vae", None) is not None else 64 latent_mel_bins = num_mel_bins // self.audio_vae_mel_compression_ratio - num_channels_latents_audio = ( self.audio_vae.config.latent_channels if getattr(self, "audio_vae", None) is not None else 8 ) - audio_latents, audio_num_frames = self.prepare_audio_latents( + audio_latents = self.prepare_audio_latents( batch_size * num_videos_per_prompt, num_channels_latents=num_channels_latents_audio, + audio_latent_length=audio_num_frames, num_mel_bins=num_mel_bins, - num_frames=num_frames, # Video frames, audio frames will be calculated from this - frame_rate=frame_rate, - sampling_rate=self.audio_sampling_rate, - hop_length=self.audio_hop_length, dtype=torch.float32, device=device, generator=generator, diff --git a/src/diffusers/pipelines/ltx2/pipeline_ltx2_image2video.py b/src/diffusers/pipelines/ltx2/pipeline_ltx2_image2video.py index 92206cee4e9b..460ff8eec7de 100644 --- a/src/diffusers/pipelines/ltx2/pipeline_ltx2_image2video.py +++ b/src/diffusers/pipelines/ltx2/pipeline_ltx2_image2video.py @@ -742,32 +742,23 @@ def prepare_audio_latents( self, batch_size: int = 1, num_channels_latents: int = 8, + audio_latent_length: int = 1, # 1 is just a dummy value num_mel_bins: int = 64, - num_frames: int = 121, - frame_rate: float = 25.0, - sampling_rate: int = 16000, - hop_length: int = 160, dtype: Optional[torch.dtype] = None, device: Optional[torch.device] = None, generator: Optional[torch.Generator] = None, latents: Optional[torch.Tensor] = None, ) -> torch.Tensor: - duration_s = num_frames / frame_rate - latents_per_second = ( - float(sampling_rate) / float(hop_length) / float(self.audio_vae_temporal_compression_ratio) - ) - latent_length = round(duration_s * latents_per_second) - if latents is not None: if latents.ndim == 4: # latents are of shape [B, C, L, M], need to be packed latents = self._pack_audio_latents(latents) - return latents.to(device=device, dtype=dtype), latent_length + return latents.to(device=device, dtype=dtype) # TODO: confirm whether this logic is correct latent_mel_bins = num_mel_bins // self.audio_vae_mel_compression_ratio - shape = (batch_size, num_channels_latents, latent_length, latent_mel_bins) + shape = (batch_size, num_channels_latents, audio_latent_length, latent_mel_bins) if isinstance(generator, list) and len(generator) != batch_size: raise ValueError( @@ -777,7 +768,7 @@ def prepare_audio_latents( latents = randn_tensor(shape, generator=generator, device=device, dtype=dtype) latents = self._pack_audio_latents(latents) - return latents, latent_length + return latents @property def guidance_scale(self): @@ -995,6 +986,19 @@ def __call__( ) # 4. Prepare latent variables + latent_num_frames = (num_frames - 1) // self.vae_temporal_compression_ratio + 1 + latent_height = height // self.vae_spatial_compression_ratio + latent_width = width // self.vae_spatial_compression_ratio + if latents is not None: + if latents.ndim == 5: + _, _, latent_num_frames, latent_height, latent_width = latents.shape # [B, C, F, H, W] + else: + logger.warning( + f"You have supplied packed `latents` of shape {latents.shape}, so the latent dims cannot be" + f" inferred. Make sure the supplied `height`, `width`, and `num_frames` are correct." + ) + video_sequence_length = latent_num_frames * latent_height * latent_width + if latents is None: image = self.video_processor.preprocess(image, height=height, width=width) image = image.to(device=device, dtype=prompt_embeds.dtype) @@ -1015,20 +1019,30 @@ def __call__( if self.do_classifier_free_guidance: conditioning_mask = torch.cat([conditioning_mask, conditioning_mask]) + duration_s = num_frames / frame_rate + audio_latents_per_second = ( + self.audio_sampling_rate / self.audio_hop_length / float(self.audio_vae_temporal_compression_ratio) + ) + audio_num_frames = round(duration_s * audio_latents_per_second) + if audio_latents is not None: + if audio_latents.ndim == 4: + _, _, audio_num_frames, _ = audio_latents.shape # [B, C, L, M] + else: + logger.warning( + f"You have supplied packed `audio_latents` of shape {audio_latents.shape}, so the latent dims" + f" cannot be inferred. Make sure the supplied `num_frames` is correct." + ) + num_mel_bins = self.audio_vae.config.mel_bins if getattr(self, "audio_vae", None) is not None else 64 latent_mel_bins = num_mel_bins // self.audio_vae_mel_compression_ratio - num_channels_latents_audio = ( self.audio_vae.config.latent_channels if getattr(self, "audio_vae", None) is not None else 8 ) - audio_latents, audio_num_frames = self.prepare_audio_latents( + audio_latents = self.prepare_audio_latents( batch_size * num_videos_per_prompt, num_channels_latents=num_channels_latents_audio, + audio_latent_length=audio_num_frames, num_mel_bins=num_mel_bins, - num_frames=num_frames, # Video frames, audio frames will be calculated from this - frame_rate=frame_rate, - sampling_rate=self.audio_sampling_rate, - hop_length=self.audio_hop_length, dtype=torch.float32, device=device, generator=generator, @@ -1036,11 +1050,6 @@ def __call__( ) # 5. Prepare timesteps - latent_num_frames = (num_frames - 1) // self.vae_temporal_compression_ratio + 1 - latent_height = height // self.vae_spatial_compression_ratio - latent_width = width // self.vae_spatial_compression_ratio - video_sequence_length = latent_num_frames * latent_height * latent_width - sigmas = np.linspace(1.0, 1 / num_inference_steps, num_inference_steps) if sigmas is None else sigmas mu = calculate_shift( video_sequence_length, From 7574bf991132f4fc2aab8def4b7855ff2f8a38b7 Mon Sep 17 00:00:00 2001 From: Pham Hong Vinh Date: Wed, 14 Jan 2026 22:50:38 +0700 Subject: [PATCH 09/26] add note for predefined sigmas --- src/diffusers/pipelines/ltx2/utils.py | 2 ++ 1 file changed, 2 insertions(+) diff --git a/src/diffusers/pipelines/ltx2/utils.py b/src/diffusers/pipelines/ltx2/utils.py index 7e143edc9bb6..8b790a4df0cb 100644 --- a/src/diffusers/pipelines/ltx2/utils.py +++ b/src/diffusers/pipelines/ltx2/utils.py @@ -1,3 +1,5 @@ +# Pre-trained sigma values for distilled model are taken from +# https://github.com/Lightricks/LTX-2/blob/main/packages/ltx-pipelines/src/ltx_pipelines/utils/constants.py DISTILLED_SIGMA_VALUES = [1.0, 0.99375, 0.9875, 0.98125, 0.975, 0.909375, 0.725, 0.421875] # Reduced schedule for super-resolution stage 2 (subset of distilled values) From c22eed5a8ef7e471e3607fc20417eee8dc39826e Mon Sep 17 00:00:00 2001 From: Pham Hong Vinh Date: Wed, 14 Jan 2026 22:51:35 +0700 Subject: [PATCH 10/26] run make style and quality --- src/diffusers/pipelines/ltx2/utils.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/diffusers/pipelines/ltx2/utils.py b/src/diffusers/pipelines/ltx2/utils.py index 8b790a4df0cb..f80469817fe6 100644 --- a/src/diffusers/pipelines/ltx2/utils.py +++ b/src/diffusers/pipelines/ltx2/utils.py @@ -1,4 +1,4 @@ -# Pre-trained sigma values for distilled model are taken from +# Pre-trained sigma values for distilled model are taken from # https://github.com/Lightricks/LTX-2/blob/main/packages/ltx-pipelines/src/ltx_pipelines/utils/constants.py DISTILLED_SIGMA_VALUES = [1.0, 0.99375, 0.9875, 0.98125, 0.975, 0.909375, 0.725, 0.421875] From c282485d0b03f1052b00aee795c95658564e5a2d Mon Sep 17 00:00:00 2001 From: Pham Hong Vinh Date: Tue, 20 Jan 2026 23:04:57 +0700 Subject: [PATCH 11/26] revert distill timesteps & set original_state_dict_repo_idd to default None --- scripts/convert_ltx2_to_diffusers.py | 2 +- src/diffusers/pipelines/ltx2/utils.py | 4 ++-- 2 files changed, 3 insertions(+), 3 deletions(-) diff --git a/scripts/convert_ltx2_to_diffusers.py b/scripts/convert_ltx2_to_diffusers.py index 72b334b71e71..e6a9aea4e46c 100644 --- a/scripts/convert_ltx2_to_diffusers.py +++ b/scripts/convert_ltx2_to_diffusers.py @@ -667,7 +667,7 @@ def get_args(): parser.add_argument( "--original_state_dict_repo_id", - default="Lightricks/LTX-2", + default=None, type=str, help="HF Hub repo id with LTX 2.0 checkpoint", ) diff --git a/src/diffusers/pipelines/ltx2/utils.py b/src/diffusers/pipelines/ltx2/utils.py index f80469817fe6..77a0e3a883a3 100644 --- a/src/diffusers/pipelines/ltx2/utils.py +++ b/src/diffusers/pipelines/ltx2/utils.py @@ -1,6 +1,6 @@ # Pre-trained sigma values for distilled model are taken from # https://github.com/Lightricks/LTX-2/blob/main/packages/ltx-pipelines/src/ltx_pipelines/utils/constants.py -DISTILLED_SIGMA_VALUES = [1.0, 0.99375, 0.9875, 0.98125, 0.975, 0.909375, 0.725, 0.421875] +DISTILLED_SIGMA_VALUES = [1.0, 0.99375, 0.9875, 0.98125, 0.975, 0.909375, 0.725, 0.421875, 0.0] # Reduced schedule for super-resolution stage 2 (subset of distilled values) -STAGE_2_DISTILLED_SIGMA_VALUES = [0.909375, 0.725, 0.421875] +STAGE_2_DISTILLED_SIGMA_VALUES = [0.909375, 0.725, 0.421875, 0.0] From 62acd4cf85f49ae1c6b65cae38c3be3bef6699c5 Mon Sep 17 00:00:00 2001 From: Pham Hong Vinh Date: Wed, 21 Jan 2026 00:01:43 +0700 Subject: [PATCH 12/26] add latent normalize --- src/diffusers/pipelines/ltx2/pipeline_ltx2.py | 7 +++++++ src/diffusers/pipelines/ltx2/pipeline_ltx2_image2video.py | 8 ++++++++ 2 files changed, 15 insertions(+) diff --git a/src/diffusers/pipelines/ltx2/pipeline_ltx2.py b/src/diffusers/pipelines/ltx2/pipeline_ltx2.py index 54f1061da5c9..b72a1079f92f 100644 --- a/src/diffusers/pipelines/ltx2/pipeline_ltx2.py +++ b/src/diffusers/pipelines/ltx2/pipeline_ltx2.py @@ -594,6 +594,12 @@ def _denormalize_latents( latents = latents * latents_std / scaling_factor + latents_mean return latents + @staticmethod + def _normalize_audio_latents(latents: torch.Tensor, latents_mean: torch.Tensor, latents_std: torch.Tensor): + latents_mean = latents_mean.to(latents.device, latents.dtype) + latents_std = latents_std.to(latents.device, latents.dtype) + return (latents - latents_mean) / latents_std + @staticmethod def _denormalize_audio_latents(latents: torch.Tensor, latents_mean: torch.Tensor, latents_std: torch.Tensor): latents_mean = latents_mean.to(latents.device, latents.dtype) @@ -693,6 +699,7 @@ def prepare_audio_latents( if latents.ndim == 4: # latents are of shape [B, C, L, M], need to be packed latents = self._pack_audio_latents(latents) + latents = self._normalize_audio_latents(latents, self.audio_vae.latents_mean, self.audio_vae.latents_std) return latents.to(device=device, dtype=dtype) # TODO: confirm whether this logic is correct diff --git a/src/diffusers/pipelines/ltx2/pipeline_ltx2_image2video.py b/src/diffusers/pipelines/ltx2/pipeline_ltx2_image2video.py index 460ff8eec7de..a316f5307130 100644 --- a/src/diffusers/pipelines/ltx2/pipeline_ltx2_image2video.py +++ b/src/diffusers/pipelines/ltx2/pipeline_ltx2_image2video.py @@ -656,6 +656,13 @@ def _unpack_audio_latents( latents = latents.unflatten(2, (-1, num_mel_bins)).transpose(1, 2) return latents + @staticmethod + # Copied from diffusers.pipelines.ltx2.pipeline_ltx2.LTX2Pipeline._normalize_audio_latents + def _normalize_audio_latents(latents: torch.Tensor, latents_mean: torch.Tensor, latents_std: torch.Tensor): + latents_mean = latents_mean.to(latents.device, latents.dtype) + latents_std = latents_std.to(latents.device, latents.dtype) + return (latents - latents_mean) / latents_std + @staticmethod # Copied from diffusers.pipelines.ltx2.pipeline_ltx2.LTX2Pipeline._denormalize_audio_latents def _denormalize_audio_latents(latents: torch.Tensor, latents_mean: torch.Tensor, latents_std: torch.Tensor): @@ -753,6 +760,7 @@ def prepare_audio_latents( if latents.ndim == 4: # latents are of shape [B, C, L, M], need to be packed latents = self._pack_audio_latents(latents) + latents = self._normalize_audio_latents(latents, self.audio_vae.latents_mean, self.audio_vae.latents_std) return latents.to(device=device, dtype=dtype) # TODO: confirm whether this logic is correct From 7a56648d9c3f99df7c7464f5d534b9f28665bfb6 Mon Sep 17 00:00:00 2001 From: Pham Hong Vinh Date: Fri, 23 Jan 2026 11:48:56 +0700 Subject: [PATCH 13/26] add create noised state, delete last sigmas --- src/diffusers/pipelines/ltx2/pipeline_ltx2.py | 16 ++++++++++++++++ src/diffusers/pipelines/ltx2/utils.py | 4 ++-- 2 files changed, 18 insertions(+), 2 deletions(-) diff --git a/src/diffusers/pipelines/ltx2/pipeline_ltx2.py b/src/diffusers/pipelines/ltx2/pipeline_ltx2.py index b72a1079f92f..223bf459968b 100644 --- a/src/diffusers/pipelines/ltx2/pipeline_ltx2.py +++ b/src/diffusers/pipelines/ltx2/pipeline_ltx2.py @@ -605,6 +605,12 @@ def _denormalize_audio_latents(latents: torch.Tensor, latents_mean: torch.Tensor latents_mean = latents_mean.to(latents.device, latents.dtype) latents_std = latents_std.to(latents.device, latents.dtype) return (latents * latents_std) + latents_mean + + @staticmethod + def _create_noised_state(latents: torch.Tensor, noise_scale: float, generator: Optional[torch.Generator] = None): + noise = randn_tensor(latents.shape, generator=generator, device=latents.device, dtype=latents.dtype) + noised_latents = noise_scale * noise + (1 - noise_scale) * latents + return noised_latents @staticmethod def _pack_audio_latents( @@ -653,6 +659,7 @@ def prepare_latents( height: int = 512, width: int = 768, num_frames: int = 121, + noise_scale: float = 0.0, dtype: Optional[torch.dtype] = None, device: Optional[torch.device] = None, generator: Optional[torch.Generator] = None, @@ -664,6 +671,7 @@ def prepare_latents( latents = self._pack_latents( latents, self.transformer_spatial_patch_size, self.transformer_temporal_patch_size ) + latents = self._create_noised_state(latents, noise_scale, generator) return latents.to(device=device, dtype=dtype) height = height // self.vae_spatial_compression_ratio @@ -690,6 +698,7 @@ def prepare_audio_latents( num_channels_latents: int = 8, audio_latent_length: int = 1, # 1 is just a dummy value num_mel_bins: int = 64, + noise_scale: float = 0.0, dtype: Optional[torch.dtype] = None, device: Optional[torch.device] = None, generator: Optional[torch.Generator] = None, @@ -700,6 +709,7 @@ def prepare_audio_latents( # latents are of shape [B, C, L, M], need to be packed latents = self._pack_audio_latents(latents) latents = self._normalize_audio_latents(latents, self.audio_vae.latents_mean, self.audio_vae.latents_std) + latents = self._create_noised_state(latents, noise_scale, generator) return latents.to(device=device, dtype=dtype) # TODO: confirm whether this logic is correct @@ -760,6 +770,7 @@ def __call__( timesteps: List[int] = None, guidance_scale: float = 4.0, guidance_rescale: float = 0.0, + noise_scale: float = 0.0, num_videos_per_prompt: Optional[int] = 1, generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None, latents: Optional[torch.Tensor] = None, @@ -815,6 +826,9 @@ def __call__( [Common Diffusion Noise Schedules and Sample Steps are Flawed](https://huggingface.co/papers/2305.08891). Guidance rescale factor should fix overexposure when using zero terminal SNR. + noise_scale (`float`, *optional*, defaults to `0.0`): + The interpolation factor between random noise and denoised latents at each timestep. Applying noise to + the `latents` and `audio_latents` before continue denoising. num_videos_per_prompt (`int`, *optional*, defaults to 1): The number of videos to generate per prompt. generator (`torch.Generator` or `List[torch.Generator]`, *optional*): @@ -950,6 +964,7 @@ def __call__( height, width, num_frames, + noise_scale, torch.float32, device, generator, @@ -980,6 +995,7 @@ def __call__( num_channels_latents=num_channels_latents_audio, audio_latent_length=audio_num_frames, num_mel_bins=num_mel_bins, + noise_scale=noise_scale, dtype=torch.float32, device=device, generator=generator, diff --git a/src/diffusers/pipelines/ltx2/utils.py b/src/diffusers/pipelines/ltx2/utils.py index 77a0e3a883a3..f80469817fe6 100644 --- a/src/diffusers/pipelines/ltx2/utils.py +++ b/src/diffusers/pipelines/ltx2/utils.py @@ -1,6 +1,6 @@ # Pre-trained sigma values for distilled model are taken from # https://github.com/Lightricks/LTX-2/blob/main/packages/ltx-pipelines/src/ltx_pipelines/utils/constants.py -DISTILLED_SIGMA_VALUES = [1.0, 0.99375, 0.9875, 0.98125, 0.975, 0.909375, 0.725, 0.421875, 0.0] +DISTILLED_SIGMA_VALUES = [1.0, 0.99375, 0.9875, 0.98125, 0.975, 0.909375, 0.725, 0.421875] # Reduced schedule for super-resolution stage 2 (subset of distilled values) -STAGE_2_DISTILLED_SIGMA_VALUES = [0.909375, 0.725, 0.421875, 0.0] +STAGE_2_DISTILLED_SIGMA_VALUES = [0.909375, 0.725, 0.421875] From 68788f7d3c5685fc62078a001dbdabcef7484dc6 Mon Sep 17 00:00:00 2001 From: Pham Hong Vinh Date: Mon, 26 Jan 2026 10:50:41 +0700 Subject: [PATCH 14/26] remove normalize step in latent upsample pipeline and move it to ltx2 pipeline --- src/diffusers/pipelines/ltx2/pipeline_ltx2.py | 22 +++++++++++++++++++ .../ltx2/pipeline_ltx2_latent_upsample.py | 14 ------------ 2 files changed, 22 insertions(+), 14 deletions(-) diff --git a/src/diffusers/pipelines/ltx2/pipeline_ltx2.py b/src/diffusers/pipelines/ltx2/pipeline_ltx2.py index 223bf459968b..c7345f3a1947 100644 --- a/src/diffusers/pipelines/ltx2/pipeline_ltx2.py +++ b/src/diffusers/pipelines/ltx2/pipeline_ltx2.py @@ -584,6 +584,17 @@ def _unpack_latents( latents = latents.permute(0, 4, 1, 5, 2, 6, 3, 7).flatten(6, 7).flatten(4, 5).flatten(2, 3) return latents + @staticmethod + # Copied from diffusers.pipelines.ltx2.pipeline_ltx2_image2video.LTX2ImageToVideoPipeline._normalize_latents + def _normalize_latents( + latents: torch.Tensor, latents_mean: torch.Tensor, latents_std: torch.Tensor, scaling_factor: float = 1.0 + ) -> torch.Tensor: + # Normalize latents across the channel dimension [B, C, F, H, W] + latents_mean = latents_mean.view(1, -1, 1, 1, 1).to(latents.device, latents.dtype) + latents_std = latents_std.view(1, -1, 1, 1, 1).to(latents.device, latents.dtype) + latents = (latents - latents_mean) * scaling_factor / latents_std + return latents + @staticmethod def _denormalize_latents( latents: torch.Tensor, latents_mean: torch.Tensor, latents_std: torch.Tensor, scaling_factor: float = 1.0 @@ -667,10 +678,17 @@ def prepare_latents( ) -> torch.Tensor: if latents is not None: if latents.ndim == 5: + latents = self._normalize_latents( + latents, self.vae.latents_mean, self.vae.latents_std, self.vae.config.scaling_factor + ) # latents are of shape [B, C, F, H, W], need to be packed latents = self._pack_latents( latents, self.transformer_spatial_patch_size, self.transformer_temporal_patch_size ) + if latents.ndim != 3: + raise ValueError( + f"Provided `latents` tensor has shape {latents.shape}, but the expected shape is [batch_size, num_seq, num_features]." + ) latents = self._create_noised_state(latents, noise_scale, generator) return latents.to(device=device, dtype=dtype) @@ -708,6 +726,10 @@ def prepare_audio_latents( if latents.ndim == 4: # latents are of shape [B, C, L, M], need to be packed latents = self._pack_audio_latents(latents) + if latents.ndim != 3: + raise ValueError( + f"Provided `latents` tensor has shape {latents.shape}, but the expected shape is [batch_size, num_seq, num_features]." + ) latents = self._normalize_audio_latents(latents, self.audio_vae.latents_mean, self.audio_vae.latents_std) latents = self._create_noised_state(latents, noise_scale, generator) return latents.to(device=device, dtype=dtype) diff --git a/src/diffusers/pipelines/ltx2/pipeline_ltx2_latent_upsample.py b/src/diffusers/pipelines/ltx2/pipeline_ltx2_latent_upsample.py index a44c40b0430f..340efd10f24f 100644 --- a/src/diffusers/pipelines/ltx2/pipeline_ltx2_latent_upsample.py +++ b/src/diffusers/pipelines/ltx2/pipeline_ltx2_latent_upsample.py @@ -228,17 +228,6 @@ def tone_map_latents(self, latents: torch.Tensor, compression: float) -> torch.T filtered = latents * scales return filtered - @staticmethod - # Copied from diffusers.pipelines.ltx2.pipeline_ltx2_image2video.LTX2ImageToVideoPipeline._normalize_latents - def _normalize_latents( - latents: torch.Tensor, latents_mean: torch.Tensor, latents_std: torch.Tensor, scaling_factor: float = 1.0 - ) -> torch.Tensor: - # Normalize latents across the channel dimension [B, C, F, H, W] - latents_mean = latents_mean.view(1, -1, 1, 1, 1).to(latents.device, latents.dtype) - latents_std = latents_std.view(1, -1, 1, 1, 1).to(latents.device, latents.dtype) - latents = (latents - latents_mean) * scaling_factor / latents_std - return latents - @staticmethod # Copied from diffusers.pipelines.ltx2.pipeline_ltx2.LTX2Pipeline._denormalize_latents def _denormalize_latents( @@ -408,9 +397,6 @@ def __call__( latents = self.tone_map_latents(latents, tone_map_compression_ratio) if output_type == "latent": - latents = self._normalize_latents( - latents, self.vae.latents_mean, self.vae.latents_std, self.vae.config.scaling_factor - ) video = latents else: if not self.vae.config.timestep_conditioning: From 8a9179b62abe130feb0299b059c89e0f623e2401 Mon Sep 17 00:00:00 2001 From: Pham Hong Vinh Date: Mon, 26 Jan 2026 15:00:35 +0700 Subject: [PATCH 15/26] add create noise latent to i2v pipeline --- src/diffusers/pipelines/ltx2/pipeline_ltx2.py | 2 +- .../ltx2/pipeline_ltx2_image2video.py | 26 ++++++++++++++++--- 2 files changed, 24 insertions(+), 4 deletions(-) diff --git a/src/diffusers/pipelines/ltx2/pipeline_ltx2.py b/src/diffusers/pipelines/ltx2/pipeline_ltx2.py index c7345f3a1947..001d3f23d577 100644 --- a/src/diffusers/pipelines/ltx2/pipeline_ltx2.py +++ b/src/diffusers/pipelines/ltx2/pipeline_ltx2.py @@ -616,7 +616,7 @@ def _denormalize_audio_latents(latents: torch.Tensor, latents_mean: torch.Tensor latents_mean = latents_mean.to(latents.device, latents.dtype) latents_std = latents_std.to(latents.device, latents.dtype) return (latents * latents_std) + latents_mean - + @staticmethod def _create_noised_state(latents: torch.Tensor, noise_scale: float, generator: Optional[torch.Generator] = None): noise = randn_tensor(latents.shape, generator=generator, device=latents.device, dtype=latents.dtype) diff --git a/src/diffusers/pipelines/ltx2/pipeline_ltx2_image2video.py b/src/diffusers/pipelines/ltx2/pipeline_ltx2_image2video.py index a316f5307130..97f5411c8cf2 100644 --- a/src/diffusers/pipelines/ltx2/pipeline_ltx2_image2video.py +++ b/src/diffusers/pipelines/ltx2/pipeline_ltx2_image2video.py @@ -614,6 +614,13 @@ def _denormalize_latents( latents = latents * latents_std / scaling_factor + latents_mean return latents + @staticmethod + # Copied from diffusers.pipelines.ltx2.pipeline_ltx2.LTX2Pipeline._create_noised_state + def _create_noised_state(latents: torch.Tensor, noise_scale: float, generator: Optional[torch.Generator] = None): + noise = randn_tensor(latents.shape, generator=generator, device=latents.device, dtype=latents.dtype) + noised_latents = noise_scale * noise + (1 - noise_scale) * latents + return noised_latents + @staticmethod # Copied from diffusers.pipelines.ltx2.pipeline_ltx2.LTX2Pipeline._pack_audio_latents def _pack_audio_latents( @@ -678,6 +685,7 @@ def prepare_latents( height: int = 512, width: int = 704, num_frames: int = 161, + noise_scale: float = 0.0, dtype: Optional[torch.dtype] = None, device: Optional[torch.device] = None, generator: Optional[torch.Generator] = None, @@ -693,14 +701,18 @@ def prepare_latents( if latents is not None: conditioning_mask = latents.new_zeros(mask_shape) conditioning_mask[:, :, 0] = 1.0 - conditioning_mask = self._pack_latents( - conditioning_mask, self.transformer_spatial_patch_size, self.transformer_temporal_patch_size - ).squeeze(-1) if latents.ndim == 5: + latents = self._normalize_latents( + latents, self.vae.latents_mean, self.vae.latents_std, self.vae.config.scaling_factor + ) + latents = self._create_noised_state(latents, noise_scale * (1 - conditioning_mask), generator) # latents are of shape [B, C, F, H, W], need to be packed latents = self._pack_latents( latents, self.transformer_spatial_patch_size, self.transformer_temporal_patch_size ) + conditioning_mask = self._pack_latents( + conditioning_mask, self.transformer_spatial_patch_size, self.transformer_temporal_patch_size + ).squeeze(-1) if latents.ndim != 3 or latents.shape[:2] != conditioning_mask.shape: raise ValueError( f"Provided `latents` tensor has shape {latents.shape}, but the expected shape is {conditioning_mask.shape + (num_channels_latents,)}." @@ -751,6 +763,7 @@ def prepare_audio_latents( num_channels_latents: int = 8, audio_latent_length: int = 1, # 1 is just a dummy value num_mel_bins: int = 64, + noise_scale: float = 0.0, dtype: Optional[torch.dtype] = None, device: Optional[torch.device] = None, generator: Optional[torch.Generator] = None, @@ -761,6 +774,7 @@ def prepare_audio_latents( # latents are of shape [B, C, L, M], need to be packed latents = self._pack_audio_latents(latents) latents = self._normalize_audio_latents(latents, self.audio_vae.latents_mean, self.audio_vae.latents_std) + latents = self._create_noised_state(latents, noise_scale, generator) return latents.to(device=device, dtype=dtype) # TODO: confirm whether this logic is correct @@ -822,6 +836,7 @@ def __call__( timesteps: List[int] = None, guidance_scale: float = 4.0, guidance_rescale: float = 0.0, + noise_scale: float = 0.0, num_videos_per_prompt: Optional[int] = 1, generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None, latents: Optional[torch.Tensor] = None, @@ -879,6 +894,9 @@ def __call__( [Common Diffusion Noise Schedules and Sample Steps are Flawed](https://huggingface.co/papers/2305.08891). Guidance rescale factor should fix overexposure when using zero terminal SNR. + noise_scale (`float`, *optional*, defaults to `0.0`): + The interpolation factor between random noise and denoised latents at each timestep. Applying noise to + the `latents` and `audio_latents` before continue denoising. num_videos_per_prompt (`int`, *optional*, defaults to 1): The number of videos to generate per prompt. generator (`torch.Generator` or `List[torch.Generator]`, *optional*): @@ -1019,6 +1037,7 @@ def __call__( height, width, num_frames, + noise_scale, torch.float32, device, generator, @@ -1051,6 +1070,7 @@ def __call__( num_channels_latents=num_channels_latents_audio, audio_latent_length=audio_num_frames, num_mel_bins=num_mel_bins, + noise_scale=noise_scale, dtype=torch.float32, device=device, generator=generator, From 7e637bea1b45ad6b3d64dd35d3ac8773fb3fb2d0 Mon Sep 17 00:00:00 2001 From: Pham Hong Vinh Date: Mon, 26 Jan 2026 15:33:04 +0700 Subject: [PATCH 16/26] fix copies --- src/diffusers/pipelines/ltx2/pipeline_ltx2_image2video.py | 4 ++++ 1 file changed, 4 insertions(+) diff --git a/src/diffusers/pipelines/ltx2/pipeline_ltx2_image2video.py b/src/diffusers/pipelines/ltx2/pipeline_ltx2_image2video.py index 97f5411c8cf2..195bb4fb00a5 100644 --- a/src/diffusers/pipelines/ltx2/pipeline_ltx2_image2video.py +++ b/src/diffusers/pipelines/ltx2/pipeline_ltx2_image2video.py @@ -773,6 +773,10 @@ def prepare_audio_latents( if latents.ndim == 4: # latents are of shape [B, C, L, M], need to be packed latents = self._pack_audio_latents(latents) + if latents.ndim != 3: + raise ValueError( + f"Provided `latents` tensor has shape {latents.shape}, but the expected shape is [batch_size, num_seq, num_features]." + ) latents = self._normalize_audio_latents(latents, self.audio_vae.latents_mean, self.audio_vae.latents_std) latents = self._create_noised_state(latents, noise_scale, generator) return latents.to(device=device, dtype=dtype) From 121c085a46d54a6a655cbc66fc103af3722c484f Mon Sep 17 00:00:00 2001 From: Pham Hong Vinh Date: Tue, 27 Jan 2026 12:30:24 +0700 Subject: [PATCH 17/26] parse none value in weight conversion script --- scripts/convert_ltx2_to_diffusers.py | 17 +++++++++++------ 1 file changed, 11 insertions(+), 6 deletions(-) diff --git a/scripts/convert_ltx2_to_diffusers.py b/scripts/convert_ltx2_to_diffusers.py index e6a9aea4e46c..88a9c8700a38 100644 --- a/scripts/convert_ltx2_to_diffusers.py +++ b/scripts/convert_ltx2_to_diffusers.py @@ -665,10 +665,15 @@ def get_model_state_dict_from_combined_ckpt(combined_ckpt: Dict[str, Any], prefi def get_args(): parser = argparse.ArgumentParser() + def none_or_str(value: str): + if isinstance(value, str) and value.lower() == "none": + return None + return value + parser.add_argument( "--original_state_dict_repo_id", - default=None, - type=str, + default="Lightricks/LTX-2", + type=none_or_str, help="HF Hub repo id with LTX 2.0 checkpoint", ) parser.add_argument( @@ -688,7 +693,7 @@ def get_args(): parser.add_argument( "--combined_filename", default="ltx-2-19b-dev.safetensors", - type=str, + type=none_or_str, help="Filename for combined checkpoint with all LTX 2.0 models (VAE, DiT, etc.)", ) parser.add_argument("--vae_prefix", default="vae.", type=str) @@ -707,19 +712,19 @@ def get_args(): parser.add_argument( "--text_encoder_model_id", default="google/gemma-3-12b-it-qat-q4_0-unquantized", - type=str, + type=none_or_str, help="HF Hub id for the LTX 2.0 base text encoder model", ) parser.add_argument( "--tokenizer_id", default="google/gemma-3-12b-it-qat-q4_0-unquantized", - type=str, + type=none_or_str, help="HF Hub id for the LTX 2.0 text tokenizer", ) parser.add_argument( "--latent_upsampler_filename", default="ltx-2-spatial-upscaler-x2-1.0.safetensors", - type=str, + type=none_or_str, help="Latent upsampler filename", ) From d0650d74d70bd20e676e017b1dbb989e9ac81f18 Mon Sep 17 00:00:00 2001 From: Pham Hong Vinh Date: Tue, 27 Jan 2026 12:50:52 +0700 Subject: [PATCH 18/26] explicit shape handling --- src/diffusers/pipelines/ltx2/pipeline_ltx2.py | 20 ++++++++++++++++--- .../ltx2/pipeline_ltx2_image2video.py | 20 ++++++++++++++++--- 2 files changed, 34 insertions(+), 6 deletions(-) diff --git a/src/diffusers/pipelines/ltx2/pipeline_ltx2.py b/src/diffusers/pipelines/ltx2/pipeline_ltx2.py index 001d3f23d577..f73e49d1cb34 100644 --- a/src/diffusers/pipelines/ltx2/pipeline_ltx2.py +++ b/src/diffusers/pipelines/ltx2/pipeline_ltx2.py @@ -971,12 +971,19 @@ def __call__( latent_width = width // self.vae_spatial_compression_ratio if latents is not None: if latents.ndim == 5: + logger.info( + "Got latents of shape [batch_size, latent_dim, latent_frames, latent_height, latent_width], `latent_num_frames`, `latent_height`, `latent_width` will be inferred." + ) _, _, latent_num_frames, latent_height, latent_width = latents.shape # [B, C, F, H, W] - else: + elif latents.ndim == 3: logger.warning( f"You have supplied packed `latents` of shape {latents.shape}, so the latent dims cannot be" f" inferred. Make sure the supplied `height`, `width`, and `num_frames` are correct." ) + else: + raise ValueError( + f"Provided `latents` tensor has shape {latents.shape}, but the expected shape is either [batch_size, seq_len, num_features] or [batch_size, latent_dim, latent_frames, latent_height, latent_width]." + ) video_sequence_length = latent_num_frames * latent_height * latent_width num_channels_latents = self.transformer.config.in_channels @@ -1000,11 +1007,18 @@ def __call__( audio_num_frames = round(duration_s * audio_latents_per_second) if audio_latents is not None: if audio_latents.ndim == 4: + logger.info( + "Got audio_latents of shape [batch_size, num_channels, audio_length, mel_bins], `audio_num_frames` will be inferred." + ) _, _, audio_num_frames, _ = audio_latents.shape # [B, C, L, M] - else: + elif audio_latents.ndim == 3: logger.warning( f"You have supplied packed `audio_latents` of shape {audio_latents.shape}, so the latent dims" - f" cannot be inferred. Make sure the supplied `num_frames` is correct." + f" cannot be inferred. Make sure the supplied `num_frames` and `frame_rate` are correct." + ) + else: + raise ValueError( + f"Provided `audio_latents` tensor has shape {audio_latents.shape}, but the expected shape is either [batch_size, seq_len, num_features] or [batch_size, num_channels, audio_length, mel_bins]." ) num_mel_bins = self.audio_vae.config.mel_bins if getattr(self, "audio_vae", None) is not None else 64 diff --git a/src/diffusers/pipelines/ltx2/pipeline_ltx2_image2video.py b/src/diffusers/pipelines/ltx2/pipeline_ltx2_image2video.py index 195bb4fb00a5..71a5572c650d 100644 --- a/src/diffusers/pipelines/ltx2/pipeline_ltx2_image2video.py +++ b/src/diffusers/pipelines/ltx2/pipeline_ltx2_image2video.py @@ -1021,12 +1021,19 @@ def __call__( latent_width = width // self.vae_spatial_compression_ratio if latents is not None: if latents.ndim == 5: + logger.info( + "Got latents of shape [batch_size, latent_dim, latent_frames, latent_height, latent_width], `latent_num_frames`, `latent_height`, `latent_width` will be inferred." + ) _, _, latent_num_frames, latent_height, latent_width = latents.shape # [B, C, F, H, W] - else: + elif latents.ndim == 3: logger.warning( f"You have supplied packed `latents` of shape {latents.shape}, so the latent dims cannot be" f" inferred. Make sure the supplied `height`, `width`, and `num_frames` are correct." ) + else: + raise ValueError( + f"Provided `latents` tensor has shape {latents.shape}, but the expected shape is either [batch_size, seq_len, num_features] or [batch_size, latent_dim, latent_frames, latent_height, latent_width]." + ) video_sequence_length = latent_num_frames * latent_height * latent_width if latents is None: @@ -1057,11 +1064,18 @@ def __call__( audio_num_frames = round(duration_s * audio_latents_per_second) if audio_latents is not None: if audio_latents.ndim == 4: + logger.info( + "Got audio_latents of shape [batch_size, num_channels, audio_length, mel_bins], `audio_num_frames` will be inferred." + ) _, _, audio_num_frames, _ = audio_latents.shape # [B, C, L, M] - else: + elif audio_latents.ndim == 3: logger.warning( f"You have supplied packed `audio_latents` of shape {audio_latents.shape}, so the latent dims" - f" cannot be inferred. Make sure the supplied `num_frames` is correct." + f" cannot be inferred. Make sure the supplied `num_frames` and `frame_rate` are correct." + ) + else: + raise ValueError( + f"Provided `audio_latents` tensor has shape {audio_latents.shape}, but the expected shape is either [batch_size, seq_len, num_features] or [batch_size, num_channels, audio_length, mel_bins]." ) num_mel_bins = self.audio_vae.config.mel_bins if getattr(self, "audio_vae", None) is not None else 64 From 1e6a8b978d5366d69ddbc3e96e3c7247a57bac1f Mon Sep 17 00:00:00 2001 From: "Vinh H. Pham" Date: Tue, 27 Jan 2026 12:52:42 +0700 Subject: [PATCH 19/26] Apply suggestions from code review Co-authored-by: dg845 <58458699+dg845@users.noreply.github.com> --- src/diffusers/pipelines/ltx2/pipeline_ltx2.py | 2 +- src/diffusers/pipelines/ltx2/pipeline_ltx2_image2video.py | 2 +- 2 files changed, 2 insertions(+), 2 deletions(-) diff --git a/src/diffusers/pipelines/ltx2/pipeline_ltx2.py b/src/diffusers/pipelines/ltx2/pipeline_ltx2.py index f73e49d1cb34..9a98747c4cf9 100644 --- a/src/diffusers/pipelines/ltx2/pipeline_ltx2.py +++ b/src/diffusers/pipelines/ltx2/pipeline_ltx2.py @@ -618,7 +618,7 @@ def _denormalize_audio_latents(latents: torch.Tensor, latents_mean: torch.Tensor return (latents * latents_std) + latents_mean @staticmethod - def _create_noised_state(latents: torch.Tensor, noise_scale: float, generator: Optional[torch.Generator] = None): + def _create_noised_state(latents: torch.Tensor, noise_scale: Union[float, torch.Tensor], generator: Optional[torch.Generator] = None): noise = randn_tensor(latents.shape, generator=generator, device=latents.device, dtype=latents.dtype) noised_latents = noise_scale * noise + (1 - noise_scale) * latents return noised_latents diff --git a/src/diffusers/pipelines/ltx2/pipeline_ltx2_image2video.py b/src/diffusers/pipelines/ltx2/pipeline_ltx2_image2video.py index 71a5572c650d..a76157d85597 100644 --- a/src/diffusers/pipelines/ltx2/pipeline_ltx2_image2video.py +++ b/src/diffusers/pipelines/ltx2/pipeline_ltx2_image2video.py @@ -616,7 +616,7 @@ def _denormalize_latents( @staticmethod # Copied from diffusers.pipelines.ltx2.pipeline_ltx2.LTX2Pipeline._create_noised_state - def _create_noised_state(latents: torch.Tensor, noise_scale: float, generator: Optional[torch.Generator] = None): + def _create_noised_state(latents: torch.Tensor, noise_scale: Union[float, torch.Tensor], generator: Optional[torch.Generator] = None): noise = randn_tensor(latents.shape, generator=generator, device=latents.device, dtype=latents.dtype) noised_latents = noise_scale * noise + (1 - noise_scale) * latents return noised_latents From f6f682f3f3db9ab69d47c21b9245152933403069 Mon Sep 17 00:00:00 2001 From: Pham Hong Vinh Date: Tue, 27 Jan 2026 13:03:34 +0700 Subject: [PATCH 20/26] make style --- src/diffusers/pipelines/ltx2/pipeline_ltx2.py | 4 +++- src/diffusers/pipelines/ltx2/pipeline_ltx2_image2video.py | 4 +++- 2 files changed, 6 insertions(+), 2 deletions(-) diff --git a/src/diffusers/pipelines/ltx2/pipeline_ltx2.py b/src/diffusers/pipelines/ltx2/pipeline_ltx2.py index 9a98747c4cf9..a92a7a2c8869 100644 --- a/src/diffusers/pipelines/ltx2/pipeline_ltx2.py +++ b/src/diffusers/pipelines/ltx2/pipeline_ltx2.py @@ -618,7 +618,9 @@ def _denormalize_audio_latents(latents: torch.Tensor, latents_mean: torch.Tensor return (latents * latents_std) + latents_mean @staticmethod - def _create_noised_state(latents: torch.Tensor, noise_scale: Union[float, torch.Tensor], generator: Optional[torch.Generator] = None): + def _create_noised_state( + latents: torch.Tensor, noise_scale: Union[float, torch.Tensor], generator: Optional[torch.Generator] = None + ): noise = randn_tensor(latents.shape, generator=generator, device=latents.device, dtype=latents.dtype) noised_latents = noise_scale * noise + (1 - noise_scale) * latents return noised_latents diff --git a/src/diffusers/pipelines/ltx2/pipeline_ltx2_image2video.py b/src/diffusers/pipelines/ltx2/pipeline_ltx2_image2video.py index a76157d85597..04d7ee89c52a 100644 --- a/src/diffusers/pipelines/ltx2/pipeline_ltx2_image2video.py +++ b/src/diffusers/pipelines/ltx2/pipeline_ltx2_image2video.py @@ -616,7 +616,9 @@ def _denormalize_latents( @staticmethod # Copied from diffusers.pipelines.ltx2.pipeline_ltx2.LTX2Pipeline._create_noised_state - def _create_noised_state(latents: torch.Tensor, noise_scale: Union[float, torch.Tensor], generator: Optional[torch.Generator] = None): + def _create_noised_state( + latents: torch.Tensor, noise_scale: Union[float, torch.Tensor], generator: Optional[torch.Generator] = None + ): noise = randn_tensor(latents.shape, generator=generator, device=latents.device, dtype=latents.dtype) noised_latents = noise_scale * noise + (1 - noise_scale) * latents return noised_latents From 12f25140a953be34a68e3c64b4a1f655ca7b3386 Mon Sep 17 00:00:00 2001 From: Pham Hong Vinh Date: Wed, 28 Jan 2026 00:27:19 +0700 Subject: [PATCH 21/26] add two stage inference tests --- .../autoencoders/autoencoder_kl_ltx2_audio.py | 4 +- tests/pipelines/ltx2/test_ltx2.py | 52 ++++++++++++++++++- tests/pipelines/ltx2/test_ltx2_image2video.py | 50 ++++++++++++++++++ 3 files changed, 103 insertions(+), 3 deletions(-) diff --git a/src/diffusers/models/autoencoders/autoencoder_kl_ltx2_audio.py b/src/diffusers/models/autoencoders/autoencoder_kl_ltx2_audio.py index 6c9c7dce3d2f..b29629a29f80 100644 --- a/src/diffusers/models/autoencoders/autoencoder_kl_ltx2_audio.py +++ b/src/diffusers/models/autoencoders/autoencoder_kl_ltx2_audio.py @@ -743,8 +743,8 @@ def __init__( # Per-channel statistics for normalizing and denormalizing the latent representation. This statics is computed over # the entire dataset and stored in model's checkpoint under AudioVAE state_dict - latents_std = torch.zeros((base_channels,)) - latents_mean = torch.ones((base_channels,)) + latents_std = torch.ones((base_channels,)) + latents_mean = torch.zeros((base_channels,)) self.register_buffer("latents_mean", latents_mean, persistent=True) self.register_buffer("latents_std", latents_std, persistent=True) diff --git a/tests/pipelines/ltx2/test_ltx2.py b/tests/pipelines/ltx2/test_ltx2.py index 6ffc23725022..7d1a3bfc9987 100644 --- a/tests/pipelines/ltx2/test_ltx2.py +++ b/tests/pipelines/ltx2/test_ltx2.py @@ -222,7 +222,57 @@ def test_inference(self): ) expected_audio_slice = torch.tensor( [ - 0.0236, 0.0499, 0.1230, 0.1094, 0.1713, 0.1044, 0.1729, 0.1009, 0.0672, -0.0069, 0.0688, 0.0097, 0.0808, 0.1231, 0.0986, 0.0739 + 0.0263, 0.0528, 0.1217, 0.1104, 0.1632, 0.1072, 0.1789, 0.0949, 0.0672, -0.0069, 0.0688, 0.0097, 0.0808, 0.1231, 0.0986, 0.0739 + ] + ) + # fmt: on + + video = video.flatten() + audio = audio.flatten() + generated_video_slice = torch.cat([video[:8], video[-8:]]) + generated_audio_slice = torch.cat([audio[:8], audio[-8:]]) + + assert torch.allclose(expected_video_slice, generated_video_slice, atol=1e-4, rtol=1e-4) + assert torch.allclose(expected_audio_slice, generated_audio_slice, atol=1e-4, rtol=1e-4) + + def test_two_stages_inference(self): + device = "cpu" + + components = self.get_dummy_components() + pipe = self.pipeline_class(**components) + pipe.to(device) + pipe.set_progress_bar_config(disable=None) + + inputs = self.get_dummy_inputs(device) + inputs["output_type"] = "latent" + first_stage_output = pipe(**inputs) + video_latent = first_stage_output.frames + audio_latent = first_stage_output.audio + + self.assertEqual(video_latent.shape, (1, 4, 3, 16, 16)) + self.assertEqual(audio_latent.shape, (1, 2, 5, 2)) + self.assertEqual(audio_latent.shape[1], components["vocoder"].config.out_channels) + + inputs["latents"] = video_latent + inputs["audio_latents"] = audio_latent + inputs["output_type"] = "pt" + second_stage_output = pipe(**inputs) + video = second_stage_output.frames + audio = second_stage_output.audio + + self.assertEqual(video.shape, (1, 5, 3, 32, 32)) + self.assertEqual(audio.shape[0], 1) + self.assertEqual(audio.shape[1], components["vocoder"].config.out_channels) + + # fmt: off + expected_video_slice = torch.tensor( + [ + 0.5514, 0.5943, 0.4260, 0.5971, 0.4306, 0.6369, 0.3124, 0.6964, 0.5419, 0.2412, 0.3882, 0.4504, 0.1941, 0.3404, 0.6037, 0.2464 + ] + ) + expected_audio_slice = torch.tensor( + [ + 0.0252, 0.0526, 0.1211, 0.1119, 0.1638, 0.1042, 0.1776, 0.0948, 0.0672, -0.0069, 0.0688, 0.0097, 0.0808, 0.1231, 0.0986, 0.0739 ] ) # fmt: on diff --git a/tests/pipelines/ltx2/test_ltx2_image2video.py b/tests/pipelines/ltx2/test_ltx2_image2video.py index 1edae9c0e098..95b42a38da26 100644 --- a/tests/pipelines/ltx2/test_ltx2_image2video.py +++ b/tests/pipelines/ltx2/test_ltx2_image2video.py @@ -237,5 +237,55 @@ def test_inference(self): assert torch.allclose(expected_video_slice, generated_video_slice, atol=1e-4, rtol=1e-4) assert torch.allclose(expected_audio_slice, generated_audio_slice, atol=1e-4, rtol=1e-4) + def test_two_stages_inference(self): + device = "cpu" + + components = self.get_dummy_components() + pipe = self.pipeline_class(**components) + pipe.to(device) + pipe.set_progress_bar_config(disable=None) + + inputs = self.get_dummy_inputs(device) + inputs["output_type"] = "latent" + first_stage_output = pipe(**inputs) + video_latent = first_stage_output.frames + audio_latent = first_stage_output.audio + + self.assertEqual(video_latent.shape, (1, 4, 3, 16, 16)) + self.assertEqual(audio_latent.shape, (1, 2, 5, 2)) + self.assertEqual(audio_latent.shape[1], components["vocoder"].config.out_channels) + + inputs["latents"] = video_latent + inputs["audio_latents"] = audio_latent + inputs["output_type"] = "pt" + second_stage_output = pipe(**inputs) + video = second_stage_output.frames + audio = second_stage_output.audio + + self.assertEqual(video.shape, (1, 5, 3, 32, 32)) + self.assertEqual(audio.shape[0], 1) + self.assertEqual(audio.shape[1], components["vocoder"].config.out_channels) + + # fmt: off + expected_video_slice = torch.tensor( + [ + 0.2665, 0.6915, 0.2939, 0.6767, 0.2552, 0.6215, 0.1765, 0.6248, 0.2800, 0.2356, 0.3480, 0.5395, 0.3190, 0.4128, 0.4784, 0.4086 + ] + ) + expected_audio_slice = torch.tensor( + [ + 0.0273, 0.0490, 0.1253, 0.1129, 0.1655, 0.1057, 0.1707, 0.0943, 0.0672, -0.0069, 0.0688, 0.0097, 0.0808, 0.1231, 0.0986, 0.0739 + ] + ) + # fmt: on + + video = video.flatten() + audio = audio.flatten() + generated_video_slice = torch.cat([video[:8], video[-8:]]) + generated_audio_slice = torch.cat([audio[:8], audio[-8:]]) + + assert torch.allclose(expected_video_slice, generated_video_slice, atol=1e-4, rtol=1e-4) + assert torch.allclose(expected_audio_slice, generated_audio_slice, atol=1e-4, rtol=1e-4) + def test_inference_batch_single_identical(self): self._test_inference_batch_single_identical(batch_size=2, expected_max_diff=2e-2) From ce6adfb70dd54625f20cfd311e178f95c153e8a8 Mon Sep 17 00:00:00 2001 From: Pham Hong Vinh Date: Wed, 28 Jan 2026 00:53:14 +0700 Subject: [PATCH 22/26] add ltx2 documentation --- docs/source/en/api/pipelines/ltx2.md | 214 +++++++++++++++++++++++++++ 1 file changed, 214 insertions(+) diff --git a/docs/source/en/api/pipelines/ltx2.md b/docs/source/en/api/pipelines/ltx2.md index 4c6860daf024..6898afa1064a 100644 --- a/docs/source/en/api/pipelines/ltx2.md +++ b/docs/source/en/api/pipelines/ltx2.md @@ -24,6 +24,220 @@ You can find all the original LTX-Video checkpoints under the [Lightricks](https The original codebase for LTX-2 can be found [here](https://github.com/Lightricks/LTX-2). +## One-stage Generation + +Sample usage of text-to-video one stage pipeline + +```py +import torch +from diffusers.pipelines.ltx2 import LTX2Pipeline +from diffusers.pipelines.ltx2.export_utils import encode_video +from diffusers.utils import load_image + +pipe = LTX2Pipeline.from_pretrained("Lightricks/LTX-2", torch_dtype=torch.bfloat16) +pipe.enable_model_cpu_offload() + +prompt = "A beautiful sunset over the ocean" +negative_prompt = "shaky, glitchy, low quality, worst quality, deformed, distorted, disfigured, motion smear, motion artifacts, fused fingers, bad anatomy, weird hand, ugly, transition, static." + +frame_rate = 24.0 +video, audio = pipe( + prompt=prompt, + negative_prompt=negative_prompt, + width=768, + height=512, + num_frames=121, + frame_rate=frame_rate, + num_inference_steps=40, + guidance_scale=4.0, + output_type="np", + return_dict=False, +) +video = (video * 255).round().astype("uint8") +video = torch.from_numpy(video) + +encode_video( + video[0], + fps=frame_rate, + audio=audio[0].float().cpu(), + audio_sample_rate=pipe.vocoder.config.output_sampling_rate, + output_path="ltx2_sample.mp4", +) +``` + +## Two-stages Generation +Recommended pipeline to achieve production quality generation, this pipeline is composed of two stages: + +- Stage 1: Generate a video at the target resolution using diffusion sampling with classifier-free guidance (CFG). This stage produces a coherent low-noise video sequence that respects the text/image conditioning. +- Stage 2: Upsample the Stage 1 output by 2 and refine details using a distilled LoRA model to improve fidelity and visual quality. Stage 2 may apply lighter CFG to preserve the structure from Stage 1 while enhancing texture and sharpness. + +Sample usage of text-to-video two stages pipeline + +```py +import torch +from diffusers import FlowMatchEulerDiscreteScheduler +from diffusers.pipelines.ltx2 import LTX2Pipeline, LTX2LatentUpsamplePipeline +from diffusers.pipelines.ltx2.latent_upsampler import LTX2LatentUpsamplerModel +from diffusers.pipelines.ltx2.utils import DISTILLED_SIGMA_VALUES, STAGE_2_DISTILLED_SIGMA_VALUES +from diffusers.pipelines.ltx2.export_utils import encode_video + +device = "cuda:0" +width = 768 +height = 512 + +pipe = LTX2Pipeline.from_pretrained( + "Lightricks/LTX-2", torch_dtype=torch.bfloat16 +) +pipe.enable_sequential_cpu_offload(device=device) + +prompt = "A beautiful sunset over the ocean" +negative_prompt = "shaky, glitchy, low quality, worst quality, deformed, distorted, disfigured, motion smear, motion artifacts, fused fingers, bad anatomy, weird hand, ugly, transition, static." + +# Stage 1 default (non-distilled) inference +frame_rate = 24.0 +video_latent, audio_latent = pipe( + prompt=prompt, + negative_prompt=negative_prompt, + width=width, + height=height, + num_frames=121, + frame_rate=frame_rate, + num_inference_steps=40, + sigmas=None, + guidance_scale=4.0, + output_type="latent", + return_dict=False, +) + +latent_upsampler = LTX2LatentUpsamplerModel.from_pretrained( + "Lightricks/LTX-2", + subfolder="latent_upsampler", + torch_dtype=torch.bfloat16, +) +upsample_pipe = LTX2LatentUpsamplePipeline(vae=pipe.vae, latent_upsampler=latent_upsampler) +upsample_pipe.enable_model_cpu_offload(device=device) +upscaled_video_latent = upsample_pipe( + latents=video_latent, + output_type="latent", + return_dict=False, +)[0] + +# Load Stage 2 distilled LoRA +pipe.load_lora_weights( + "Lightricks/LTX-2", adapter_name="stage_2_distilled", weight_name="ltx-2-19b-distilled-lora-384.safetensors" +) +pipe.set_adapters("stage_2_distilled", 1.0) +# VAE tiling seems necessary to avoid OOM error when VAE decoding +pipe.vae.enable_tiling() +# Change scheduler to use Stage 2 distilled sigmas as is +new_scheduler = FlowMatchEulerDiscreteScheduler.from_config( + pipe.scheduler.config, use_dynamic_shifting=False, shift_terminal=None +) +pipe.scheduler = new_scheduler +# Stage 2 inference with distilled LoRA and sigmas +video, audio = pipe( + latents=upscaled_video_latent, + audio_latents=audio_latent, + prompt=prompt, + negative_prompt=negative_prompt, + num_inference_steps=3, + noise_scale=STAGE_2_DISTILLED_SIGMA_VALUES[0], + sigmas=STAGE_2_DISTILLED_SIGMA_VALUES, + guidance_scale=1.0, + output_type="np", + return_dict=False, +) +video = (video * 255).round().astype("uint8") +video = torch.from_numpy(video) + +encode_video( + video[0], + fps=frame_rate, + audio=audio[0].float().cpu(), + audio_sample_rate=pipe.vocoder.config.output_sampling_rate, + output_path="ltx2_lora_distilled_sample.mp4", +) +``` + +## Distilled checkpoint generation +Fastest two-stages generation pipeline using a distilled checkpoint. + +```py +import torch +from diffusers.pipelines.ltx2 import LTX2Pipeline, LTX2LatentUpsamplePipeline +from diffusers.pipelines.ltx2.latent_upsampler import LTX2LatentUpsamplerModel +from diffusers.pipelines.ltx2.utils import DISTILLED_SIGMA_VALUES, STAGE_2_DISTILLED_SIGMA_VALUES +from diffusers.pipelines.ltx2.export_utils import encode_video + +device = "cuda" +width = 768 +height = 512 +random_seed = 42 +generator = torch.Generator(device).manual_seed(random_seed) +model_path = "rootonchair/LTX-2-19b-distilled" + +pipe = LTX2Pipeline.from_pretrained( + model_path, torch_dtype=torch.bfloat16 +) +pipe.enable_sequential_cpu_offload(device=device) + +prompt = "A beautiful sunset over the ocean" +negative_prompt = "shaky, glitchy, low quality, worst quality, deformed, distorted, disfigured, motion smear, motion artifacts, fused fingers, bad anatomy, weird hand, ugly, transition, static." + +frame_rate = 24.0 +video_latent, audio_latent = pipe( + prompt=prompt, + negative_prompt=negative_prompt, + width=width, + height=height, + num_frames=121, + frame_rate=frame_rate, + num_inference_steps=8, + sigmas=DISTILLED_SIGMA_VALUES, + guidance_scale=1.0, + generator=generator, + output_type="latent", + return_dict=False, +) + +latent_upsampler = LTX2LatentUpsamplerModel.from_pretrained( + model_path, + subfolder="latent_upsampler", + torch_dtype=torch.bfloat16, +) +upsample_pipe = LTX2LatentUpsamplePipeline(vae=pipe.vae, latent_upsampler=latent_upsampler) +upsample_pipe.enable_model_cpu_offload(device=device) +upscaled_video_latent = upsample_pipe( + latents=video_latent, + output_type="latent", + return_dict=False, +)[0] + +video, audio = pipe( + latents=upscaled_video_latent, + audio_latents=audio_latent, + prompt=prompt, + negative_prompt=negative_prompt, + num_inference_steps=3, + noise_scale=STAGE_2_DISTILLED_SIGMA_VALUES[0], + sigmas=STAGE_2_DISTILLED_SIGMA_VALUES, + generator=generator, + guidance_scale=1.0, + output_type="np", + return_dict=False, +) +video = (video * 255).round().astype("uint8") +video = torch.from_numpy(video) + +encode_video( + video[0], + fps=frame_rate, + audio=audio[0].float().cpu(), + audio_sample_rate=pipe.vocoder.config.output_sampling_rate, + output_path="ltx2_distilled_sample.mp4", +) +``` + ## LTX2Pipeline [[autodoc]] LTX2Pipeline From 0c70c8b9a4f4f4ef5e05d42f4b22048e7a6b8d82 Mon Sep 17 00:00:00 2001 From: Pham Hong Vinh Date: Wed, 28 Jan 2026 01:43:07 +0700 Subject: [PATCH 23/26] update i2v expected_audio_slice --- tests/pipelines/ltx2/test_ltx2_image2video.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/tests/pipelines/ltx2/test_ltx2_image2video.py b/tests/pipelines/ltx2/test_ltx2_image2video.py index 95b42a38da26..3653e1cfc5e4 100644 --- a/tests/pipelines/ltx2/test_ltx2_image2video.py +++ b/tests/pipelines/ltx2/test_ltx2_image2video.py @@ -224,7 +224,7 @@ def test_inference(self): ) expected_audio_slice = torch.tensor( [ - 0.0236, 0.0499, 0.1230, 0.1094, 0.1713, 0.1044, 0.1729, 0.1009, 0.0672, -0.0069, 0.0688, 0.0097, 0.0808, 0.1231, 0.0986, 0.0739 + 0.0294, 0.0498, 0.1269, 0.1135, 0.1639, 0.1116, 0.1730, 0.0931, 0.0672, -0.0069, 0.0688, 0.0097, 0.0808, 0.1231, 0.0986, 0.0739 ] ) # fmt: on From 7303754769dc1a5b44b5c3f01a6ab05557a4dc88 Mon Sep 17 00:00:00 2001 From: "Vinh H. Pham" Date: Wed, 28 Jan 2026 08:42:35 +0700 Subject: [PATCH 24/26] Apply suggestions from code review Co-authored-by: dg845 <58458699+dg845@users.noreply.github.com> --- docs/source/en/api/pipelines/ltx2.md | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/docs/source/en/api/pipelines/ltx2.md b/docs/source/en/api/pipelines/ltx2.md index 6898afa1064a..1e3f01758b29 100644 --- a/docs/source/en/api/pipelines/ltx2.md +++ b/docs/source/en/api/pipelines/ltx2.md @@ -78,7 +78,7 @@ import torch from diffusers import FlowMatchEulerDiscreteScheduler from diffusers.pipelines.ltx2 import LTX2Pipeline, LTX2LatentUpsamplePipeline from diffusers.pipelines.ltx2.latent_upsampler import LTX2LatentUpsamplerModel -from diffusers.pipelines.ltx2.utils import DISTILLED_SIGMA_VALUES, STAGE_2_DISTILLED_SIGMA_VALUES +from diffusers.pipelines.ltx2.utils import STAGE_2_DISTILLED_SIGMA_VALUES from diffusers.pipelines.ltx2.export_utils import encode_video device = "cuda:0" @@ -127,7 +127,7 @@ pipe.load_lora_weights( "Lightricks/LTX-2", adapter_name="stage_2_distilled", weight_name="ltx-2-19b-distilled-lora-384.safetensors" ) pipe.set_adapters("stage_2_distilled", 1.0) -# VAE tiling seems necessary to avoid OOM error when VAE decoding +# VAE tiling is usually necessary to avoid OOM error when VAE decoding pipe.vae.enable_tiling() # Change scheduler to use Stage 2 distilled sigmas as is new_scheduler = FlowMatchEulerDiscreteScheduler.from_config( From 833b42718748adbff7319d70f8e536e4c4e38f5f Mon Sep 17 00:00:00 2001 From: "Vinh H. Pham" Date: Wed, 28 Jan 2026 08:43:34 +0700 Subject: [PATCH 25/26] Apply suggestion from @dg845 Co-authored-by: dg845 <58458699+dg845@users.noreply.github.com> --- docs/source/en/api/pipelines/ltx2.md | 1 - 1 file changed, 1 deletion(-) diff --git a/docs/source/en/api/pipelines/ltx2.md b/docs/source/en/api/pipelines/ltx2.md index 1e3f01758b29..2d5586a60c65 100644 --- a/docs/source/en/api/pipelines/ltx2.md +++ b/docs/source/en/api/pipelines/ltx2.md @@ -32,7 +32,6 @@ Sample usage of text-to-video one stage pipeline import torch from diffusers.pipelines.ltx2 import LTX2Pipeline from diffusers.pipelines.ltx2.export_utils import encode_video -from diffusers.utils import load_image pipe = LTX2Pipeline.from_pretrained("Lightricks/LTX-2", torch_dtype=torch.bfloat16) pipe.enable_model_cpu_offload() From 0191986a239eeb7a2fa8f2efbaf943fe9a360d1e Mon Sep 17 00:00:00 2001 From: "Vinh H. Pham" Date: Wed, 28 Jan 2026 21:39:30 +0700 Subject: [PATCH 26/26] Update ltx2.md to remove one-stage example Removed one-stage generation example code and added comments for noise scale in two-stage generation. --- docs/source/en/api/pipelines/ltx2.md | 44 ++-------------------------- 1 file changed, 2 insertions(+), 42 deletions(-) diff --git a/docs/source/en/api/pipelines/ltx2.md b/docs/source/en/api/pipelines/ltx2.md index 2d5586a60c65..24776b42309e 100644 --- a/docs/source/en/api/pipelines/ltx2.md +++ b/docs/source/en/api/pipelines/ltx2.md @@ -24,46 +24,6 @@ You can find all the original LTX-Video checkpoints under the [Lightricks](https The original codebase for LTX-2 can be found [here](https://github.com/Lightricks/LTX-2). -## One-stage Generation - -Sample usage of text-to-video one stage pipeline - -```py -import torch -from diffusers.pipelines.ltx2 import LTX2Pipeline -from diffusers.pipelines.ltx2.export_utils import encode_video - -pipe = LTX2Pipeline.from_pretrained("Lightricks/LTX-2", torch_dtype=torch.bfloat16) -pipe.enable_model_cpu_offload() - -prompt = "A beautiful sunset over the ocean" -negative_prompt = "shaky, glitchy, low quality, worst quality, deformed, distorted, disfigured, motion smear, motion artifacts, fused fingers, bad anatomy, weird hand, ugly, transition, static." - -frame_rate = 24.0 -video, audio = pipe( - prompt=prompt, - negative_prompt=negative_prompt, - width=768, - height=512, - num_frames=121, - frame_rate=frame_rate, - num_inference_steps=40, - guidance_scale=4.0, - output_type="np", - return_dict=False, -) -video = (video * 255).round().astype("uint8") -video = torch.from_numpy(video) - -encode_video( - video[0], - fps=frame_rate, - audio=audio[0].float().cpu(), - audio_sample_rate=pipe.vocoder.config.output_sampling_rate, - output_path="ltx2_sample.mp4", -) -``` - ## Two-stages Generation Recommended pipeline to achieve production quality generation, this pipeline is composed of two stages: @@ -140,7 +100,7 @@ video, audio = pipe( prompt=prompt, negative_prompt=negative_prompt, num_inference_steps=3, - noise_scale=STAGE_2_DISTILLED_SIGMA_VALUES[0], + noise_scale=STAGE_2_DISTILLED_SIGMA_VALUES[0], # renoise with first sigma value https://github.com/Lightricks/LTX-2/blob/main/packages/ltx-pipelines/src/ltx_pipelines/ti2vid_two_stages.py#L218 sigmas=STAGE_2_DISTILLED_SIGMA_VALUES, guidance_scale=1.0, output_type="np", @@ -218,7 +178,7 @@ video, audio = pipe( prompt=prompt, negative_prompt=negative_prompt, num_inference_steps=3, - noise_scale=STAGE_2_DISTILLED_SIGMA_VALUES[0], + noise_scale=STAGE_2_DISTILLED_SIGMA_VALUES[0], # renoise with first sigma value https://github.com/Lightricks/LTX-2/blob/main/packages/ltx-pipelines/src/ltx_pipelines/distilled.py#L178 sigmas=STAGE_2_DISTILLED_SIGMA_VALUES, generator=generator, guidance_scale=1.0,