Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
24 changes: 15 additions & 9 deletions src/diffusers/models/modeling_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -198,15 +198,21 @@ def find_tensor_attributes(module: nn.Module) -> list[tuple[str, Tensor]]:
return tuples

gen = parameter._named_members(get_members_fn=find_tensor_attributes)
last_tuple = None
for tuple in gen:
last_tuple = tuple
if tuple[1].is_floating_point():
return tuple[1].dtype

if last_tuple is not None:
# fallback to the last dtype
return last_tuple[1].dtype
last_t = None
for t in gen:
last_t = t
if t[1].is_floating_point():
return t[1].dtype

if last_t is not None:
# fallback to the last dtype found via __dict__ inspection
return last_t[1].dtype

raise ValueError(
f"Could not determine the dtype of {parameter.__class__.__name__}: no parameters, buffers, or tensor "
"attributes were found. If you are using nn.DataParallel, make sure the module is moved to a device "
"before wrapping it (e.g. model.to('cuda') before DataParallel(model))."
)


@contextmanager
Expand Down
11 changes: 9 additions & 2 deletions src/diffusers/schedulers/scheduling_flow_match_euler_discrete.py
Original file line number Diff line number Diff line change
Expand Up @@ -128,6 +128,15 @@ def __init__(
timesteps = torch.from_numpy(timesteps).to(dtype=torch.float32)

sigmas = timesteps / num_train_timesteps

# Store sigma_min / sigma_max from the *unshifted* linear schedule so that
# set_timesteps can use them as the raw [0, 1] bounds when regenerating the
# sigma grid. If they were stored after shifting, set_timesteps would feed
# already-shifted values back through the shift formula a second time,
# producing a doubly-shifted (and therefore wrong) sigma schedule (#13243).
self.sigma_min = sigmas[-1].item()
self.sigma_max = sigmas[0].item()

if not use_dynamic_shifting:
# when use_dynamic_shifting is True, we apply the timestep shifting on the fly based on the image resolution
sigmas = shift * sigmas / (1 + (shift - 1) * sigmas)
Expand All @@ -140,8 +149,6 @@ def __init__(
self._shift = shift

self.sigmas = sigmas.to("cpu") # to avoid too much CPU/GPU communication
self.sigma_min = self.sigmas[-1].item()
self.sigma_max = self.sigmas[0].item()

@property
def shift(self):
Expand Down
27 changes: 27 additions & 0 deletions tests/schedulers/test_scheduler_flow_map_euler_discrete.py
Original file line number Diff line number Diff line change
Expand Up @@ -187,3 +187,30 @@ def test_scale_noise_endpoints(self):
torch.testing.assert_close(scheduler.scale_noise(sample, zero_t, noise), sample)
full_t = torch.tensor([float(scheduler.config.num_train_timesteps)])
torch.testing.assert_close(scheduler.scale_noise(sample, full_t, noise), noise)

def test_set_timesteps_no_double_shift(self):
"""set_timesteps must not apply the shift formula twice (regression #13243).

When sigma_min/sigma_max were stored *after* shifting in __init__, calling
set_timesteps fed already-shifted values back through the shift formula a
second time. After the fix the schedule produced by set_timesteps must be
identical to the one built in __init__ for the same number of steps.
"""
shift = 3.0
n = 1000
scheduler = self.scheduler_class(**self.get_default_config(shift=shift))

# The sigmas stored in __init__ — these are the ground-truth shifted values.
init_sigmas = scheduler.sigmas[:-1] # drop terminal 0 added by set_timesteps

scheduler.set_timesteps(num_inference_steps=n)
inferred_sigmas = scheduler.sigmas[:-1]

self.assertEqual(len(init_sigmas), len(inferred_sigmas))
for i, (s_init, s_infer) in enumerate(zip(init_sigmas, inferred_sigmas)):
self.assertAlmostEqual(
s_init.item(),
s_infer.item(),
places=5,
msg=f"sigma mismatch at index {i}: init={s_init:.6f} vs set_timesteps={s_infer:.6f}",
)
Loading