Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
760 changes: 760 additions & 0 deletions backends/nvidia/tensorrt/converter_utils.py

Large diffs are not rendered by default.

3 changes: 3 additions & 0 deletions backends/nvidia/tensorrt/converters/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,3 +5,6 @@
# LICENSE file in the root directory of this source tree.

"""TensorRT converters for ExecuTorch operations."""

# Import converters to trigger registration via @converter decorator
from executorch.backends.nvidia.tensorrt.converters import add # noqa: F401
150 changes: 150 additions & 0 deletions backends/nvidia/tensorrt/converters/add.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,150 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.

"""Converter for element-wise addition operations."""

from typing import Any, Dict, Optional

import tensorrt as trt
import torch
from executorch.backends.nvidia.tensorrt.converter_registry import converter
from executorch.backends.nvidia.tensorrt.converter_utils import (
broadcast_tensors,
get_node_dtype,
get_trt_tensor,
promote_and_cast_tensors,
set_layer_name,
)


def _get_input_ndim(arg: Any, input_map: Dict[torch.fx.Node, Any]) -> int:
"""Get the number of dimensions for an elementwise input argument.

Uses node metadata when available for reliability during network building.

Args:
arg: Input argument (either torch.fx.Node or scalar).
input_map: Mapping from FX nodes to TensorRT tensors.

Returns:
Number of dimensions (0 for scalars).
"""
if isinstance(arg, torch.fx.Node):
# Try to get ndim from node metadata first (most reliable)
if "val" in arg.meta and hasattr(arg.meta["val"], "shape"):
return len(arg.meta["val"].shape)
# Fall back to TRT tensor shape
if arg in input_map:
trt_tensor = input_map[arg]
shape = trt_tensor.shape
if shape is not None:
return len(shape)
# For scalars, return 0 (will be broadcast)
return 0


def _get_elementwise_input(
    network: trt.INetworkDefinition,
    input_map: Dict[torch.fx.Node, Any],
    arg: Any,
    name: str,
    dtype: Optional[torch.dtype],
) -> trt.ITensor:
    """Resolve one operand of an elementwise op to a TensorRT tensor.

    FX nodes are looked up in ``input_map`` (they must already have been
    converted); scalar operands are materialized as TensorRT constants.

    Args:
        network: TensorRT network definition.
        input_map: Mapping from FX nodes to TensorRT tensors.
        arg: Input argument (either torch.fx.Node or scalar value).
        name: Name for the constant tensor if created.
        dtype: Data type for scalar conversion.

    Returns:
        TensorRT tensor for the input.

    Raises:
        ValueError: If arg is a Node but not found in input_map.
    """
    # Scalar operand: wrap it in a constant tensor of the requested dtype.
    if not isinstance(arg, torch.fx.Node):
        return get_trt_tensor(network, arg, name, dtype)

    # Node operand: it must already have a converted TensorRT tensor.
    if arg not in input_map:
        raise ValueError(
            f"Input node '{arg.name}' not found in input_map. "
            f"Available nodes: {list(input_map.keys())}"
        )
    return input_map[arg]


@converter("aten.add.Tensor", "aten.add_.Tensor")
def convert_add(
node: torch.fx.Node,
network: trt.INetworkDefinition,
input_map: Dict[torch.fx.Node, Any],
ctx: Any = None,
) -> trt.ITensor:
"""Convert aten.add.Tensor and aten.add_.Tensor to TensorRT ElementWise SUM.

Handles tensor + tensor, tensor + scalar, and scalar + tensor cases.
The alpha parameter (x + alpha * y) is validated to be 1.
Note: In-place variant (add_) is handled identically since TensorRT doesn't
have in-place operations.

Args:
node: FX node representing the add operation.
network: TensorRT network definition.
input_map: Mapping from FX nodes to TensorRT tensors.
ctx: Optional conversion context.

Returns:
TensorRT tensor representing the sum.

Raises:
ValueError: If alpha != 1 or if required inputs are missing.
"""
# Validate args
if len(node.args) < 2:
raise ValueError(
f"aten.add requires at least 2 arguments, got {len(node.args)}"
)

lhs_arg = node.args[0]
rhs_arg = node.args[1]

alpha = node.args[2] if len(node.args) > 2 else node.kwargs.get("alpha", 1)
if alpha != 1:
raise ValueError(
f"aten.add.Tensor with alpha != 1 is not supported, got alpha={alpha}"
)

