
Conversation

@kashif (Contributor) commented on Jan 4, 2026

What does this PR do?

Add experimental support for discrete token diffusion methods and pipeline

Before submitting

Who can review?

Anyone in the community is free to review the PR once the tests have passed. Feel free to tag
members/contributors who may be interested in your PR.

@kashif marked this pull request as draft on January 4, 2026 at 23:34
@kashif changed the title from "Discrete diffusion in diffuers" to "Discrete diffusion in diffusers" on Jan 4, 2026
@yiyixuxu (Collaborator) commented

Thanks for this PR!
cc @dg845, can you take a look here? It's related to Dream 7B (#12091), which you are working on.

@dg845 (Collaborator) commented on Jan 23, 2026

Thanks for the PR! Some preliminary design questions and comments:

  1. I think it could be useful to have a natural place to implement logic which is common to discrete diffusion models. Would something like a DiscreteDiffusionPipelineMixin make sense? For example, I think _resolve_start_token_id, _normalize_prefix_ids, _top_p_filtering, etc. could be candidates as mixin methods. (A possible alternative could be to put the methods in DiffusionPipeline, but it feels a little weird to put the methods there because they aren't applicable to continuous diffusion models.) But maybe this is premature, since we might not know what logic will end up being useful for all (or most) discrete diffusion models.
    1. One motivation for this is that we often want to do semi-autoregressive (SAR) sampling for discrete diffusion models, so it would be useful to have autoregressive sampling techniques such as top-$p$ sampling, top-$k$ sampling, etc. implemented and tested in one place, so that new discrete diffusion models that support SAR sampling can use them without having to copy them every time (see the sketch after this list).
  2. Similarly, would it make sense to have a TokenizerTextProcessor class which handles text pre-processing and post-processing, analogous to how VaeImageProcessor handles image pre- and post-processing? It's probably less necessary since we don't need to do as much normalization as for images, but I could see this being useful for handling e.g. chat templates like in the SDAR and LLaDA 2 pipelines.
    1. As an aside, this could also be useful for existing (continuous) diffusion models, some of which have pretty involved text processing, such as pipelines like SanaPipeline that use a _text_preprocessing method:
      # Copied from diffusers.pipelines.deepfloyd_if.pipeline_if.IFPipeline._text_preprocessing
      def _text_preprocessing(self, text, clean_caption=False):
  3. Currently it looks like the pipelines only support denoising models with a transformers-like interface. But we would probably want to implement some discrete diffusion transformers in diffusers, which currently doesn't enforce that interface. So I think we should think about how we can handle both cases gracefully in discrete diffusion pipelines. (One solution could be to simply adopt the transformers interface for all discrete denoising models in diffusers, but that could be unnecessarily restrictive.)
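
For concreteness, a minimal sketch of what such a mixin could look like (the class layout and the `_top_p_filtering` implementation below are illustrative assumptions, not code from this PR):

import torch


class DiscreteDiffusionPipelineMixin:
    """Hypothetical mixin collecting sampling utilities shared by discrete diffusion pipelines."""

    @staticmethod
    def _top_p_filtering(logits: torch.Tensor, top_p: float = 1.0, filter_value: float = float("-inf")) -> torch.Tensor:
        # Nucleus (top-p) filtering over the last dimension of `logits` (shape: ..., vocab_size):
        # keep only the smallest set of tokens whose cumulative probability exceeds `top_p`.
        if top_p >= 1.0:
            return logits
        sorted_logits, sorted_indices = torch.sort(logits, descending=True, dim=-1)
        cumulative_probs = torch.softmax(sorted_logits, dim=-1).cumsum(dim=-1)
        sorted_indices_to_remove = cumulative_probs > top_p
        # Shift right so the first token that crosses the threshold is always kept.
        sorted_indices_to_remove[..., 1:] = sorted_indices_to_remove[..., :-1].clone()
        sorted_indices_to_remove[..., 0] = False
        indices_to_remove = sorted_indices_to_remove.scatter(-1, sorted_indices, sorted_indices_to_remove)
        return logits.masked_fill(indices_to_remove, filter_value)

A pipeline supporting SAR sampling could then inherit from both DiffusionPipeline and this mixin and call self._top_p_filtering(logits, top_p) before sampling tokens.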

Comment on lines +73 to +77
self.register_to_config(
    seq_len=seq_len,
    num_inference_steps=num_inference_steps,
    inject_start_token=inject_start_token,
)

Generally we don't register default __call__ arguments to the config, but rather set them as default arguments to the __call__ method:

def __call__(
    self,
    prompt: Union[str, List[str]] = None,
    negative_prompt: Optional[Union[str, List[str]]] = None,
    height: int = 512,
    width: int = 768,
    num_frames: int = 121,
    frame_rate: float = 24.0,
    num_inference_steps: int = 40,
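
Concretely, seq_len, num_inference_steps, and inject_start_token would just become __call__ parameters with defaults, something like the following sketch (the default values here are purely illustrative):

def __call__(
    self,
    prompt: Optional[Union[str, List[str]]] = None,
    seq_len: int = 256,               # illustrative default
    num_inference_steps: int = 50,    # illustrative default
    inject_start_token: bool = True,  # illustrative default
    generator: Optional[torch.Generator] = None,
):
    ...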

Comment on lines +148 to +149
*,
batch_size: int = 1,

diffusers pipelines usually don't set __call__ arguments to be keyword-only. (That's not to say there are no arguments in favor of keyword-only parameters, but because other pipelines allow positional arguments I think the expectation is that discrete diffusion pipelines will allow them as well.)

Comment on lines +185 to +190
if seq_len is None:
seq_len = int(self.config.seq_len)
if num_inference_steps is None:
num_inference_steps = int(self.config.num_inference_steps)
if inject_start_token is None:
inject_start_token = bool(self.config.inject_start_token)

Following up on #12911 (comment), this logic could be removed if we don't register default arguments to the config.

Comment on lines +217 to +221
if infill_mask is not None:
    if infill_mask.shape != (batch_size, seq_len):
        raise ValueError(
            f"`infill_mask` must have shape {(batch_size, seq_len)}, got {tuple(infill_mask.shape)}."
        )

I think input checking and exceptions should be moved to a check_inputs method, which is the usual practice for diffusers pipelines:

def check_inputs(
    self,
    prompt,
    height,
    width,
    prompt_embeds=None,
    callback_on_step_end_tensor_inputs=None,
):
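
For this pipeline, the infill_mask validation above could move into such a method, roughly like this (the parameter list is illustrative):

def check_inputs(self, prompt, seq_len, batch_size, infill_mask=None):
    # Illustrative sketch: collect the validation currently done inline in __call__.
    if infill_mask is not None and infill_mask.shape != (batch_size, seq_len):
        raise ValueError(
            f"`infill_mask` must have shape {(batch_size, seq_len)}, got {tuple(infill_mask.shape)}."
        )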

return int(token_id)
return None

def _init_latents(

We usually name methods which sample latents from the prior distribution `prepare_latents`.

Comment on lines +102 to +118
if hasattr(self.scheduler, "forward_process") and getattr(self.scheduler, "forward_process") == "uniform":
    # Uniform prior over token IDs. Mirror scheduler's exclude-mask behavior.
    if getattr(self.scheduler, "exclude_mask_from_uniform", False) and hasattr(
        self.scheduler, "_sample_uniform_tokens"
    ):
        return self.scheduler._sample_uniform_tokens(
            torch.Size((batch_size, seq_len)),
            device=device,
            dtype=torch.long,
            generator=generator,
        )
    vocab_size = int(getattr(self.scheduler, "vocab_size", 0))
    if vocab_size <= 0:
        raise ValueError("Scheduler must define `vocab_size` for uniform prior sampling.")
    return torch.randint(
        0, vocab_size, (batch_size, seq_len), device=device, dtype=torch.long, generator=generator
    )

Suggestion: maybe it would be cleaner to define a scheduler method called (say) sample_prior which samples from the prior distribution based on the configured forward_process? So if self.forward_process == "uniform", we would call _sample_uniform_tokens under the hood in sample_prior to sample from a uniform prior distribution.

I think this would allow for more graceful support of other possible forward processes, and make the pipeline code cleaner (as most of the logic would be handled inside the scheduler).
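
A rough sketch of what that could look like (the mask_token_id attribute and the exact signature are assumptions, not taken from this PR):

def sample_prior(self, shape, device=None, dtype=torch.long, generator=None):
    # Hypothetical scheduler method: draw x_T from the prior implied by `forward_process`.
    if self.forward_process == "absorbing":
        # Absorbing prior: every position starts as the mask token
        # (`mask_token_id` is an assumed attribute here).
        return torch.full(shape, self.mask_token_id, device=device, dtype=dtype)
    if self.forward_process == "uniform":
        if getattr(self, "exclude_mask_from_uniform", False):
            return self._sample_uniform_tokens(torch.Size(shape), device=device, dtype=dtype, generator=generator)
        return torch.randint(0, self.vocab_size, shape, device=device, dtype=dtype, generator=generator)
    raise ValueError(f"Unsupported forward process: {self.forward_process!r}")

The pipeline's latent preparation would then reduce to something like self.scheduler.sample_prior((batch_size, seq_len), device=device, generator=generator).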

timesteps = torch.linspace(
    self.num_train_timesteps - 1, 0, self.num_inference_steps, dtype=torch.float32
).round()
self.timesteps = timesteps.to(dtype=torch.long, device=device)

Suggestion: we could pre-compute the alpha schedule here once the timestep discretization is fixed. This could be a little more efficient since I think we use the alpha for each timestep twice in the denoising loop (once as alpha_t and once as alpha_prev).
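
For example, set_timesteps could also build the alpha table (a linear alpha schedule is assumed in this sketch; swap in whatever schedule the scheduler actually uses):

timesteps = torch.linspace(
    self.num_train_timesteps - 1, 0, self.num_inference_steps, dtype=torch.float32
).round()
self.timesteps = timesteps.to(dtype=torch.long, device=device)
# Precompute alpha for every discretized timestep, appending the terminal value 1.0 so the
# denoising loop can read alpha_t = self.alphas[i] and alpha_prev = self.alphas[i + 1].
alphas = 1.0 - timesteps / self.num_train_timesteps
self.alphas = torch.cat([alphas, torch.ones(1)]).to(device)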

Comment on lines +441 to +445
if self.forward_process != "absorbing":
    raise ValueError(f"Unsupported forward process for `step()`: {self.forward_process!r}")

# p_denoise = (alpha_prev - alpha_t) / (1 - alpha_t)
denom = (1.0 - alpha_t).clamp_min(torch.finfo(torch.float32).eps)

nit: I find this code structure a little confusing, and think that something like

        if self.forward_process == "uniform":
            ...
        elif self.forward_process == "absorbing":
            ...
        else:
             raise ValueError(f"Unsupported forward process for `step()`: {self.forward_process!r}")

        if not return_dict:
            return (x_prev,)
        return TokenDiffusionSchedulerOutput(prev_sample=x_prev)

would be more readable.

from .scheduling_token_diffusion import TokenDiffusionScheduler, TokenDiffusionSchedulerOutput


class BlockTokenDiffusionScheduler(TokenDiffusionScheduler):

I think we would generally prefer a single-file design rather than inheriting from TokenDiffusionScheduler. CC @yiyixuxu

A possible alternative would be to move the block_mask logic into the add_noise and step methods of TokenDiffusionScheduler, retaining the current default of noising/denoising over the entire input if block_mask is None.
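
A rough sketch of the second option (`_denoise_step` below is a stand-in for the existing absorbing/uniform update logic, not an actual method in this PR):

def step(self, model_output, timestep, sample, block_mask=None, generator=None, return_dict=True):
    # `block_mask` (bool, same shape as `sample`): True where this step may update tokens.
    x_prev = self._denoise_step(model_output, timestep, sample, generator=generator)
    if block_mask is not None:
        # Outside the active block, carry the current tokens through unchanged.
        x_prev = torch.where(block_mask, x_prev, sample)
    if not return_dict:
        return (x_prev,)
    return TokenDiffusionSchedulerOutput(prev_sample=x_prev)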
