Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions backends/arm/_passes/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -134,6 +134,7 @@
from .rewrite_index_put_pass import RewriteIndexPutPass # noqa
from .rewrite_le_lt_to_ge_gt_pass import RewriteLeLtToGeGtPass # noqa
from .rewrite_matmul import RewriteMatmulPass # noqa
from .rewrite_pad import RewritePadPass # noqa
from .rewrite_upsample import RewriteUpsamplePass # noqa
from .scalars_to_attribute_pass import ScalarsToAttributePass # noqa
from .size_adjust_input_pass import SizeAdjustInputPass # noqa
Expand Down
2 changes: 2 additions & 0 deletions backends/arm/_passes/arm_pass_manager.py
Original file line number Diff line number Diff line change
Expand Up @@ -119,6 +119,7 @@
RewriteIndexPutPass,
RewriteLeLtToGeGtPass,
RewriteMatmulPass,
RewritePadPass,
RewriteUpsamplePass,
ScalarsToAttributePass,
SizeAdjustInputPass,
Expand Down Expand Up @@ -370,6 +371,7 @@ def _tosa_pipeline(
RewriteUpsamplePass(),
RewriteConvPass(exported_program),
RewriteMatmulPass(),
RewritePadPass(),
]
)

Expand Down
67 changes: 67 additions & 0 deletions backends/arm/_passes/rewrite_pad.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,67 @@
# Copyright 2026 Arm Limited and/or its affiliates.
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.

from typing import Set, Type

import torch

from executorch.backends.arm._passes import ArmPass
from executorch.exir.dialects._ops import ops as exir_ops
from executorch.exir.pass_base import ExportPass


class RewritePadPass(ArmPass):
    """Rewrite ``aten.constant_pad_nd`` into the TOSA ``PAD`` operator.

    TOSA PAD expects a padding tensor of length ``rank * 2`` ordered
    dim-major as (dim0_before, dim0_after, dim1_before, dim1_after, ...),
    whereas ``constant_pad_nd`` receives a flat list that pads the *last*
    dimensions first. This pass reorders the padding accordingly and, for
    quantized outputs, replaces the fill value with its quantized
    representation.
    """

    _passes_required_after: Set[Type[ExportPass]] = set()
    targeted_ops = {
        exir_ops.edge.aten.constant_pad_nd.default,
    }

    def call_operator(self, op, args, kwargs, meta, updated=False):
        if op not in self.targeted_ops:
            return super().call_operator(op, args, kwargs, meta)

        # constant_pad_nd may omit the fill value; it then defaults to 0.
        if len(args) == 3:
            input_tensor, pad, value = args
        else:
            input_tensor, pad = args
            value = 0

        output_dtype = meta["val"].dtype
        if output_dtype in (torch.int8, torch.int16):
            # Quantized graphs must pad with the quantized form of the fill
            # value, derived from the input's quantization parameters.
            input_qparams = meta.data.get("input_qparams", {})
            if len(input_qparams) == 0:
                raise ValueError(
                    f"No input quantization parameters found in metadata for constant_pad_nd with output dtype {output_dtype}"
                )
            value = input_qparams[0].quantize_value(value).item()

        # Torch padding pads the last dimensions first, e.g. for 4D input:
        # (W_before, W_after, H_before, H_after, ...). Walk the pairs from the
        # end to obtain TOSA's dim-major order
        # (N_before, N_after, ..., W_before, W_after).
        reordered_pad = [
            pad[pair_start + offset]
            for pair_start in range(len(pad) - 2, -1, -2)
            for offset in (0, 1)
        ]
        input_rank = len(input_tensor.data.shape)
        # Leading dimensions torch did not pad get explicit zero padding.
        shape = [0] * (input_rank * 2 - len(pad)) + reordered_pad

        pad_shape = super().call_shape_operator(
            exir_ops.backend.tosa.CONST_SHAPE.default, (shape,), {}, meta, True
        )

        return super().call_operator(
            exir_ops.backend.tosa.PAD.default,
            (input_tensor, pad_shape),
            {"value": value},
            meta,
            True,
        )