dtype = get_node_dtype(node)

lhs = _get_elementwise_input(network, input_map, lhs_arg, "lhs", dtype)
rhs = _get_elementwise_input(network, input_map, rhs_arg, "rhs", dtype)

# Type promotion: ensure both operands have compatible types
lhs, rhs = promote_and_cast_tensors(network, lhs, rhs, f"add_{node.name}")

# Get target ndim from node metadata for reliability
lhs_ndim = _get_input_ndim(lhs_arg, input_map)
rhs_ndim = _get_input_ndim(rhs_arg, input_map)
target_ndim = max(lhs_ndim, rhs_ndim)

# Fall back to output shape from node metadata if we couldn't get input shapes
if target_ndim == 0 and "val" in node.meta and hasattr(node.meta["val"], "shape"):
target_ndim = len(node.meta["val"].shape)

# If still 0, both inputs are scalars - result is scalar (0-dim tensor in TRT is 1-dim)
if target_ndim == 0:
target_ndim = 1

lhs, rhs = broadcast_tensors(network, [lhs, rhs], target_ndim, f"add_{node.name}")

layer = network.add_elementwise(lhs, rhs, trt.ElementWiseOperation.SUM)
if layer is None:
raise RuntimeError(f"Failed to create elementwise SUM layer for {node.name}")
set_layer_name(layer, node, "add")

return layer.get_output(0)
2 changes: 2 additions & 0 deletions backends/nvidia/tensorrt/converters/targets.bzl
Original file line number Diff line number Diff line change
Expand Up @@ -11,9 +11,11 @@ def define_common_targets():
name = "converters",
srcs = [
"__init__.py",
"add.py",
],
visibility = ["PUBLIC"],
deps = [
"//executorch/backends/nvidia/tensorrt:converter_registry",
"//executorch/backends/nvidia/tensorrt:converter_utils",
],
)
12 changes: 12 additions & 0 deletions backends/nvidia/tensorrt/targets.bzl
Original file line number Diff line number Diff line change
Expand Up @@ -40,3 +40,15 @@ def define_common_targets():
"//caffe2:torch",
],
)

runtime.python_library(
name = "converter_utils",
srcs = [
"converter_utils.py",
],
visibility = ["PUBLIC"],
deps = [
"//caffe2:torch",
"//deeplearning/trt/python:py_tensorrt",
],
)
34 changes: 34 additions & 0 deletions backends/nvidia/tensorrt/test/TARGETS
Original file line number Diff line number Diff line change
@@ -0,0 +1,34 @@
load(":targets.bzl", "define_common_targets")
load("@fbcode_macros//build_defs:python_unittest_remote_gpu.bzl", "python_unittest_remote_gpu")

oncall("executorch")

define_common_targets()

# GPU-dependent tests: these import TensorRT which requires CUDA to initialize.
# python_unittest_remote_gpu routes them to GPU-equipped CI workers.

python_unittest_remote_gpu(
name = "test_converter_registry",
srcs = [
"test_converter_registry.py",
],
deps = [
"//caffe2:torch",
"//executorch/backends/nvidia/tensorrt:converter_registry",
"//executorch/backends/nvidia/tensorrt:converter_utils",
"//executorch/backends/nvidia/tensorrt/converters:converters",
],
)

python_unittest_remote_gpu(
name = "test_operator_support",
srcs = [
"test_operator_support.py",
],
deps = [
"//caffe2:torch",
"//executorch/backends/nvidia/tensorrt/partitioner:partitioner",
"//executorch/exir:lib",
],
)
7 changes: 7 additions & 0 deletions backends/nvidia/tensorrt/test/__init__.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,7 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.

"""Tests for TensorRT backend."""
9 changes: 9 additions & 0 deletions backends/nvidia/tensorrt/test/targets.bzl
Original file line number Diff line number Diff line change
@@ -0,0 +1,9 @@
load("@fbsource//xplat/executorch/build:runtime_wrapper.bzl", "runtime")

def define_common_targets():
    """Defines targets shared between fbcode and xplat builds.

    The directory containing this targets.bzl file should also contain both
    TARGETS and BUCK files that call this function.
    """
    # No shared test targets yet; the GPU unittests are declared directly
    # in the TARGETS file because they need fbcode-only GPU macros.
    pass
