diff --git a/backends/nvidia/tensorrt/converters/__init__.py b/backends/nvidia/tensorrt/converters/__init__.py
index 69f5a9cb53d..58f6218e01a 100644
--- a/backends/nvidia/tensorrt/converters/__init__.py
+++ b/backends/nvidia/tensorrt/converters/__init__.py
@@ -11,6 +11,7 @@
 from executorch.backends.nvidia.tensorrt.converters import add  # noqa: F401
 from executorch.backends.nvidia.tensorrt.converters import addmm  # noqa: F401
 from executorch.backends.nvidia.tensorrt.converters import batch_norm  # noqa: F401
+from executorch.backends.nvidia.tensorrt.converters import bmm  # noqa: F401
 from executorch.backends.nvidia.tensorrt.converters import clamp  # noqa: F401
 from executorch.backends.nvidia.tensorrt.converters import concat  # noqa: F401
 from executorch.backends.nvidia.tensorrt.converters import conv2d  # noqa: F401
diff --git a/backends/nvidia/tensorrt/converters/bmm.py b/backends/nvidia/tensorrt/converters/bmm.py
new file mode 100644
index 00000000000..abe7e23b3e7
--- /dev/null
+++ b/backends/nvidia/tensorrt/converters/bmm.py
@@ -0,0 +1,45 @@
+# Copyright (c) Meta Platforms, Inc. and affiliates.
+# All rights reserved.
+#
+# This source code is licensed under the BSD-style license found in the
+# LICENSE file in the root directory of this source tree.
+
+"""Converter for batch matrix multiplication operations."""
+
+from typing import Any, Dict, Optional
+
+import tensorrt as trt
+import torch
+from executorch.backends.nvidia.tensorrt.converter_registry import converter
+from executorch.backends.nvidia.tensorrt.converter_utils import set_layer_name
+
+
+@converter("aten.bmm.default")
+def convert_bmm(
+    node: torch.fx.Node,
+    network: trt.INetworkDefinition,
+    input_map: Dict[torch.fx.Node, Any],
+    edge_program: Optional[Any] = None,
+) -> trt.ITensor:
+    """Convert aten.bmm.default to TensorRT MatrixMultiply.
+
+    Performs batch matrix multiplication of two 3D tensors (B, M, K) @ (B, K, N) -> (B, M, N).
+    TensorRT's IMatrixMultiplyLayer supports batch matrix multiplication natively.
+    """
+    lhs_arg = node.args[0]
+    rhs_arg = node.args[1]
+
+    if lhs_arg not in input_map:
+        raise ValueError(f"Input node '{lhs_arg.name}' not found in input_map for bmm")
+    if rhs_arg not in input_map:
+        raise ValueError(f"Input node '{rhs_arg.name}' not found in input_map for bmm")
+
+    lhs = input_map[lhs_arg]
+    rhs = input_map[rhs_arg]
+
+    layer = network.add_matrix_multiply(
+        lhs, trt.MatrixOperation.NONE, rhs, trt.MatrixOperation.NONE
+    )
+    set_layer_name(layer, node, "bmm")
+
+    return layer.get_output(0)
diff --git a/backends/nvidia/tensorrt/converters/targets.bzl b/backends/nvidia/tensorrt/converters/targets.bzl
index 2427da1be9f..00648c5e49a 100644
--- a/backends/nvidia/tensorrt/converters/targets.bzl
+++ b/backends/nvidia/tensorrt/converters/targets.bzl
@@ -15,6 +15,7 @@ def define_common_targets():
         "add.py",
         "addmm.py",
         "batch_norm.py",
+        "bmm.py",
         "clamp.py",
         "concat.py",
         "conv2d.py",
diff --git a/examples/nvidia/tensorrt/export.py b/examples/nvidia/tensorrt/export.py
index 5dc22bcfffc..a3b823cf402 100644
--- a/examples/nvidia/tensorrt/export.py
+++ b/examples/nvidia/tensorrt/export.py
@@ -38,6 +42,10 @@
     "conv1d",
     "dl3",
     "edsr",
+    # "efficient_sam",  # TODO: diff ~41 — likely bicubic interpolation decomposition or ConvTranspose2d issue
+    "emformer_join",
+    # "emformer_predict",  # TODO: passes 1/3 seeds — precision sensitive with randomized inputs
+    "emformer_transcribe",
     "ic3",
     "linear",
     "mul",
diff --git a/examples/nvidia/tensorrt/tests/test_export.py b/examples/nvidia/tensorrt/tests/test_export.py
index d96222e4d03..79b57d65b4c 100644
--- a/examples/nvidia/tensorrt/tests/test_export.py
+++ b/examples/nvidia/tensorrt/tests/test_export.py
@@ -40,11 +36,7 @@ def _populate_weight_cache() -> None:
     for env_var, filename in _WEIGHT_ENV_VARS.items():
         src = os.environ.get(env_var)
         if src and os.path.isfile(src):
-            # dog.jpg goes to CWD (mv2 model downloads it there)
-            if filename == "dog.jpg":
-                dst = os.path.join(os.getcwd(), filename)
-            else:
-                dst = os.path.join(cache_dir, filename)
+            dst = os.path.join(cache_dir, filename)
             if not os.path.exists(dst):
                 shutil.copy2(src, dst)
                 logger.info(f"Cached {filename} from {src}")
@@ -121,3 +117,12 @@ def test_ic3(self) -> None:

     def test_sdpa(self) -> None:
         _export_and_verify("sdpa")
+
+    def test_emformer_join(self) -> None:
+        _export_and_verify("emformer_join")
+
+    def test_softmax(self) -> None:
+        _export_and_verify("softmax")
+
+    def test_mv3(self) -> None:
+        _export_and_verify("mv3")