From bc6fae687352d04e0ce103f4c1d4a859821b4244 Mon Sep 17 00:00:00 2001 From: Alexey Zolotenkov Date: Wed, 3 Jun 2026 16:18:02 +0200 Subject: [PATCH 1/2] Add Flux2 DreamBooth prior preservation tests --- .../dreambooth/test_dreambooth_lora_flux2.py | 27 +++++++++++++++++++ .../test_dreambooth_lora_flux2_klein.py | 27 +++++++++++++++++++ 2 files changed, 54 insertions(+) diff --git a/examples/dreambooth/test_dreambooth_lora_flux2.py b/examples/dreambooth/test_dreambooth_lora_flux2.py index 80a0b502f9a2..168e4f788bdc 100644 --- a/examples/dreambooth/test_dreambooth_lora_flux2.py +++ b/examples/dreambooth/test_dreambooth_lora_flux2.py @@ -111,6 +111,33 @@ def test_dreambooth_lora_latent_caching(self): starts_with_transformer = all(key.startswith("transformer") for key in lora_state_dict.keys()) self.assertTrue(starts_with_transformer) + def test_dreambooth_lora_flux2_prior_preservation_batch_size_two(self): + with tempfile.TemporaryDirectory() as tmpdir: + test_args = f""" + {self.script_path} + --pretrained_model_name_or_path {self.pretrained_model_name_or_path} + --instance_data_dir {self.instance_data_dir} + --instance_prompt {self.instance_prompt} + --with_prior_preservation + --class_data_dir {self.instance_data_dir} + --class_prompt dog + --num_class_images 2 + --resolution 64 + --train_batch_size 2 + --gradient_accumulation_steps 1 + --max_train_steps 1 + --learning_rate 5.0e-04 + --scale_lr + --lr_scheduler constant + --lr_warmup_steps 0 + --max_sequence_length 8 + --text_encoder_out_layers 1 + --output_dir {tmpdir} + """.split() + + run_command(self._launch_args + test_args) + self.assertTrue(os.path.isfile(os.path.join(tmpdir, "pytorch_lora_weights.safetensors"))) + def test_dreambooth_lora_layers(self): with tempfile.TemporaryDirectory() as tmpdir: test_args = f""" diff --git a/examples/dreambooth/test_dreambooth_lora_flux2_klein.py b/examples/dreambooth/test_dreambooth_lora_flux2_klein.py index 0e5506e1a3eb..1d715d6cb519 100644 --- a/examples/dreambooth/test_dreambooth_lora_flux2_klein.py +++ b/examples/dreambooth/test_dreambooth_lora_flux2_klein.py @@ -111,6 +111,33 @@ def test_dreambooth_lora_latent_caching(self): starts_with_transformer = all(key.startswith("transformer") for key in lora_state_dict.keys()) self.assertTrue(starts_with_transformer) + def test_dreambooth_lora_flux2_klein_prior_preservation_batch_size_two(self): + with tempfile.TemporaryDirectory() as tmpdir: + test_args = f""" + {self.script_path} + --pretrained_model_name_or_path {self.pretrained_model_name_or_path} + --instance_data_dir {self.instance_data_dir} + --instance_prompt {self.instance_prompt} + --with_prior_preservation + --class_data_dir {self.instance_data_dir} + --class_prompt dog + --num_class_images 2 + --resolution 64 + --train_batch_size 2 + --gradient_accumulation_steps 1 + --max_train_steps 1 + --learning_rate 5.0e-04 + --scale_lr + --lr_scheduler constant + --lr_warmup_steps 0 + --max_sequence_length 8 + --text_encoder_out_layers 1 + --output_dir {tmpdir} + """.split() + + run_command(self._launch_args + test_args) + self.assertTrue(os.path.isfile(os.path.join(tmpdir, "pytorch_lora_weights.safetensors"))) + def test_dreambooth_lora_layers(self): with tempfile.TemporaryDirectory() as tmpdir: test_args = f""" From b6636d2c74845a59251527cded63762469956fd6 Mon Sep 17 00:00:00 2001 From: Alexey Zolotenkov Date: Wed, 3 Jun 2026 19:52:26 +0200 Subject: [PATCH 2/2] Use isolated class dirs in Flux2 prior tests --- examples/dreambooth/test_dreambooth_lora_flux2.py | 8 +++++++- examples/dreambooth/test_dreambooth_lora_flux2_klein.py | 8 +++++++- 2 files changed, 14 insertions(+), 2 deletions(-) diff --git a/examples/dreambooth/test_dreambooth_lora_flux2.py b/examples/dreambooth/test_dreambooth_lora_flux2.py index 168e4f788bdc..93f683585a18 100644 --- a/examples/dreambooth/test_dreambooth_lora_flux2.py +++ b/examples/dreambooth/test_dreambooth_lora_flux2.py @@ -16,6 +16,7 @@ import json import logging import os +import shutil import sys import tempfile @@ -113,13 +114,18 @@ def test_dreambooth_lora_latent_caching(self): def test_dreambooth_lora_flux2_prior_preservation_batch_size_two(self): with tempfile.TemporaryDirectory() as tmpdir: + class_data_dir = os.path.join(tmpdir, "class_data") + os.makedirs(class_data_dir) + for image_name in os.listdir(self.instance_data_dir)[:2]: + shutil.copy(os.path.join(self.instance_data_dir, image_name), class_data_dir) + test_args = f""" {self.script_path} --pretrained_model_name_or_path {self.pretrained_model_name_or_path} --instance_data_dir {self.instance_data_dir} --instance_prompt {self.instance_prompt} --with_prior_preservation - --class_data_dir {self.instance_data_dir} + --class_data_dir {class_data_dir} --class_prompt dog --num_class_images 2 --resolution 64 diff --git a/examples/dreambooth/test_dreambooth_lora_flux2_klein.py b/examples/dreambooth/test_dreambooth_lora_flux2_klein.py index 1d715d6cb519..b12ae51bc3f6 100644 --- a/examples/dreambooth/test_dreambooth_lora_flux2_klein.py +++ b/examples/dreambooth/test_dreambooth_lora_flux2_klein.py @@ -16,6 +16,7 @@ import json import logging import os +import shutil import sys import tempfile @@ -113,13 +114,18 @@ def test_dreambooth_lora_latent_caching(self): def test_dreambooth_lora_flux2_klein_prior_preservation_batch_size_two(self): with tempfile.TemporaryDirectory() as tmpdir: + class_data_dir = os.path.join(tmpdir, "class_data") + os.makedirs(class_data_dir) + for image_name in os.listdir(self.instance_data_dir)[:2]: + shutil.copy(os.path.join(self.instance_data_dir, image_name), class_data_dir) + test_args = f""" {self.script_path} --pretrained_model_name_or_path {self.pretrained_model_name_or_path} --instance_data_dir {self.instance_data_dir} --instance_prompt {self.instance_prompt} --with_prior_preservation - --class_data_dir {self.instance_data_dir} + --class_data_dir {class_data_dir} --class_prompt dog --num_class_images 2 --resolution 64