77 changes: 77 additions & 0 deletions backends/nvidia/tensorrt/test/test_converter_registry.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,77 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.

"""Tests for TensorRT converter registry and converter utilities."""

import unittest


class ConverterRegistryTest(unittest.TestCase):
    """Tests for converter registry functionality."""

    def test_registry_functions_exist(self) -> None:
        """The registry module exposes its full public API."""
        from executorch.backends.nvidia.tensorrt.converter_registry import (
            clear_registry,
            get_registered_ops,
            has_converter,
            lookup_converter,
            register_converter,
        )

        for registry_fn in (
            has_converter,
            lookup_converter,
            register_converter,
            get_registered_ops,
            clear_registry,
        ):
            self.assertIsNotNone(registry_fn)

    def test_add_converter_registered(self) -> None:
        """Importing the add module registers aten.add.Tensor."""
        from executorch.backends.nvidia.tensorrt.converter_registry import (
            get_registered_ops,
            has_converter,
            lookup_converter,
        )
        from executorch.backends.nvidia.tensorrt.converters import add  # noqa: F401

        self.assertTrue(has_converter("aten.add.Tensor"))
        self.assertIn("aten.add.Tensor", get_registered_ops())
        self.assertIsNotNone(lookup_converter("aten.add.Tensor"))

    def test_all_converters_registered(self) -> None:
        """Test that all converters are registered after importing converters."""
        from executorch.backends.nvidia.tensorrt.converter_registry import (
            get_registered_ops,
            has_converter,
        )
        from executorch.backends.nvidia.tensorrt.converters import add  # noqa: F401

        registered = get_registered_ops()
        for op in ("aten.add.Tensor",):
            self.assertTrue(has_converter(op), f"Missing converter for {op}")
            self.assertIn(op, registered)


class ConverterUtilsTest(unittest.TestCase):
    """Tests for converter utility functions."""

    def test_converter_utils_functions_exist(self) -> None:
        """All public converter utilities are importable and defined.

        Also covers promote_and_cast_tensors, which the add converter
        imports from the same module for dtype promotion.
        """
        from executorch.backends.nvidia.tensorrt.converter_utils import (
            broadcast_tensors,
            get_node_dtype,
            get_trt_tensor,
            promote_and_cast_tensors,
            set_layer_name,
            torch_dtype_to_trt,
            trt_dtype_to_torch,
        )

        self.assertIsNotNone(torch_dtype_to_trt)
        self.assertIsNotNone(trt_dtype_to_torch)
        self.assertIsNotNone(get_trt_tensor)
        self.assertIsNotNone(broadcast_tensors)
        self.assertIsNotNone(promote_and_cast_tensors)
        self.assertIsNotNone(get_node_dtype)
        self.assertIsNotNone(set_layer_name)
44 changes: 44 additions & 0 deletions backends/nvidia/tensorrt/test/test_operator_support.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,44 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.

"""Tests for TensorRT operator support functionality."""

import unittest

import torch


class OperatorSupportTest(unittest.TestCase):
    """Tests for TensorRTOperatorSupport functionality."""

    def test_get_op_name_for_add(self) -> None:
        """Test that TensorRTOperatorSupport correctly identifies add.Tensor."""
        from executorch.backends.nvidia.tensorrt.partitioner.operator_support import (
            TensorRTOperatorSupport,
        )
        from executorch.exir import to_edge
        from torch.export import export

        class AddModel(torch.nn.Module):
            def forward(self, x: torch.Tensor, y: torch.Tensor) -> torch.Tensor:
                return x + y

        # Export a minimal add model and lower it to the edge dialect.
        sample_inputs = (torch.randn(2, 3), torch.randn(2, 3))
        edge_program = to_edge(export(AddModel(), sample_inputs)).exported_program()

        op_support = TensorRTOperatorSupport()

        # Locate the add node in the lowered graph, if any.
        add_nodes = [
            n
            for n in edge_program.graph_module.graph.nodes
            if n.op == "call_function" and "add" in n.name
        ]
        if not add_nodes:
            self.fail("Could not find add node in graph")

        full_op_name = op_support._get_op_name(add_nodes[0])
        formatted_name = op_support._remove_namespace(full_op_name)
        self.assertEqual(formatted_name, "add.Tensor")
Loading