Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
36 commits
Select commit Hold shift + click to select a range
7dd7d9c
[Discrete Diffusion] Add DFlash pipeline
kashif May 8, 2026
715ca50
[Docs] flesh out DFlash pipeline + scheduler pages
kashif May 8, 2026
4c0e3dd
[DFlash] fix train_dflash position_ids + clarify trust_remote_code
kashif May 8, 2026
cd0ce7b
[DFlash] remove spec_generate fast path; explicit loop handles all ta…
kashif May 8, 2026
7bc8161
Merge branch 'main' into add-dflash-pipeline
kashif May 9, 2026
a70e329
[DFlash] add num_timesteps property for parity with LLaDA2
kashif May 9, 2026
471afa9
[DFlash] document examples in discrete_diffusion README
kashif May 9, 2026
cab4413
[DFlash] align train recipe with paper
kashif May 10, 2026
3478781
Merge branch 'main' into add-dflash-pipeline
kashif May 10, 2026
9a5fb11
[DFlash] cite paper in docs
kashif May 10, 2026
4f5688b
Merge branch 'main' into add-dflash-pipeline
kashif May 11, 2026
0481c84
[DFlash] simplify pipeline and scheduler
kashif May 13, 2026
e8ed08b
[DFlash] remove hybrid model cache support
kashif May 14, 2026
1e7f43b
[DFlash] pass model config to DynamicCache for hybrid target models
kashif May 14, 2026
9036539
Merge branch 'main' into add-dflash-pipeline
kashif May 18, 2026
83e0d1e
Merge branch 'main' into add-dflash-pipeline
kashif May 20, 2026
d1214a7
Merge branch 'main' into add-dflash-pipeline
kashif May 26, 2026
ae8bfd2
Merge branch 'add-dflash-pipeline' of https://github.com/kashif/diffu…
kashif May 26, 2026
a2f736b
[DFlash] multi-anchor sparse training via FlexAttention
kashif May 26, 2026
e486a1e
[DFlash] cache torch.compile of create_block_mask
kashif May 26, 2026
b57915e
[DFlash] add unwrap_model helper for Accelerate compile compatibility
kashif May 26, 2026
65e7a2b
[DFlash] flesh out target-regen recipe in the module docstring
kashif May 26, 2026
1df25be
[DFlash] add regenerate_dflash_data.py + jsonl train support + README…
kashif May 26, 2026
3bbaf58
Merge branch 'main' into add-dflash-pipeline
kashif May 31, 2026
1cf666f
feat(dflash): per-sample anchors, block_keep_mask, min-token filter
kashif May 31, 2026
73cf1e4
docs(dflash): add Returns section to DFlashPipeline.__call__
kashif May 31, 2026
8539942
Merge branch 'main' into add-dflash-pipeline
kashif Jun 4, 2026
97d527e
Merge branch 'main' into add-dflash-pipeline
kashif Jun 9, 2026
99eaf71
[DFlash] fix int32 gather bug, non-contiguous view; fix docs
kashif Jun 9, 2026
8d8293f
[DFlash] doc-builder style fix
kashif Jun 9, 2026
2e71bbe
[DFlash] reveal noise tokens one block at a time in visualization
kashif Jun 9, 2026
99ace9d
[DFlash] fix LLaDA2 display: correct transfer_index offset, add refin…
kashif Jun 9, 2026
1b90033
[DFlash] remove dead block_size constructor param from TokenDisplay
kashif Jun 9, 2026
a5d6db8
[DFlash] move visualization to examples/discrete_diffusion/
kashif Jun 9, 2026
a425b07
[DFlash] fix import formatting in examples/discrete_diffusion/
kashif Jun 9, 2026
8cd1455
[DFlash] remove dev scripts superseded by examples/discrete_diffusion/
kashif Jun 9, 2026
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 4 additions & 0 deletions docs/source/en/_toctree.yml
Original file line number Diff line number Diff line change
Expand Up @@ -637,6 +637,8 @@
title: Z-Image
title: Image
- sections:
- local: api/pipelines/dflash
title: DFlash
- local: api/pipelines/llada2
title: LLaDA2
title: Text
Expand Down Expand Up @@ -706,6 +708,8 @@
title: DDPMScheduler
- local: api/schedulers/deis
title: DEISMultistepScheduler
- local: api/schedulers/dflash_token_diffusion
title: DFlashTokenDiffusionScheduler
- local: api/schedulers/multistep_dpm_solver_inverse
title: DPMSolverMultistepInverse
- local: api/schedulers/multistep_dpm_solver
Expand Down
90 changes: 90 additions & 0 deletions docs/source/en/api/pipelines/dflash.md
Original file line number Diff line number Diff line change
@@ -0,0 +1,90 @@
<!--Copyright 2025 The HuggingFace Team. All rights reserved.

Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with
the License. You may obtain a copy of the License at

http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on
an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the
specific language governing permissions and limitations under the License.
-->

# DFlash

[DFlash: Block Diffusion for Flash Speculative Decoding](https://huggingface.co/papers/2602.06036) is by Jian Chen, Yesheng Liang, and Zhijian Liu.

The abstract from the paper is:

*Autoregressive large language models (LLMs) deliver strong performance but require inherently sequential decoding, leading to high inference latency and poor GPU utilization. Speculative decoding mitigates this bottleneck by using a fast draft model whose outputs are verified in parallel by the target LLM. However, existing methods still rely on autoregressive drafting, which remains sequential and constrains practical speedups. Diffusion LLMs offer a promising alternative by enabling parallel generation, but current diffusion models typically underperform compared with autoregressive models. In this paper, we introduce DFlash, a speculative decoding framework that employs a lightweight block diffusion model for parallel drafting. We show that speculative decoding provides a natural and effective setting for diffusion models. By generating draft tokens in a single forward pass, DFlash enables efficient drafting, and by conditioning the draft model on context features extracted from the target model, it achieves high-quality drafts with higher acceptance rates. Experiments show that DFlash achieves over 6× lossless acceleration across a range of models and tasks, delivering up to 2.5× higher speedup than the state-of-the-art speculative decoding method EAGLE-3.*

`DFlashPipeline` ties the two models together: prefill on the target, draft a block, verify against the target's
posterior via [`DFlashTokenDiffusionScheduler`], commit the accepted prefix and the next-token resample, and repeat
until `max_new_tokens` or a stop token. Pretrained draft/target pairs are available in the
[z-lab/dflash collection](https://huggingface.co/collections/z-lab/dflash); the canonical pair is
`z-lab/Qwen3-8B-DFlash-b16` with `Qwen/Qwen3-8B`.

## Usage

```py
import torch
from transformers import AutoModel, AutoModelForCausalLM, AutoTokenizer

from diffusers import DFlashPipeline

# Draft ships custom modeling code via `auto_map` — `trust_remote_code=True` is required.
draft = AutoModel.from_pretrained(
"z-lab/Qwen3-8B-DFlash-b16", trust_remote_code=True, dtype=torch.bfloat16, device_map="auto"
)
target = AutoModelForCausalLM.from_pretrained("Qwen/Qwen3-8B", dtype=torch.bfloat16, device_map="auto")
tokenizer = AutoTokenizer.from_pretrained("Qwen/Qwen3-8B")

pipe = DFlashPipeline(draft_model=draft, target_model=target, tokenizer=tokenizer)
output = pipe(
prompt="What is 2 + 2? Answer in one sentence.",
max_new_tokens=256,
temperature=0.0,
chat_template_kwargs={"enable_thinking": False},
)
print(output.texts[0])
```

> **Note:** Qwen3 is a reasoning model and generates a `<think>...</think>` block before the answer. Pass
> `chat_template_kwargs={"enable_thinking": False}` to suppress thinking mode, or increase `max_new_tokens`
> (e.g. `8192`) when you want the full reasoning trace.

`DFlashPipeline` currently runs `batch_size=1` only. Multi-prompt batching requires per-row partial-accept tracking
and is not yet supported.

## Hybrid-attention targets

Target models with linear-attention layers (e.g. Qwen3.5's gated-delta-net) are not yet supported. For those
targets, `DynamicCache.crop()` silently no-ops on linear-attention layers, which would leak rejected speculative
tokens into the recurrent state. The current pipeline targets full-attention models only (e.g. `Qwen/Qwen3-8B`).

## Callbacks

Callbacks run after each block-verify step. Pass `callback_on_step_end_tensor_inputs` to select which tensors are
included in `callback_kwargs`. Allowed keys: `block_output_ids` (the drafted block), `draft_logits`,
`accepted_length`, `next_token`, and `output_ids` (the running output buffer). Return `{"output_ids": ...}` from the
callback to replace the buffer.

```py
def on_step_end(pipe, step, timestep, callback_kwargs):
output_ids = callback_kwargs["output_ids"]
return {"output_ids": output_ids}

out = pipe(
prompt="...",
callback_on_step_end=on_step_end,
callback_on_step_end_tensor_inputs=["output_ids"],
)
```

## DFlashPipeline
[[autodoc]] DFlashPipeline
- all
- __call__

## DFlashPipelineOutput
[[autodoc]] pipelines.DFlashPipelineOutput
24 changes: 24 additions & 0 deletions docs/source/en/api/schedulers/dflash_token_diffusion.md
Original file line number Diff line number Diff line change
@@ -0,0 +1,24 @@
<!--Copyright 2025 The HuggingFace Team. All rights reserved.

Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with
the License. You may obtain a copy of the License at

http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on
an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the
specific language governing permissions and limitations under the License.
-->

# DFlashTokenDiffusionScheduler

[`DFlashTokenDiffusionScheduler`] implements the verification step for DFlash-style block-diffusion speculative
decoding. It samples a posterior block from the target logits, computes the acceptance length as the longest prefix
where the draft proposal matches the posterior, and exposes the resampled `next_token` for the first rejected
position. Used by [`DFlashPipeline`].

## DFlashTokenDiffusionScheduler
[[autodoc]] DFlashTokenDiffusionScheduler

## DFlashTokenDiffusionSchedulerOutput
[[autodoc]] schedulers.scheduling_dflash_token_diffusion.DFlashTokenDiffusionSchedulerOutput
73 changes: 73 additions & 0 deletions examples/discrete_diffusion/README.md
Original file line number Diff line number Diff line change
Expand Up @@ -48,3 +48,76 @@ python examples/discrete_diffusion/sample_llada2.py \
--use_chat_template \
--add_generation_prompt
```

## DFlash

[DFlash](https://huggingface.co/papers/2602.06036) ([model collection](https://huggingface.co/collections/z-lab/dflash)) is a block-diffusion **speculative decoding** scheme: a small diffusion *draft* model, conditioned on hidden features from a frozen *target* causal LM, proposes a block of tokens that the target verifies in a single forward pass. The pipeline accepts the longest matching prefix and resamples the next token at the rejection point.

### Sample

The published draft pairs with a stock target (no `trust_remote_code` for the target):

```bash
python examples/discrete_diffusion/sample_dflash.py \
--draft_model_id z-lab/Qwen3.5-4B-DFlash \
--target_model_id Qwen/Qwen3.5-4B \
--prompt "How many positive whole-number divisors does 196 have?" \
--max_new_tokens 4096
```

The draft ships a custom `DFlashDraftModel` class via `auto_map`, so the sample script loads it with `trust_remote_code=True`; the target loads as a stock Qwen3 / Qwen3.5 model. Per-draft thinking-mode defaults from the upstream model cards:

| Draft | `enable_thinking` |
|---|---|
| `z-lab/Qwen3.5-*-DFlash` | `True` |
| `z-lab/Qwen3-*-DFlash-b16` | `False` (drafts are non-thinking-trained) |

### Prepare training data

Per the DFlash paper (§5.1) the draft should be trained on responses **regenerated by the target model**, otherwise acceptance length at inference drops. Spin up the target with any OpenAI-compatible server (vLLM or SGLang both work) and re-roll a chat dataset's assistant turns:

```bash
# 1. Serve the target
vllm serve Qwen/Qwen3-4B --port 8000 --served-model-name target

# 2. Re-roll a slice (≈30 LOC asyncio helper bundled with this example)
python examples/discrete_diffusion/regenerate_dflash_data.py \
--server http://localhost:8000/v1 \
--model target \
--dataset_name wikitext \
--dataset_config_name wikitext-2-raw-v1 \
--split "train[:1000]" \
--output_path dflash-train.jsonl \
--temperature 0.8 \
--concurrency 64
```

Use `temperature 0.7-0.8` (NOT greedy) — diversity in supervised tokens teaches the draft to handle a wider distribution at inference. Concurrency 64-128 saturates one vLLM server.

### Train

The training loop conditions the draft on intermediate target hidden states and predicts the next `block_size − 1` tokens at each of `--num_anchors` anchor positions per sequence (paper Appendix A.1 trains with 512 anchors per sequence):

```bash
accelerate launch examples/discrete_diffusion/train_dflash.py \
--draft_model_id z-lab/Qwen3-4B-DFlash-b16 \
--target_model_id Qwen/Qwen3-4B \
--dataset_name dflash-train.jsonl \
--text_column text \
--output_dir dflash-output \
--max_train_steps 1000 \
--learning_rate 6e-4 \
--num_anchors 512 \
--attention_backend flex_attention
```

Key flags beyond the standard ones:

| Flag | Notes |
|---|---|
| `--num_anchors` | Anchor blocks per sequence (paper default 512). Each anchor contributes `block_size - 1` supervised tokens. |
| `--attention_backend` | `flex_attention` (sparse, scales to N=512) or `sdpa` (dense, OOMs above ~64). |
| `--no_overlap_anchors` | Disable independent (overlapping) sampling and use non-overlapping anchors instead. Caps N at `seq_len // block_size`. |
| `--block_size 0` | Reads block size from the draft config (16 for `b16` / `Qwen3.5-*-DFlash` drafts). |

End-to-end model compilation is delegated to Accelerate — pass `--dynamo_backend inductor` to `accelerate launch` to enable it.
78 changes: 78 additions & 0 deletions examples/discrete_diffusion/regenerate_dflash_data.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,78 @@
#!/usr/bin/env python
# Small async helper that re-rolls assistant responses for a tiny dataset
# slice through a locally-served target model (vLLM or SGLang — both expose
# the OpenAI /v1/chat/completions API).
#
# This is a smoke-test stand-in for SpecForge's scripts/regenerate_train_data.py
# and Automodel's nemo_automodel.components.speculative.regenerate, kept tiny so
# the full pipeline (serve -> regen -> train) fits in one sbatch.

from __future__ import annotations

import argparse
import asyncio
import json
import sys

from datasets import load_dataset
from openai import AsyncOpenAI


async def regenerate(client, model, messages, max_tokens, temperature):
resp = await client.chat.completions.create(
model=model,
messages=messages,
max_tokens=max_tokens,
temperature=temperature,
)
return resp.choices[0].message.content


async def main():
p = argparse.ArgumentParser()
p.add_argument("--server", default="http://localhost:8000/v1")
p.add_argument("--model", required=True, help="served-model-name from the vLLM/SGLang server")
p.add_argument("--dataset_name", default="wikitext")
p.add_argument("--dataset_config_name", default="wikitext-2-raw-v1")
p.add_argument("--split", default="train[:32]")
p.add_argument("--text_column", default="text")
p.add_argument("--output_path", required=True)
p.add_argument("--max_tokens", type=int, default=128)
p.add_argument("--temperature", type=float, default=0.8) # SpecForge default
p.add_argument("--concurrency", type=int, default=8)
args = p.parse_args()

ds = load_dataset(args.dataset_name, args.dataset_config_name, split=args.split)
prompts = [row[args.text_column] for row in ds if row.get(args.text_column, "").strip()]
print(f"[regen] {len(prompts)} prompts; serving={args.server}; model={args.model}", file=sys.stderr)

client = AsyncOpenAI(base_url=args.server, api_key="EMPTY")
sem = asyncio.Semaphore(args.concurrency)

async def one(i, prompt):
messages = [{"role": "user", "content": prompt[:1024]}]
async with sem:
try:
assistant = await regenerate(client, args.model, messages, args.max_tokens, args.temperature)
except Exception as exc:
print(f"[regen] sample {i} failed: {exc}", file=sys.stderr)
return None
# Emit both the flat "text" column (so train_dflash.py can tokenize directly)
# and "conversations" (so downstream consumers that expect chat format work too).
return {
"id": i,
"text": assistant,
"conversations": messages + [{"role": "assistant", "content": assistant}],
}

results = await asyncio.gather(*[one(i, p) for i, p in enumerate(prompts)])
rows = [r for r in results if r is not None]

with open(args.output_path, "w") as f:
for row in rows:
f.write(json.dumps(row) + "\n")
print(f"[regen] wrote {len(rows)}/{len(prompts)} rows -> {args.output_path}", file=sys.stderr)


if __name__ == "__main__":
asyncio.run(main())
Loading
Loading