
[Qwen3.5] dedup position ids#2102

Open
daijh wants to merge 4 commits into microsoft:main from daijh:qwen3.5-dedup-position_ids

Conversation

@daijh
Contributor

@daijh daijh commented Apr 27, 2026

Extract `_make_synthetic_position_ids()` to derive `batch_size` and `sequence_length` from the `position_ids` model input `[3, B, S]` instead of running Shape on intermediate Q/K tensors. Use a fixed basename so `make_node` dedup creates the subgraph once and reuses it across all layers and Q/K calls.

  • Reduces ONNX graph size by eliminating per-layer duplicates

  • Eliminates redundant INT64 Mul nodes that cause CPU fallback on WebGPU (WebGPU Mul does not support INT64)
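The fixed-basename trick described above can be sketched as a tiny builder that caches nodes by name. This is a hypothetical illustration (the class and field names here are not from the real builder), assuming `make_node` treats an already-seen name as a cache hit:

```python
# Hypothetical sketch of name-based node dedup: a node created under a
# fixed basename is built once and reused on every later request.
class GraphBuilder:
    def __init__(self):
        self.nodes = {}  # node name -> node record

    def make_node(self, op_type, inputs, outputs, name):
        if name in self.nodes:       # fixed basename => cache hit
            return self.nodes[name]  # reuse across all layers / Q/K calls
        node = {"op_type": op_type, "inputs": inputs, "outputs": outputs, "name": name}
        self.nodes[name] = node
        return node

builder = GraphBuilder()
# Every layer asks for the same fixed basename, so the synthetic
# position_ids subgraph is emitted exactly once.
for layer in range(4):
    builder.make_node(
        "Range",
        ["start", "limit", "delta"],
        ["/model/synthetic_position_ids/Range/output_0"],
        name="/model/synthetic_position_ids/Range",
    )
assert len(builder.nodes) == 1
```

With per-layer basenames instead, the loop above would emit four identical subgraphs, which is exactly the duplication this PR removes.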

Copilot AI review requested due to automatic review settings April 27, 2026 07:48
@daijh daijh marked this pull request as draft April 27, 2026 07:51
Contributor

Copilot AI left a comment


Pull request overview

This PR reduces ONNX graph duplication in the Qwen3.5 builder by centralizing synthetic position_ids generation and deduplicating the resulting subgraph across all layers/attention paths, helping shrink graphs and avoid repeated INT64 ops.

Changes:

  • Extracted _make_synthetic_position_ids() to derive batch_size/sequence_length from the position_ids model input shape [3, B, S].
  • Switched mRoPE rotation to reuse a shared synthetic position_ids subgraph via a fixed basename (leveraging name-based node dedup).
  • Removed per-Q/K (and per-layer) synthetic position_ids subgraph construction from _apply_mrope_rotation().
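To make the first bullet concrete, here is what the Shape + Gather subgraph computes on the `position_ids` input, illustrated with NumPy rather than actual ONNX node construction (the shapes are example values, not from the PR):

```python
import numpy as np

# Illustrative only: position_ids has shape [3, B, S] (3 mRoPE sections).
B, S = 2, 1024
position_ids = np.zeros((3, B, S), dtype=np.int64)

# Shape -> [3, B, S]; Gather(axis=0) picks out the dynamic dims.
shape = np.array(position_ids.shape, dtype=np.int64)
batch_size = shape[1]        # Gather with index 1
sequence_length = shape[2]   # Gather with index 2

assert batch_size == B and sequence_length == S
```

Because these dims come from a model input rather than intermediate Q/K tensors, the subgraph has no dependency on any particular layer and can safely be shared.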

@daijh daijh force-pushed the qwen3.5-dedup-position_ids branch from cd1a67c to c25aff1 Compare April 27, 2026 08:07
@daijh
Contributor Author

daijh commented Apr 28, 2026

  • Total Nodes: 1509 → 1434 (75 nodes pruned)
  • WebGPU-EP MemcpyFromHost: 16 → 5 (11 host-to-device transfers eliminated)

Performance Comparison - Prefill-1024, Decode-128

| Model Variant | Prefill Speed (TPS) | Decode Speed (TPS) |
| --- | --- | --- |
| Qwen3.5-0.8B-int4 | 2415.4 | 56.5 |
| Qwen3.5-0.8B-int4-with-optimization | 2452.8 | 58.4 |

@daijh daijh requested a review from Copilot April 28, 2026 07:34
@daijh daijh marked this pull request as ready for review April 28, 2026 07:34
@daijh
Contributor Author

daijh commented Apr 28, 2026

@apsonawane @kunal-vaishnavi PTAL

Contributor

Copilot AI left a comment


Pull request overview

Copilot reviewed 1 out of 1 changed files in this pull request and generated 1 comment.

Comment thread src/python/py/models/builders/qwen.py Outdated
@kunal-vaishnavi kunal-vaishnavi dismissed their stale review April 28, 2026 18:46

There are some styling changes that can be made.

Comment thread src/python/py/models/builders/qwen.py Outdated
@apsonawane
Contributor

apsonawane commented Apr 29, 2026

Lgtm after Kunal's and Copilot changes are done

@daijh daijh force-pushed the qwen3.5-dedup-position_ids branch from c25aff1 to e98cb96 Compare April 29, 2026 04:47
@daijh
Contributor Author

daijh commented Apr 29, 2026

Thanks.
Updated the PR to address all the comments.

@daijh
Contributor Author

daijh commented Apr 29, 2026

When I run lintrunner -a, it's touching unrelated code.
Should I include these lint fixes in this PR, or leave them out for now?

diff --git a/src/python/py/models/builders/qwen.py b/src/python/py/models/builders/qwen.py
index 4e47e5c5..ffabd1ee 100644
--- a/src/python/py/models/builders/qwen.py
+++ b/src/python/py/models/builders/qwen.py
@@ -449,9 +449,7 @@ class Qwen25VLTextModel(Model):
 
         seq_len_node = f"{basename}/SeqLen/Gather"
         seq_len_out = f"{seq_len_node}/output_0"
-        self.make_gather(
-            seq_len_node, [f"{shape_node}/output_0", "/model/constants/INT64/1"], ir.DataType.INT64, [], 0
-        )
+        self.make_gather(seq_len_node, [f"{shape_node}/output_0", "/model/constants/INT64/1"], ir.DataType.INT64, [], 0)
 
         # Calculate Total Tokens = B * S
         mul_len_node = f"{basename}/TotalLen/Mul"
@@ -463,7 +461,10 @@ class Qwen25VLTextModel(Model):
         range_node = f"{basename}/Range"
         range_out = f"{range_node}/output_0"
         self.make_range(
-            range_node, ["/model/constants/INT64/0", mul_len_out, "/model/constants/INT64/1"], ir.DataType.INT64, ["total_token_count"]
+            range_node,
+            ["/model/constants/INT64/0", mul_len_out, "/model/constants/INT64/1"],
+            ir.DataType.INT64,
+            ["total_token_count"],
         )
         range_out = f"{range_node}/output_0"
 
@@ -1801,7 +1802,10 @@ class Qwen35TextModel(Model):
 
         a_plus_dt_name = f"{basename}/decay/Add"
         self.make_add(
-            a_plus_dt_name, [f"{a_cast_name}/output_0", dt_bias_init], ir.DataType.FLOAT, ["batch_size", "sequence_length", n_kv]
+            a_plus_dt_name,
+            [f"{a_cast_name}/output_0", dt_bias_init],
+            ir.DataType.FLOAT,
+            ["batch_size", "sequence_length", n_kv],
         )
         a_plus_dt_output = f"{a_plus_dt_name}/output_0"
 
@@ -1810,7 +1814,9 @@ class Qwen35TextModel(Model):
         softplus_output = f"{softplus_name}/output_0"
 
         g_fp32_name = f"{basename}/decay/Mul"
-        self.make_mul(g_fp32_name, [neg_exp_a_name, softplus_output], ir.DataType.FLOAT, ["batch_size", "sequence_length", n_kv])
+        self.make_mul(
+            g_fp32_name, [neg_exp_a_name, softplus_output], ir.DataType.FLOAT, ["batch_size", "sequence_length", n_kv]
+        )
         g_fp32_output = f"{g_fp32_name}/output_0"
 
         # Cast decay back to io_dtype for the kernel
@@ -1900,8 +1906,13 @@ class Qwen35TextModel(Model):
         self.make_mul(sq_name, [input_name, input_name], self.io_dtype, full_shape)
 
         sum_name = f"{basename}/SumSq/ReduceSum"
-        self.make_reduce_sum(sum_name, [f"{sq_name}/output_0", "/model/constants/INT64/[-1]"],
-                             self.io_dtype, reduced_shape, keepdims=True)
+        self.make_reduce_sum(
+            sum_name,
+            [f"{sq_name}/output_0", "/model/constants/INT64/[-1]"],
+            self.io_dtype,
+            reduced_shape,
+            keepdims=True,
+        )
 
         # sum(x^2) + eps
         eps_name = self._get_shared_l2_eps()
@@ -1990,12 +2001,17 @@ class Qwen35TextModel(Model):
         # output = norm * silu(z) in fp32
         gated_fp32_name = f"{basename}/gated_fp32/Mul"
         self.make_mul(
-            gated_fp32_name, [f"{norm_cast_name}/output_0", z_silu_output], ir.DataType.FLOAT, ["batch_size", "sequence_length", v_dim]
+            gated_fp32_name,
+            [f"{norm_cast_name}/output_0", z_silu_output],
+            ir.DataType.FLOAT,
+            ["batch_size", "sequence_length", v_dim],
         )
 
         # Cast back to io_dtype
         gated_name = f"{basename}/gated/Cast"
-        self.make_cast(gated_name, f"{gated_fp32_name}/output_0", self.io_dtype, ["batch_size", "sequence_length", v_dim])
+        self.make_cast(
+            gated_name, f"{gated_fp32_name}/output_0", self.io_dtype, ["batch_size", "sequence_length", v_dim]
+        )
         gated_output = f"{gated_name}/output_0"
 
         return gated_output

daijh added 3 commits April 29, 2026 13:08
Extract _make_synthetic_position_ids() to derive batch_size and
sequence_length from the position_ids model input [3, B, S] instead of
Shape on intermediate Q/K tensors. Use a fixed basename so make_node
dedup creates the subgraph once and reuses it across all layers and Q/K
calls.

- Reduces ONNX graph size by eliminating per-layer duplicates

- Eliminates INT64 Mul by using Reshape([3, -1]) to infer B*S
  implicitly, avoiding CPU fallback on WebGPU (WebGPU Mul does not
  support INT64)
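The Reshape trick in the last bullet can be sketched in NumPy (example shapes only): a target shape of `[3, -1]` lets the runtime infer the flattened `B*S` dimension, so no explicit INT64 Mul node is needed.

```python
import numpy as np

# Sketch: Reshape with target [3, -1] infers B*S from the element count,
# replacing an INT64 Mul node (which WebGPU does not support and would
# force a CPU fallback).
B, S = 2, 1024
position_ids = np.arange(3 * B * S, dtype=np.int64).reshape(3, B, S)

flat = position_ids.reshape(3, -1)   # -1 is inferred as B * S
assert flat.shape == (3, B * S)
```

ONNX Reshape has the same semantics: a single `-1` in the target shape is computed from the remaining dimensions, which is why the multiplication can be dropped from the graph.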
@daijh daijh force-pushed the qwen3.5-dedup-position_ids branch from e98cb96 to c298862 Compare April 29, 2026 05:09
@daijh
Contributor Author

daijh commented Apr 29, 2026

> When I run lintrunner -a, it's touching unrelated code.
> Should I include these lint fixes in this PR, or leave them out for now?

Okay, rebasing against main cleared the lint errors.

@daijh daijh requested a review from a team as a code owner May 6, 2026 01:27
@daijh
Contributor Author

daijh commented May 6, 2026

@kunal-vaishnavi
It looks like the previous CI failures were due to infrastructure issues.
Could you please re-trigger the checks and merge the PR if they pass?

