[Bug] train_dreambooth_lora_qwen_image.py crashes with --with_prior_preservation due to tensor concatenation errors #13386
Description
Describe the bug
When running the train_dreambooth_lora_qwen_image.py script with the --with_prior_preservation flag, the training crashes during the text embedding extraction phase. There are two distinct bugs related to tensor concatenation at L1323 and L1324.
Bug 1: Shape mismatch for prompt_embeds (L1323)
RuntimeError: Sizes of tensors must match except in dimension 0.
QwenImage's _get_qwen_prompt_embeds() pads sequences dynamically to the longest prompt in the current call, rather than a fixed global maximum length. If the instance_prompt (e.g., "sks stuffed animal") and class_prompt (e.g., "photo of a stuffed animal") are encoded separately, their resulting prompt_embeds can have different sequence lengths (e.g., dimension 1 is 9 vs. 12). Directly using torch.cat([prompt_embeds, class_prompt_embeds], dim=0) causes a dimension mismatch.
Related code in Diffusers:
```python
txt_tokens = self.tokenizer(
max_seq_len = max([e.size(0) for e in split_hidden_states])
```
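A minimal sketch of the mismatch and the padding fix. The shapes (9 vs. 12 tokens) match the error above; the hidden size of 16 is purely illustrative:

```python
import torch
import torch.nn.functional as F

# Instance and class prompts are encoded in separate calls, so each is
# padded only to the longest prompt of its own call (illustrative shapes).
prompt_embeds = torch.randn(1, 9, 16)
class_prompt_embeds = torch.randn(1, 12, 16)

# torch.cat along dim=0 requires all other dims to match, so concatenating
# these directly raises the RuntimeError above. Fix: right-pad the shorter
# tensor along the sequence dimension (dim=1) first.
max_len = max(prompt_embeds.size(1), class_prompt_embeds.size(1))
prompt_embeds = F.pad(prompt_embeds, (0, 0, 0, max_len - prompt_embeds.size(1)))
class_prompt_embeds = F.pad(class_prompt_embeds, (0, 0, 0, max_len - class_prompt_embeds.size(1)))

merged = torch.cat([prompt_embeds, class_prompt_embeds], dim=0)
print(merged.shape)  # torch.Size([2, 12, 16])
```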
Bug 2: NoneType concatenation for prompt_embeds_mask (L1324)
TypeError: expected Tensor as element 0 in argument 0, but got NoneType
During encoding, if the attention mask consists entirely of 1s (meaning all text tokens are valid and no explicit masking is required), the QwenImage pipeline actively folds the mask into None (see pipeline_qwenimage.py:253). Consequently, both prompt_embeds_mask and class_prompt_embeds_mask are returned as None. Passing a list of None types to torch.cat throws a TypeError.
Related code in Diffusers:
```python
if prompt_embeds is None:
```
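As a minimal illustration of the guard needed here (the helper name is hypothetical, not the script's actual code): when the pipeline folds an all-valid mask to `None`, it can be rebuilt as an all-ones tensor from the embedding's batch and sequence dimensions before concatenation.

```python
import torch

def cat_prompt_masks(mask_a, embeds_a, mask_b, embeds_b):
    # Hypothetical helper: rebuild any mask the pipeline folded to None
    # (meaning all tokens are valid), then concatenate along the batch dim.
    if mask_a is None:
        mask_a = torch.ones(embeds_a.shape[:2], dtype=torch.long, device=embeds_a.device)
    if mask_b is None:
        mask_b = torch.ones(embeds_b.shape[:2], dtype=torch.long, device=embeds_b.device)
    return torch.cat([mask_a, mask_b], dim=0)

# Both masks folded to None by the pipeline (hidden size 16 is illustrative):
mask = cat_prompt_masks(None, torch.randn(1, 9, 16), None, torch.randn(1, 9, 16))
print(mask.shape)  # torch.Size([2, 9])
```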
Note: I have submitted a PR #13387 to fix these issues.
Solution
To resolve these issues, the tensor concatenation logic needs to handle dynamic padding and missing masks:
- Reconstruct masks: If a mask is missing (`None`), reconstruct it as an all-ones tensor based on its corresponding embedding's sequence length.
- Pad to max length: If the sequence lengths differ, pad the shorter tensor(s) to match the longest sequence.
- Fold back to `None`: After concatenation, if the merged mask consists entirely of 1s, fold it back to `None`.
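The three steps above can be sketched as a single helper. This is an illustrative sketch, not the actual code of the submitted PR; the function name and the hidden size used below are assumptions:

```python
import torch
import torch.nn.functional as F

def merge_prompt_embeds(embeds_a, mask_a, embeds_b, mask_b):
    """Concatenate two (embeds, mask) pairs whose padded lengths may differ.

    Illustrative sketch of the three-step fix; names are hypothetical.
    """
    # 1. Reconstruct missing masks as all-ones (the pipeline folds
    #    all-valid masks to None).
    if mask_a is None:
        mask_a = torch.ones(embeds_a.shape[:2], dtype=torch.long, device=embeds_a.device)
    if mask_b is None:
        mask_b = torch.ones(embeds_b.shape[:2], dtype=torch.long, device=embeds_b.device)

    # 2. Right-pad embeddings and masks to the longer sequence length.
    max_len = max(embeds_a.size(1), embeds_b.size(1))
    embeds_a = F.pad(embeds_a, (0, 0, 0, max_len - embeds_a.size(1)))
    embeds_b = F.pad(embeds_b, (0, 0, 0, max_len - embeds_b.size(1)))
    mask_a = F.pad(mask_a, (0, max_len - mask_a.size(1)))
    mask_b = F.pad(mask_b, (0, max_len - mask_b.size(1)))

    embeds = torch.cat([embeds_a, embeds_b], dim=0)
    mask = torch.cat([mask_a, mask_b], dim=0)

    # 3. Fold the merged mask back to None if every token is valid.
    if bool(mask.all()):
        mask = None
    return embeds, mask
```

Note that after padding, the padded positions carry zeros in the mask, so a merged batch with unequal lengths keeps an explicit mask, while equal-length batches fold back to `None` as before.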
Reproduction
Command to reproduce:
```shell
accelerate launch --config_file ./default_config.yaml train_dreambooth_lora_qwen_image.py \
  --pretrained_model_name_or_path=$MODEL_NAME \
  --instance_data_dir=$INSTANCE_DIR \
  --output_dir=$OUTPUT_DIR \
  --mixed_precision="bf16" \
  --instance_prompt="sks stuffed animal" \
  --class_prompt="photo of a stuffed animal" \
  --class_data_dir=$CLASS_DIR \
  --with_prior_preservation \
  --resolution=1024 \
  --train_batch_size=1 \
  --gradient_accumulation_steps=1 \
  --use_8bit_adam \
  --learning_rate=2e-4 \
  --report_to="wandb" \
  --lr_scheduler="constant" \
  --lr_warmup_steps=0 \
  --max_train_steps=1500 \
  --checkpointing_steps=100 \
  --cache_latents \
  --gradient_checkpointing \
  --seed="42"
```
Logs
**Traceback for Bug 1 (Sequence Length Mismatch):**
```
[rank0]: Traceback (most recent call last):
[rank0]:   File "train_dreambooth_lora_qwen_image.py", line 1704, in <module>
[rank0]:     main(args)
[rank0]:   File "train_dreambooth_lora_qwen_image.py", line 1323, in main
[rank0]:     prompt_embeds = torch.cat([prompt_embeds, class_prompt_embeds], dim=0)
[rank0]: RuntimeError: Sizes of tensors must match except in dimension 0. Expected size 9 but got size 12 for tensor number 1 in the list.
```
**Traceback for Bug 2 (NoneType Concatenation):**
_(Note: If you patch Bug 1 to pad the tensors, the script will immediately hit this second error)_
```
[rank0]: Traceback (most recent call last):
[rank0]:   File "train_dreambooth_lora_qwen_image.py", line 1713, in <module>
[rank0]:     main(args)
[rank0]:   File "train_dreambooth_lora_qwen_image.py", line 1333, in main
[rank0]:     prompt_embeds_mask = torch.cat([prompt_embeds_mask, class_prompt_embeds_mask], dim=0)
[rank0]: TypeError: expected Tensor as element 0 in argument 0, but got NoneType
```
System Info
- 🤗 Diffusers version: 0.38.0.dev0
- Platform: Linux-5.15.0-102-generic-x86_64-with-glibc2.35
- Running on Google Colab?: No
- Python version: 3.10.20
- PyTorch version (GPU?): 2.5.0+cu124 (True)
- Flax version (CPU?/GPU?/TPU?): not installed (NA)
- Jax version: not installed
- JaxLib version: not installed
- Huggingface_hub version: 1.8.0
- Transformers version: 5.4.0
- Accelerate version: 1.13.0
- PEFT version: 0.18.1
- Bitsandbytes version: 0.49.2
- Safetensors version: 0.7.0
- xFormers version: not installed
- Accelerator: 8 × NVIDIA A100 80GB PCIe, 81920 MiB
- Using GPU in script?:
- Using distributed or parallel set-up in script?: