class RewritePadPass(ArmPass):
    """Lower ``aten.constant_pad_nd`` to the TOSA backend ``PAD`` operator.

    The padding amounts are emitted as a ``CONST_SHAPE`` node holding one
    ``(before, after)`` pair per input dimension, and the fill value is passed
    to ``PAD`` through its ``value`` keyword argument.

    """

    _passes_required_after: Set[Type[ExportPass]] = set()
    targeted_ops = {
        exir_ops.edge.aten.constant_pad_nd.default,
    }

    def call_operator(self, op, args, kwargs, meta, updated=False):
        """Replace a targeted pad call with ``CONST_SHAPE`` followed by ``PAD``.

        Operators outside ``targeted_ops`` are forwarded to the parent class
        untouched.

        Raises:
            ValueError: If the output dtype is int8/int16 and the node
                metadata holds no ``input_qparams`` to quantize the fill
                value with.

        """
        if op not in self.targeted_ops:
            return super().call_operator(op, args, kwargs, meta)

        # The fill value is the optional third positional argument.
        if len(args) != 3:
            input_tensor, pad = args
            value = 0
        else:
            input_tensor, pad, value = args

        output_dtype = meta["val"].dtype
        if output_dtype in (torch.int8, torch.int16):
            # For quantized outputs the fill value is converted with the
            # quantization parameters of the first input.
            input_qparams = meta.data.get("input_qparams", {})
            if len(input_qparams) == 0:
                raise ValueError(
                    f"No input quantization parameters found in metadata for constant_pad_nd with output dtype {output_dtype}"
                )
            value = input_qparams[0].quantize_value(value).item()

        # `pad` lists (before, after) pairs starting from the LAST dimension:
        # (left, right) pads the last dim, (left, right, top, bottom) pads the
        # last two, and so on. Walk the pairs back to front so the result is
        # ordered by ascending dimension, e.g. for a 4D input:
        # (N_before, N_after, C_before, C_after, H_before, H_after, W_before, W_after).
        trailing_pads = [
            amount
            for start in reversed(range(0, len(pad), 2))
            for amount in (pad[start], pad[start + 1])
        ]
        input_rank = len(input_tensor.data.shape)
        # Leading dimensions that `pad` does not mention get zero padding.
        untouched_dims = [0] * (input_rank * 2 - len(pad))
        shape = untouched_dims + trailing_pads

        pad_shape = super().call_shape_operator(
            exir_ops.backend.tosa.CONST_SHAPE.default, (shape,), {}, meta, True
        )

        pad_args = (input_tensor, pad_shape)
        pad_kwargs = {"value": value}
        return super().call_operator(
            exir_ops.backend.tosa.PAD.default, pad_args, pad_kwargs, meta, True
        )
