Add ernie image #13432

Open
HsiaWinter wants to merge 12 commits into huggingface:main from HsiaWinter:add-ernie-image

Conversation

@HsiaWinter

What does this PR do?

We have introduced a new text-to-image model called ERNIE-Image, which will soon be open-sourced to the community. This PR includes the model architecture definition, the pipeline, as well as the related documentation and test files.

Before submitting

Who can review?

Anyone in the community is free to review the PR once the tests have passed. Feel free to tag
members/contributors who may be interested in your PR.

@github-actions github-actions bot added documentation Improvements or additions to documentation models tests utils pipelines size/L PR with diff > 200 LOC labels Apr 8, 2026
Collaborator

@yiyixuxu yiyixuxu left a comment

Thanks for the PR! I left some feedback.

```python
    def __init__(self, hidden_size: int, num_heads: int, ffn_hidden_size: int, eps: float = 1e-6, qk_layernorm: bool = True):
        super().__init__()
        self.adaLN_sa_ln = RMSNorm(hidden_size, eps=eps)
        self.self_attention = Attention(
```
Collaborator

Author

ok, I recreated a custom attention class

Author

fix

```python
        return x.reshape(B, D, Hp * Wp).transpose(1, 2).contiguous()


class TimestepEmbedding(nn.Module):
```
Collaborator

Author

@HsiaWinter HsiaWinter Apr 8, 2026

fix


```python
        return ErnieImageTransformer2DModelOutput(sample=output) if return_dict else (output,)

    def _pad_text(self, text_hiddens: List[torch.Tensor], device: torch.device, dtype: torch.dtype):
```
Collaborator

Ohh, we are padding the text embeddings here. Is it possible to move this outside of the model, into the pipeline? E.g. you could pass image_ids, text_ids and text_seq_lens instead.
I think it would affect torch.compile too if we pad the text embeddings inside the transformer.
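For illustration, padding in the pipeline rather than the model could look roughly like this. A minimal pure-Python sketch: plain nested lists stand in for torch.Tensors, and `pad_text_embeddings`/`pad_value` are hypothetical names, not the PR's actual code.

```python
# Sketch only: the pipeline pads variable-length text embeddings itself and
# hands the transformer a fixed-shape batch plus the true lengths, so the
# model's forward stays static-shaped (and friendlier to torch.compile).
def pad_text_embeddings(text_hiddens, pad_value=0.0):
    """Right-pad a batch of [seq_len][dim] embeddings to a common length.

    Returns (padded_batch, text_seq_lens); the transformer consumes the
    padded batch and lengths instead of padding internally.
    """
    text_seq_lens = [len(seq) for seq in text_hiddens]
    max_len = max(text_seq_lens)
    dim = len(text_hiddens[0][0])
    pad_row = [pad_value] * dim
    padded = [seq + [pad_row] * (max_len - len(seq)) for seq in text_hiddens]
    return padded, text_seq_lens
```

With real tensors the same idea is usually expressed with `torch.nn.utils.rnn.pad_sequence` in the pipeline's prompt-encoding step.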

Author

fix

```python
        return out.float()


class EmbedND3(nn.Module):
```
Collaborator

Suggested change

```diff
- class EmbedND3(nn.Module):
+ class ErnieImageEmbedND3(nn.Module):
```

can we follow our naming conventions and add the ErnieImage prefix everywhere?
Author

fix

```python
        self.vae_scale_factor = 16  # VAE downsample factor

    @classmethod
    def from_pretrained(cls, pretrained_model_name_or_path: str, **kwargs):
```
Collaborator

why did you write a custom from_pretrained method here? is there any reason you could not use the from_pretrained inherited from DiffusionPipeline?

Author

fix

Comment on lines +186 to +189

```python
if hasattr(self.pe, "_hf_hook") and hasattr(self.pe._hf_hook, "execution_device"):
    pe_device = self.pe._hf_hook.execution_device
else:
    pe_device = device
```
Collaborator

Suggested change

```diff
- if hasattr(self.pe, "_hf_hook") and hasattr(self.pe._hf_hook, "execution_device"):
-     pe_device = self.pe._hf_hook.execution_device
- else:
-     pe_device = device
+ pe_device = device or self._execution_device
```

this is basically self._execution_device, no? https://github.com/huggingface/diffusers/blob/main/src/diffusers/pipelines/pipeline_utils.py#L1136
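For context, the inherited property follows roughly this pattern. This is a simplified sketch, not the actual diffusers implementation linked above (which also walks the pipeline's model components); `PipelineDeviceSketch` is an illustrative name.

```python
# Simplified sketch of the _execution_device idea from DiffusionPipeline:
# prefer the accelerate hook's execution device when offloading is active,
# otherwise fall back to the pipeline's own device.
class PipelineDeviceSketch:
    def __init__(self, device="cpu"):
        self.device = device
        self._hf_hook = None  # set by accelerate when model offloading is enabled

    @property
    def _execution_device(self):
        hook = getattr(self, "_hf_hook", None)
        if hook is not None and getattr(hook, "execution_device", None) is not None:
            return hook.execution_device
        return self.device
```

With such a property available, the hasattr dance above collapses to the single `pe_device = device or self._execution_device` line the reviewer suggests.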

Author

fix

```python
        text_hiddens = self.encode_prompt(prompt, device, num_images_per_prompt)

        # CFG with negative prompt
        do_cfg = guidance_scale > 1.0
```
Collaborator

can we add a do_classifier_free_guidance property instead?
https://github.com/huggingface/diffusers/blob/main/src/diffusers/pipelines/flux2/pipeline_flux2_klein.py#L590

Suggested change

```diff
- do_cfg = guidance_scale > 1.0
```
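The linked Flux2 pipeline exposes this as a property. A minimal sketch of the pattern (illustrative class name, not the actual ErnieImage or Flux2 code):

```python
# Sketch of the do_classifier_free_guidance property pattern used in
# diffusers pipelines: __call__ stores guidance_scale once, and the CFG
# flag is derived from it everywhere instead of a local do_cfg variable.
class CfgPipelineSketch:
    def __init__(self):
        self._guidance_scale = 1.0

    @property
    def guidance_scale(self):
        return self._guidance_scale

    @property
    def do_classifier_free_guidance(self):
        # classifier-free guidance only kicks in above a scale of 1.0
        return self._guidance_scale > 1.0
```

The property keeps the condition in one place, so every branch in the denoising loop reads `self.do_classifier_free_guidance` consistently.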

Author

fix


```python
        # CFG with negative prompt
        do_cfg = guidance_scale > 1.0
        if do_cfg:
```
Collaborator

Suggested change

```diff
- if do_cfg:
+ if self.do_classifier_free_guidance:
```

Author

fix

@yiyixuxu yiyixuxu requested a review from dg845 April 8, 2026 09:02
@github-actions github-actions bot added size/L PR with diff > 200 LOC and removed size/L PR with diff > 200 LOC labels Apr 8, 2026
Collaborator

@yiyixuxu yiyixuxu left a comment

Thanks! I left a few more comments.

```python
        return torch.stack([emb, emb], dim=-1).reshape(*emb.shape[:-1], -1)  # [B, S, 1, head_dim]


class PatchEmbedDynamic(nn.Module):
```
Collaborator

can we add a prefix to these names too?

Author

fix

```python
        return self.processor(self, hidden_states, encoder_hidden_states, attention_mask, image_rotary_emb, **kwargs)


class FeedForward(nn.Module):
```
Collaborator

Suggested change

```diff
- class FeedForward(nn.Module):
+ class ErnieImageFeedForward(nn.Module):
```

Author

fix


```python
    def forward(self, x: torch.Tensor) -> torch.Tensor:
        x = self.proj(x)
        B, D, Hp, Wp = x.shape
```
Collaborator

Suggested change

```diff
- B, D, Hp, Wp = x.shape
+ batch_size, dim, height, width = x.shape
```

Author

fix

```python
    def forward(self, x: torch.Tensor) -> torch.Tensor:
        x = self.proj(x)
        B, D, Hp, Wp = x.shape
        return x.reshape(B, D, Hp * Wp).transpose(1, 2).contiguous()
```
Collaborator

Suggested change

```diff
- return x.reshape(B, D, Hp * Wp).transpose(1, 2).contiguous()
+ return x.reshape(batch_size, dim, height * width).transpose(1, 2).contiguous()
```

we prefer to use more descriptive variable names

```python
        self,
        attn: Attention,
        hidden_states: torch.Tensor,
        encoder_hidden_states: torch.Tensor | None = None,
```
Collaborator

Suggested change

```diff
- encoder_hidden_states: torch.Tensor | None = None,
```

it's not needed since we are single-stream here, no?

Author

fix

@github-actions github-actions bot added size/L PR with diff > 200 LOC and removed size/L PR with diff > 200 LOC labels Apr 9, 2026
Collaborator

@yiyixuxu yiyixuxu left a comment

Thanks! I left two small comments.
Let's merge this soon.

```python
        sample = sample.to(self.time_embedding.linear_1.weight.dtype)
        c = self.time_embedding(sample)
        shift_msa, scale_msa, gate_msa, shift_mlp, scale_mlp, gate_mlp = [t.unsqueeze(0).expand(S, -1, -1).contiguous() for t in self.adaLN_modulation(c).chunk(6, dim=-1)]
        for layer in self.layers:
```
Collaborator


```python
@unittest.skipIf(
    IS_GITHUB_ACTIONS,
    reason="Skipping test-suite inside the CI because the model has `torch.empty()` inside of it during init and we don't have a clear way to override it in the modeling tests.",
```
Collaborator

Ohhh, I think we should not skip the test here.
Let's not have torch.empty() during init then? (I didn't find any torch.empty() there actually.)

Author

fix

@yiyixuxu
Collaborator

yiyixuxu commented Apr 9, 2026

@claude can you do a review here also? Please keep these 3 notes in mind as well during your review:

  1. Compare the Ernie model/pipeline to others like Qwen/Flux, and let us know if there are any significant inconsistencies you find.
  2. If you see any unused code paths, let us know.
  3. Look over the PR comments I made and check whether the same patterns we caught/fixed still exist elsewhere in the code.

@github-actions
Contributor

github-actions bot commented Apr 9, 2026

Claude Code is working…

I'll analyze this and get back to you.

View job run

@github-actions github-actions bot added size/L PR with diff > 200 LOC and removed size/L PR with diff > 200 LOC labels Apr 10, 2026
