Discrete diffusion in diffusers #12911
Conversation
Thanks for the PR! Some preliminary design questions and comments:
```python
self.register_to_config(
    seq_len=seq_len,
    num_inference_steps=num_inference_steps,
    inject_start_token=inject_start_token,
)
```
Generally we don't register default `__call__` arguments to the config, but rather set them as default arguments to the `__call__` method:
diffusers/src/diffusers/pipelines/ltx2/pipeline_ltx2.py
Lines 744 to 752 in d4f97d1

```python
def __call__(
    self,
    prompt: Union[str, List[str]] = None,
    negative_prompt: Optional[Union[str, List[str]]] = None,
    height: int = 512,
    width: int = 768,
    num_frames: int = 121,
    frame_rate: float = 24.0,
    num_inference_steps: int = 40,
```
```python
    *,
    batch_size: int = 1,
```
diffusers pipelines usually don't make `__call__` arguments keyword-only. (That's not to say there are no arguments in favor of it, but since other pipelines accept positional arguments, I think the expectation is that discrete diffusion pipelines will as well.)
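Taken together with the config point above, a minimal sketch of what the resulting signature could look like (the pipeline class name and the default values below are assumptions for illustration, not the PR's actual code):

```python
from typing import Optional

import torch

from diffusers import DiffusionPipeline


class TokenDiffusionPipeline(DiffusionPipeline):
    # Hypothetical sketch: defaults live on __call__ (positional-friendly), not in the config.
    @torch.no_grad()
    def __call__(
        self,
        batch_size: int = 1,
        seq_len: int = 256,  # assumed default
        num_inference_steps: int = 50,  # assumed default
        inject_start_token: bool = True,  # assumed default
        infill_mask: Optional[torch.Tensor] = None,
        generator: Optional[torch.Generator] = None,
    ):
        ...
```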
```python
if seq_len is None:
    seq_len = int(self.config.seq_len)
if num_inference_steps is None:
    num_inference_steps = int(self.config.num_inference_steps)
if inject_start_token is None:
    inject_start_token = bool(self.config.inject_start_token)
```
Following up on #12911 (comment), this logic could be removed if we don't register default arguments to the config.
```python
if infill_mask is not None:
    if infill_mask.shape != (batch_size, seq_len):
        raise ValueError(
            f"`infill_mask` must have shape {(batch_size, seq_len)}, got {tuple(infill_mask.shape)}."
        )
```
I think input checking and exceptions should be moved to a `check_inputs` method, which is the usual practice for diffusers pipelines:
diffusers/src/diffusers/pipelines/flux2/pipeline_flux2.py
Lines 686 to 693 in d4f97d1

```python
def check_inputs(
    self,
    prompt,
    height,
    width,
    prompt_embeds=None,
    callback_on_step_end_tensor_inputs=None,
):
```
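Applied to this pipeline, a minimal sketch (the exact parameter list is an assumption) might be:

```python
# Hypothetical sketch: the infill_mask validation moves out of __call__.
def check_inputs(self, batch_size, seq_len, infill_mask=None):
    if infill_mask is not None and infill_mask.shape != (batch_size, seq_len):
        raise ValueError(
            f"`infill_mask` must have shape {(batch_size, seq_len)}, got {tuple(infill_mask.shape)}."
        )
```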
```python
return int(token_id)
return None


def _init_latents(
```
We usually name methods that sample latents from the prior distribution `prepare_latents`:

```python
def prepare_latents(
```
```python
if hasattr(self.scheduler, "forward_process") and getattr(self.scheduler, "forward_process") == "uniform":
    # Uniform prior over token IDs. Mirror scheduler's exclude-mask behavior.
    if getattr(self.scheduler, "exclude_mask_from_uniform", False) and hasattr(
        self.scheduler, "_sample_uniform_tokens"
    ):
        return self.scheduler._sample_uniform_tokens(
            torch.Size((batch_size, seq_len)),
            device=device,
            dtype=torch.long,
            generator=generator,
        )
    vocab_size = int(getattr(self.scheduler, "vocab_size", 0))
    if vocab_size <= 0:
        raise ValueError("Scheduler must define `vocab_size` for uniform prior sampling.")
    return torch.randint(
        0, vocab_size, (batch_size, seq_len), device=device, dtype=torch.long, generator=generator
    )
```
Suggestion: maybe it would be cleaner to define a scheduler method called (say) `sample_prior` which samples from the prior distribution based on the configured `forward_process`? So if `self.forward_process == "uniform"`, we would call `_sample_uniform_tokens` under the hood in `sample_prior` to sample from a uniform prior distribution.
I think this would allow for more graceful support of other possible forward processes, and make the pipeline code cleaner (as most of the logic would be handled inside the scheduler).
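A rough sketch of what that could look like on the scheduler side (the absorbing branch and the `mask_token_id` attribute are assumptions; only `forward_process`, `exclude_mask_from_uniform`, `vocab_size`, and `_sample_uniform_tokens` appear in the excerpt above):

```python
import torch


# Hypothetical sketch of a scheduler-side prior sampler.
def sample_prior(self, shape, device=None, generator=None):
    if self.forward_process == "absorbing":
        # Assumed: the absorbing process starts every position at the mask token.
        return torch.full(shape, self.mask_token_id, device=device, dtype=torch.long)
    if self.forward_process == "uniform":
        if self.exclude_mask_from_uniform:
            return self._sample_uniform_tokens(shape, device=device, dtype=torch.long, generator=generator)
        return torch.randint(0, self.vocab_size, shape, device=device, dtype=torch.long, generator=generator)
    raise ValueError(f"Unsupported forward process: {self.forward_process!r}")
```

The pipeline's `prepare_latents` would then reduce to a single `self.scheduler.sample_prior(...)` call.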
```python
timesteps = torch.linspace(
    self.num_train_timesteps - 1, 0, self.num_inference_steps, dtype=torch.float32
).round()
self.timesteps = timesteps.to(dtype=torch.long, device=device)
```
Suggestion: we could pre-compute the alpha schedule here once the timestep discretization is fixed. This could be a little more efficient since I think we use the alpha for each timestep twice in the denoising loop (once as alpha_t and once as alpha_prev).
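A rough sketch of that caching (the `alphas_cumprod` attribute name is an assumption about the scheduler's internals; the PR may expose the schedule differently):

```python
# Hypothetical sketch, placed at the end of set_timesteps once self.timesteps is fixed.
# Assumes the scheduler keeps its per-train-timestep schedule in self.alphas_cumprod.
self.alphas = self.alphas_cumprod.to(device)[self.timesteps]
# step() can then read alpha_t (and the previous step's alpha) from this cache
# instead of recomputing them on every denoising iteration.
```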
```python
if self.forward_process != "absorbing":
    raise ValueError(f"Unsupported forward process for `step()`: {self.forward_process!r}")

# p_denoise = (alpha_prev - alpha_t) / (1 - alpha_t)
denom = (1.0 - alpha_t).clamp_min(torch.finfo(torch.float32).eps)
```
nit: I find this code structure a little confusing, and think that something like

```python
if self.forward_process == "uniform":
    ...
elif self.forward_process == "absorbing":
    ...
else:
    raise ValueError(f"Unsupported forward process for `step()`: {self.forward_process!r}")

if not return_dict:
    return (x_prev,)
return TokenDiffusionSchedulerOutput(prev_sample=x_prev)
```

would be more readable.
```python
from .scheduling_token_diffusion import TokenDiffusionScheduler, TokenDiffusionSchedulerOutput


class BlockTokenDiffusionScheduler(TokenDiffusionScheduler):
```
I think we would generally prefer a single-file design rather than inheriting from `TokenDiffusionScheduler`. CC @yiyixuxu
A possible alternative would be to move the `block_mask` logic into the `add_noise` and `step` methods of `TokenDiffusionScheduler`, retaining the current default of noising/denoising over the entire input if `block_mask` is None.
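For example, a rough sketch of the single-file shape (the signature and the masking line are assumptions; only the `block_mask` idea is taken from the PR):

```python
import torch


# Hypothetical sketch: optional block_mask support folded into TokenDiffusionScheduler.step.
def step(self, model_output, timestep, sample, block_mask=None, generator=None, return_dict=True):
    # ... existing absorbing/uniform update that produces x_prev ...
    if block_mask is not None:
        # Only positions inside the active block are updated; all others keep their current tokens.
        x_prev = torch.where(block_mask, x_prev, sample)
    if not return_dict:
        return (x_prev,)
    return TokenDiffusionSchedulerOutput(prev_sample=x_prev)
```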
What does this PR do?
Adds experimental support for discrete token diffusion methods and a pipeline.
Who can review?
Anyone in the community is free to review the PR once the tests have passed. Feel free to tag members/contributors who may be interested in your PR.