diff --git a/src/diffusers/models/autoencoders/vae.py b/src/diffusers/models/autoencoders/vae.py index 042cb5c10021..66b733b71d28 100644 --- a/src/diffusers/models/autoencoders/vae.py +++ b/src/diffusers/models/autoencoders/vae.py @@ -1,927 +1,929 @@ -# Copyright 2025 The HuggingFace Team. All rights reserved. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. -from dataclasses import dataclass - +# Copyright 2025 The HuggingFace Team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +from dataclasses import dataclass + import numpy as np import torch import torch.nn as nn -from ...utils import BaseOutput +from ...utils import BaseOutput, logging from ...utils.torch_utils import randn_tensor from ..activations import get_activation from ..attention_processor import SpatialNorm from ..unets.unet_2d_blocks import ( - AutoencoderTinyBlock, - UNetMidBlock2D, - get_down_block, + AutoencoderTinyBlock, + UNetMidBlock2D, + get_down_block, get_up_block, ) - -@dataclass -class EncoderOutput(BaseOutput): - r""" - Output of encoding method. - - Args: - latent (`torch.Tensor` of shape `(batch_size, num_channels, latent_height, latent_width)`): - The encoded latent. - """ - - latent: torch.Tensor +logger = logging.get_logger(__name__) @dataclass -class DecoderOutput(BaseOutput): - r""" - Output of decoding method. - - Args: - sample (`torch.Tensor` of shape `(batch_size, num_channels, height, width)`): - The decoded output sample from the last layer of the model. - """ - - sample: torch.Tensor - commit_loss: torch.FloatTensor | None = None - - -class Encoder(nn.Module): - r""" - The `Encoder` layer of a variational autoencoder that encodes its input into a latent representation. - - Args: - in_channels (`int`, *optional*, defaults to 3): - The number of input channels. - out_channels (`int`, *optional*, defaults to 3): - The number of output channels. - down_block_types (`tuple[str, ...]`, *optional*, defaults to `("DownEncoderBlock2D",)`): - The types of down blocks to use. See `~diffusers.models.unet_2d_blocks.get_down_block` for available - options. - block_out_channels (`tuple[int, ...]`, *optional*, defaults to `(64,)`): - The number of output channels for each block. - layers_per_block (`int`, *optional*, defaults to 2): - The number of layers per block. - norm_num_groups (`int`, *optional*, defaults to 32): - The number of groups for normalization. - act_fn (`str`, *optional*, defaults to `"silu"`): - The activation function to use. See `~diffusers.models.activations.get_activation` for available options. - double_z (`bool`, *optional*, defaults to `True`): - Whether to double the number of output channels for the last block. - """ - - def __init__( - self, - in_channels: int = 3, - out_channels: int = 3, - down_block_types: tuple[str, ...] = ("DownEncoderBlock2D",), - block_out_channels: tuple[int, ...] = (64,), - layers_per_block: int = 2, - norm_num_groups: int = 32, - act_fn: str = "silu", - double_z: bool = True, - mid_block_add_attention=True, - ): - super().__init__() - self.layers_per_block = layers_per_block - - self.conv_in = nn.Conv2d( - in_channels, - block_out_channels[0], - kernel_size=3, - stride=1, - padding=1, - ) - - self.down_blocks = nn.ModuleList([]) - - # down - output_channel = block_out_channels[0] - for i, down_block_type in enumerate(down_block_types): - input_channel = output_channel - output_channel = block_out_channels[i] - is_final_block = i == len(block_out_channels) - 1 - - down_block = get_down_block( - down_block_type, - num_layers=self.layers_per_block, - in_channels=input_channel, - out_channels=output_channel, - add_downsample=not is_final_block, - resnet_eps=1e-6, - downsample_padding=0, - resnet_act_fn=act_fn, - resnet_groups=norm_num_groups, - attention_head_dim=output_channel, - temb_channels=None, - ) - self.down_blocks.append(down_block) - - # mid - self.mid_block = UNetMidBlock2D( - in_channels=block_out_channels[-1], - resnet_eps=1e-6, - resnet_act_fn=act_fn, - output_scale_factor=1, - resnet_time_scale_shift="default", - attention_head_dim=block_out_channels[-1], - resnet_groups=norm_num_groups, - temb_channels=None, - add_attention=mid_block_add_attention, - ) - - # out - self.conv_norm_out = nn.GroupNorm(num_channels=block_out_channels[-1], num_groups=norm_num_groups, eps=1e-6) - self.conv_act = nn.SiLU() - - conv_out_channels = 2 * out_channels if double_z else out_channels - self.conv_out = nn.Conv2d(block_out_channels[-1], conv_out_channels, 3, padding=1) - - self.gradient_checkpointing = False - - def forward(self, sample: torch.Tensor) -> torch.Tensor: - r"""The forward method of the `Encoder` class.""" - - sample = self.conv_in(sample) - - if torch.is_grad_enabled() and self.gradient_checkpointing: - # down - for down_block in self.down_blocks: - sample = self._gradient_checkpointing_func(down_block, sample) - # middle - sample = self._gradient_checkpointing_func(self.mid_block, sample) - - else: - # down - for down_block in self.down_blocks: - sample = down_block(sample) - - # middle - sample = self.mid_block(sample) - - # post-process - sample = self.conv_norm_out(sample) - sample = self.conv_act(sample) - sample = self.conv_out(sample) - - return sample - - -class Decoder(nn.Module): - r""" - The `Decoder` layer of a variational autoencoder that decodes its latent representation into an output sample. - - Args: - in_channels (`int`, *optional*, defaults to 3): - The number of input channels. - out_channels (`int`, *optional*, defaults to 3): - The number of output channels. - up_block_types (`tuple[str, ...]`, *optional*, defaults to `("UpDecoderBlock2D",)`): - The types of up blocks to use. See `~diffusers.models.unet_2d_blocks.get_up_block` for available options. - block_out_channels (`tuple[int, ...]`, *optional*, defaults to `(64,)`): - The number of output channels for each block. - layers_per_block (`int`, *optional*, defaults to 2): - The number of layers per block. - norm_num_groups (`int`, *optional*, defaults to 32): - The number of groups for normalization. - act_fn (`str`, *optional*, defaults to `"silu"`): - The activation function to use. See `~diffusers.models.activations.get_activation` for available options. - norm_type (`str`, *optional*, defaults to `"group"`): - The normalization type to use. Can be either `"group"` or `"spatial"`. - """ - - def __init__( - self, - in_channels: int = 3, - out_channels: int = 3, - up_block_types: tuple[str, ...] = ("UpDecoderBlock2D",), - block_out_channels: tuple[int, ...] = (64,), - layers_per_block: int = 2, - norm_num_groups: int = 32, - act_fn: str = "silu", - norm_type: str = "group", # group, spatial - mid_block_add_attention=True, - ): - super().__init__() - self.layers_per_block = layers_per_block - - self.conv_in = nn.Conv2d( - in_channels, - block_out_channels[-1], - kernel_size=3, - stride=1, - padding=1, - ) - - self.up_blocks = nn.ModuleList([]) - - temb_channels = in_channels if norm_type == "spatial" else None - - # mid - self.mid_block = UNetMidBlock2D( - in_channels=block_out_channels[-1], - resnet_eps=1e-6, - resnet_act_fn=act_fn, - output_scale_factor=1, - resnet_time_scale_shift="default" if norm_type == "group" else norm_type, - attention_head_dim=block_out_channels[-1], - resnet_groups=norm_num_groups, - temb_channels=temb_channels, - add_attention=mid_block_add_attention, - ) - - # up - reversed_block_out_channels = list(reversed(block_out_channels)) - output_channel = reversed_block_out_channels[0] - for i, up_block_type in enumerate(up_block_types): - prev_output_channel = output_channel - output_channel = reversed_block_out_channels[i] - - is_final_block = i == len(block_out_channels) - 1 - - up_block = get_up_block( - up_block_type, - num_layers=self.layers_per_block + 1, - in_channels=prev_output_channel, - out_channels=output_channel, - prev_output_channel=prev_output_channel, - add_upsample=not is_final_block, - resnet_eps=1e-6, - resnet_act_fn=act_fn, - resnet_groups=norm_num_groups, - attention_head_dim=output_channel, - temb_channels=temb_channels, - resnet_time_scale_shift=norm_type, - ) - self.up_blocks.append(up_block) - prev_output_channel = output_channel - - # out - if norm_type == "spatial": - self.conv_norm_out = SpatialNorm(block_out_channels[0], temb_channels) - else: - self.conv_norm_out = nn.GroupNorm(num_channels=block_out_channels[0], num_groups=norm_num_groups, eps=1e-6) - self.conv_act = nn.SiLU() - self.conv_out = nn.Conv2d(block_out_channels[0], out_channels, 3, padding=1) - - self.gradient_checkpointing = False - - def forward( - self, - sample: torch.Tensor, - latent_embeds: torch.Tensor | None = None, - ) -> torch.Tensor: - r"""The forward method of the `Decoder` class.""" - - sample = self.conv_in(sample) - - if torch.is_grad_enabled() and self.gradient_checkpointing: - # middle - sample = self._gradient_checkpointing_func(self.mid_block, sample, latent_embeds) - - # up - for up_block in self.up_blocks: - sample = self._gradient_checkpointing_func(up_block, sample, latent_embeds) - else: - # middle - sample = self.mid_block(sample, latent_embeds) - - # up - for up_block in self.up_blocks: - sample = up_block(sample, latent_embeds) - - # post-process - if latent_embeds is None: - sample = self.conv_norm_out(sample) - else: - sample = self.conv_norm_out(sample, latent_embeds) - sample = self.conv_act(sample) - sample = self.conv_out(sample) - - return sample - - -class UpSample(nn.Module): - r""" - The `UpSample` layer of a variational autoencoder that upsamples its input. - - Args: - in_channels (`int`, *optional*, defaults to 3): - The number of input channels. - out_channels (`int`, *optional*, defaults to 3): - The number of output channels. - """ - - def __init__( - self, - in_channels: int, - out_channels: int, - ) -> None: - super().__init__() - self.in_channels = in_channels - self.out_channels = out_channels - self.deconv = nn.ConvTranspose2d(in_channels, out_channels, kernel_size=4, stride=2, padding=1) - - def forward(self, x: torch.Tensor) -> torch.Tensor: - r"""The forward method of the `UpSample` class.""" - x = torch.relu(x) - x = self.deconv(x) - return x - - -class MaskConditionEncoder(nn.Module): - """ - used in AsymmetricAutoencoderKL - """ - - def __init__( - self, - in_ch: int, - out_ch: int = 192, - res_ch: int = 768, - stride: int = 16, - ) -> None: - super().__init__() - - channels = [] - while stride > 1: - stride = stride // 2 - in_ch_ = out_ch * 2 - if out_ch > res_ch: - out_ch = res_ch - if stride == 1: - in_ch_ = res_ch - channels.append((in_ch_, out_ch)) - out_ch *= 2 - - out_channels = [] - for _in_ch, _out_ch in channels: - out_channels.append(_out_ch) - out_channels.append(channels[-1][0]) - - layers = [] - in_ch_ = in_ch - for l in range(len(out_channels)): - out_ch_ = out_channels[l] - if l == 0 or l == 1: - layers.append(nn.Conv2d(in_ch_, out_ch_, kernel_size=3, stride=1, padding=1)) - else: - layers.append(nn.Conv2d(in_ch_, out_ch_, kernel_size=4, stride=2, padding=1)) - in_ch_ = out_ch_ - - self.layers = nn.Sequential(*layers) - - def forward(self, x: torch.Tensor, mask=None) -> torch.Tensor: - r"""The forward method of the `MaskConditionEncoder` class.""" - out = {} - for l in range(len(self.layers)): - layer = self.layers[l] - x = layer(x) - out[str(tuple(x.shape))] = x - x = torch.relu(x) - return out - - -class MaskConditionDecoder(nn.Module): - r"""The `MaskConditionDecoder` should be used in combination with [`AsymmetricAutoencoderKL`] to enhance the model's - decoder with a conditioner on the mask and masked image. - - Args: - in_channels (`int`, *optional*, defaults to 3): - The number of input channels. - out_channels (`int`, *optional*, defaults to 3): - The number of output channels. - up_block_types (`tuple[str, ...]`, *optional*, defaults to `("UpDecoderBlock2D",)`): - The types of up blocks to use. See `~diffusers.models.unet_2d_blocks.get_up_block` for available options. - block_out_channels (`tuple[int, ...]`, *optional*, defaults to `(64,)`): - The number of output channels for each block. - layers_per_block (`int`, *optional*, defaults to 2): - The number of layers per block. - norm_num_groups (`int`, *optional*, defaults to 32): - The number of groups for normalization. - act_fn (`str`, *optional*, defaults to `"silu"`): - The activation function to use. See `~diffusers.models.activations.get_activation` for available options. - norm_type (`str`, *optional*, defaults to `"group"`): - The normalization type to use. Can be either `"group"` or `"spatial"`. - """ - - def __init__( - self, - in_channels: int = 3, - out_channels: int = 3, - up_block_types: tuple[str, ...] = ("UpDecoderBlock2D",), - block_out_channels: tuple[int, ...] = (64,), - layers_per_block: int = 2, - norm_num_groups: int = 32, - act_fn: str = "silu", - norm_type: str = "group", # group, spatial - ): - super().__init__() - self.layers_per_block = layers_per_block - - self.conv_in = nn.Conv2d( - in_channels, - block_out_channels[-1], - kernel_size=3, - stride=1, - padding=1, - ) - - self.up_blocks = nn.ModuleList([]) - - temb_channels = in_channels if norm_type == "spatial" else None - - # mid - self.mid_block = UNetMidBlock2D( - in_channels=block_out_channels[-1], - resnet_eps=1e-6, - resnet_act_fn=act_fn, - output_scale_factor=1, - resnet_time_scale_shift="default" if norm_type == "group" else norm_type, - attention_head_dim=block_out_channels[-1], - resnet_groups=norm_num_groups, - temb_channels=temb_channels, - ) - - # up - reversed_block_out_channels = list(reversed(block_out_channels)) - output_channel = reversed_block_out_channels[0] - for i, up_block_type in enumerate(up_block_types): - prev_output_channel = output_channel - output_channel = reversed_block_out_channels[i] - - is_final_block = i == len(block_out_channels) - 1 - - up_block = get_up_block( - up_block_type, - num_layers=self.layers_per_block + 1, - in_channels=prev_output_channel, - out_channels=output_channel, - prev_output_channel=None, - add_upsample=not is_final_block, - resnet_eps=1e-6, - resnet_act_fn=act_fn, - resnet_groups=norm_num_groups, - attention_head_dim=output_channel, - temb_channels=temb_channels, - resnet_time_scale_shift=norm_type, - ) - self.up_blocks.append(up_block) - prev_output_channel = output_channel - - # condition encoder - self.condition_encoder = MaskConditionEncoder( - in_ch=out_channels, - out_ch=block_out_channels[0], - res_ch=block_out_channels[-1], - ) - - # out - if norm_type == "spatial": - self.conv_norm_out = SpatialNorm(block_out_channels[0], temb_channels) - else: - self.conv_norm_out = nn.GroupNorm(num_channels=block_out_channels[0], num_groups=norm_num_groups, eps=1e-6) - self.conv_act = nn.SiLU() - self.conv_out = nn.Conv2d(block_out_channels[0], out_channels, 3, padding=1) - - self.gradient_checkpointing = False - - def forward( - self, - z: torch.Tensor, - image: torch.Tensor | None = None, - mask: torch.Tensor | None = None, - latent_embeds: torch.Tensor | None = None, - ) -> torch.Tensor: - r"""The forward method of the `MaskConditionDecoder` class.""" - sample = z - sample = self.conv_in(sample) - - upscale_dtype = next(iter(self.up_blocks.parameters())).dtype - if torch.is_grad_enabled() and self.gradient_checkpointing: - # middle - sample = self._gradient_checkpointing_func(self.mid_block, sample, latent_embeds) - sample = sample.to(upscale_dtype) - - # condition encoder - if image is not None and mask is not None: - masked_image = (1 - mask) * image - im_x = self._gradient_checkpointing_func( - self.condition_encoder, - masked_image, - mask, - ) - - # up - for up_block in self.up_blocks: - if image is not None and mask is not None: - sample_ = im_x[str(tuple(sample.shape))] - mask_ = nn.functional.interpolate(mask, size=sample.shape[-2:], mode="nearest") - sample = sample * mask_ + sample_ * (1 - mask_) - sample = self._gradient_checkpointing_func(up_block, sample, latent_embeds) - if image is not None and mask is not None: - sample = sample * mask + im_x[str(tuple(sample.shape))] * (1 - mask) - else: - # middle - sample = self.mid_block(sample, latent_embeds) - sample = sample.to(upscale_dtype) - - # condition encoder - if image is not None and mask is not None: - masked_image = (1 - mask) * image - im_x = self.condition_encoder(masked_image, mask) - - # up - for up_block in self.up_blocks: - if image is not None and mask is not None: - sample_ = im_x[str(tuple(sample.shape))] - mask_ = nn.functional.interpolate(mask, size=sample.shape[-2:], mode="nearest") - sample = sample * mask_ + sample_ * (1 - mask_) - sample = up_block(sample, latent_embeds) - if image is not None and mask is not None: - sample = sample * mask + im_x[str(tuple(sample.shape))] * (1 - mask) - - # post-process - if latent_embeds is None: - sample = self.conv_norm_out(sample) - else: - sample = self.conv_norm_out(sample, latent_embeds) - sample = self.conv_act(sample) - sample = self.conv_out(sample) - - return sample - - -class VectorQuantizer(nn.Module): - """ - Improved version over VectorQuantizer, can be used as a drop-in replacement. Mostly avoids costly matrix - multiplications and allows for post-hoc remapping of indices. - """ - - # NOTE: due to a bug the beta term was applied to the wrong term. for - # backwards compatibility we use the buggy version by default, but you can - # specify legacy=False to fix it. - def __init__( - self, - n_e: int, - vq_embed_dim: int, - beta: float, - remap=None, - unknown_index: str = "random", - sane_index_shape: bool = False, - legacy: bool = True, - ): - super().__init__() - self.n_e = n_e - self.vq_embed_dim = vq_embed_dim - self.beta = beta - self.legacy = legacy - - self.embedding = nn.Embedding(self.n_e, self.vq_embed_dim) - self.embedding.weight.data.uniform_(-1.0 / self.n_e, 1.0 / self.n_e) - - self.remap = remap - if self.remap is not None: - self.register_buffer("used", torch.tensor(np.load(self.remap))) - self.used: torch.Tensor - self.re_embed = self.used.shape[0] +class EncoderOutput(BaseOutput): + r""" + Output of encoding method. + + Args: + latent (`torch.Tensor` of shape `(batch_size, num_channels, latent_height, latent_width)`): + The encoded latent. + """ + + latent: torch.Tensor + + +@dataclass +class DecoderOutput(BaseOutput): + r""" + Output of decoding method. + + Args: + sample (`torch.Tensor` of shape `(batch_size, num_channels, height, width)`): + The decoded output sample from the last layer of the model. + """ + + sample: torch.Tensor + commit_loss: torch.FloatTensor | None = None + + +class Encoder(nn.Module): + r""" + The `Encoder` layer of a variational autoencoder that encodes its input into a latent representation. + + Args: + in_channels (`int`, *optional*, defaults to 3): + The number of input channels. + out_channels (`int`, *optional*, defaults to 3): + The number of output channels. + down_block_types (`tuple[str, ...]`, *optional*, defaults to `("DownEncoderBlock2D",)`): + The types of down blocks to use. See `~diffusers.models.unet_2d_blocks.get_down_block` for available + options. + block_out_channels (`tuple[int, ...]`, *optional*, defaults to `(64,)`): + The number of output channels for each block. + layers_per_block (`int`, *optional*, defaults to 2): + The number of layers per block. + norm_num_groups (`int`, *optional*, defaults to 32): + The number of groups for normalization. + act_fn (`str`, *optional*, defaults to `"silu"`): + The activation function to use. See `~diffusers.models.activations.get_activation` for available options. + double_z (`bool`, *optional*, defaults to `True`): + Whether to double the number of output channels for the last block. + """ + + def __init__( + self, + in_channels: int = 3, + out_channels: int = 3, + down_block_types: tuple[str, ...] = ("DownEncoderBlock2D",), + block_out_channels: tuple[int, ...] = (64,), + layers_per_block: int = 2, + norm_num_groups: int = 32, + act_fn: str = "silu", + double_z: bool = True, + mid_block_add_attention=True, + ): + super().__init__() + self.layers_per_block = layers_per_block + + self.conv_in = nn.Conv2d( + in_channels, + block_out_channels[0], + kernel_size=3, + stride=1, + padding=1, + ) + + self.down_blocks = nn.ModuleList([]) + + # down + output_channel = block_out_channels[0] + for i, down_block_type in enumerate(down_block_types): + input_channel = output_channel + output_channel = block_out_channels[i] + is_final_block = i == len(block_out_channels) - 1 + + down_block = get_down_block( + down_block_type, + num_layers=self.layers_per_block, + in_channels=input_channel, + out_channels=output_channel, + add_downsample=not is_final_block, + resnet_eps=1e-6, + downsample_padding=0, + resnet_act_fn=act_fn, + resnet_groups=norm_num_groups, + attention_head_dim=output_channel, + temb_channels=None, + ) + self.down_blocks.append(down_block) + + # mid + self.mid_block = UNetMidBlock2D( + in_channels=block_out_channels[-1], + resnet_eps=1e-6, + resnet_act_fn=act_fn, + output_scale_factor=1, + resnet_time_scale_shift="default", + attention_head_dim=block_out_channels[-1], + resnet_groups=norm_num_groups, + temb_channels=None, + add_attention=mid_block_add_attention, + ) + + # out + self.conv_norm_out = nn.GroupNorm(num_channels=block_out_channels[-1], num_groups=norm_num_groups, eps=1e-6) + self.conv_act = nn.SiLU() + + conv_out_channels = 2 * out_channels if double_z else out_channels + self.conv_out = nn.Conv2d(block_out_channels[-1], conv_out_channels, 3, padding=1) + + self.gradient_checkpointing = False + + def forward(self, sample: torch.Tensor) -> torch.Tensor: + r"""The forward method of the `Encoder` class.""" + + sample = self.conv_in(sample) + + if torch.is_grad_enabled() and self.gradient_checkpointing: + # down + for down_block in self.down_blocks: + sample = self._gradient_checkpointing_func(down_block, sample) + # middle + sample = self._gradient_checkpointing_func(self.mid_block, sample) + + else: + # down + for down_block in self.down_blocks: + sample = down_block(sample) + + # middle + sample = self.mid_block(sample) + + # post-process + sample = self.conv_norm_out(sample) + sample = self.conv_act(sample) + sample = self.conv_out(sample) + + return sample + + +class Decoder(nn.Module): + r""" + The `Decoder` layer of a variational autoencoder that decodes its latent representation into an output sample. + + Args: + in_channels (`int`, *optional*, defaults to 3): + The number of input channels. + out_channels (`int`, *optional*, defaults to 3): + The number of output channels. + up_block_types (`tuple[str, ...]`, *optional*, defaults to `("UpDecoderBlock2D",)`): + The types of up blocks to use. See `~diffusers.models.unet_2d_blocks.get_up_block` for available options. + block_out_channels (`tuple[int, ...]`, *optional*, defaults to `(64,)`): + The number of output channels for each block. + layers_per_block (`int`, *optional*, defaults to 2): + The number of layers per block. + norm_num_groups (`int`, *optional*, defaults to 32): + The number of groups for normalization. + act_fn (`str`, *optional*, defaults to `"silu"`): + The activation function to use. See `~diffusers.models.activations.get_activation` for available options. + norm_type (`str`, *optional*, defaults to `"group"`): + The normalization type to use. Can be either `"group"` or `"spatial"`. + """ + + def __init__( + self, + in_channels: int = 3, + out_channels: int = 3, + up_block_types: tuple[str, ...] = ("UpDecoderBlock2D",), + block_out_channels: tuple[int, ...] = (64,), + layers_per_block: int = 2, + norm_num_groups: int = 32, + act_fn: str = "silu", + norm_type: str = "group", # group, spatial + mid_block_add_attention=True, + ): + super().__init__() + self.layers_per_block = layers_per_block + + self.conv_in = nn.Conv2d( + in_channels, + block_out_channels[-1], + kernel_size=3, + stride=1, + padding=1, + ) + + self.up_blocks = nn.ModuleList([]) + + temb_channels = in_channels if norm_type == "spatial" else None + + # mid + self.mid_block = UNetMidBlock2D( + in_channels=block_out_channels[-1], + resnet_eps=1e-6, + resnet_act_fn=act_fn, + output_scale_factor=1, + resnet_time_scale_shift="default" if norm_type == "group" else norm_type, + attention_head_dim=block_out_channels[-1], + resnet_groups=norm_num_groups, + temb_channels=temb_channels, + add_attention=mid_block_add_attention, + ) + + # up + reversed_block_out_channels = list(reversed(block_out_channels)) + output_channel = reversed_block_out_channels[0] + for i, up_block_type in enumerate(up_block_types): + prev_output_channel = output_channel + output_channel = reversed_block_out_channels[i] + + is_final_block = i == len(block_out_channels) - 1 + + up_block = get_up_block( + up_block_type, + num_layers=self.layers_per_block + 1, + in_channels=prev_output_channel, + out_channels=output_channel, + prev_output_channel=prev_output_channel, + add_upsample=not is_final_block, + resnet_eps=1e-6, + resnet_act_fn=act_fn, + resnet_groups=norm_num_groups, + attention_head_dim=output_channel, + temb_channels=temb_channels, + resnet_time_scale_shift=norm_type, + ) + self.up_blocks.append(up_block) + prev_output_channel = output_channel + + # out + if norm_type == "spatial": + self.conv_norm_out = SpatialNorm(block_out_channels[0], temb_channels) + else: + self.conv_norm_out = nn.GroupNorm(num_channels=block_out_channels[0], num_groups=norm_num_groups, eps=1e-6) + self.conv_act = nn.SiLU() + self.conv_out = nn.Conv2d(block_out_channels[0], out_channels, 3, padding=1) + + self.gradient_checkpointing = False + + def forward( + self, + sample: torch.Tensor, + latent_embeds: torch.Tensor | None = None, + ) -> torch.Tensor: + r"""The forward method of the `Decoder` class.""" + + sample = self.conv_in(sample) + + if torch.is_grad_enabled() and self.gradient_checkpointing: + # middle + sample = self._gradient_checkpointing_func(self.mid_block, sample, latent_embeds) + + # up + for up_block in self.up_blocks: + sample = self._gradient_checkpointing_func(up_block, sample, latent_embeds) + else: + # middle + sample = self.mid_block(sample, latent_embeds) + + # up + for up_block in self.up_blocks: + sample = up_block(sample, latent_embeds) + + # post-process + if latent_embeds is None: + sample = self.conv_norm_out(sample) + else: + sample = self.conv_norm_out(sample, latent_embeds) + sample = self.conv_act(sample) + sample = self.conv_out(sample) + + return sample + + +class UpSample(nn.Module): + r""" + The `UpSample` layer of a variational autoencoder that upsamples its input. + + Args: + in_channels (`int`, *optional*, defaults to 3): + The number of input channels. + out_channels (`int`, *optional*, defaults to 3): + The number of output channels. + """ + + def __init__( + self, + in_channels: int, + out_channels: int, + ) -> None: + super().__init__() + self.in_channels = in_channels + self.out_channels = out_channels + self.deconv = nn.ConvTranspose2d(in_channels, out_channels, kernel_size=4, stride=2, padding=1) + + def forward(self, x: torch.Tensor) -> torch.Tensor: + r"""The forward method of the `UpSample` class.""" + x = torch.relu(x) + x = self.deconv(x) + return x + + +class MaskConditionEncoder(nn.Module): + """ + used in AsymmetricAutoencoderKL + """ + + def __init__( + self, + in_ch: int, + out_ch: int = 192, + res_ch: int = 768, + stride: int = 16, + ) -> None: + super().__init__() + + channels = [] + while stride > 1: + stride = stride // 2 + in_ch_ = out_ch * 2 + if out_ch > res_ch: + out_ch = res_ch + if stride == 1: + in_ch_ = res_ch + channels.append((in_ch_, out_ch)) + out_ch *= 2 + + out_channels = [] + for _in_ch, _out_ch in channels: + out_channels.append(_out_ch) + out_channels.append(channels[-1][0]) + + layers = [] + in_ch_ = in_ch + for l in range(len(out_channels)): + out_ch_ = out_channels[l] + if l == 0 or l == 1: + layers.append(nn.Conv2d(in_ch_, out_ch_, kernel_size=3, stride=1, padding=1)) + else: + layers.append(nn.Conv2d(in_ch_, out_ch_, kernel_size=4, stride=2, padding=1)) + in_ch_ = out_ch_ + + self.layers = nn.Sequential(*layers) + + def forward(self, x: torch.Tensor, mask=None) -> torch.Tensor: + r"""The forward method of the `MaskConditionEncoder` class.""" + out = {} + for l in range(len(self.layers)): + layer = self.layers[l] + x = layer(x) + out[str(tuple(x.shape))] = x + x = torch.relu(x) + return out + + +class MaskConditionDecoder(nn.Module): + r"""The `MaskConditionDecoder` should be used in combination with [`AsymmetricAutoencoderKL`] to enhance the model's + decoder with a conditioner on the mask and masked image. + + Args: + in_channels (`int`, *optional*, defaults to 3): + The number of input channels. + out_channels (`int`, *optional*, defaults to 3): + The number of output channels. + up_block_types (`tuple[str, ...]`, *optional*, defaults to `("UpDecoderBlock2D",)`): + The types of up blocks to use. See `~diffusers.models.unet_2d_blocks.get_up_block` for available options. + block_out_channels (`tuple[int, ...]`, *optional*, defaults to `(64,)`): + The number of output channels for each block. + layers_per_block (`int`, *optional*, defaults to 2): + The number of layers per block. + norm_num_groups (`int`, *optional*, defaults to 32): + The number of groups for normalization. + act_fn (`str`, *optional*, defaults to `"silu"`): + The activation function to use. See `~diffusers.models.activations.get_activation` for available options. + norm_type (`str`, *optional*, defaults to `"group"`): + The normalization type to use. Can be either `"group"` or `"spatial"`. + """ + + def __init__( + self, + in_channels: int = 3, + out_channels: int = 3, + up_block_types: tuple[str, ...] = ("UpDecoderBlock2D",), + block_out_channels: tuple[int, ...] = (64,), + layers_per_block: int = 2, + norm_num_groups: int = 32, + act_fn: str = "silu", + norm_type: str = "group", # group, spatial + ): + super().__init__() + self.layers_per_block = layers_per_block + + self.conv_in = nn.Conv2d( + in_channels, + block_out_channels[-1], + kernel_size=3, + stride=1, + padding=1, + ) + + self.up_blocks = nn.ModuleList([]) + + temb_channels = in_channels if norm_type == "spatial" else None + + # mid + self.mid_block = UNetMidBlock2D( + in_channels=block_out_channels[-1], + resnet_eps=1e-6, + resnet_act_fn=act_fn, + output_scale_factor=1, + resnet_time_scale_shift="default" if norm_type == "group" else norm_type, + attention_head_dim=block_out_channels[-1], + resnet_groups=norm_num_groups, + temb_channels=temb_channels, + ) + + # up + reversed_block_out_channels = list(reversed(block_out_channels)) + output_channel = reversed_block_out_channels[0] + for i, up_block_type in enumerate(up_block_types): + prev_output_channel = output_channel + output_channel = reversed_block_out_channels[i] + + is_final_block = i == len(block_out_channels) - 1 + + up_block = get_up_block( + up_block_type, + num_layers=self.layers_per_block + 1, + in_channels=prev_output_channel, + out_channels=output_channel, + prev_output_channel=None, + add_upsample=not is_final_block, + resnet_eps=1e-6, + resnet_act_fn=act_fn, + resnet_groups=norm_num_groups, + attention_head_dim=output_channel, + temb_channels=temb_channels, + resnet_time_scale_shift=norm_type, + ) + self.up_blocks.append(up_block) + prev_output_channel = output_channel + + # condition encoder + self.condition_encoder = MaskConditionEncoder( + in_ch=out_channels, + out_ch=block_out_channels[0], + res_ch=block_out_channels[-1], + ) + + # out + if norm_type == "spatial": + self.conv_norm_out = SpatialNorm(block_out_channels[0], temb_channels) + else: + self.conv_norm_out = nn.GroupNorm(num_channels=block_out_channels[0], num_groups=norm_num_groups, eps=1e-6) + self.conv_act = nn.SiLU() + self.conv_out = nn.Conv2d(block_out_channels[0], out_channels, 3, padding=1) + + self.gradient_checkpointing = False + + def forward( + self, + z: torch.Tensor, + image: torch.Tensor | None = None, + mask: torch.Tensor | None = None, + latent_embeds: torch.Tensor | None = None, + ) -> torch.Tensor: + r"""The forward method of the `MaskConditionDecoder` class.""" + sample = z + sample = self.conv_in(sample) + + upscale_dtype = next(iter(self.up_blocks.parameters())).dtype + if torch.is_grad_enabled() and self.gradient_checkpointing: + # middle + sample = self._gradient_checkpointing_func(self.mid_block, sample, latent_embeds) + sample = sample.to(upscale_dtype) + + # condition encoder + if image is not None and mask is not None: + masked_image = (1 - mask) * image + im_x = self._gradient_checkpointing_func( + self.condition_encoder, + masked_image, + mask, + ) + + # up + for up_block in self.up_blocks: + if image is not None and mask is not None: + sample_ = im_x[str(tuple(sample.shape))] + mask_ = nn.functional.interpolate(mask, size=sample.shape[-2:], mode="nearest") + sample = sample * mask_ + sample_ * (1 - mask_) + sample = self._gradient_checkpointing_func(up_block, sample, latent_embeds) + if image is not None and mask is not None: + sample = sample * mask + im_x[str(tuple(sample.shape))] * (1 - mask) + else: + # middle + sample = self.mid_block(sample, latent_embeds) + sample = sample.to(upscale_dtype) + + # condition encoder + if image is not None and mask is not None: + masked_image = (1 - mask) * image + im_x = self.condition_encoder(masked_image, mask) + + # up + for up_block in self.up_blocks: + if image is not None and mask is not None: + sample_ = im_x[str(tuple(sample.shape))] + mask_ = nn.functional.interpolate(mask, size=sample.shape[-2:], mode="nearest") + sample = sample * mask_ + sample_ * (1 - mask_) + sample = up_block(sample, latent_embeds) + if image is not None and mask is not None: + sample = sample * mask + im_x[str(tuple(sample.shape))] * (1 - mask) + + # post-process + if latent_embeds is None: + sample = self.conv_norm_out(sample) + else: + sample = self.conv_norm_out(sample, latent_embeds) + sample = self.conv_act(sample) + sample = self.conv_out(sample) + + return sample + + +class VectorQuantizer(nn.Module): + """ + Improved version over VectorQuantizer, can be used as a drop-in replacement. Mostly avoids costly matrix + multiplications and allows for post-hoc remapping of indices. + """ + + # NOTE: due to a bug the beta term was applied to the wrong term. for + # backwards compatibility we use the buggy version by default, but you can + # specify legacy=False to fix it. + def __init__( + self, + n_e: int, + vq_embed_dim: int, + beta: float, + remap=None, + unknown_index: str = "random", + sane_index_shape: bool = False, + legacy: bool = True, + ): + super().__init__() + self.n_e = n_e + self.vq_embed_dim = vq_embed_dim + self.beta = beta + self.legacy = legacy + + self.embedding = nn.Embedding(self.n_e, self.vq_embed_dim) + self.embedding.weight.data.uniform_(-1.0 / self.n_e, 1.0 / self.n_e) + + self.remap = remap + if self.remap is not None: + self.register_buffer("used", torch.tensor(np.load(self.remap))) + self.used: torch.Tensor + self.re_embed = self.used.shape[0] self.unknown_index = unknown_index # "random" or "extra" or integer if self.unknown_index == "extra": self.unknown_index = self.re_embed self.re_embed = self.re_embed + 1 - print( + logger.info( f"Remapping {self.n_e} indices to {self.re_embed} indices. " f"Using {self.unknown_index} for unknown indices." ) - else: - self.re_embed = n_e - - self.sane_index_shape = sane_index_shape - - def remap_to_used(self, inds: torch.LongTensor) -> torch.LongTensor: - ishape = inds.shape - assert len(ishape) > 1 - inds = inds.reshape(ishape[0], -1) - used = self.used.to(inds) - match = (inds[:, :, None] == used[None, None, ...]).long() - new = match.argmax(-1) - unknown = match.sum(2) < 1 - if self.unknown_index == "random": - new[unknown] = torch.randint(0, self.re_embed, size=new[unknown].shape).to(device=new.device) - else: - new[unknown] = self.unknown_index - return new.reshape(ishape) - - def unmap_to_all(self, inds: torch.LongTensor) -> torch.LongTensor: - ishape = inds.shape - assert len(ishape) > 1 - inds = inds.reshape(ishape[0], -1) - used = self.used.to(inds) - if self.re_embed > self.used.shape[0]: # extra token - inds[inds >= self.used.shape[0]] = 0 # simply set to zero - back = torch.gather(used[None, :][inds.shape[0] * [0], :], 1, inds) - return back.reshape(ishape) - - def forward(self, z: torch.Tensor) -> tuple[torch.Tensor, torch.Tensor, tuple]: - # reshape z -> (batch, height, width, channel) and flatten - z = z.permute(0, 2, 3, 1).contiguous() - z_flattened = z.view(-1, self.vq_embed_dim) - - # distances from z to embeddings e_j (z - e)^2 = z^2 + e^2 - 2 e * z - min_encoding_indices = torch.argmin(torch.cdist(z_flattened, self.embedding.weight), dim=1) - - z_q = self.embedding(min_encoding_indices).view(z.shape) - perplexity = None - min_encodings = None - - # compute loss for embedding - if not self.legacy: - loss = self.beta * torch.mean((z_q.detach() - z) ** 2) + torch.mean((z_q - z.detach()) ** 2) - else: - loss = torch.mean((z_q.detach() - z) ** 2) + self.beta * torch.mean((z_q - z.detach()) ** 2) - - # preserve gradients - z_q: torch.Tensor = z + (z_q - z).detach() - - # reshape back to match original input shape - z_q = z_q.permute(0, 3, 1, 2).contiguous() - - if self.remap is not None: - min_encoding_indices = min_encoding_indices.reshape(z.shape[0], -1) # add batch axis - min_encoding_indices = self.remap_to_used(min_encoding_indices) - min_encoding_indices = min_encoding_indices.reshape(-1, 1) # flatten - - if self.sane_index_shape: - min_encoding_indices = min_encoding_indices.reshape(z_q.shape[0], z_q.shape[2], z_q.shape[3]) - - return z_q, loss, (perplexity, min_encodings, min_encoding_indices) - - def get_codebook_entry(self, indices: torch.LongTensor, shape: tuple[int, ...]) -> torch.Tensor: - # shape specifying (batch, height, width, channel) - if self.remap is not None: - indices = indices.reshape(shape[0], -1) # add batch axis - indices = self.unmap_to_all(indices) - indices = indices.reshape(-1) # flatten again - - # get quantized latent vectors - z_q: torch.Tensor = self.embedding(indices) - - if shape is not None: - z_q = z_q.view(shape) - # reshape back to match original input shape - z_q = z_q.permute(0, 3, 1, 2).contiguous() - - return z_q - - -class DiagonalGaussianDistribution(object): - def __init__(self, parameters: torch.Tensor, deterministic: bool = False): - self.parameters = parameters - self.mean, self.logvar = torch.chunk(parameters, 2, dim=1) - self.logvar = torch.clamp(self.logvar, -30.0, 20.0) - self.deterministic = deterministic - self.std = torch.exp(0.5 * self.logvar) - self.var = torch.exp(self.logvar) - if self.deterministic: - self.var = self.std = torch.zeros_like( - self.mean, device=self.parameters.device, dtype=self.parameters.dtype - ) - - def sample(self, generator: torch.Generator | None = None) -> torch.Tensor: - # make sure sample is on the same device as the parameters and has same dtype - sample = randn_tensor( - self.mean.shape, - generator=generator, - device=self.parameters.device, - dtype=self.parameters.dtype, - ) - x = self.mean + self.std * sample - return x - - def kl(self, other: "DiagonalGaussianDistribution" = None) -> torch.Tensor: - if self.deterministic: - return torch.Tensor([0.0]) - else: - if other is None: - return 0.5 * torch.sum( - torch.pow(self.mean, 2) + self.var - 1.0 - self.logvar, - dim=[1, 2, 3], - ) - else: - return 0.5 * torch.sum( - torch.pow(self.mean - other.mean, 2) / other.var - + self.var / other.var - - 1.0 - - self.logvar - + other.logvar, - dim=[1, 2, 3], - ) - - def nll(self, sample: torch.Tensor, dims: tuple[int, ...] = [1, 2, 3]) -> torch.Tensor: - if self.deterministic: - return torch.Tensor([0.0]) - logtwopi = np.log(2.0 * np.pi) - return 0.5 * torch.sum( - logtwopi + self.logvar + torch.pow(sample - self.mean, 2) / self.var, - dim=dims, - ) - - def mode(self) -> torch.Tensor: - return self.mean - - -class IdentityDistribution(object): - def __init__(self, parameters: torch.Tensor): - self.parameters = parameters - - def sample(self, generator: torch.Generator | None = None) -> torch.Tensor: - return self.parameters - - def mode(self) -> torch.Tensor: - return self.parameters - - -class EncoderTiny(nn.Module): - r""" - The `EncoderTiny` layer is a simpler version of the `Encoder` layer. - - Args: - in_channels (`int`): - The number of input channels. - out_channels (`int`): - The number of output channels. - num_blocks (`tuple[int, ...]`): - Each value of the tuple represents a Conv2d layer followed by `value` number of `AutoencoderTinyBlock`'s to - use. - block_out_channels (`tuple[int, ...]`): - The number of output channels for each block. - act_fn (`str`): - The activation function to use. See `~diffusers.models.activations.get_activation` for available options. - """ - - def __init__( - self, - in_channels: int, - out_channels: int, - num_blocks: tuple[int, ...], - block_out_channels: tuple[int, ...], - act_fn: str, - ): - super().__init__() - - layers = [] - for i, num_block in enumerate(num_blocks): - num_channels = block_out_channels[i] - - if i == 0: - layers.append(nn.Conv2d(in_channels, num_channels, kernel_size=3, padding=1)) - else: - layers.append( - nn.Conv2d( - num_channels, - num_channels, - kernel_size=3, - padding=1, - stride=2, - bias=False, - ) - ) - - for _ in range(num_block): - layers.append(AutoencoderTinyBlock(num_channels, num_channels, act_fn)) - - layers.append(nn.Conv2d(block_out_channels[-1], out_channels, kernel_size=3, padding=1)) - - self.layers = nn.Sequential(*layers) - self.gradient_checkpointing = False - - def forward(self, x: torch.Tensor) -> torch.Tensor: - r"""The forward method of the `EncoderTiny` class.""" - if torch.is_grad_enabled() and self.gradient_checkpointing: - x = self._gradient_checkpointing_func(self.layers, x) - - else: - # scale image from [-1, 1] to [0, 1] to match TAESD convention - x = self.layers(x.add(1).div(2)) - - return x - - -class DecoderTiny(nn.Module): - r""" - The `DecoderTiny` layer is a simpler version of the `Decoder` layer. - - Args: - in_channels (`int`): - The number of input channels. - out_channels (`int`): - The number of output channels. - num_blocks (`tuple[int, ...]`): - Each value of the tuple represents a Conv2d layer followed by `value` number of `AutoencoderTinyBlock`'s to - use. - block_out_channels (`tuple[int, ...]`): - The number of output channels for each block. - upsampling_scaling_factor (`int`): - The scaling factor to use for upsampling. - act_fn (`str`): - The activation function to use. See `~diffusers.models.activations.get_activation` for available options. - """ - - def __init__( - self, - in_channels: int, - out_channels: int, - num_blocks: tuple[int, ...], - block_out_channels: tuple[int, ...], - upsampling_scaling_factor: int, - act_fn: str, - upsample_fn: str, - ): - super().__init__() - - layers = [ - nn.Conv2d(in_channels, block_out_channels[0], kernel_size=3, padding=1), - get_activation(act_fn), - ] - - for i, num_block in enumerate(num_blocks): - is_final_block = i == (len(num_blocks) - 1) - num_channels = block_out_channels[i] - - for _ in range(num_block): - layers.append(AutoencoderTinyBlock(num_channels, num_channels, act_fn)) - - if not is_final_block: - layers.append(nn.Upsample(scale_factor=upsampling_scaling_factor, mode=upsample_fn)) - - conv_out_channel = num_channels if not is_final_block else out_channels - layers.append( - nn.Conv2d( - num_channels, - conv_out_channel, - kernel_size=3, - padding=1, - bias=is_final_block, - ) - ) - - self.layers = nn.Sequential(*layers) - self.gradient_checkpointing = False - - def forward(self, x: torch.Tensor) -> torch.Tensor: - r"""The forward method of the `DecoderTiny` class.""" - # Clamp. - x = torch.tanh(x / 3) * 3 - - if torch.is_grad_enabled() and self.gradient_checkpointing: - x = self._gradient_checkpointing_func(self.layers, x) - else: - x = self.layers(x) - - # scale image from [0, 1] to [-1, 1] to match diffusers convention - return x.mul(2).sub(1) - - -class AutoencoderMixin: - def enable_tiling(self): - r""" - Enable tiled VAE decoding. When this option is enabled, the VAE will split the input tensor into tiles to - compute decoding and encoding in several steps. This is useful for saving a large amount of memory and to allow - processing larger images. - """ - if not hasattr(self, "use_tiling"): - raise NotImplementedError(f"Tiling doesn't seem to be implemented for {self.__class__.__name__}.") - self.use_tiling = True - - def disable_tiling(self): - r""" - Disable tiled VAE decoding. If `enable_tiling` was previously enabled, this method will go back to computing - decoding in one step. - """ - self.use_tiling = False - - def enable_slicing(self): - r""" - Enable sliced VAE decoding. When this option is enabled, the VAE will split the input tensor in slices to - compute decoding in several steps. This is useful to save some memory and allow larger batch sizes. - """ - if not hasattr(self, "use_slicing"): - raise NotImplementedError(f"Slicing doesn't seem to be implemented for {self.__class__.__name__}.") - self.use_slicing = True - - def disable_slicing(self): - r""" - Disable sliced VAE decoding. If `enable_slicing` was previously enabled, this method will go back to computing - decoding in one step. - """ - self.use_slicing = False + else: + self.re_embed = n_e + + self.sane_index_shape = sane_index_shape + + def remap_to_used(self, inds: torch.LongTensor) -> torch.LongTensor: + ishape = inds.shape + assert len(ishape) > 1 + inds = inds.reshape(ishape[0], -1) + used = self.used.to(inds) + match = (inds[:, :, None] == used[None, None, ...]).long() + new = match.argmax(-1) + unknown = match.sum(2) < 1 + if self.unknown_index == "random": + new[unknown] = torch.randint(0, self.re_embed, size=new[unknown].shape).to(device=new.device) + else: + new[unknown] = self.unknown_index + return new.reshape(ishape) + + def unmap_to_all(self, inds: torch.LongTensor) -> torch.LongTensor: + ishape = inds.shape + assert len(ishape) > 1 + inds = inds.reshape(ishape[0], -1) + used = self.used.to(inds) + if self.re_embed > self.used.shape[0]: # extra token + inds[inds >= self.used.shape[0]] = 0 # simply set to zero + back = torch.gather(used[None, :][inds.shape[0] * [0], :], 1, inds) + return back.reshape(ishape) + + def forward(self, z: torch.Tensor) -> tuple[torch.Tensor, torch.Tensor, tuple]: + # reshape z -> (batch, height, width, channel) and flatten + z = z.permute(0, 2, 3, 1).contiguous() + z_flattened = z.view(-1, self.vq_embed_dim) + + # distances from z to embeddings e_j (z - e)^2 = z^2 + e^2 - 2 e * z + min_encoding_indices = torch.argmin(torch.cdist(z_flattened, self.embedding.weight), dim=1) + + z_q = self.embedding(min_encoding_indices).view(z.shape) + perplexity = None + min_encodings = None + + # compute loss for embedding + if not self.legacy: + loss = self.beta * torch.mean((z_q.detach() - z) ** 2) + torch.mean((z_q - z.detach()) ** 2) + else: + loss = torch.mean((z_q.detach() - z) ** 2) + self.beta * torch.mean((z_q - z.detach()) ** 2) + + # preserve gradients + z_q: torch.Tensor = z + (z_q - z).detach() + + # reshape back to match original input shape + z_q = z_q.permute(0, 3, 1, 2).contiguous() + + if self.remap is not None: + min_encoding_indices = min_encoding_indices.reshape(z.shape[0], -1) # add batch axis + min_encoding_indices = self.remap_to_used(min_encoding_indices) + min_encoding_indices = min_encoding_indices.reshape(-1, 1) # flatten + + if self.sane_index_shape: + min_encoding_indices = min_encoding_indices.reshape(z_q.shape[0], z_q.shape[2], z_q.shape[3]) + + return z_q, loss, (perplexity, min_encodings, min_encoding_indices) + + def get_codebook_entry(self, indices: torch.LongTensor, shape: tuple[int, ...]) -> torch.Tensor: + # shape specifying (batch, height, width, channel) + if self.remap is not None: + indices = indices.reshape(shape[0], -1) # add batch axis + indices = self.unmap_to_all(indices) + indices = indices.reshape(-1) # flatten again + + # get quantized latent vectors + z_q: torch.Tensor = self.embedding(indices) + + if shape is not None: + z_q = z_q.view(shape) + # reshape back to match original input shape + z_q = z_q.permute(0, 3, 1, 2).contiguous() + + return z_q + + +class DiagonalGaussianDistribution(object): + def __init__(self, parameters: torch.Tensor, deterministic: bool = False): + self.parameters = parameters + self.mean, self.logvar = torch.chunk(parameters, 2, dim=1) + self.logvar = torch.clamp(self.logvar, -30.0, 20.0) + self.deterministic = deterministic + self.std = torch.exp(0.5 * self.logvar) + self.var = torch.exp(self.logvar) + if self.deterministic: + self.var = self.std = torch.zeros_like( + self.mean, device=self.parameters.device, dtype=self.parameters.dtype + ) + + def sample(self, generator: torch.Generator | None = None) -> torch.Tensor: + # make sure sample is on the same device as the parameters and has same dtype + sample = randn_tensor( + self.mean.shape, + generator=generator, + device=self.parameters.device, + dtype=self.parameters.dtype, + ) + x = self.mean + self.std * sample + return x + + def kl(self, other: "DiagonalGaussianDistribution" = None) -> torch.Tensor: + if self.deterministic: + return torch.Tensor([0.0]) + else: + if other is None: + return 0.5 * torch.sum( + torch.pow(self.mean, 2) + self.var - 1.0 - self.logvar, + dim=[1, 2, 3], + ) + else: + return 0.5 * torch.sum( + torch.pow(self.mean - other.mean, 2) / other.var + + self.var / other.var + - 1.0 + - self.logvar + + other.logvar, + dim=[1, 2, 3], + ) + + def nll(self, sample: torch.Tensor, dims: tuple[int, ...] = [1, 2, 3]) -> torch.Tensor: + if self.deterministic: + return torch.Tensor([0.0]) + logtwopi = np.log(2.0 * np.pi) + return 0.5 * torch.sum( + logtwopi + self.logvar + torch.pow(sample - self.mean, 2) / self.var, + dim=dims, + ) + + def mode(self) -> torch.Tensor: + return self.mean + + +class IdentityDistribution(object): + def __init__(self, parameters: torch.Tensor): + self.parameters = parameters + + def sample(self, generator: torch.Generator | None = None) -> torch.Tensor: + return self.parameters + + def mode(self) -> torch.Tensor: + return self.parameters + + +class EncoderTiny(nn.Module): + r""" + The `EncoderTiny` layer is a simpler version of the `Encoder` layer. + + Args: + in_channels (`int`): + The number of input channels. + out_channels (`int`): + The number of output channels. + num_blocks (`tuple[int, ...]`): + Each value of the tuple represents a Conv2d layer followed by `value` number of `AutoencoderTinyBlock`'s to + use. + block_out_channels (`tuple[int, ...]`): + The number of output channels for each block. + act_fn (`str`): + The activation function to use. See `~diffusers.models.activations.get_activation` for available options. + """ + + def __init__( + self, + in_channels: int, + out_channels: int, + num_blocks: tuple[int, ...], + block_out_channels: tuple[int, ...], + act_fn: str, + ): + super().__init__() + + layers = [] + for i, num_block in enumerate(num_blocks): + num_channels = block_out_channels[i] + + if i == 0: + layers.append(nn.Conv2d(in_channels, num_channels, kernel_size=3, padding=1)) + else: + layers.append( + nn.Conv2d( + num_channels, + num_channels, + kernel_size=3, + padding=1, + stride=2, + bias=False, + ) + ) + + for _ in range(num_block): + layers.append(AutoencoderTinyBlock(num_channels, num_channels, act_fn)) + + layers.append(nn.Conv2d(block_out_channels[-1], out_channels, kernel_size=3, padding=1)) + + self.layers = nn.Sequential(*layers) + self.gradient_checkpointing = False + + def forward(self, x: torch.Tensor) -> torch.Tensor: + r"""The forward method of the `EncoderTiny` class.""" + if torch.is_grad_enabled() and self.gradient_checkpointing: + x = self._gradient_checkpointing_func(self.layers, x) + + else: + # scale image from [-1, 1] to [0, 1] to match TAESD convention + x = self.layers(x.add(1).div(2)) + + return x + + +class DecoderTiny(nn.Module): + r""" + The `DecoderTiny` layer is a simpler version of the `Decoder` layer. + + Args: + in_channels (`int`): + The number of input channels. + out_channels (`int`): + The number of output channels. + num_blocks (`tuple[int, ...]`): + Each value of the tuple represents a Conv2d layer followed by `value` number of `AutoencoderTinyBlock`'s to + use. + block_out_channels (`tuple[int, ...]`): + The number of output channels for each block. + upsampling_scaling_factor (`int`): + The scaling factor to use for upsampling. + act_fn (`str`): + The activation function to use. See `~diffusers.models.activations.get_activation` for available options. + """ + + def __init__( + self, + in_channels: int, + out_channels: int, + num_blocks: tuple[int, ...], + block_out_channels: tuple[int, ...], + upsampling_scaling_factor: int, + act_fn: str, + upsample_fn: str, + ): + super().__init__() + + layers = [ + nn.Conv2d(in_channels, block_out_channels[0], kernel_size=3, padding=1), + get_activation(act_fn), + ] + + for i, num_block in enumerate(num_blocks): + is_final_block = i == (len(num_blocks) - 1) + num_channels = block_out_channels[i] + + for _ in range(num_block): + layers.append(AutoencoderTinyBlock(num_channels, num_channels, act_fn)) + + if not is_final_block: + layers.append(nn.Upsample(scale_factor=upsampling_scaling_factor, mode=upsample_fn)) + + conv_out_channel = num_channels if not is_final_block else out_channels + layers.append( + nn.Conv2d( + num_channels, + conv_out_channel, + kernel_size=3, + padding=1, + bias=is_final_block, + ) + ) + + self.layers = nn.Sequential(*layers) + self.gradient_checkpointing = False + + def forward(self, x: torch.Tensor) -> torch.Tensor: + r"""The forward method of the `DecoderTiny` class.""" + # Clamp. + x = torch.tanh(x / 3) * 3 + + if torch.is_grad_enabled() and self.gradient_checkpointing: + x = self._gradient_checkpointing_func(self.layers, x) + else: + x = self.layers(x) + + # scale image from [0, 1] to [-1, 1] to match diffusers convention + return x.mul(2).sub(1) + + +class AutoencoderMixin: + def enable_tiling(self): + r""" + Enable tiled VAE decoding. When this option is enabled, the VAE will split the input tensor into tiles to + compute decoding and encoding in several steps. This is useful for saving a large amount of memory and to allow + processing larger images. + """ + if not hasattr(self, "use_tiling"): + raise NotImplementedError(f"Tiling doesn't seem to be implemented for {self.__class__.__name__}.") + self.use_tiling = True + + def disable_tiling(self): + r""" + Disable tiled VAE decoding. If `enable_tiling` was previously enabled, this method will go back to computing + decoding in one step. + """ + self.use_tiling = False + + def enable_slicing(self): + r""" + Enable sliced VAE decoding. When this option is enabled, the VAE will split the input tensor in slices to + compute decoding in several steps. This is useful to save some memory and allow larger batch sizes. + """ + if not hasattr(self, "use_slicing"): + raise NotImplementedError(f"Slicing doesn't seem to be implemented for {self.__class__.__name__}.") + self.use_slicing = True + + def disable_slicing(self): + r""" + Disable sliced VAE decoding. If `enable_slicing` was previously enabled, this method will go back to computing + decoding in one step. + """ + self.use_slicing = False diff --git a/tests/models/autoencoders/test_models_vq.py b/tests/models/autoencoders/test_models_vq.py index ce1606f0e859..c1a039687c42 100644 --- a/tests/models/autoencoders/test_models_vq.py +++ b/tests/models/autoencoders/test_models_vq.py @@ -1,129 +1,143 @@ # coding=utf-8 # Copyright 2025 HuggingFace Inc. # -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. +import numpy as np import pytest import torch from diffusers import VQModel +from diffusers.models.autoencoders.vae import VectorQuantizer from diffusers.utils.torch_utils import randn_tensor from ...testing_utils import backend_manual_seed, enable_full_determinism, torch_device from ..testing_utils import BaseModelTesterConfig, MemoryTesterMixin, ModelTesterMixin, TrainingTesterMixin from .testing_utils import NewAutoencoderTesterMixin - - -enable_full_determinism() - - -class VQModelTesterConfig(BaseModelTesterConfig): - @property - def model_class(self): - return VQModel - - @property - def main_input_name(self) -> str: - return "sample" - - @property - def output_shape(self) -> tuple: - return (3, 32, 32) - - @property - def generator(self): - return torch.Generator("cpu").manual_seed(0) - - def get_init_dict(self) -> dict: - return { - "block_out_channels": [8, 16], - "norm_num_groups": 8, - "in_channels": 3, - "out_channels": 3, - "down_block_types": ["DownEncoderBlock2D", "DownEncoderBlock2D"], - "up_block_types": ["UpDecoderBlock2D", "UpDecoderBlock2D"], - "latent_channels": 3, - } - - def get_dummy_inputs(self) -> dict: - batch_size = 4 - num_channels = 3 - sizes = (32, 32) - image = randn_tensor((batch_size, num_channels, *sizes), generator=self.generator, device=torch_device) - return {"sample": image} - - + + +enable_full_determinism() + + +class VQModelTesterConfig(BaseModelTesterConfig): + @property + def model_class(self): + return VQModel + + @property + def main_input_name(self) -> str: + return "sample" + + @property + def output_shape(self) -> tuple: + return (3, 32, 32) + + @property + def generator(self): + return torch.Generator("cpu").manual_seed(0) + + def get_init_dict(self) -> dict: + return { + "block_out_channels": [8, 16], + "norm_num_groups": 8, + "in_channels": 3, + "out_channels": 3, + "down_block_types": ["DownEncoderBlock2D", "DownEncoderBlock2D"], + "up_block_types": ["UpDecoderBlock2D", "UpDecoderBlock2D"], + "latent_channels": 3, + } + + def get_dummy_inputs(self) -> dict: + batch_size = 4 + num_channels = 3 + sizes = (32, 32) + image = randn_tensor((batch_size, num_channels, *sizes), generator=self.generator, device=torch_device) + return {"sample": image} + + class TestVQModel(VQModelTesterConfig, ModelTesterMixin): - def test_from_pretrained_hub(self): - model, loading_info = VQModel.from_pretrained("fusing/vqgan-dummy", output_loading_info=True) - assert model is not None - assert len(loading_info["missing_keys"]) == 0 - - model.to(torch_device) - image = model(**self.get_dummy_inputs()) - - assert image is not None, "Make sure output is not None" - - def test_output_pretrained(self): - model = VQModel.from_pretrained("fusing/vqgan-dummy") - model.to(torch_device).eval() - - torch.manual_seed(0) - backend_manual_seed(torch_device, 0) - - image = torch.randn(1, model.config.in_channels, model.config.sample_size, model.config.sample_size) - image = image.to(torch_device) - with torch.no_grad(): - output = model(image).sample - - output_slice = output[0, -1, -3:, -3:].flatten().cpu() - # fmt: off - expected_output_slice = torch.tensor([-0.0153, -0.4044, -0.1880, -0.5161, -0.2418, -0.4072, -0.1612, -0.0633, -0.0143]) - # fmt: on - assert torch.allclose(output_slice, expected_output_slice, atol=1e-3) - + def test_from_pretrained_hub(self): + model, loading_info = VQModel.from_pretrained("fusing/vqgan-dummy", output_loading_info=True) + assert model is not None + assert len(loading_info["missing_keys"]) == 0 + + model.to(torch_device) + image = model(**self.get_dummy_inputs()) + + assert image is not None, "Make sure output is not None" + + def test_output_pretrained(self): + model = VQModel.from_pretrained("fusing/vqgan-dummy") + model.to(torch_device).eval() + + torch.manual_seed(0) + backend_manual_seed(torch_device, 0) + + image = torch.randn(1, model.config.in_channels, model.config.sample_size, model.config.sample_size) + image = image.to(torch_device) + with torch.no_grad(): + output = model(image).sample + + output_slice = output[0, -1, -3:, -3:].flatten().cpu() + # fmt: off + expected_output_slice = torch.tensor([-0.0153, -0.4044, -0.1880, -0.5161, -0.2418, -0.4072, -0.1612, -0.0633, -0.0143]) + # fmt: on + assert torch.allclose(output_slice, expected_output_slice, atol=1e-3) + def test_loss_pretrained(self): model = VQModel.from_pretrained("fusing/vqgan-dummy") model.to(torch_device).eval() - - torch.manual_seed(0) - backend_manual_seed(torch_device, 0) - - image = torch.randn(1, model.config.in_channels, model.config.sample_size, model.config.sample_size) - image = image.to(torch_device) - with torch.no_grad(): - output = model(image).commit_loss.cpu() - # fmt: off + + torch.manual_seed(0) + backend_manual_seed(torch_device, 0) + + image = torch.randn(1, model.config.in_channels, model.config.sample_size, model.config.sample_size) + image = image.to(torch_device) + with torch.no_grad(): + output = model(image).commit_loss.cpu() + # fmt: off expected_output = torch.tensor([0.1936]) # fmt: on assert torch.allclose(output, expected_output, atol=1e-3) - -class TestVQModelTraining(VQModelTesterConfig, TrainingTesterMixin): - """Training tests for VQModel.""" - - @pytest.mark.skip("Test not supported.") - def test_training(self): - super().test_training() - - @pytest.mark.skip("Test not supported.") - def test_training_with_ema(self): - super().test_training_with_ema() - - -class TestVQModelMemory(VQModelTesterConfig, MemoryTesterMixin): - """Memory optimization tests for VQModel.""" - - -class TestVQModelSlicingTiling(VQModelTesterConfig, NewAutoencoderTesterMixin): - """Slicing and tiling tests for VQModel.""" + def test_vector_quantizer_logs_remap_configuration(self, tmp_path): + remap_path = tmp_path / "used.npy" + np.save(remap_path, np.array([0, 2, 4], dtype=np.int64)) + + with self.assertLogs("diffusers.models.autoencoders.vae", level="INFO") as captured_logs: + VectorQuantizer(n_e=8, vq_embed_dim=4, beta=0.25, remap=str(remap_path), unknown_index="extra") + + assert any( + "Remapping 8 indices to 4 indices. Using 3 for unknown indices." in message + for message in captured_logs.output + ) + + +class TestVQModelTraining(VQModelTesterConfig, TrainingTesterMixin): + """Training tests for VQModel.""" + + @pytest.mark.skip("Test not supported.") + def test_training(self): + super().test_training() + + @pytest.mark.skip("Test not supported.") + def test_training_with_ema(self): + super().test_training_with_ema() + + +class TestVQModelMemory(VQModelTesterConfig, MemoryTesterMixin): + """Memory optimization tests for VQModel.""" + + +class TestVQModelSlicingTiling(VQModelTesterConfig, NewAutoencoderTesterMixin): + """Slicing and tiling tests for VQModel."""