
[Bug] train_dreambooth_lora_qwen_image.py crashes with --with_prior_preservation due to tensor concatenation errors #13386

@chenyangzhu1

Description


Describe the bug

When running the train_dreambooth_lora_qwen_image.py script with the --with_prior_preservation flag, the training crashes during the text embedding extraction phase. There are two distinct bugs related to tensor concatenation at L1323 and L1324.

Bug 1: Shape mismatch for prompt_embeds (L1323)

RuntimeError: Sizes of tensors must match except in dimension 0.

QwenImage's _get_qwen_prompt_embeds() pads sequences dynamically to the longest prompt in the current call, rather than a fixed global maximum length. If the instance_prompt (e.g., "sks stuffed animal") and class_prompt (e.g., "photo of a stuffed animal") are encoded separately, their resulting prompt_embeds can have different sequence lengths (e.g., dimension 1 is 9 vs. 12). Directly using torch.cat([prompt_embeds, class_prompt_embeds], dim=0) causes a dimension mismatch.
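A minimal sketch of the mismatch and the padding fix (the hidden size 3584 and the shapes are illustrative, not taken from the actual script):

```python
import torch
import torch.nn.functional as F

# Separately encoded prompts: same batch/hidden dims, different sequence lengths.
prompt_embeds = torch.randn(1, 9, 3584)         # e.g. "sks stuffed animal"
class_prompt_embeds = torch.randn(1, 12, 3584)  # e.g. "photo of a stuffed animal"

try:
    torch.cat([prompt_embeds, class_prompt_embeds], dim=0)  # crashes: dim 1 differs
except RuntimeError as e:
    print(e)

# Right-pad the shorter tensor along the sequence dimension, then concatenate.
max_len = max(prompt_embeds.shape[1], class_prompt_embeds.shape[1])
prompt_embeds = F.pad(prompt_embeds, (0, 0, 0, max_len - prompt_embeds.shape[1]))
class_prompt_embeds = F.pad(class_prompt_embeds, (0, 0, 0, max_len - class_prompt_embeds.shape[1]))
merged = torch.cat([prompt_embeds, class_prompt_embeds], dim=0)
print(merged.shape)  # torch.Size([2, 12, 3584])
```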

Bug 2: NoneType concatenation for prompt_embeds_mask (L1324)

TypeError: expected Tensor as element 0 in argument 0, but got NoneType

During encoding, if the attention mask consists entirely of 1s (meaning all text tokens are valid and no explicit masking is required), the QwenImage pipeline actively folds the mask into None (see pipeline_qwenimage.py:253). Consequently, both prompt_embeds_mask and class_prompt_embeds_mask are returned as None. Passing a list of None types to torch.cat throws a TypeError.

Note: I have submitted a PR #13387 to fix these issues.

Solution

To resolve these issues, the tensor concatenation logic needs to handle both the dynamic padding and the missing masks:

  1. Reconstruct Masks: If a mask is missing (None), reconstruct it as an all-ones tensor based on its corresponding embedding's sequence length.
  2. Pad to Max Length: If the sequence lengths differ, pad the shorter tensor(s) to match the longest sequence.
  3. Fold Back to None: After concatenation, if the merged mask consists entirely of 1s, fold it back to None.
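The three steps above can be sketched as a small helper. This is only an illustration of the approach, not the actual diff in the PR; the function name `concat_prompt_embeds` and the argument layout are hypothetical:

```python
import torch
import torch.nn.functional as F

def concat_prompt_embeds(embeds, class_embeds, mask, class_mask):
    # 1. Reconstruct a missing mask as all ones over (batch, seq_len).
    if mask is None:
        mask = torch.ones(embeds.shape[:2], dtype=torch.long, device=embeds.device)
    if class_mask is None:
        class_mask = torch.ones(class_embeds.shape[:2], dtype=torch.long, device=class_embeds.device)

    # 2. Right-pad embeddings (and masks, with zeros) to the longer sequence.
    max_len = max(embeds.shape[1], class_embeds.shape[1])
    embeds = F.pad(embeds, (0, 0, 0, max_len - embeds.shape[1]))
    class_embeds = F.pad(class_embeds, (0, 0, 0, max_len - class_embeds.shape[1]))
    mask = F.pad(mask, (0, max_len - mask.shape[1]))
    class_mask = F.pad(class_mask, (0, max_len - class_mask.shape[1]))

    embeds = torch.cat([embeds, class_embeds], dim=0)
    mask = torch.cat([mask, class_mask], dim=0)

    # 3. Fold an all-ones merged mask back to None, mirroring the pipeline's behavior.
    if mask.bool().all():
        mask = None
    return embeds, mask
```

Note that step 3 only triggers when no padding was added: if the two prompts had different lengths, the padded positions are zeros and the merged mask is kept.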

Reproduction

Command to reproduce:

accelerate launch --config_file ./default_config.yaml train_dreambooth_lora_qwen_image.py \
  --pretrained_model_name_or_path=$MODEL_NAME  \
  --instance_data_dir=$INSTANCE_DIR \
  --output_dir=$OUTPUT_DIR \
  --mixed_precision="bf16" \
  --instance_prompt="sks stuffed animal" \
  --class_prompt="photo of a stuffed animal" \
  --class_data_dir=$CLASS_DIR \
  --with_prior_preservation \
  --resolution=1024 \
  --train_batch_size=1 \
  --gradient_accumulation_steps=1 \
  --use_8bit_adam \
  --learning_rate=2e-4 \
  --report_to="wandb" \
  --lr_scheduler="constant" \
  --lr_warmup_steps=0 \
  --max_train_steps=1500 \
  --checkpointing_steps=100 \
  --cache_latents \
  --gradient_checkpointing \
  --seed="42" 

Logs

**Traceback for Bug 1 (Sequence Length Mismatch):**


[rank0]: Traceback (most recent call last):
[rank0]:   File "train_dreambooth_lora_qwen_image.py", line 1704, in <module>
[rank0]:     main(args)
[rank0]:   File "train_dreambooth_lora_qwen_image.py", line 1323, in main
[rank0]:     prompt_embeds = torch.cat([prompt_embeds, class_prompt_embeds], dim=0)
[rank0]: RuntimeError: Sizes of tensors must match except in dimension 0. Expected size 9 but got size 12 for tensor number 1 in the list.



**Traceback for Bug 2 (NoneType Concatenation):**

_(Note: If you patch Bug 1 to pad the tensors, the script will immediately hit this second error)_


[rank0]: Traceback (most recent call last):
[rank0]:   File "train_dreambooth_lora_qwen_image.py", line 1713, in <module>
[rank0]:     main(args)
[rank0]:   File "train_dreambooth_lora_qwen_image.py", line 1333, in main
[rank0]:     prompt_embeds_mask = torch.cat([prompt_embeds_mask, class_prompt_embeds_mask], dim=0)
[rank0]: TypeError: expected Tensor as element 0 in argument 0, but got NoneType

System Info

  • 🤗 Diffusers version: 0.38.0.dev0
  • Platform: Linux-5.15.0-102-generic-x86_64-with-glibc2.35
  • Running on Google Colab?: No
  • Python version: 3.10.20
  • PyTorch version (GPU?): 2.5.0+cu124 (True)
  • Flax version (CPU?/GPU?/TPU?): not installed (NA)
  • Jax version: not installed
  • JaxLib version: not installed
  • Huggingface_hub version: 1.8.0
  • Transformers version: 5.4.0
  • Accelerate version: 1.13.0
  • PEFT version: 0.18.1
  • Bitsandbytes version: 0.49.2
  • Safetensors version: 0.7.0
  • xFormers version: not installed
  • Accelerator: NVIDIA A100 80GB PCIe, 81920 MiB
    NVIDIA A100 80GB PCIe, 81920 MiB
    NVIDIA A100 80GB PCIe, 81920 MiB
    NVIDIA A100 80GB PCIe, 81920 MiB
    NVIDIA A100 80GB PCIe, 81920 MiB
    NVIDIA A100 80GB PCIe, 81920 MiB
    NVIDIA A100 80GB PCIe, 81920 MiB
    NVIDIA A100 80GB PCIe, 81920 MiB
  • Using GPU in script?:
  • Using distributed or parallel set-up in script?:

Who can help?

@sayakpaul
