Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
10 changes: 10 additions & 0 deletions src/diffusers/loaders/single_file_model.py
Original file line number Diff line number Diff line change
Expand Up @@ -42,6 +42,8 @@
convert_flux_transformer_checkpoint_to_diffusers,
convert_hidream_transformer_to_diffusers,
convert_hunyuan_video_transformer_to_diffusers,
convert_hunyuanimage_transformer_to_diffusers,
convert_ideogram4_transformer_checkpoint_to_diffusers,
convert_ldm_unet_checkpoint,
convert_ldm_vae_checkpoint,
convert_ltx2_audio_vae_to_diffusers,
Expand Down Expand Up @@ -136,6 +138,10 @@
"checkpoint_mapping_fn": convert_mochi_transformer_checkpoint_to_diffusers,
"default_subfolder": "transformer",
},
"HunyuanImageTransformer2DModel": {
"checkpoint_mapping_fn": convert_hunyuanimage_transformer_to_diffusers,
"default_subfolder": "transformer",
},
"HunyuanVideoTransformer3DModel": {
"checkpoint_mapping_fn": convert_hunyuan_video_transformer_to_diffusers,
"default_subfolder": "transformer",
Expand Down Expand Up @@ -176,6 +182,10 @@
"checkpoint_mapping_fn": convert_cosmos_transformer_checkpoint_to_diffusers,
"default_subfolder": "transformer",
},
"Ideogram4Transformer2DModel": {
"checkpoint_mapping_fn": convert_ideogram4_transformer_checkpoint_to_diffusers,
"default_subfolder": "transformer",
},
"QwenImageTransformer2DModel": {
"checkpoint_mapping_fn": lambda checkpoint, **kwargs: checkpoint,
"default_subfolder": "transformer",
Expand Down
144 changes: 144 additions & 0 deletions src/diffusers/loaders/single_file_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -2900,6 +2900,130 @@ def update_state_dict_(state_dict, old_key, new_key):
return checkpoint


def convert_hunyuanimage_transformer_to_diffusers(checkpoint, **kwargs):
def remap_norm_scale_shift_(key, state_dict):
weight = state_dict.pop(key)
shift, scale = weight.chunk(2, dim=0)
new_weight = torch.cat([scale, shift], dim=0)
state_dict[key.replace("final_layer.adaLN_modulation.1", "norm_out.linear")] = new_weight

def remap_txt_in_(key, state_dict):
def rename_key(key):
new_key = key.replace("individual_token_refiner.blocks", "token_refiner.refiner_blocks")
new_key = new_key.replace("adaLN_modulation.1", "norm_out.linear")
new_key = new_key.replace("txt_in", "context_embedder")
new_key = new_key.replace("t_embedder.mlp.0", "time_text_embed.timestep_embedder.linear_1")
new_key = new_key.replace("t_embedder.mlp.2", "time_text_embed.timestep_embedder.linear_2")
new_key = new_key.replace("c_embedder", "time_text_embed.text_embedder")
new_key = new_key.replace("mlp", "ff")
return new_key

if "self_attn_qkv" in key:
weight = state_dict.pop(key)
to_q, to_k, to_v = weight.chunk(3, dim=0)
state_dict[rename_key(key.replace("self_attn_qkv", "attn.to_q"))] = to_q
state_dict[rename_key(key.replace("self_attn_qkv", "attn.to_k"))] = to_k
state_dict[rename_key(key.replace("self_attn_qkv", "attn.to_v"))] = to_v
else:
state_dict[rename_key(key)] = state_dict.pop(key)

def remap_img_attn_qkv_(key, state_dict):
weight = state_dict.pop(key)
to_q, to_k, to_v = weight.chunk(3, dim=0)
state_dict[key.replace("img_attn_qkv", "attn.to_q")] = to_q
state_dict[key.replace("img_attn_qkv", "attn.to_k")] = to_k
state_dict[key.replace("img_attn_qkv", "attn.to_v")] = to_v

def remap_txt_attn_qkv_(key, state_dict):
weight = state_dict.pop(key)
to_q, to_k, to_v = weight.chunk(3, dim=0)
state_dict[key.replace("txt_attn_qkv", "attn.add_q_proj")] = to_q
state_dict[key.replace("txt_attn_qkv", "attn.add_k_proj")] = to_k
state_dict[key.replace("txt_attn_qkv", "attn.add_v_proj")] = to_v

def remap_single_transformer_blocks_(key, state_dict):
hidden_size = 3584

if "linear1.weight" in key:
linear1_weight = state_dict.pop(key)
split_size = (hidden_size, hidden_size, hidden_size, linear1_weight.size(0) - 3 * hidden_size)
q, k, v, mlp = torch.split(linear1_weight, split_size, dim=0)
new_key = key.replace("single_blocks", "single_transformer_blocks").removesuffix(".linear1.weight")
state_dict[f"{new_key}.attn.to_q.weight"] = q
state_dict[f"{new_key}.attn.to_k.weight"] = k
state_dict[f"{new_key}.attn.to_v.weight"] = v
state_dict[f"{new_key}.proj_mlp.weight"] = mlp

elif "linear1.bias" in key:
linear1_bias = state_dict.pop(key)
split_size = (hidden_size, hidden_size, hidden_size, linear1_bias.size(0) - 3 * hidden_size)
q_bias, k_bias, v_bias, mlp_bias = torch.split(linear1_bias, split_size, dim=0)
new_key = key.replace("single_blocks", "single_transformer_blocks").removesuffix(".linear1.bias")
state_dict[f"{new_key}.attn.to_q.bias"] = q_bias
state_dict[f"{new_key}.attn.to_k.bias"] = k_bias
state_dict[f"{new_key}.attn.to_v.bias"] = v_bias
state_dict[f"{new_key}.proj_mlp.bias"] = mlp_bias

else:
new_key = key.replace("single_blocks", "single_transformer_blocks")
new_key = new_key.replace("linear2", "proj_out")
new_key = new_key.replace("q_norm", "attn.norm_q")
new_key = new_key.replace("k_norm", "attn.norm_k")
state_dict[new_key] = state_dict.pop(key)

TRANSFORMER_KEYS_RENAME_DICT = {
"byt5_in.fc1": "context_embedder_2.linear_1",
"byt5_in.fc2": "context_embedder_2.linear_2",
"byt5_in.fc3": "context_embedder_2.linear_3",
"byt5_in.layernorm": "context_embedder_2.norm",
"img_in": "x_embedder",
"time_in.mlp.0": "time_guidance_embed.timestep_embedder.linear_1",
"time_in.mlp.2": "time_guidance_embed.timestep_embedder.linear_2",
"double_blocks": "transformer_blocks",
"img_attn_q_norm": "attn.norm_q",
"img_attn_k_norm": "attn.norm_k",
"img_attn_proj": "attn.to_out.0",
"txt_attn_q_norm": "attn.norm_added_q",
"txt_attn_k_norm": "attn.norm_added_k",
"txt_attn_proj": "attn.to_add_out",
"img_mod.linear": "norm1.linear",
"img_mlp": "ff",
"txt_mod.linear": "norm1_context.linear",
"txt_mlp": "ff_context",
"self_attn_proj": "attn.to_out.0",
"modulation.linear": "norm.linear",
"final_layer.linear": "proj_out",
"fc1": "net.0.proj",
"fc2": "net.2",
"input_embedder": "proj_in",
}

TRANSFORMER_SPECIAL_KEYS_REMAP = {
"txt_in": remap_txt_in_,
"img_attn_qkv": remap_img_attn_qkv_,
"txt_attn_qkv": remap_txt_attn_qkv_,
"single_blocks": remap_single_transformer_blocks_,
"final_layer.adaLN_modulation.1": remap_norm_scale_shift_,
}

def update_state_dict_(state_dict, old_key, new_key):
state_dict[new_key] = state_dict.pop(old_key)

for key in list(checkpoint.keys()):
new_key = key[:]
for replace_key, rename_key in TRANSFORMER_KEYS_RENAME_DICT.items():
new_key = new_key.replace(replace_key, rename_key)
update_state_dict_(checkpoint, key, new_key)

for key in list(checkpoint.keys()):
for special_key, handler_fn_inplace in TRANSFORMER_SPECIAL_KEYS_REMAP.items():
if special_key not in key:
continue
handler_fn_inplace(key, checkpoint)

return checkpoint


def convert_auraflow_transformer_checkpoint_to_diffusers(checkpoint, **kwargs):
converted_state_dict = {}
state_dict_keys = list(checkpoint.keys())
Expand Down Expand Up @@ -4180,3 +4304,23 @@ def convert_ernie_image_transformer_checkpoint_to_diffusers(checkpoint, **kwargs
checkpoint[k.replace("model.diffusion_model.", "")] = checkpoint.pop(k)

return checkpoint


def convert_ideogram4_transformer_checkpoint_to_diffusers(checkpoint, **kwargs):
# Original/GGUF Ideogram 4 checkpoints fuse the attention projection as `attention.qkv` and name
# the output projection `attention.o`; diffusers uses split to_q/to_k/to_v and to_out.0. Every
# other key already matches the diffusers layout.
if not any(k.endswith("attention.qkv.weight") for k in checkpoint):
return checkpoint
converted_state_dict = {}
num_layers = max(int(k.split(".")[1]) for k in checkpoint if k.startswith("layers.")) + 1
for i in range(num_layers):
q, k, v = torch.chunk(checkpoint.pop(f"layers.{i}.attention.qkv.weight"), 3, dim=0)
converted_state_dict[f"layers.{i}.attention.to_q.weight"] = q
converted_state_dict[f"layers.{i}.attention.to_k.weight"] = k
converted_state_dict[f"layers.{i}.attention.to_v.weight"] = v
converted_state_dict[f"layers.{i}.attention.to_out.0.weight"] = checkpoint.pop(
f"layers.{i}.attention.o.weight"
)
converted_state_dict.update(checkpoint)
return converted_state_dict
Loading