diff --git a/src/python/py/models/builders/qwen.py b/src/python/py/models/builders/qwen.py index 0811261f9..06933faaa 100644 --- a/src/python/py/models/builders/qwen.py +++ b/src/python/py/models/builders/qwen.py @@ -439,17 +439,17 @@ def apply_mrope_rotation(self, layer_id, q_or_k_path, q_or_k_shape, dyn_cos, dyn shape_node = f"{basename}/Shape" self.make_shape(shape_node, q_or_k_path, [3]) - # Extract B and S + # Extract B and S (scalar Gather indices → scalar outputs) batch_size_node = f"{basename}/BatchSize/Gather" batch_size_out = f"{batch_size_node}/output_0" self.make_gather( - batch_size_node, [f"{shape_node}/output_0", "/model/constants/INT64/[0]"], ir.DataType.INT64, [], 0 + batch_size_node, [f"{shape_node}/output_0", "/model/constants/INT64/0"], ir.DataType.INT64, [], 0 ) seq_len_node = f"{basename}/SeqLen/Gather" seq_len_out = f"{seq_len_node}/output_0" self.make_gather( - seq_len_node, [f"{shape_node}/output_0", "/model/constants/INT64/[1]"], ir.DataType.INT64, [], 0 + seq_len_node, [f"{shape_node}/output_0", "/model/constants/INT64/1"], ir.DataType.INT64, [], 0 ) # Calculate Total Tokens = B * S @@ -1372,6 +1372,84 @@ def interleave(suffix, cache_name): return interleave("cos", cos_cache), interleave("sin", sin_cache) + def _make_synthetic_position_ids(self): + """Build synthetic position_ids [B, S] with values 0 .. B*S-1. + + Derives B and S from the ``position_ids`` model input ``[3, B, S]`` + instead of using Shape on intermediate Q/K tensors. This avoids a + data-dependency on Q/K computation. + + B*S is obtained by reshaping position_ids to ``[3, -1]`` and reading + the inferred dimension from the shape. This lets the runtime compute + the product implicitly (Reshape is metadata-only) and avoids an + explicit INT64 Mul that would fall back to CPU on WebGPU. + + Uses a fixed basename so ``make_node`` dedup ensures nodes are + created once and reused across all layers and Q/K calls. Returns the output name of the final Reshape node, ``/model/attn/synthetic_pos_ids/Reshape/output_0``. NOTE(review): node reuse depends on ``make_node`` dedup, which is not shown in this patch — confirm. 
+ """ + basename = "/model/attn/synthetic_pos_ids" + pos_ids_input = self.input_names["position_ids"] + + # Shape(position_ids) → [3, B, S] (a 1-D tensor with three entries) + shape_name = f"{basename}/Shape" + self.make_shape(shape_name, root_input=pos_ids_input, shape=[3]) + + # Slice shape[1:3] → [B, S] (used as reshape target at the end); inputs follow the ONNX Slice order: data, starts=[1], ends=[3], axes=[0] + bs_shape_name = f"{basename}/bs_shape/Slice" + self.make_slice( + bs_shape_name, + inputs=[ + f"{shape_name}/output_0", + "/model/constants/INT64/[1]", + "/model/constants/INT64/[3]", + "/model/constants/INT64/[0]", + ], + dtype=ir.DataType.INT64, + shape=[2], + ) + + # Reshape position_ids [3, B, S] → [3, -1] to get B*S implicitly + flat_name = f"{basename}/flat/Reshape" + self.make_reshape( + flat_name, + inputs=[pos_ids_input, "/model/constants/INT64/[3, -1]"], + dtype=ir.DataType.INT64, + shape=[3, "batch_seq"], + ) + + # Shape([3, B*S]) → [3, B*S], Gather scalar index 1 → scalar B*S + shape2_name = f"{basename}/Shape2" + self.make_shape(shape2_name, root_input=f"{flat_name}/output_0", shape=[2]) + + total_name = f"{basename}/total/Gather" + self.make_gather( + total_name, + inputs=[f"{shape2_name}/output_0", "/model/constants/INT64/1"], + dtype=ir.DataType.INT64, + shape=[], + axis=0, + ) + + # Range(0, B*S, 1) → 1-D tensor [0, 1, ..., B*S-1]; start, limit and delta are all passed as scalars + range_name = f"{basename}/range/Range" + self.make_range( + range_name, + inputs=["/model/constants/INT64/0", f"{total_name}/output_0", "/model/constants/INT64/1"], + dtype=ir.DataType.INT64, + shape=["batch_seq"], + ) + + # Reshape the flat range to [B, S], using the sliced [B, S] shape tensor as the target shape + pos_ids_name = f"{basename}/Reshape" + self.make_reshape( + pos_ids_name, + inputs=[f"{range_name}/output_0", f"{bs_shape_name}/output_0"], + dtype=ir.DataType.INT64, + shape=["batch_size", "sequence_length"], + ) + + return f"{pos_ids_name}/output_0" + def _apply_mrope_rotation(self, layer_id, qk_path, qk_shape, dyn_cos, dyn_sin, num_heads, basename): """Apply mRoPE via com.microsoft.RotaryEmbedding (4-input variant). 
@@ -1419,45 +1497,10 @@ def _apply_mrope_rotation(self, layer_id, qk_path, qk_shape, dyn_cos, dyn_sin, n rope_sin = f"{sin_cast_name}/output_0" # --- Build synthetic position_ids [B, S] = Range(0, B*S).reshape(B, S) --- - # Shape(Q/K input) → [B, S, N*H], Gather dim 0 and 1 → B, S - shape_name = f"{basename}/qk_shape/Shape" - self.make_shape(shape_name, qk_path, [3]) - - batch_name = f"{basename}/batch/Gather" - self.make_gather(batch_name, [f"{shape_name}/output_0", "/model/constants/INT64/[0]"], ir.DataType.INT64, [1], axis=0) - - seq_name = f"{basename}/seq/Gather" - self.make_gather(seq_name, [f"{shape_name}/output_0", "/model/constants/INT64/[1]"], ir.DataType.INT64, [1], axis=0) - - total_name = f"{basename}/total/Mul" - self.make_mul(total_name, [f"{batch_name}/output_0", f"{seq_name}/output_0"], ir.DataType.INT64, [1]) - - range_name = f"{basename}/range/Range" - self.make_range( - range_name, - ["/model/constants/INT64/[0]", f"{total_name}/output_0", "/model/constants/INT64/[1]"], - ir.DataType.INT64, - ["batch_seq"], - ) - - # Reshape to [B, S] - bs_shape_name = f"{basename}/bs_shape/Concat" - self.make_concat( - bs_shape_name, - [f"{batch_name}/output_0", f"{seq_name}/output_0"], - ir.DataType.INT64, - [2], - axis=0, - ) - - pos_ids_name = f"{basename}/pos_ids/Reshape" - self.make_reshape( - pos_ids_name, - [f"{range_name}/output_0", f"{bs_shape_name}/output_0"], - ir.DataType.INT64, - ["batch_size", "sequence_length"], - ) - pos_ids = f"{pos_ids_name}/output_0" + # Derive B and S from the position_ids input [3, B, S] instead of + # using Shape on intermediate Q/K tensors. Shared across all layers + # and Q/K calls via make_node dedup. + pos_ids = self._make_synthetic_position_ids() # --- Reshape Q/K to [B, N, S, H] for com.microsoft.RotaryEmbedding --- head_size = qk_shape[-1] // num_heads if isinstance(qk_shape[-1], int) else self.head_size