Skip to content
Open
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
127 changes: 85 additions & 42 deletions src/python/py/models/builders/qwen.py
Original file line number Diff line number Diff line change
Expand Up @@ -439,17 +439,17 @@ def apply_mrope_rotation(self, layer_id, q_or_k_path, q_or_k_shape, dyn_cos, dyn
shape_node = f"{basename}/Shape"
self.make_shape(shape_node, q_or_k_path, [3])

# Extract B and S
# Extract B and S (scalar Gather indices → scalar outputs)
batch_size_node = f"{basename}/BatchSize/Gather"
batch_size_out = f"{batch_size_node}/output_0"
self.make_gather(
batch_size_node, [f"{shape_node}/output_0", "/model/constants/INT64/[0]"], ir.DataType.INT64, [], 0
batch_size_node, [f"{shape_node}/output_0", "/model/constants/INT64/0"], ir.DataType.INT64, [], 0
)

seq_len_node = f"{basename}/SeqLen/Gather"
seq_len_out = f"{seq_len_node}/output_0"
self.make_gather(
seq_len_node, [f"{shape_node}/output_0", "/model/constants/INT64/[1]"], ir.DataType.INT64, [], 0
seq_len_node, [f"{shape_node}/output_0", "/model/constants/INT64/1"], ir.DataType.INT64, [], 0
)

# Calculate Total Tokens = B * S
Expand Down Expand Up @@ -1382,6 +1382,84 @@ def interleave(suffix, cache_name):

return interleave("cos", cos_cache), interleave("sin", sin_cache)

def _make_synthetic_position_ids(self):
    """Emit the shared subgraph that yields synthetic position ids.

    The produced tensor is INT64 with shape ``[B, S]`` and holds the
    values ``0 .. B*S-1``. B and S are read from the ``position_ids``
    model input (laid out as ``[3, B, S]``) rather than from a Shape op
    on an intermediate Q/K tensor, so this subgraph carries no data
    dependency on the Q/K computation.

    The element count B*S is never multiplied out explicitly: the input
    is reshaped to ``[3, -1]`` and the inferred second dimension is read
    back from its shape. The intent is to keep an INT64 Mul (which would
    fall back to CPU on WebGPU) out of the graph; Reshape only touches
    metadata.

    Every node lives under one fixed prefix, so the dedup performed by
    ``make_node`` means the nodes are built once and then shared by all
    layers and by both the Q and K call sites.

    Returns:
        Name of the graph value holding the ``[B, S]`` position ids.
    """
    prefix = "/model/attn/synthetic_pos_ids"
    position_ids_input = self.input_names["position_ids"]

    def output_of(node_name):
        # Name of the first (only) output value of a node.
        return f"{node_name}/output_0"

    # Step 1: Shape(position_ids) gives the 3-element vector [3, B, S].
    input_shape_node = f"{prefix}/Shape"
    self.make_shape(input_shape_node, root_input=position_ids_input, shape=[3])

    # Step 2: keep elements 1..2 of that vector, i.e. [B, S]. This is the
    # target shape consumed by the final Reshape.
    target_shape_node = f"{prefix}/bs_shape/Slice"
    slice_inputs = [
        output_of(input_shape_node),
        "/model/constants/INT64/[1]",  # starts
        "/model/constants/INT64/[3]",  # ends
        "/model/constants/INT64/[0]",  # axes
    ]
    self.make_slice(target_shape_node, inputs=slice_inputs, dtype=ir.DataType.INT64, shape=[2])

    # Step 3: flatten [3, B, S] to [3, -1] so the inferred dim equals B*S.
    flatten_node = f"{prefix}/flat/Reshape"
    flatten_inputs = [position_ids_input, "/model/constants/INT64/[3, -1]"]
    self.make_reshape(flatten_node, inputs=flatten_inputs, dtype=ir.DataType.INT64, shape=[3, "batch_seq"])

    # Step 4: Shape of the flattened tensor is [3, B*S]; a Gather with the
    # scalar index 1 pulls B*S out as a scalar.
    flat_shape_node = f"{prefix}/Shape2"
    self.make_shape(flat_shape_node, root_input=output_of(flatten_node), shape=[2])

    count_node = f"{prefix}/total/Gather"
    count_inputs = [output_of(flat_shape_node), "/model/constants/INT64/1"]
    self.make_gather(count_node, inputs=count_inputs, dtype=ir.DataType.INT64, shape=[], axis=0)

    # Step 5: Range(start=0, limit=B*S, delta=1).
    arange_node = f"{prefix}/range/Range"
    arange_inputs = [
        "/model/constants/INT64/0",
        output_of(count_node),
        "/model/constants/INT64/1",
    ]
    self.make_range(arange_node, inputs=arange_inputs, dtype=ir.DataType.INT64, shape=["batch_seq"])

    # Step 6: fold the flat range back into [B, S].
    result_node = f"{prefix}/Reshape"
    result_inputs = [output_of(arange_node), output_of(target_shape_node)]
    self.make_reshape(
        result_node,
        inputs=result_inputs,
        dtype=ir.DataType.INT64,
        shape=["batch_size", "sequence_length"],
    )

    return output_of(result_node)

def _apply_mrope_rotation(self, layer_id, qk_path, qk_shape, dyn_cos, dyn_sin, num_heads, basename):
"""Apply mRoPE via com.microsoft.RotaryEmbedding (4-input variant).

Expand Down Expand Up @@ -1429,45 +1507,10 @@ def _apply_mrope_rotation(self, layer_id, qk_path, qk_shape, dyn_cos, dyn_sin, n
rope_sin = f"{sin_cast_name}/output_0"

# --- Build synthetic position_ids [B, S] = Range(0, B*S).reshape(B, S) ---
# Shape(Q/K input) → [B, S, N*H], Gather dim 0 and 1 → B, S
shape_name = f"{basename}/qk_shape/Shape"
self.make_shape(shape_name, qk_path, [3])

batch_name = f"{basename}/batch/Gather"
self.make_gather(batch_name, [f"{shape_name}/output_0", "/model/constants/INT64/[0]"], ir.DataType.INT64, [1], axis=0)

seq_name = f"{basename}/seq/Gather"
self.make_gather(seq_name, [f"{shape_name}/output_0", "/model/constants/INT64/[1]"], ir.DataType.INT64, [1], axis=0)

total_name = f"{basename}/total/Mul"
self.make_mul(total_name, [f"{batch_name}/output_0", f"{seq_name}/output_0"], ir.DataType.INT64, [1])

range_name = f"{basename}/range/Range"
self.make_range(
range_name,
["/model/constants/INT64/[0]", f"{total_name}/output_0", "/model/constants/INT64/[1]"],
ir.DataType.INT64,
["batch_seq"],
)

# Reshape to [B, S]
bs_shape_name = f"{basename}/bs_shape/Concat"
self.make_concat(
bs_shape_name,
[f"{batch_name}/output_0", f"{seq_name}/output_0"],
ir.DataType.INT64,
[2],
axis=0,
)

pos_ids_name = f"{basename}/pos_ids/Reshape"
self.make_reshape(
pos_ids_name,
[f"{range_name}/output_0", f"{bs_shape_name}/output_0"],
ir.DataType.INT64,
["batch_size", "sequence_length"],
)
pos_ids = f"{pos_ids_name}/output_0"
# Derive B and S from the position_ids input [3, B, S] instead of
# using Shape on intermediate Q/K tensors. Shared across all layers
# and Q/K calls via make_node dedup.
pos_ids = self._make_synthetic_position_ids()

# --- Reshape Q/K to [B, N, S, H] for com.microsoft.RotaryEmbedding ---
head_size = qk_shape[-1] // num_heads if isinstance(qk_shape[-1], int) else self.head_size
Expand Down