diff --git a/src/diffusers/loaders/single_file_model.py b/src/diffusers/loaders/single_file_model.py index 01185afda382..8b908612e0e4 100644 --- a/src/diffusers/loaders/single_file_model.py +++ b/src/diffusers/loaders/single_file_model.py @@ -42,6 +42,8 @@ convert_flux_transformer_checkpoint_to_diffusers, convert_hidream_transformer_to_diffusers, convert_hunyuan_video_transformer_to_diffusers, + convert_hunyuanimage_transformer_to_diffusers, + convert_ideogram4_transformer_checkpoint_to_diffusers, convert_ldm_unet_checkpoint, convert_ldm_vae_checkpoint, convert_ltx2_audio_vae_to_diffusers, @@ -136,6 +138,10 @@ "checkpoint_mapping_fn": convert_mochi_transformer_checkpoint_to_diffusers, "default_subfolder": "transformer", }, + "HunyuanImageTransformer2DModel": { + "checkpoint_mapping_fn": convert_hunyuanimage_transformer_to_diffusers, + "default_subfolder": "transformer", + }, "HunyuanVideoTransformer3DModel": { "checkpoint_mapping_fn": convert_hunyuan_video_transformer_to_diffusers, "default_subfolder": "transformer", @@ -176,6 +182,10 @@ "checkpoint_mapping_fn": convert_cosmos_transformer_checkpoint_to_diffusers, "default_subfolder": "transformer", }, + "Ideogram4Transformer2DModel": { + "checkpoint_mapping_fn": convert_ideogram4_transformer_checkpoint_to_diffusers, + "default_subfolder": "transformer", + }, "QwenImageTransformer2DModel": { "checkpoint_mapping_fn": lambda checkpoint, **kwargs: checkpoint, "default_subfolder": "transformer", diff --git a/src/diffusers/loaders/single_file_utils.py b/src/diffusers/loaders/single_file_utils.py index 296f32f891f0..0a53fcb54a23 100644 --- a/src/diffusers/loaders/single_file_utils.py +++ b/src/diffusers/loaders/single_file_utils.py @@ -2900,6 +2900,130 @@ def update_state_dict_(state_dict, old_key, new_key): return checkpoint +def convert_hunyuanimage_transformer_to_diffusers(checkpoint, **kwargs): + def remap_norm_scale_shift_(key, state_dict): + weight = state_dict.pop(key) + shift, scale = weight.chunk(2, dim=0) + new_weight = torch.cat([scale, shift], dim=0) + state_dict[key.replace("final_layer.adaLN_modulation.1", "norm_out.linear")] = new_weight + + def remap_txt_in_(key, state_dict): + def rename_key(key): + new_key = key.replace("individual_token_refiner.blocks", "token_refiner.refiner_blocks") + new_key = new_key.replace("adaLN_modulation.1", "norm_out.linear") + new_key = new_key.replace("txt_in", "context_embedder") + new_key = new_key.replace("t_embedder.mlp.0", "time_text_embed.timestep_embedder.linear_1") + new_key = new_key.replace("t_embedder.mlp.2", "time_text_embed.timestep_embedder.linear_2") + new_key = new_key.replace("c_embedder", "time_text_embed.text_embedder") + new_key = new_key.replace("mlp", "ff") + return new_key + + if "self_attn_qkv" in key: + weight = state_dict.pop(key) + to_q, to_k, to_v = weight.chunk(3, dim=0) + state_dict[rename_key(key.replace("self_attn_qkv", "attn.to_q"))] = to_q + state_dict[rename_key(key.replace("self_attn_qkv", "attn.to_k"))] = to_k + state_dict[rename_key(key.replace("self_attn_qkv", "attn.to_v"))] = to_v + else: + state_dict[rename_key(key)] = state_dict.pop(key) + + def remap_img_attn_qkv_(key, state_dict): + weight = state_dict.pop(key) + to_q, to_k, to_v = weight.chunk(3, dim=0) + state_dict[key.replace("img_attn_qkv", "attn.to_q")] = to_q + state_dict[key.replace("img_attn_qkv", "attn.to_k")] = to_k + state_dict[key.replace("img_attn_qkv", "attn.to_v")] = to_v + + def remap_txt_attn_qkv_(key, state_dict): + weight = state_dict.pop(key) + to_q, to_k, to_v = weight.chunk(3, dim=0) + state_dict[key.replace("txt_attn_qkv", "attn.add_q_proj")] = to_q + state_dict[key.replace("txt_attn_qkv", "attn.add_k_proj")] = to_k + state_dict[key.replace("txt_attn_qkv", "attn.add_v_proj")] = to_v + + def remap_single_transformer_blocks_(key, state_dict): + hidden_size = 3584 + + if "linear1.weight" in key: + linear1_weight = state_dict.pop(key) + split_size = (hidden_size, hidden_size, hidden_size, linear1_weight.size(0) - 3 * hidden_size) + q, k, v, mlp = torch.split(linear1_weight, split_size, dim=0) + new_key = key.replace("single_blocks", "single_transformer_blocks").removesuffix(".linear1.weight") + state_dict[f"{new_key}.attn.to_q.weight"] = q + state_dict[f"{new_key}.attn.to_k.weight"] = k + state_dict[f"{new_key}.attn.to_v.weight"] = v + state_dict[f"{new_key}.proj_mlp.weight"] = mlp + + elif "linear1.bias" in key: + linear1_bias = state_dict.pop(key) + split_size = (hidden_size, hidden_size, hidden_size, linear1_bias.size(0) - 3 * hidden_size) + q_bias, k_bias, v_bias, mlp_bias = torch.split(linear1_bias, split_size, dim=0) + new_key = key.replace("single_blocks", "single_transformer_blocks").removesuffix(".linear1.bias") + state_dict[f"{new_key}.attn.to_q.bias"] = q_bias + state_dict[f"{new_key}.attn.to_k.bias"] = k_bias + state_dict[f"{new_key}.attn.to_v.bias"] = v_bias + state_dict[f"{new_key}.proj_mlp.bias"] = mlp_bias + + else: + new_key = key.replace("single_blocks", "single_transformer_blocks") + new_key = new_key.replace("linear2", "proj_out") + new_key = new_key.replace("q_norm", "attn.norm_q") + new_key = new_key.replace("k_norm", "attn.norm_k") + state_dict[new_key] = state_dict.pop(key) + + TRANSFORMER_KEYS_RENAME_DICT = { + "byt5_in.fc1": "context_embedder_2.linear_1", + "byt5_in.fc2": "context_embedder_2.linear_2", + "byt5_in.fc3": "context_embedder_2.linear_3", + "byt5_in.layernorm": "context_embedder_2.norm", + "img_in": "x_embedder", + "time_in.mlp.0": "time_guidance_embed.timestep_embedder.linear_1", + "time_in.mlp.2": "time_guidance_embed.timestep_embedder.linear_2", + "double_blocks": "transformer_blocks", + "img_attn_q_norm": "attn.norm_q", + "img_attn_k_norm": "attn.norm_k", + "img_attn_proj": "attn.to_out.0", + "txt_attn_q_norm": "attn.norm_added_q", + "txt_attn_k_norm": "attn.norm_added_k", + "txt_attn_proj": "attn.to_add_out", + "img_mod.linear": "norm1.linear", + "img_mlp": "ff", + "txt_mod.linear": "norm1_context.linear", + "txt_mlp": "ff_context", + "self_attn_proj": "attn.to_out.0", + "modulation.linear": "norm.linear", + "final_layer.linear": "proj_out", + "fc1": "net.0.proj", + "fc2": "net.2", + "input_embedder": "proj_in", + } + + TRANSFORMER_SPECIAL_KEYS_REMAP = { + "txt_in": remap_txt_in_, + "img_attn_qkv": remap_img_attn_qkv_, + "txt_attn_qkv": remap_txt_attn_qkv_, + "single_blocks": remap_single_transformer_blocks_, + "final_layer.adaLN_modulation.1": remap_norm_scale_shift_, + } + + def update_state_dict_(state_dict, old_key, new_key): + state_dict[new_key] = state_dict.pop(old_key) + + for key in list(checkpoint.keys()): + new_key = key[:] + for replace_key, rename_key in TRANSFORMER_KEYS_RENAME_DICT.items(): + new_key = new_key.replace(replace_key, rename_key) + update_state_dict_(checkpoint, key, new_key) + + for key in list(checkpoint.keys()): + for special_key, handler_fn_inplace in TRANSFORMER_SPECIAL_KEYS_REMAP.items(): + if special_key not in key: + continue + handler_fn_inplace(key, checkpoint) + + return checkpoint + + def convert_auraflow_transformer_checkpoint_to_diffusers(checkpoint, **kwargs): converted_state_dict = {} state_dict_keys = list(checkpoint.keys()) @@ -4180,3 +4304,23 @@ def convert_ernie_image_transformer_checkpoint_to_diffusers(checkpoint, **kwargs checkpoint[k.replace("model.diffusion_model.", "")] = checkpoint.pop(k) return checkpoint + + +def convert_ideogram4_transformer_checkpoint_to_diffusers(checkpoint, **kwargs): + # Original/GGUF Ideogram 4 checkpoints fuse the attention projection as `attention.qkv` and name + # the output projection `attention.o`; diffusers uses split to_q/to_k/to_v and to_out.0. Every + # other key already matches the diffusers layout. + if not any(k.endswith("attention.qkv.weight") for k in checkpoint): + return checkpoint + converted_state_dict = {} + num_layers = max(int(k.split(".")[1]) for k in checkpoint if k.startswith("layers.")) + 1 + for i in range(num_layers): + q, k, v = torch.chunk(checkpoint.pop(f"layers.{i}.attention.qkv.weight"), 3, dim=0) + converted_state_dict[f"layers.{i}.attention.to_q.weight"] = q + converted_state_dict[f"layers.{i}.attention.to_k.weight"] = k + converted_state_dict[f"layers.{i}.attention.to_v.weight"] = v + converted_state_dict[f"layers.{i}.attention.to_out.0.weight"] = checkpoint.pop( + f"layers.{i}.attention.o.weight" + ) + converted_state_dict.update(checkpoint) + return converted_state_dict