+ if node.target == exir_ops.backend.tosa.PAD.default: + dim_order = tuple( + i for axis in dim_order for i in (2 * axis, 2 * axis + 1) + ) + else: + dim_order = tuple(range(len(arg.meta["val"]))) if old_dim_order and arg.meta["tosa_dim_order"] != dim_order: raise RuntimeError( f"Conflicting dim orders {arg.meta['tosa_dim_order']} and {dim_order} for shape node {arg.name}" diff --git a/backends/arm/operators/__init__.py b/backends/arm/operators/__init__.py index 4069ff1dc74..c8ba7c844a1 100644 --- a/backends/arm/operators/__init__.py +++ b/backends/arm/operators/__init__.py @@ -23,7 +23,6 @@ op_ceil, op_clamp, op_cond_if, - op_constant_pad_nd, op_cos, op_eq, op_erf, @@ -57,6 +56,7 @@ op_tosa_depthwise_conv2d, op_tosa_gather, op_tosa_matmul, + op_tosa_pad, op_tosa_rescale, op_tosa_resize, op_tosa_scatter, diff --git a/backends/arm/operators/op_constant_pad_nd.py b/backends/arm/operators/op_constant_pad_nd.py deleted file mode 100644 index 57c44d8f7cf..00000000000 --- a/backends/arm/operators/op_constant_pad_nd.py +++ /dev/null @@ -1,113 +0,0 @@ -# Copyright 2025-2026 Arm Limited and/or its affiliates. -# -# This source code is licensed under the BSD-style license found in the -# LICENSE file in the root directory of this source tree. 
@register_node_visitor
class ConstantPadNDVisitor(NodeVisitor):
    """Serialize ``aten.constant_pad_nd.default`` as a TOSA ``PAD`` operator."""

    target = "aten.constant_pad_nd.default"

    def define_node(
        self,
        node: torch.fx.Node,
        tosa_graph: Any,
        inputs: List[TosaArg],
        output: TosaArg,
    ) -> None:
        """Emit the padding constant, the fill-value constant and the PAD op.

        ``inputs`` is read as (tensor, pad list, fill value).

        """
        data, pad_arg, fill_arg = inputs[0], inputs[1], inputs[2]

        validate_num_inputs(self.target, inputs, 3)
        validate_same_dtype(self.target, [data, output], ts)
        validate_valid_dtype(
            self.target,
            [data, output],
            [
                ts.DType.INT8,
                ts.DType.INT16,
                ts.DType.INT32,
                ts.DType.FP16,
                ts.DType.FP32,
                ts.DType.BF16,
                ts.DType.BOOL,
            ],
            self.tosa_spec,
        )

        # The fill constant always carries the input dtype; for the quantized
        # integer dtypes its value is first quantized with the qparams of the
        # first input.
        pad_const_dtype = data.dtype
        if data.dtype in (ts.DType.INT8, ts.DType.INT16):
            qargs = get_input_qparams(node)[0]
            pad_const_val = qargs.quantize_value(fill_arg.number).item()
        else:
            pad_const_val = fill_arg.number

        rank = len(output.shape)
        pads = pad_arg.special
        # The pad list holds (before, after) pairs starting from the last
        # dimension, so walk the pairs back to front to get ascending
        # dimension order.
        input_pad = [
            amount
            for start in reversed(range(0, len(pads), 2))
            for amount in (pads[start], pads[start + 1])
        ]
        # Dimensions the pad list does not mention get zero padding.
        input_pad = [0] * (rank * 2 - len(pads)) + input_pad

        # Move each dimension's pair to the slot where that dimension appears
        # in the output's dim order.
        output_pad = [0] * rank * 2
        for src_dim in range(rank):
            dst_dim = output.dim_order.index(src_dim)
            output_pad[dst_dim * 2 : (dst_dim + 1) * 2] = input_pad[
                src_dim * 2 : (src_dim + 1) * 2
            ]

        padding = tosa_graph.addConst(
            shape=[len(output_pad)], dtype=ts.DType.SHAPE, vals=output_pad
        )
        pad_const = tosa_graph.addConst(
            shape=[1], dtype=pad_const_dtype, vals=[pad_const_val]
        )

        attr = ts.TosaSerializerAttribute()
        attr.PadAttribute()
        operand_names = [data.name, padding.name, pad_const.name]
        self._serialize_operator(
            node,
            tosa_graph,
            ts.Op.PAD,
            operand_names,
            [output.name],
            attr,
        )
@register_node_visitor
class TosaPadVisitor(NodeVisitor):
    """Serialize ``tosa.PAD.default`` nodes as TOSA ``PAD`` operators."""

    target = "tosa.PAD.default"

    tosa_specs = NodeVisitor.tosa_specs

    def define_node(
        self,
        node: torch.fx.Node,
        tosa_graph: Any,
        inputs: List[TosaArg],
        output: TosaArg,
    ) -> None:
        """Add the fill-value constant and the PAD operator to ``tosa_graph``.

        ``inputs[0]`` is the tensor to pad and ``inputs[1]`` the padding
        operand; the fill value is read from the node's ``value`` kwarg and
        defaults to 0 when absent.

        """
        fill_value = node.kwargs.get("value", 0)
        # Single-element constant, typed like the output, holding the fill value.
        pad_const = tosa_graph.addConst(
            [1],
            output.dtype,
            [fill_value],
            name=node.name + "_padding_value",
        )

        pad_attr = ts.TosaSerializerAttribute()
        pad_attr.PadAttribute()

        operand_names = [inputs[0].name, inputs[1].name, pad_const.name]
        self._serialize_operator(
            node,
            tosa_graph,
            ts.Op.PAD,
            operand_names,
            [output.name],
            pad_attr,
        )
@register_fake_tosa_op(
    "PAD(Tensor input1, SymInt[] padding, *, Scalar value) -> Tensor",  # schema
    (
        TosaSpecification.create_from_string("TOSA-1.0+INT"),
        TosaSpecification.create_from_string("TOSA-1.0+FP"),
    ),  # target TOSA specifications
)
def PAD(a: torch.Tensor, padding: List[int | torch.SymInt], *, value):
    """Fake implementation of TOSA ``PAD``: validate inputs, compute the shape.

    Args:
        a: Tensor to pad.
        padding: Flat list of ``(before, after)`` amounts, one pair per
            dimension of ``a`` in dimension order.
        value: Fill value; not inspected here.

    Returns:
        An uninitialized tensor with the dtype of ``a`` whose every dimension
        is grown by its ``before`` and ``after`` amounts.

    Raises:
        TosaValueError: If the dtype of ``a`` is not supported by the TOSA
            specification in context, if ``padding`` does not hold exactly
            two entries per dimension, or if a concrete padding amount is
            negative.

    """
    tosa_spec = get_context_spec()

    supported_dtypes = {torch.bool}
    if tosa_spec.support_integer():
        supported_dtypes.update({torch.int8, torch.int16, torch.int32})
    if tosa_spec.support_float():
        supported_dtypes.update({torch.float16, torch.float32})
        if tosa_spec.support_extension("bf16"):
            supported_dtypes.add(torch.bfloat16)
    if a.dtype not in supported_dtypes:
        raise TosaValueError(
            f"Input tensor dtype {a.dtype} is not supported by the target TOSA specification."
            f" Supported dtypes are: {supported_dtypes}",
            op="PAD",
        )

    if len(padding) != 2 * len(a.shape):
        raise TosaValueError(
            f"Padding length {len(padding)} is not compatible with input rank {len(a.shape)}",
            op="PAD",
        )

    # TOSA PAD only grows a tensor, whereas PyTorch padding may be negative
    # (cropping). Without this check a negative amount would silently yield a
    # shrunken fake shape and an invalid PAD operator. Only concrete ints are
    # inspected so symbolic amounts are not compared (and thus not guarded on).
    negative = [p for p in padding if isinstance(p, int) and p < 0]
    if negative:
        raise TosaValueError(
            f"Padding values must be non-negative, got {list(padding)}",
            op="PAD",
        )

    # Even indices are the "before" amounts, odd indices the "after" amounts.
    new_shape: List[int | torch.SymInt] = [
        pad_before + dim + pad_after
        for dim, pad_before, pad_after in zip(a.shape, padding[0::2], padding[1::2])
    ]

    # Only shape and dtype matter for a fake op, so the contents stay
    # uninitialized.
    return torch.empty(size=new_shape, dtype=a.dtype)