diff --git a/backends/nvidia/tensorrt/backend.py b/backends/nvidia/tensorrt/backend.py index a5237e3082f..3484ee9ce30 100644 --- a/backends/nvidia/tensorrt/backend.py +++ b/backends/nvidia/tensorrt/backend.py @@ -6,8 +6,10 @@ """TensorRT backend implementation for ExecuTorch.""" -from typing import final, List +import logging +from typing import Any, Dict, final, List, Optional, Tuple +import torch from executorch.exir.backend.backend_details import ( BackendDetails, CompileSpec, @@ -15,6 +17,18 @@ ) from torch.export.exported_program import ExportedProgram +from executorch.backends.nvidia.tensorrt.compile_spec import ( + TensorRTCompileSpec, + TensorRTPrecision, +) +from executorch.backends.nvidia.tensorrt.converter_registry import ( + lookup_converter, + needs_edge_program, +) + +logger = logging.getLogger(__name__) +logger.setLevel(logging.WARNING) + @final class TensorRTBackend(BackendDetails): @@ -38,5 +52,305 @@ def preprocess( Returns: PreprocessResult containing the serialized TensorRT engine. """ + try: + import tensorrt as trt + except ImportError as e: + raise RuntimeError( + "TensorRT is not available. Please install TensorRT to use this backend." + ) from e + + # Import converters to trigger registration + from executorch.backends.nvidia.tensorrt import ( # noqa: F401 + converters as _converters, + ) + from executorch.backends.nvidia.tensorrt.converter_utils import ( + ConversionContext, + get_op_name, + get_trt_tensor, + torch_dtype_to_trt, + ) + + # Parse compile specs + spec = TensorRTCompileSpec.from_compile_specs(compile_specs) + if spec is None: + spec = TensorRTCompileSpec() + + graph_module = edge_program.graph_module + + # Identify input and output nodes + input_nodes = _get_input_nodes(graph_module, edge_program) + output_nodes = _get_output_nodes(graph_module) + + # Create TensorRT builder and network + trt_logger = trt.Logger(trt.Logger.WARNING) + builder = trt.Builder(trt_logger) + network = builder.create_network( + 1 << int(trt.NetworkDefinitionCreationFlag.EXPLICIT_BATCH) + ) + if network is None: + raise RuntimeError("Failed to create TensorRT network") + + # Create conversion context for this build + ctx = ConversionContext(net=network) + + # Build the network + input_map = _add_network_inputs(network, input_nodes, torch_dtype_to_trt) + _process_graph_nodes( + graph_module, edge_program, network, input_map, get_trt_tensor, get_op_name, ctx + ) + _mark_network_outputs(network, output_nodes, input_map) + + # Configure and build engine + config = _create_builder_config(builder, spec, trt) + serialized_engine = builder.build_serialized_network(network, config) + + if serialized_engine is None: + raise RuntimeError("Failed to build TensorRT engine") + + return PreprocessResult(processed_bytes=bytes(serialized_engine)) + + +def _get_input_nodes( + graph_module: torch.fx.GraphModule, + exported_program: ExportedProgram, +) -> List[torch.fx.Node]: + """Get graph input placeholder nodes (excluding parameters/buffers).""" + input_nodes = [] + for node in graph_module.graph.nodes: + if node.op == "placeholder": + if not _is_param_or_buffer(node, exported_program): + input_nodes.append(node) + return input_nodes + + +def _get_output_nodes(graph_module: torch.fx.GraphModule) -> List[torch.fx.Node]: + """Get nodes that are graph outputs.""" + output_nodes = [] + for node in graph_module.graph.nodes: + if node.op == "output": + for arg in node.args: + if isinstance(arg, (list, tuple)): + output_nodes.extend( + item for item in arg if isinstance(item, torch.fx.Node) + ) + elif isinstance(arg, torch.fx.Node): + output_nodes.append(arg) + return output_nodes + + +def _is_param_or_buffer( + node: torch.fx.Node, exported_program: ExportedProgram +) -> bool: + """Check if a placeholder node is a parameter or buffer.""" + if node.op != "placeholder": + return False + + if hasattr(exported_program, "state_dict"): + if node.name in exported_program.state_dict: + return True + + if hasattr(exported_program, "graph_signature"): + sig = exported_program.graph_signature + if hasattr(sig, "inputs_to_parameters"): + if node.name in sig.inputs_to_parameters: + return True + if hasattr(sig, "inputs_to_buffers"): + if node.name in sig.inputs_to_buffers: + return True + + return False + + +def _get_tensor_shape_and_dtype( + node: torch.fx.Node, +) -> Tuple[Optional[Tuple[int, ...]], Optional[torch.dtype]]: + """Extract tensor shape and dtype from node metadata.""" + if "val" in node.meta: + val = node.meta["val"] + if isinstance(val, torch.Tensor): + return tuple(val.shape), val.dtype + if hasattr(val, "shape") and hasattr(val, "dtype"): + return tuple(val.shape), val.dtype + return None, None + + +def _get_attr_value( + graph_module: torch.fx.GraphModule, attr_name: str +) -> Optional[torch.Tensor]: + """Get attribute value from graph module.""" + try: + parts = attr_name.split(".") + obj = graph_module + for part in parts: + obj = getattr(obj, part) + if isinstance(obj, torch.Tensor): + return obj + if isinstance(obj, torch.nn.Parameter): + return obj.data + return None + except AttributeError: + return None + + +def _add_network_inputs( + network: Any, + input_nodes: List[torch.fx.Node], + dtype_converter: Any, +) -> Dict[torch.fx.Node, Any]: + """Add input tensors to TensorRT network.""" + input_map: Dict[torch.fx.Node, Any] = {} + + for input_node in input_nodes: + shape, dtype = _get_tensor_shape_and_dtype(input_node) + if shape is None: + raise RuntimeError( + f"Cannot determine shape for input node: {input_node.name}" + ) + + trt_dtype = dtype_converter(dtype if dtype else torch.float32) + trt_input = network.add_input( + name=input_node.name, + dtype=trt_dtype, + shape=shape, + ) + if trt_input is None: + raise RuntimeError(f"Failed to add input to network: {input_node.name}") + + input_map[input_node] = trt_input + + return input_map + + +def _process_graph_nodes( + graph_module: torch.fx.GraphModule, + exported_program: ExportedProgram, + network: Any, + input_map: Dict[torch.fx.Node, Any], + get_trt_tensor_fn: Any, + get_op_name_fn: Any, + ctx: Any = None, +) -> None: + """Process graph nodes and convert to TensorRT layers. + + Args: + graph_module: The FX graph module to process. + exported_program: The ExportedProgram for weight extraction. + network: TensorRT network definition. + input_map: Mapping from FX nodes to TensorRT tensors. + get_trt_tensor_fn: Function to create TensorRT constant tensors. + get_op_name_fn: Function to extract operation name from nodes. + ctx: Optional ConversionContext for unique layer naming. + """ + for node in graph_module.graph.nodes: + if node.op == "call_function": + op_name = get_op_name_fn(node) + + converter = lookup_converter(op_name) + if converter is None: + raise RuntimeError(f"No converter registered for operation: {op_name}") + + # Check if converter needs edge_program for weight extraction + if needs_edge_program(op_name): + output_tensor = converter(node, network, input_map, exported_program, ctx) + else: + output_tensor = converter(node, network, input_map, ctx) + + input_map[node] = output_tensor + + elif node.op == "get_attr": + attr_name = node.target + param = _get_attr_value(graph_module, attr_name) + if param is not None: + input_map[node] = get_trt_tensor_fn( + network, param, f"param_{node.name}" + ) + + +def _mark_network_outputs( + network: Any, + output_nodes: List[torch.fx.Node], + input_map: Dict[torch.fx.Node, Any], +) -> None: + """Mark network outputs in TensorRT network.""" + for output_node in output_nodes: + if output_node not in input_map: + raise RuntimeError( + f"Output node not found in input_map: {output_node.name}" + ) + + output_tensor = input_map[output_node] + if hasattr(output_tensor, "name"): + output_tensor.name = f"output_{output_node.name}" + network.mark_output(output_tensor) + + +def _create_builder_config(builder: Any, spec: TensorRTCompileSpec, trt: Any) -> Any: + """Create and configure TensorRT builder config.""" + config = builder.create_builder_config() + if config is None: + raise RuntimeError("Failed to create TensorRT builder config") + + config.set_memory_pool_limit(trt.MemoryPoolType.WORKSPACE, spec.workspace_size) + + # Disable TF32 for strict FP32 precision on Ampere+ GPUs. + if hasattr(trt.BuilderFlag, "TF32"): + config.clear_flag(trt.BuilderFlag.TF32) + + # Report build progress if TRT supports IProgressMonitor. + if hasattr(trt, "IProgressMonitor"): + + class _ProgressMonitor(trt.IProgressMonitor): + def __init__(self): + self._seen = set() + # Report build progress if TRT supports IProgressMonitor. + if hasattr(trt, "IProgressMonitor"): + + class _ProgressMonitor(trt.IProgressMonitor): + def __init__(self): + trt.IProgressMonitor.__init__(self) + self._seen = set() + + def phase_start(self, phase_name, parent_phase, num_steps): + key = (phase_name, parent_phase) + if key not in self._seen: + self._seen.add(key) + indent = " " if parent_phase else " " + print(f"{indent}TRT: {phase_name}", flush=True) + + def step_complete(self, phase_name, step): + return True + + def phase_finish(self, phase_name): + pass + + config.progress_monitor = _ProgressMonitor() + + # TensorRT 10.6+ enables WEIGHT_STREAMING by default, which generates + # weight-separated plan files that require IStreamReader for deserialization. + # We disable this flag to generate standard plan files that can be + # deserialized with the simpler deserializeCudaEngine(data, size) API. + if hasattr(trt.BuilderFlag, "WEIGHT_STREAMING"): + config.clear_flag(trt.BuilderFlag.WEIGHT_STREAMING) + + if spec.precision == TensorRTPrecision.FP16: + if builder.platform_has_fast_fp16: + config.set_flag(trt.BuilderFlag.FP16) + else: + logger.warning("FP16 not supported on this platform, using FP32") + + if spec.precision == TensorRTPrecision.INT8: + if builder.platform_has_fast_int8: + config.set_flag(trt.BuilderFlag.INT8) + else: + logger.warning("INT8 not supported on this platform, using FP32") + + if spec.strict_type_constraints: + config.set_flag(trt.BuilderFlag.STRICT_TYPES) + + if spec.dla_core >= 0: + config.default_device_type = trt.DeviceType.DLA + config.DLA_core = spec.dla_core + if spec.allow_gpu_fallback: + config.set_flag(trt.BuilderFlag.GPU_FALLBACK) - return PreprocessResult(processed_bytes=b"") + return config diff --git a/backends/nvidia/tensorrt/targets.bzl b/backends/nvidia/tensorrt/targets.bzl index 695b55798f6..49332ddbc25 100644 --- a/backends/nvidia/tensorrt/targets.bzl +++ b/backends/nvidia/tensorrt/targets.bzl @@ -15,6 +15,11 @@ def define_common_targets(): visibility = ["PUBLIC"], deps = [ "//caffe2:torch", + "//deeplearning/trt/python:py_tensorrt", + "//executorch/backends/nvidia/tensorrt:compile_spec", + "//executorch/backends/nvidia/tensorrt:converter_registry", + "//executorch/backends/nvidia/tensorrt:converter_utils", + "//executorch/backends/nvidia/tensorrt/converters:converters", "//executorch/exir/backend:backend_details", ], ) @@ -27,6 +32,8 @@ def define_common_targets(): visibility = ["PUBLIC"], deps = [ ":backend", + "//executorch/backends/nvidia/tensorrt:compile_spec", + "//executorch/backends/nvidia/tensorrt/partitioner:partitioner", ], ) diff --git a/backends/nvidia/tensorrt/test/TARGETS b/backends/nvidia/tensorrt/test/TARGETS index f2c1738acd7..5dfd12e9470 100644 --- a/backends/nvidia/tensorrt/test/TARGETS +++ b/backends/nvidia/tensorrt/test/TARGETS @@ -5,9 +5,6 @@ oncall("executorch") define_common_targets() -# GPU-dependent tests: these import TensorRT which requires CUDA to initialize. -# python_unittest_remote_gpu routes them to GPU-equipped CI workers. - python_unittest_remote_gpu( name = "test_converter_registry", srcs = [ @@ -32,3 +29,17 @@ python_unittest_remote_gpu( "//executorch/exir:lib", ], ) + +python_unittest_remote_gpu( + name = "test_preprocess", + srcs = [ + "test_preprocess.py", + ], + deps = [ + "//caffe2:torch", + "//deeplearning/trt/python:py_tensorrt", + "//executorch/backends/nvidia/tensorrt:compile_spec", + "//executorch/backends/nvidia/tensorrt:tensorrt", + "//executorch/exir:lib", + ], +) diff --git a/backends/nvidia/tensorrt/test/test_preprocess.py b/backends/nvidia/tensorrt/test/test_preprocess.py new file mode 100644 index 00000000000..5794114c55c --- /dev/null +++ b/backends/nvidia/tensorrt/test/test_preprocess.py @@ -0,0 +1,113 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the BSD-style license found in the +# LICENSE file in the root directory of this source tree. + +"""Tests for TensorRT preprocess and backend functionality.""" + +import unittest + +import torch + +# Check TensorRT availability at module level for skip decorators +_TENSORRT_AVAILABLE = False +try: + import tensorrt # noqa: F401 + + _TENSORRT_AVAILABLE = True +except ImportError: + pass + + +def _requires_tensorrt(test_class): + """Class decorator to skip all tests in a class if TensorRT is not available.""" + if not _TENSORRT_AVAILABLE: + return unittest.skip("TensorRT is not available")(test_class) + return test_class + + +class _AddModel(torch.nn.Module): + """Simple add model for testing.""" + + def forward(self, x: torch.Tensor, y: torch.Tensor) -> torch.Tensor: + return x + y + + +class TensorRTBackendTest(unittest.TestCase): + """Tests for TensorRTBackend class structure.""" + + @unittest.skipUnless(_TENSORRT_AVAILABLE, "TensorRT is not available") + def test_backend_import(self) -> None: + """Test that TensorRTBackend can be imported.""" + from executorch.backends.nvidia.tensorrt.backend import TensorRTBackend + + self.assertIsNotNone(TensorRTBackend) + self.assertTrue(hasattr(TensorRTBackend, "preprocess")) + + @unittest.skipUnless(_TENSORRT_AVAILABLE, "TensorRT is not available") + def test_preprocess_signature(self) -> None: + """Test that preprocess has the correct signature.""" + import inspect + + from executorch.backends.nvidia.tensorrt.backend import TensorRTBackend + + sig = inspect.signature(TensorRTBackend.preprocess) + params = list(sig.parameters.keys()) + self.assertIn("edge_program", params) + self.assertIn("compile_specs", params) + + +@_requires_tensorrt +class TensorRTPreprocessTest(unittest.TestCase): + """Integration tests for TensorRT preprocess (requires TensorRT).""" + + def test_preprocess_add_model(self) -> None: + """Test that preprocess produces non-empty bytes for add model.""" + from executorch.backends.nvidia.tensorrt.compile_spec import TensorRTCompileSpec + from executorch.backends.nvidia.tensorrt.backend import TensorRTBackend + + model = _AddModel() + model.eval() + + example_inputs = (torch.randn(2, 3), torch.randn(2, 3)) + exported = torch.export.export(model, example_inputs) + + compile_spec = TensorRTCompileSpec() + result = TensorRTBackend.preprocess(exported, compile_spec.to_compile_specs()) + + self.assertIsNotNone(result) + self.assertIsNotNone(result.processed_bytes) + self.assertGreater(len(result.processed_bytes), 0) + + def test_preprocess_empty_compile_specs(self) -> None: + """Test that preprocess works with empty compile specs (uses defaults).""" + from executorch.backends.nvidia.tensorrt.backend import TensorRTBackend + + model = _AddModel() + model.eval() + + example_inputs = (torch.randn(2, 3), torch.randn(2, 3)) + exported = torch.export.export(model, example_inputs) + + result = TensorRTBackend.preprocess(exported, []) + + self.assertIsNotNone(result) + self.assertGreater(len(result.processed_bytes), 0) + + def test_preprocess_different_shapes(self) -> None: + """Test preprocess with different input shapes.""" + from executorch.backends.nvidia.tensorrt.compile_spec import TensorRTCompileSpec + from executorch.backends.nvidia.tensorrt.backend import TensorRTBackend + + model = _AddModel() + model.eval() + + example_inputs = (torch.randn(4, 8, 16), torch.randn(4, 8, 16)) + exported = torch.export.export(model, example_inputs) + + compile_spec = TensorRTCompileSpec() + result = TensorRTBackend.preprocess(exported, compile_spec.to_compile_specs()) + + self.assertIsNotNone(result) + self.assertGreater(len(result.processed_bytes), 0)