10 changes: 9 additions & 1 deletion backends/arm/_passes/to_tosa_memory_format_pass.py
Original file line number Diff line number Diff line change
Expand Up @@ -413,8 +413,16 @@ def _propagate_dim_order_to_shape_args(self, node: torch.fx.Node) -> None:
# shape nodes depending on the order of user traversal.
old_dim_order = arg.meta.get("tosa_dim_order", None) is not None
dim_order = node.meta["tosa_dim_order"]
# The shape node may have a different rank than the dim_order being propagated from its users
if len(dim_order) != len(arg.meta["val"]):
dim_order = tuple(range(len(arg.meta["val"])))
# For pad shape nodes, the rank is always 2x of the input tensor rank, and the dim order needs to be adjusted accordingly.
# For other shape nodes, we assume the dim order is the same as the order of dimensions in the shape.
if node.target == exir_ops.backend.tosa.PAD.default:
dim_order = tuple(
i for axis in dim_order for i in (2 * axis, 2 * axis + 1)
)
else:
dim_order = tuple(range(len(arg.meta["val"])))
if old_dim_order and arg.meta["tosa_dim_order"] != dim_order:
raise RuntimeError(
f"Conflicting dim orders {arg.meta['tosa_dim_order']} and {dim_order} for shape node {arg.name}"
Expand Down
2 changes: 1 addition & 1 deletion backends/arm/operators/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,7 +23,6 @@
op_ceil,
op_clamp,
op_cond_if,
op_constant_pad_nd,
op_cos,
op_eq,
op_erf,
Expand Down Expand Up @@ -57,6 +56,7 @@
op_tosa_depthwise_conv2d,
op_tosa_gather,
op_tosa_matmul,
op_tosa_pad,
op_tosa_rescale,
op_tosa_resize,
op_tosa_scatter,
Expand Down
113 changes: 0 additions & 113 deletions backends/arm/operators/op_constant_pad_nd.py

This file was deleted.

53 changes: 53 additions & 0 deletions backends/arm/operators/op_tosa_pad.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,53 @@
# Copyright 2026 Arm Limited and/or its affiliates.
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.

from typing import Any, List

import torch

import tosa_serializer as ts

from executorch.backends.arm.operators.node_visitor import (
NodeVisitor,
register_node_visitor,
)
from executorch.backends.arm.tosa.mapping import TosaArg


@register_node_visitor
class TosaPadVisitor(NodeVisitor):
    """Serialize a ``tosa.PAD`` dialect node into a TOSA PAD operator.

    The padding amounts arrive as the node's second input (a shape tensor);
    the constant fill value is materialized as a rank-1 const tensor, which
    the TOSA PAD signature takes as a third operand.
    """

    target = "tosa.PAD.default"

    tosa_specs = NodeVisitor.tosa_specs

    def define_node(
        self,
        node: torch.fx.Node,
        tosa_graph: Any,
        inputs: List[TosaArg],
        output: TosaArg,
    ) -> None:
        # Fill value defaults to 0 when no "value" kwarg was set on the node.
        fill_value = node.kwargs.get("value", 0)
        pad_const = tosa_graph.addConst(
            [1],
            output.dtype,
            [fill_value],
            name=node.name + "_padding_value",
        )

        pad_attr = ts.TosaSerializerAttribute()
        pad_attr.PadAttribute()

        self._serialize_operator(
            node,
            tosa_graph,
            ts.Op.PAD,
            [inputs[0].name, inputs[1].name, pad_const.name],
            [output.name],
            pad_attr,
        )
1 change: 1 addition & 0 deletions backends/arm/tosa/dialect/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,7 @@
depthwise_conv2d,
gather,
matmul,
pad,
rescale,
resize,
scatter,
Expand Down
57 changes: 57 additions & 0 deletions backends/arm/tosa/dialect/ops/pad.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,57 @@
# Copyright 2026 Arm Limited and/or its affiliates.
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.

from typing import List

import torch

from executorch.backends.arm.tosa.dialect.lib import TosaValueError
from executorch.backends.arm.tosa.dialect.ops_registration import register_fake_tosa_op

from executorch.backends.arm.tosa.specification import (
get_context_spec,
TosaSpecification,
)


@register_fake_tosa_op(
    "PAD(Tensor input1, SymInt[] padding, *, Scalar value) -> Tensor",  # schema
    (
        TosaSpecification.create_from_string("TOSA-1.0+INT"),
        TosaSpecification.create_from_string("TOSA-1.0+FP"),
    ),  # target TOSA specifications
)
def PAD(a: torch.Tensor, padding: List[int | torch.SymInt], *, value):
    """Fake (meta) implementation of the TOSA PAD operator.

    Validates the input dtype against the active TOSA specification and the
    padding length against the input rank, then returns an empty tensor with
    the padded output shape.
    """
    tosa_spec = get_context_spec()

    # Bool is always accepted; integer/float/bf16 depend on the spec profile.
    supported_dtypes = {torch.bool}
    if tosa_spec.support_integer():
        supported_dtypes |= {torch.int8, torch.int16, torch.int32}
    if tosa_spec.support_float():
        supported_dtypes |= {torch.float16, torch.float32}
    if tosa_spec.support_extension("bf16"):
        supported_dtypes.add(torch.bfloat16)
    if a.dtype not in supported_dtypes:
        raise TosaValueError(
            f"Input tensor dtype {a.dtype} is not supported by the target TOSA specification."
            f" Supported dtypes are: {supported_dtypes}",
            op="PAD",
        )

    # Each dimension needs exactly one (before, after) pair.
    if len(padding) != 2 * len(a.shape):
        raise TosaValueError(
            f"Padding length {len(padding)} is not compatible with input rank {len(a.shape)}",
            op="PAD",
        )

    # Every dim grows by the sum of its before/after padding amounts.
    new_shape: List[int | torch.SymInt] = [
        padding[2 * axis] + dim + padding[2 * axis + 1]
        for axis, dim in enumerate(a.shape)
    ]

    return torch.empty(size=new_shape, dtype=a.dtype)
Loading