From 22016a88a055578593dbdab9ef63b87e1550526b Mon Sep 17 00:00:00 2001 From: shoumikhin Date: Thu, 5 Mar 2026 09:58:51 -0800 Subject: [PATCH] [executorch][nvidia][tensorrt][16/n] Enable MobileNetV3 model support Add converters and optimizations to enable MobileNetV3 model support with the TensorRT backend. Differential Revision: [D93275043](https://our.internmc.facebook.com/intern/diff/D93275043/) [ghstack-poisoned] --- backends/nvidia/tensorrt/backend.py | 68 ++ .../nvidia/tensorrt/converters/__init__.py | 18 + .../nvidia/tensorrt/converters/activations.py | 1016 +++++++++++++++++ .../nvidia/tensorrt/converters/batch_norm.py | 191 ++++ backends/nvidia/tensorrt/converters/clamp.py | 253 ++++ backends/nvidia/tensorrt/converters/concat.py | 553 +++++++++ backends/nvidia/tensorrt/converters/conv2d.py | 314 +++++ .../tensorrt/converters/dim_order_ops.py | 233 ++++ .../nvidia/tensorrt/converters/getitem.py | 110 ++ backends/nvidia/tensorrt/converters/linear.py | 255 +++++ .../nvidia/tensorrt/converters/pooling.py | 506 ++++++++ .../nvidia/tensorrt/converters/reduction.py | 262 +++++ .../nvidia/tensorrt/converters/reshape.py | 868 ++++++++++++++ .../nvidia/tensorrt/converters/targets.bzl | 11 + examples/nvidia/tensorrt/export.py | 1 + examples/nvidia/tensorrt/tests/test_export.py | 3 + 16 files changed, 4662 insertions(+) create mode 100644 backends/nvidia/tensorrt/converters/activations.py create mode 100644 backends/nvidia/tensorrt/converters/batch_norm.py create mode 100644 backends/nvidia/tensorrt/converters/clamp.py create mode 100644 backends/nvidia/tensorrt/converters/concat.py create mode 100644 backends/nvidia/tensorrt/converters/conv2d.py create mode 100644 backends/nvidia/tensorrt/converters/dim_order_ops.py create mode 100644 backends/nvidia/tensorrt/converters/getitem.py create mode 100644 backends/nvidia/tensorrt/converters/linear.py create mode 100644 backends/nvidia/tensorrt/converters/pooling.py create mode 100644 
backends/nvidia/tensorrt/converters/reduction.py create mode 100644 backends/nvidia/tensorrt/converters/reshape.py diff --git a/backends/nvidia/tensorrt/backend.py b/backends/nvidia/tensorrt/backend.py index 29185b8559a..5fdfeeb0d4a 100644 --- a/backends/nvidia/tensorrt/backend.py +++ b/backends/nvidia/tensorrt/backend.py @@ -30,6 +30,9 @@ TensorRTBlobMetadata, TensorRTIOBinding, ) +from executorch.backends.nvidia.tensorrt.converters import ( + clear_converter_weight_storage, +) logger = logging.getLogger(__name__) logger.setLevel(logging.WARNING) @@ -100,6 +103,10 @@ def preprocess( # Build the network input_map = _add_network_inputs(network, input_nodes, torch_dtype_to_trt) + # Add params/buffers as constant tensors + _add_params_to_input_map( + graph_module, edge_program, network, input_map, get_trt_tensor + ) _process_graph_nodes( graph_module, edge_program, network, input_map, get_trt_tensor, get_op_name, ctx ) @@ -173,6 +180,67 @@ def _is_param_or_buffer( return False +def _add_params_to_input_map( + graph_module: torch.fx.GraphModule, + exported_program: ExportedProgram, + network: Any, + input_map: Dict[torch.fx.Node, Any], + get_trt_tensor_fn: Any, +) -> None: + """Add parameters and buffers as constant TensorRT tensors to input_map. + + In ExecuTorch's edge dialect, parameters are often "lifted" as placeholder + inputs rather than get_attr nodes. This function identifies these placeholder + nodes that represent parameters/buffers and adds them to input_map as + TensorRT constant tensors. 
+ """ + for node in graph_module.graph.nodes: + if node.op == "placeholder": + # Skip if already in input_map (it's a real input, not a param) + if node in input_map: + continue + + param_tensor = None + + # Try to get from state_dict first + if hasattr(exported_program, "state_dict"): + if node.name in exported_program.state_dict: + param_tensor = exported_program.state_dict[node.name] + + # Try to get from graph_signature mapping + if param_tensor is None and hasattr(exported_program, "graph_signature"): + sig = exported_program.graph_signature + param_name = None + if hasattr(sig, "inputs_to_parameters"): + param_name = sig.inputs_to_parameters.get(node.name) + if param_name is None and hasattr(sig, "inputs_to_buffers"): + param_name = sig.inputs_to_buffers.get(node.name) + + if param_name is not None and hasattr(exported_program, "state_dict"): + param_tensor = exported_program.state_dict.get(param_name) + + # If we found a parameter tensor, add it to input_map + if param_tensor is not None: + if isinstance(param_tensor, torch.nn.Parameter): + param_tensor = param_tensor.data + if isinstance(param_tensor, torch.Tensor): + # Convert int64/int32 tensors to float32 for TensorRT compatibility + # These are often used in elementwise operations with float tensors + # (e.g., batch norm statistics in MobileNetV3) + original_dtype = param_tensor.dtype + if param_tensor.dtype in (torch.int32, torch.int64): + param_tensor = param_tensor.float() + logger.debug( + f"Converting param {node.name} from {original_dtype} to float32 " + f"for TensorRT compatibility" + ) + elif param_tensor.dtype == torch.float64: + param_tensor = param_tensor.float() + input_map[node] = get_trt_tensor_fn( + network, param_tensor, f"param_{node.name}" + ) + + def _get_tensor_shape_and_dtype( node: torch.fx.Node, ) -> Tuple[Optional[Tuple[int, ...]], Optional[torch.dtype]]: diff --git a/backends/nvidia/tensorrt/converters/__init__.py b/backends/nvidia/tensorrt/converters/__init__.py index 
f529dbef65b..7db8177ae54 100644 --- a/backends/nvidia/tensorrt/converters/__init__.py +++ b/backends/nvidia/tensorrt/converters/__init__.py @@ -7,11 +7,29 @@ """TensorRT converters for ExecuTorch operations.""" # Import converters to trigger registration via @converter decorator +from executorch.backends.nvidia.tensorrt.converters import activations # noqa: F401 from executorch.backends.nvidia.tensorrt.converters import add # noqa: F401 from executorch.backends.nvidia.tensorrt.converters import addmm # noqa: F401 +from executorch.backends.nvidia.tensorrt.converters import batch_norm # noqa: F401 +from executorch.backends.nvidia.tensorrt.converters import clamp # noqa: F401 +from executorch.backends.nvidia.tensorrt.converters import concat # noqa: F401 +from executorch.backends.nvidia.tensorrt.converters import conv2d # noqa: F401 +from executorch.backends.nvidia.tensorrt.converters import dim_order_ops # noqa: F401 from executorch.backends.nvidia.tensorrt.converters import div # noqa: F401 +from executorch.backends.nvidia.tensorrt.converters import getitem # noqa: F401 +from executorch.backends.nvidia.tensorrt.converters import linear # noqa: F401 from executorch.backends.nvidia.tensorrt.converters import mm # noqa: F401 from executorch.backends.nvidia.tensorrt.converters import mul # noqa: F401 from executorch.backends.nvidia.tensorrt.converters import permute_copy # noqa: F401 +from executorch.backends.nvidia.tensorrt.converters import pooling # noqa: F401 +from executorch.backends.nvidia.tensorrt.converters import reduction # noqa: F401 from executorch.backends.nvidia.tensorrt.converters import relu # noqa: F401 +from executorch.backends.nvidia.tensorrt.converters import reshape # noqa: F401 from executorch.backends.nvidia.tensorrt.converters import sub # noqa: F401 + + +def clear_converter_weight_storage() -> None: + """Clear weight storage to free memory after engine build.""" + conv2d.clear_weight_storage() + batch_norm.clear_weight_storage() + 
linear.clear_weight_storage() diff --git a/backends/nvidia/tensorrt/converters/activations.py b/backends/nvidia/tensorrt/converters/activations.py new file mode 100644 index 00000000000..5029ea52f6a --- /dev/null +++ b/backends/nvidia/tensorrt/converters/activations.py @@ -0,0 +1,1016 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the BSD-style license found in the +# LICENSE file in the root directory of this source tree. + +# pyre-strict + +""" +TensorRT Converters for Activation Operations. + +This module provides converters for PyTorch activation operations to TensorRT +activation layers. + +Supported operations: +- aten.sigmoid.default: Sigmoid activation +- aten.tanh.default: Tanh activation +- aten.gelu.default: GELU activation +- aten.silu.default: SiLU/Swish activation (x * sigmoid(x)) +- aten.softmax.int: Softmax (uses dedicated layer, not add_activation) +- aten.hardswish.default: Hard-swish (x * relu6(x + 3) / 6) - critical for MobileNetV3 +- aten.hardsigmoid.default: Hard-sigmoid (min(max((x + 3) / 6, 0), 1)) - critical for SE blocks + +Notes: +- Simple activations (sigmoid, tanh, gelu, silu) use network.add_activation() +- Softmax uses network.add_softmax() with axis configuration +- hardswish/hardsigmoid are decomposed into elementwise operations +""" + +import logging +from typing import Any, Dict, Optional + +import torch +from executorch.backends.nvidia.tensorrt.converter_registry import converter + +logger: logging.Logger = logging.getLogger(__name__) + + +def validate_unary_activation(node: torch.fx.Node) -> bool: + """ + Validate that an activation node can be converted to TensorRT. + + Args: + node: FX node representing the activation operation. + + Returns: + True if the node can be converted, False otherwise. 
+ """ + if node.op != "call_function": + logger.debug( + f"[TensorRT] validate_activation: node {node.name} is not call_function" + ) + return False + + args = node.args + # Minimum args: input + if len(args) < 1: + logger.debug( + f"[TensorRT] validate_activation: node {node.name} has insufficient args" + ) + return False + + return True + + +def validate_softmax(node: torch.fx.Node) -> bool: + """ + Validate that a softmax node can be converted to TensorRT. + + Args: + node: FX node representing the softmax operation. + + Returns: + True if the node can be converted, False otherwise. + """ + if node.op != "call_function": + logger.debug( + f"[TensorRT] validate_softmax: node {node.name} is not call_function" + ) + return False + + args = node.args + # Args: input, dim + if len(args) < 2: + logger.debug( + f"[TensorRT] validate_softmax: node {node.name} has insufficient args" + ) + return False + + # dim should be an int + dim = args[1] + if not isinstance(dim, int): + logger.debug( + f"[TensorRT] validate_softmax: dim must be int, got {type(dim)}" + ) + return False + + return True + + +def _create_scalar_constant( + network: Any, # trt.INetworkDefinition + scalar_value: float, + name_suffix: str, + target_ndims: int = 0, +) -> Any: # trt.ITensor + """ + Create a TensorRT constant tensor from a scalar value. + + Args: + network: TensorRT network definition. + scalar_value: The scalar value to create. + name_suffix: Suffix for the layer name. + target_ndims: Number of dimensions for the output shape. + The constant will be created with shape [1, 1, ..., 1] (target_ndims ones). + If 0, creates a scalar constant. + + Returns: + TensorRT constant tensor for broadcasting. 
+ """ + try: + import tensorrt as trt + import numpy as np + except ImportError as e: + raise ImportError("TensorRT and numpy required") from e + + # Create shape with appropriate number of dimensions for broadcasting + if target_ndims > 0: + shape = [1] * target_ndims + scalar_array = np.full(shape, scalar_value, dtype=np.float32) + else: + scalar_array = np.array([scalar_value], dtype=np.float32) + shape = [1] + + weights = trt.Weights(scalar_array) + layer = network.add_constant(trt.Dims(shape), weights) + + if layer is None: + raise RuntimeError(f"Failed to create constant layer: {name_suffix}") + + layer.name = f"scalar_const_{name_suffix}" + return layer.get_output(0) + + +@converter("aten.sigmoid.default", validator_fn=validate_unary_activation) +def convert_sigmoid( + node: torch.fx.Node, + network: Any, # trt.INetworkDefinition + input_map: Dict[torch.fx.Node, Any], # Dict[Node, trt.ITensor] + edge_program: Optional[Any] = None, +) -> Any: # trt.ITensor + """ + Convert PyTorch sigmoid to TensorRT activation layer. + + Args: + node: FX node representing the sigmoid operation. + network: TensorRT network definition. + input_map: Mapping from FX nodes to TensorRT tensors. + + Returns: + TensorRT output tensor. 
+ """ + try: + import tensorrt as trt + except ImportError as e: + raise ImportError("TensorRT is required for convert_sigmoid") from e + + logger.debug(f"[TensorRT] Converting sigmoid node: {node.name}") + + input_node = node.args[0] + + if not isinstance(input_node, torch.fx.Node): + raise ValueError(f"Input to sigmoid must be a node, got {type(input_node)}") + + if input_node not in input_map: + raise ValueError(f"Input node {input_node.name} not found in input_map") + + input_trt = input_map[input_node] + + layer = network.add_activation(input_trt, trt.ActivationType.SIGMOID) + if layer is None: + raise RuntimeError(f"Failed to create sigmoid layer for node {node.name}") + + layer.name = f"sigmoid_{node.name}" + logger.debug(f"[TensorRT] Created sigmoid layer: {layer.name}") + + return layer.get_output(0) + + +@converter("aten.tanh.default", validator_fn=validate_unary_activation) +def convert_tanh( + node: torch.fx.Node, + network: Any, # trt.INetworkDefinition + input_map: Dict[torch.fx.Node, Any], # Dict[Node, trt.ITensor] + edge_program: Optional[Any] = None, +) -> Any: # trt.ITensor + """ + Convert PyTorch tanh to TensorRT activation layer. + + Args: + node: FX node representing the tanh operation. + network: TensorRT network definition. + input_map: Mapping from FX nodes to TensorRT tensors. + + Returns: + TensorRT output tensor. 
+ """ + try: + import tensorrt as trt + except ImportError as e: + raise ImportError("TensorRT is required for convert_tanh") from e + + logger.debug(f"[TensorRT] Converting tanh node: {node.name}") + + input_node = node.args[0] + + if not isinstance(input_node, torch.fx.Node): + raise ValueError(f"Input to tanh must be a node, got {type(input_node)}") + + if input_node not in input_map: + raise ValueError(f"Input node {input_node.name} not found in input_map") + + input_trt = input_map[input_node] + + layer = network.add_activation(input_trt, trt.ActivationType.TANH) + if layer is None: + raise RuntimeError(f"Failed to create tanh layer for node {node.name}") + + layer.name = f"tanh_{node.name}" + logger.debug(f"[TensorRT] Created tanh layer: {layer.name}") + + return layer.get_output(0) + + +@converter("aten.gelu.default", validator_fn=validate_unary_activation) +def convert_gelu( + node: torch.fx.Node, + network: Any, # trt.INetworkDefinition + input_map: Dict[torch.fx.Node, Any], # Dict[Node, trt.ITensor] + edge_program: Optional[Any] = None, +) -> Any: # trt.ITensor + """ + Convert PyTorch GELU to TensorRT activation layer. + + Args: + node: FX node representing the gelu operation. + network: TensorRT network definition. + input_map: Mapping from FX nodes to TensorRT tensors. + + Returns: + TensorRT output tensor. 
+ """ + try: + import tensorrt as trt + except ImportError as e: + raise ImportError("TensorRT is required for convert_gelu") from e + + logger.debug(f"[TensorRT] Converting gelu node: {node.name}") + + input_node = node.args[0] + + if not isinstance(input_node, torch.fx.Node): + raise ValueError(f"Input to gelu must be a node, got {type(input_node)}") + + if input_node not in input_map: + raise ValueError(f"Input node {input_node.name} not found in input_map") + + input_trt = input_map[input_node] + + # TensorRT 8.6+ has native GELU support via GELU_ERF or GELU_TANH + # For older versions, we need to fall back to a manual implementation + try: + layer = network.add_activation(input_trt, trt.ActivationType.GELU_ERF) + except AttributeError: + # Fallback: GELU(x) ≈ 0.5 * x * (1 + tanh(sqrt(2/π) * (x + 0.044715 * x^3))) + # This is the tanh approximation of GELU + import math + + # Constants + sqrt_2_pi = math.sqrt(2.0 / math.pi) # ~0.7979 + coeff = 0.044715 + + # Create constant tensors matching input shape + input_ndims = len(input_trt.shape) + + def create_scalar(val: float, name: str) -> Any: + const_shape = tuple([1] * input_ndims) + const_data = torch.tensor([val], dtype=torch.float32).numpy() + const_weights = trt.Weights(const_data) + layer = network.add_constant(const_shape, const_weights) + if layer is None: + raise RuntimeError(f"Failed to create constant {name}") + layer.name = f"gelu_const_{name}_{node.name}" + return layer.get_output(0) + + const_half = create_scalar(0.5, "half") + const_one = create_scalar(1.0, "one") + const_sqrt2pi = create_scalar(sqrt_2_pi, "sqrt2pi") + const_coeff = create_scalar(coeff, "coeff") + + # x^3 + x_sq = network.add_elementwise(input_trt, input_trt, trt.ElementWiseOperation.PROD) + x_sq.name = f"gelu_x_sq_{node.name}" + x_cubed = network.add_elementwise(x_sq.get_output(0), input_trt, trt.ElementWiseOperation.PROD) + x_cubed.name = f"gelu_x_cubed_{node.name}" + + # 0.044715 * x^3 + coeff_x_cubed = 
network.add_elementwise(const_coeff, x_cubed.get_output(0), trt.ElementWiseOperation.PROD) + coeff_x_cubed.name = f"gelu_coeff_x_cubed_{node.name}" + + # x + 0.044715 * x^3 + x_plus_term = network.add_elementwise(input_trt, coeff_x_cubed.get_output(0), trt.ElementWiseOperation.SUM) + x_plus_term.name = f"gelu_x_plus_term_{node.name}" + + # sqrt(2/π) * (x + 0.044715 * x^3) + scaled = network.add_elementwise(const_sqrt2pi, x_plus_term.get_output(0), trt.ElementWiseOperation.PROD) + scaled.name = f"gelu_scaled_{node.name}" + + # tanh(...) + tanh_layer = network.add_activation(scaled.get_output(0), trt.ActivationType.TANH) + tanh_layer.name = f"gelu_tanh_{node.name}" + + # 1 + tanh(...) + one_plus_tanh = network.add_elementwise(const_one, tanh_layer.get_output(0), trt.ElementWiseOperation.SUM) + one_plus_tanh.name = f"gelu_one_plus_tanh_{node.name}" + + # x * (1 + tanh(...)) + x_times_term = network.add_elementwise(input_trt, one_plus_tanh.get_output(0), trt.ElementWiseOperation.PROD) + x_times_term.name = f"gelu_x_times_term_{node.name}" + + # 0.5 * x * (1 + tanh(...)) + layer = network.add_elementwise(const_half, x_times_term.get_output(0), trt.ElementWiseOperation.PROD) + layer.name = f"gelu_final_{node.name}" + + logger.debug(f"[TensorRT] Using GELU tanh approximation (GELU_ERF not available)") + return layer.get_output(0) + + if layer is None: + raise RuntimeError(f"Failed to create gelu layer for node {node.name}") + + layer.name = f"gelu_{node.name}" + logger.debug(f"[TensorRT] Created gelu layer: {layer.name}") + + return layer.get_output(0) + + +@converter("aten.silu.default", validator_fn=validate_unary_activation) +def convert_silu( + node: torch.fx.Node, + network: Any, # trt.INetworkDefinition + input_map: Dict[torch.fx.Node, Any], # Dict[Node, trt.ITensor] + edge_program: Optional[Any] = None, +) -> Any: # trt.ITensor + """ + Convert PyTorch SiLU (Swish) to TensorRT activation layer. 
+ + SiLU is defined as: x * sigmoid(x) + + Args: + node: FX node representing the silu operation. + network: TensorRT network definition. + input_map: Mapping from FX nodes to TensorRT tensors. + + Returns: + TensorRT output tensor. + """ + try: + import tensorrt as trt + except ImportError as e: + raise ImportError("TensorRT is required for convert_silu") from e + + logger.debug(f"[TensorRT] Converting silu node: {node.name}") + + input_node = node.args[0] + + if not isinstance(input_node, torch.fx.Node): + raise ValueError(f"Input to silu must be a node, got {type(input_node)}") + + if input_node not in input_map: + raise ValueError(f"Input node {input_node.name} not found in input_map") + + input_trt = input_map[input_node] + + # TensorRT has native Swish (SiLU) support + layer = network.add_activation(input_trt, trt.ActivationType.SWISH) + if layer is None: + raise RuntimeError(f"Failed to create silu layer for node {node.name}") + + layer.name = f"silu_{node.name}" + logger.debug(f"[TensorRT] Created silu layer: {layer.name}") + + return layer.get_output(0) + + +@converter("aten.softmax.int", "aten._softmax.default", validator_fn=validate_softmax) +def convert_softmax( + node: torch.fx.Node, + network: Any, # trt.INetworkDefinition + input_map: Dict[torch.fx.Node, Any], # Dict[Node, trt.ITensor] + edge_program: Optional[Any] = None, +) -> Any: # trt.ITensor + """ + Convert PyTorch softmax to TensorRT softmax layer. + + Note: Softmax uses network.add_softmax() instead of add_activation(). + + Args: + node: FX node representing the softmax operation. + network: TensorRT network definition. + input_map: Mapping from FX nodes to TensorRT tensors. + + Returns: + TensorRT output tensor. 
+ """ + try: + import tensorrt as trt + except ImportError as e: + raise ImportError("TensorRT is required for convert_softmax") from e + + logger.debug(f"[TensorRT] Converting softmax node: {node.name}") + + input_node = node.args[0] + dim = node.args[1] + + if not isinstance(input_node, torch.fx.Node): + raise ValueError(f"Input to softmax must be a node, got {type(input_node)}") + + if input_node not in input_map: + raise ValueError(f"Input node {input_node.name} not found in input_map") + + input_trt = input_map[input_node] + + # Handle negative dim using TRT tensor shape. + ndim = len(input_trt.shape) + if dim < 0: + dim = ndim + dim + + # Create softmax layer (NOT add_activation) + layer = network.add_softmax(input_trt) + if layer is None: + raise RuntimeError(f"Failed to create softmax layer for node {node.name}") + + # Set the axis for softmax + # TensorRT uses axes as a bitmask + layer.axes = 1 << dim + + layer.name = f"softmax_{node.name}" + logger.debug(f"[TensorRT] Created softmax layer: {layer.name}, axis={dim}") + + return layer.get_output(0) + + +@converter("aten.log_softmax.int", "aten._log_softmax.default", validator_fn=validate_softmax) +def convert_log_softmax( + node: torch.fx.Node, + network: Any, # trt.INetworkDefinition + input_map: Dict[torch.fx.Node, Any], # Dict[Node, trt.ITensor] + edge_program: Optional[Any] = None, +) -> Any: # trt.ITensor + """ + Convert PyTorch _log_softmax to TensorRT softmax + log layers. + + We decompose log_softmax into softmax followed by log operation, + reusing the existing converters. + + Args: + node: FX node representing the _log_softmax operation. + network: TensorRT network definition. + input_map: Mapping from FX nodes to TensorRT tensors. + + Returns: + TensorRT output tensor. 
+ """ + try: + import tensorrt as trt + except ImportError as e: + raise ImportError("TensorRT is required for convert_log_softmax") from e + + logger.debug(f"[TensorRT] Converting _log_softmax node: {node.name}") + + input_node = node.args[0] + dim = node.args[1] + + if not isinstance(input_node, torch.fx.Node): + raise ValueError(f"Input to _log_softmax must be a node, got {type(input_node)}") + + if input_node not in input_map: + raise ValueError(f"Input node {input_node.name} not found in input_map") + + input_trt = input_map[input_node] + + # Handle negative dimension + input_shape = input_trt.shape + ndim = len(input_shape) + if dim < 0: + dim = ndim + dim + + # Step 1: Apply softmax + softmax_layer = network.add_softmax(input_trt) + if softmax_layer is None: + raise RuntimeError( + f"Failed to create softmax layer for _log_softmax node {node.name}" + ) + softmax_layer.axes = 1 << dim + softmax_layer.name = f"log_softmax_softmax_{node.name}" + softmax_output = softmax_layer.get_output(0) + + logger.debug( + f"[TensorRT] Created softmax layer for log_softmax: " + f"{softmax_layer.name}, axis={dim}" + ) + + # Step 2: Apply log + log_layer = network.add_unary(softmax_output, trt.UnaryOperation.LOG) + if log_layer is None: + raise RuntimeError( + f"Failed to create log layer for _log_softmax node {node.name}" + ) + log_layer.name = f"log_softmax_log_{node.name}" + + logger.debug(f"[TensorRT] Created log_softmax composite: {log_layer.name}") + + return log_layer.get_output(0) + + +@converter("aten.hardswish.default", "aten.hardswish_.default", validator_fn=validate_unary_activation) +def convert_hardswish( + node: torch.fx.Node, + network: Any, # trt.INetworkDefinition + input_map: Dict[torch.fx.Node, Any], # Dict[Node, trt.ITensor] + edge_program: Optional[Any] = None, +) -> Any: # trt.ITensor + """ + Convert PyTorch hardswish to TensorRT composite layers. 
+ + Hardswish is defined as: x * relu6(x + 3) / 6 + Which is: x * min(max(x + 3, 0), 6) / 6 + + This is critical for MobileNetV3 support. + + Args: + node: FX node representing the hardswish operation. + network: TensorRT network definition. + input_map: Mapping from FX nodes to TensorRT tensors. + + Returns: + TensorRT output tensor. + """ + try: + import tensorrt as trt + except ImportError as e: + raise ImportError("TensorRT is required for convert_hardswish") from e + + logger.debug(f"[TensorRT] Converting hardswish node: {node.name}") + + input_node = node.args[0] + + if not isinstance(input_node, torch.fx.Node): + raise ValueError(f"Input to hardswish must be a node, got {type(input_node)}") + + if input_node not in input_map: + raise ValueError(f"Input node {input_node.name} not found in input_map") + + input_trt = input_map[input_node] + + # Get input dimensions for proper broadcasting + input_ndims = len(input_trt.shape) + + # Step 1: Add 3 + const_3 = _create_scalar_constant( + network, 3.0, f"{node.name}_const3", target_ndims=input_ndims + ) + add_3_layer = network.add_elementwise( + input_trt, const_3, trt.ElementWiseOperation.SUM + ) + if add_3_layer is None: + raise RuntimeError(f"Failed to create add_3 layer for hardswish {node.name}") + add_3_layer.name = f"hardswish_add3_{node.name}" + add_3_output = add_3_layer.get_output(0) + + # Step 2: ReLU (max(x + 3, 0)) + relu_layer = network.add_activation(add_3_output, trt.ActivationType.RELU) + if relu_layer is None: + raise RuntimeError(f"Failed to create relu layer for hardswish {node.name}") + relu_layer.name = f"hardswish_relu_{node.name}" + relu_output = relu_layer.get_output(0) + + # Step 3: Min with 6 (clip to 6, implementing relu6) + const_6 = _create_scalar_constant( + network, 6.0, f"{node.name}_const6", target_ndims=input_ndims + ) + min_layer = network.add_elementwise( + relu_output, const_6, trt.ElementWiseOperation.MIN + ) + if min_layer is None: + raise RuntimeError(f"Failed to create min 
layer for hardswish {node.name}") + min_layer.name = f"hardswish_min6_{node.name}" + min_output = min_layer.get_output(0) + + # Step 4: Divide by 6 (reuse const_6) + div_layer = network.add_elementwise( + min_output, const_6, trt.ElementWiseOperation.DIV + ) + if div_layer is None: + raise RuntimeError(f"Failed to create div layer for hardswish {node.name}") + div_layer.name = f"hardswish_div6_{node.name}" + div_output = div_layer.get_output(0) + + # Step 5: Multiply by input (x * relu6(x + 3) / 6) + mul_layer = network.add_elementwise( + input_trt, div_output, trt.ElementWiseOperation.PROD + ) + if mul_layer is None: + raise RuntimeError(f"Failed to create mul layer for hardswish {node.name}") + mul_layer.name = f"hardswish_{node.name}" + + logger.debug(f"[TensorRT] Created hardswish composite: {mul_layer.name}") + + return mul_layer.get_output(0) + + +@converter("aten.hardsigmoid.default", validator_fn=validate_unary_activation) +def convert_hardsigmoid( + node: torch.fx.Node, + network: Any, # trt.INetworkDefinition + input_map: Dict[torch.fx.Node, Any], # Dict[Node, trt.ITensor] + edge_program: Optional[Any] = None, +) -> Any: # trt.ITensor + """ + Convert PyTorch hardsigmoid to TensorRT composite layers. + + Hardsigmoid is defined as: min(max((x + 3) / 6, 0), 1) + Which is: clip((x + 3) / 6, 0, 1) + + This is critical for MobileNetV3 Squeeze-and-Excitation blocks. + + Args: + node: FX node representing the hardsigmoid operation. + network: TensorRT network definition. + input_map: Mapping from FX nodes to TensorRT tensors. + + Returns: + TensorRT output tensor. 
+ """ + try: + import tensorrt as trt + except ImportError as e: + raise ImportError("TensorRT is required for convert_hardsigmoid") from e + + logger.debug(f"[TensorRT] Converting hardsigmoid node: {node.name}") + + input_node = node.args[0] + + if not isinstance(input_node, torch.fx.Node): + raise ValueError(f"Input to hardsigmoid must be a node, got {type(input_node)}") + + if input_node not in input_map: + raise ValueError(f"Input node {input_node.name} not found in input_map") + + input_trt = input_map[input_node] + + # Get input dimensions for proper broadcasting + input_ndims = len(input_trt.shape) + + # Step 1: Add 3 + const_3 = _create_scalar_constant( + network, 3.0, f"{node.name}_const3", target_ndims=input_ndims + ) + add_3_layer = network.add_elementwise( + input_trt, const_3, trt.ElementWiseOperation.SUM + ) + if add_3_layer is None: + raise RuntimeError(f"Failed to create add_3 layer for hardsigmoid {node.name}") + add_3_layer.name = f"hardsigmoid_add3_{node.name}" + add_3_output = add_3_layer.get_output(0) + + # Step 2: Divide by 6 + const_6 = _create_scalar_constant( + network, 6.0, f"{node.name}_const6", target_ndims=input_ndims + ) + div_layer = network.add_elementwise( + add_3_output, const_6, trt.ElementWiseOperation.DIV + ) + if div_layer is None: + raise RuntimeError(f"Failed to create div layer for hardsigmoid {node.name}") + div_layer.name = f"hardsigmoid_div6_{node.name}" + div_output = div_layer.get_output(0) + + # Step 3: ReLU (max((x+3)/6, 0)) + relu_layer = network.add_activation(div_output, trt.ActivationType.RELU) + if relu_layer is None: + raise RuntimeError(f"Failed to create relu layer for hardsigmoid {node.name}") + relu_layer.name = f"hardsigmoid_relu_{node.name}" + relu_output = relu_layer.get_output(0) + + # Step 4: Min with 1 (clip to [0, 1]) + const_1 = _create_scalar_constant( + network, 1.0, f"{node.name}_const1", target_ndims=input_ndims + ) + min_layer = network.add_elementwise( + relu_output, const_1, 
trt.ElementWiseOperation.MIN + ) + if min_layer is None: + raise RuntimeError(f"Failed to create min layer for hardsigmoid {node.name}") + min_layer.name = f"hardsigmoid_{node.name}" + + logger.debug(f"[TensorRT] Created hardsigmoid composite: {min_layer.name}") + + return min_layer.get_output(0) + + +def validate_clamp(node: torch.fx.Node) -> bool: + """ + Validate that a clamp node can be converted to TensorRT. + + Args: + node: FX node representing the clamp operation. + + Returns: + True if the node can be converted, False otherwise. + """ + if node.op != "call_function": + return False + + args = node.args + # clamp takes at least 1 arg (input), and optionally min and max + if len(args) < 1: + logger.debug( + f"[TensorRT] validate_clamp: node {node.name} has insufficient args" + ) + return False + + return True + + +@converter( + "aten.clamp.default", + "aten.clamp.Tensor", + "aten.clamp_min.default", + "aten.clamp_max.default", + validator_fn=validate_clamp, +) +def convert_clamp( + node: torch.fx.Node, + network: Any, # trt.INetworkDefinition + input_map: Dict[torch.fx.Node, Any], # Dict[Node, trt.ITensor] + edge_program: Optional[Any] = None, +) -> Any: # trt.ITensor + """ + Convert PyTorch clamp to TensorRT elementwise operations. + + Clamp is defined as: output = min(max(input, min_val), max_val) + + Args: + node: FX node representing the clamp operation. + network: TensorRT network definition. + input_map: Mapping from FX nodes to TensorRT tensors. + edge_program: ExportedProgram for extracting parameters. + + Returns: + TensorRT output tensor. 
+ """ + try: + import tensorrt as trt + except ImportError as e: + raise ImportError("TensorRT is required for convert_clamp") from e + + logger.debug(f"[TensorRT] Converting clamp node: {node.name}") + + args = node.args + input_node = args[0] + + if not isinstance(input_node, torch.fx.Node): + raise ValueError(f"Input to clamp must be a node, got {type(input_node)}") + + if input_node not in input_map: + raise ValueError(f"Input node {input_node.name} not found in input_map") + + input_trt = input_map[input_node] + + # Get min and max values + op_name = str(node.target) + min_val = None + max_val = None + + if "clamp_min" in op_name: + # clamp_min only has min + min_val = args[1] if len(args) > 1 else None + elif "clamp_max" in op_name: + # clamp_max only has max + max_val = args[1] if len(args) > 1 else None + else: + # Regular clamp has both + min_val = args[1] if len(args) > 1 else None + max_val = args[2] if len(args) > 2 else None + + # Also check kwargs + min_val = node.kwargs.get("min", min_val) + max_val = node.kwargs.get("max", max_val) + + output = input_trt + + # Get input dimensions for proper broadcasting + input_ndims = len(input_trt.shape) + + # Apply min (max with min_val) + if min_val is not None: + min_val_float = float(min_val) if isinstance(min_val, (int, float)) else 0.0 + min_const = _create_scalar_constant( + network, min_val_float, f"{node.name}_min", target_ndims=input_ndims + ) + max_layer = network.add_elementwise( + output, min_const, trt.ElementWiseOperation.MAX + ) + if max_layer is None: + raise RuntimeError(f"Failed to create max layer for clamp {node.name}") + max_layer.name = f"clamp_min_{node.name}" + output = max_layer.get_output(0) + + # Apply max (min with max_val) + if max_val is not None: + max_val_float = float(max_val) if isinstance(max_val, (int, float)) else 1.0 + max_const = _create_scalar_constant( + network, max_val_float, f"{node.name}_max", target_ndims=input_ndims + ) + min_layer = network.add_elementwise( + 
output, max_const, trt.ElementWiseOperation.MIN + ) + if min_layer is None: + raise RuntimeError(f"Failed to create min layer for clamp {node.name}") + min_layer.name = f"clamp_max_{node.name}" + output = min_layer.get_output(0) + + logger.debug(f"[TensorRT] Created clamp layers for node: {node.name}") + + return output + + +def validate_hardtanh(node: torch.fx.Node) -> bool: + """ + Validate that a hardtanh node can be converted to TensorRT. + + Args: + node: FX node representing the hardtanh operation. + + Returns: + True if the node can be converted, False otherwise. + """ + if node.op != "call_function": + return False + + args = node.args + if len(args) < 1: + logger.debug( + f"[TensorRT] validate_hardtanh: node {node.name} has insufficient args" + ) + return False + + return True + + +@converter( + "aten.hardtanh.default", + "aten.hardtanh_.default", + validator_fn=validate_hardtanh, +) +def convert_hardtanh( + node: torch.fx.Node, + network: Any, # trt.INetworkDefinition + input_map: Dict[torch.fx.Node, Any], # Dict[Node, trt.ITensor] + edge_program: Optional[Any] = None, +) -> Any: # trt.ITensor + """ + Convert PyTorch hardtanh to TensorRT elementwise operations. + + Hardtanh is defined as: output = clamp(input, min_val, max_val) + This is used by ReLU6 (hardtanh with min=0, max=6) in MobileNetV2. + + Args: + node: FX node representing the hardtanh operation. + network: TensorRT network definition. + input_map: Mapping from FX nodes to TensorRT tensors. + edge_program: ExportedProgram for extracting parameters. + + Returns: + TensorRT output tensor. 
+ """ + try: + import tensorrt as trt + except ImportError as e: + raise ImportError("TensorRT is required for convert_hardtanh") from e + + logger.debug(f"[TensorRT] Converting hardtanh node: {node.name}") + + args = node.args + input_node = args[0] + + if not isinstance(input_node, torch.fx.Node): + raise ValueError(f"Input to hardtanh must be a node, got {type(input_node)}") + + if input_node not in input_map: + raise ValueError(f"Input node {input_node.name} not found in input_map") + + input_trt = input_map[input_node] + + # Get min and max values (default: min=-1, max=1) + min_val = args[1] if len(args) > 1 else node.kwargs.get("min_val", -1.0) + max_val = args[2] if len(args) > 2 else node.kwargs.get("max_val", 1.0) + + min_val_float = float(min_val) + max_val_float = float(max_val) + + # Get input dimensions for proper broadcasting + input_ndims = len(input_trt.shape) + + output = input_trt + + # Apply min (max with min_val) - clamp from below + min_const = _create_scalar_constant( + network, min_val_float, f"{node.name}_min", target_ndims=input_ndims + ) + max_layer = network.add_elementwise( + output, min_const, trt.ElementWiseOperation.MAX + ) + if max_layer is None: + raise RuntimeError(f"Failed to create max layer for hardtanh {node.name}") + max_layer.name = f"hardtanh_min_{node.name}" + output = max_layer.get_output(0) + + # Apply max (min with max_val) - clamp from above + max_const = _create_scalar_constant( + network, max_val_float, f"{node.name}_max", target_ndims=input_ndims + ) + min_layer = network.add_elementwise( + output, max_const, trt.ElementWiseOperation.MIN + ) + if min_layer is None: + raise RuntimeError(f"Failed to create min layer for hardtanh {node.name}") + min_layer.name = f"hardtanh_max_{node.name}" + output = min_layer.get_output(0) + + logger.debug( + f"[TensorRT] Created hardtanh layers for node: {node.name} " + f"(min={min_val_float}, max={max_val_float})" + ) + + return output + + +__all__ = [ + "convert_sigmoid", + 
"convert_tanh", + "convert_gelu", + "convert_silu", + "convert_softmax", + "convert_log_softmax", + "convert_hardswish", + "convert_hardsigmoid", + "convert_clamp", + "convert_hardtanh", + "convert_dropout", + "validate_unary_activation", + "validate_softmax", + "validate_clamp", + "validate_hardtanh", +] + + +def validate_dropout(node: torch.fx.Node) -> bool: + """Validate that a dropout node can be converted to TensorRT.""" + if node.op != "call_function": + return False + + # Dropout must have at least 1 arg (input) + if len(node.args) < 1: + return False + + return True + + +@converter( + "aten.dropout.default", + "aten.dropout_.default", + validator_fn=validate_dropout, +) +def convert_dropout( + node: torch.fx.Node, + network: Any, # trt.INetworkDefinition + input_map: Dict[torch.fx.Node, Any], # Dict[Node, trt.ITensor] + edge_program: Optional[Any] = None, +) -> Any: # trt.ITensor + """ + Convert PyTorch dropout to TensorRT (no-op in inference mode). + + Dropout is a no-op during inference - just pass through the input tensor. + + Args: + node: FX node representing the dropout operation. + network: TensorRT network definition. + input_map: Mapping from FX nodes to TensorRT tensors. + + Returns: + TensorRT output tensor (same as input). 
+ """ + logger.debug(f"[TensorRT] Converting dropout node: {node.name} (no-op)") + + input_node = node.args[0] + + if not isinstance(input_node, torch.fx.Node): + raise ValueError(f"Input to dropout must be a node, got {type(input_node)}") + + if input_node not in input_map: + raise ValueError(f"Input node {input_node.name} not found in input_map") + + # Dropout is a no-op in inference mode - return input directly + return input_map[input_node] diff --git a/backends/nvidia/tensorrt/converters/batch_norm.py b/backends/nvidia/tensorrt/converters/batch_norm.py new file mode 100644 index 00000000000..597414debba --- /dev/null +++ b/backends/nvidia/tensorrt/converters/batch_norm.py @@ -0,0 +1,191 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the BSD-style license found in the +# LICENSE file in the root directory of this source tree. + +# pyre-strict + +"""TensorRT Converter for Batch Normalization Operations.""" + +import logging +from typing import Any, Dict, Optional, Union + +import torch + +from executorch.backends.nvidia.tensorrt.converter_registry import converter + +from torch._export.utils import ( + get_buffer, + get_lifted_tensor_constant, + get_param, + is_buffer, + is_lifted_tensor_constant, + is_param, +) +from torch.export.exported_program import ExportedProgram + +logger: logging.Logger = logging.getLogger(__name__) + + +def validate_batch_norm(node: torch.fx.Node) -> bool: + """Validate that a batch norm node can be converted to TensorRT.""" + if node.op != "call_function": + return False + if len(node.args) < 5: + return False + return True + + +def _get_param_tensor( + exp_prog: Optional[ExportedProgram], + node: Any, +) -> Optional[torch.Tensor]: + """Extract a constant tensor from an ExportedProgram.""" + if node is None: + return None + if isinstance(node, torch.Tensor): + return node + if not isinstance(node, torch.fx.Node): + return None + + if exp_prog is not None: + if 
is_param(exp_prog, node): + return get_param(exp_prog, node) + elif is_buffer(exp_prog, node): + return get_buffer(exp_prog, node) + elif is_lifted_tensor_constant(exp_prog, node): + return get_lifted_tensor_constant(exp_prog, node) + + # Fallback for get_attr nodes + if isinstance(node, torch.fx.Node) and node.op == "get_attr": + if exp_prog is not None: + try: + target = node.target + if isinstance(target, str): + return getattr(exp_prog.graph_module, target) + except AttributeError: + pass + try: + if hasattr(node, "graph") and hasattr(node.graph, "owning_module"): + target = node.target + if isinstance(target, str): + return getattr(node.graph.owning_module, target) + except AttributeError: + pass + + return None + + +@converter( + "aten._native_batch_norm_legit.default", + "aten._native_batch_norm_legit_no_training.default", + "aten.batch_norm.default", + validator_fn=validate_batch_norm, + needs_edge_program=True, +) +def convert_batch_norm( + node: torch.fx.Node, + network: Any, + input_map: Dict[torch.fx.Node, Any], + edge_program: Optional[Union[ExportedProgram, torch.fx.GraphModule]] = None, + ctx: Any = None, +) -> Any: + """Convert PyTorch batch norm to TensorRT scale layer. 
+ + Implements batch normalization as a fused scale operation: + output = scale * input + shift + where: + scale = gamma / sqrt(running_var + eps) + shift = beta - running_mean * scale + """ + try: + import tensorrt as trt + import numpy as np + except ImportError as e: + raise ImportError("TensorRT is required for convert_batch_norm.") from e + + args = node.args + kwargs = node.kwargs + + input_node = args[0] + weight_node = args[1] if len(args) > 1 else kwargs.get("weight", None) + bias_node = args[2] if len(args) > 2 else kwargs.get("bias", None) + running_mean_node = args[3] if len(args) > 3 else kwargs.get("running_mean", None) + running_var_node = args[4] if len(args) > 4 else kwargs.get("running_var", None) + + target_str = str(node.target) + if "no_training" in target_str: + eps = args[6] if len(args) > 6 else kwargs.get("eps", 1e-5) + else: + eps = args[7] if len(args) > 7 else kwargs.get("eps", 1e-5) + + if not isinstance(input_node, torch.fx.Node) or input_node not in input_map: + raise ValueError(f"Input node {input_node} not found in input_map") + + input_trt = input_map[input_node] + + exp_prog = edge_program if isinstance(edge_program, ExportedProgram) else None + weight_tensor = _get_param_tensor(exp_prog, weight_node) + bias_tensor = _get_param_tensor(exp_prog, bias_node) + running_mean_tensor = _get_param_tensor(exp_prog, running_mean_node) + running_var_tensor = _get_param_tensor(exp_prog, running_var_node) + + if running_mean_tensor is None: + raise ValueError(f"running_mean must be a constant tensor for {node.name}") + if running_var_tensor is None: + raise ValueError(f"running_var must be a constant tensor for {node.name}") + + mean_np = running_mean_tensor.detach().cpu().numpy().astype(np.float32) + var_np = running_var_tensor.detach().cpu().numpy().astype(np.float32) + + if weight_tensor is not None: + gamma_np = weight_tensor.detach().cpu().numpy().astype(np.float32) + else: + gamma_np = np.ones_like(mean_np, dtype=np.float32) + + if 
bias_tensor is not None: + beta_np = bias_tensor.detach().cpu().numpy().astype(np.float32) + else: + beta_np = np.zeros_like(mean_np, dtype=np.float32) + + num_channels = mean_np.shape[0] + + # Fuse BN into scale layer: y = scale * x + shift + fused_scale = np.ascontiguousarray( + (gamma_np / np.sqrt(var_np + eps)).astype(np.float32) + ) + fused_shift = np.ascontiguousarray( + (beta_np - mean_np * fused_scale).astype(np.float32) + ) + power_weights = np.ascontiguousarray(np.ones(num_channels, dtype=np.float32)) + + # Store arrays to prevent GC before engine build completes + if not hasattr(convert_batch_norm, '_weight_storage'): + convert_batch_norm._weight_storage = [] + convert_batch_norm._weight_storage.extend([fused_scale, fused_shift, power_weights]) + + scale_layer = network.add_scale( + input_trt, + trt.ScaleMode.CHANNEL, + shift=trt.Weights(fused_shift), + scale=trt.Weights(fused_scale), + power=trt.Weights(power_weights), + ) + if scale_layer is None: + raise RuntimeError(f"Failed to create Scale layer for {node.name}") + scale_layer.name = f"bn_scale_{node.name}" + return scale_layer.get_output(0) + + +def clear_weight_storage() -> None: + """Clear weight storage to free memory after engine build.""" + if hasattr(convert_batch_norm, '_weight_storage'): + convert_batch_norm._weight_storage.clear() + + +__all__ = [ + "clear_weight_storage", + "convert_batch_norm", + "validate_batch_norm", +] diff --git a/backends/nvidia/tensorrt/converters/clamp.py b/backends/nvidia/tensorrt/converters/clamp.py new file mode 100644 index 00000000000..e180016e1fc --- /dev/null +++ b/backends/nvidia/tensorrt/converters/clamp.py @@ -0,0 +1,253 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the BSD-style license found in the +# LICENSE file in the root directory of this source tree. + +"""TensorRT Converters for Clamp/Clip Operations. 
+ +Supported operations: +- aten.clamp.default: Clamps all elements in input into the range [min, max] +- aten.clamp.Tensor: Clamps with tensor bounds +- aten.clip.default: Alias for clamp +- aten.hardtanh.default: Clamps between min_val and max_val (ReLU6 variant) + +TensorRT supports clamping via the IActivationLayer with CLIP activation type. +""" + +import logging +from typing import Any, Dict, Optional + +import torch +from executorch.backends.nvidia.tensorrt.converter_registry import converter +from executorch.backends.nvidia.tensorrt.converter_utils import ( + broadcast_tensors, + get_trt_tensor, +) + +logger: logging.Logger = logging.getLogger(__name__) + + +def validate_clamp(node: torch.fx.Node) -> bool: + """Validate that a clamp node can be converted to TensorRT. + + Args: + node: FX node representing the clamp operation. + + Returns: + True if the node can be converted, False otherwise. + """ + if node.op != "call_function": + logger.debug( + f"[TensorRT] validate_clamp: node {node.name} is not call_function" + ) + return False + + args = node.args + if len(args) < 1: + logger.debug( + f"[TensorRT] validate_clamp: node {node.name} has insufficient args" + ) + return False + + if not isinstance(args[0], torch.fx.Node): + logger.debug( + f"[TensorRT] validate_clamp: input is not a node, got {type(args[0])}" + ) + return False + + return True + + +@converter("aten.clamp.default", validator_fn=validate_clamp) +def convert_clamp( + node: torch.fx.Node, + network: Any, # trt.INetworkDefinition + input_map: Dict[torch.fx.Node, Any], # Dict[Node, trt.ITensor] + edge_program: Optional[Any] = None, +) -> Any: # trt.ITensor + """Convert PyTorch clamp to TensorRT. + + PyTorch signature: + aten.clamp(Tensor self, Scalar? min=None, Scalar? max=None) -> Tensor + + For constant min/max bounds, we use IActivationLayer with CLIP type. + For None bounds, we implement only the specified bound. + + Args: + node: FX node representing the clamp operation. 
+ network: TensorRT network definition. + input_map: Mapping from FX nodes to TensorRT tensors. + edge_program: Optional edge program for accessing weights/constants. + + Returns: + TensorRT output tensor. + """ + try: + import tensorrt as trt + except ImportError as e: + raise ImportError("TensorRT is required for convert_clamp") from e + + logger.debug(f"[TensorRT] Converting clamp node: {node.name}") + + args = node.args + kwargs = node.kwargs + + input_node = args[0] + min_val = args[1] if len(args) > 1 else kwargs.get("min", None) + max_val = args[2] if len(args) > 2 else kwargs.get("max", None) + + if not isinstance(input_node, torch.fx.Node): + raise ValueError(f"Input to clamp must be a node, got {type(input_node)}") + + if input_node not in input_map: + raise ValueError(f"Input node {input_node.name} not found in input_map") + + input_trt = input_map[input_node] + + logger.debug(f"[TensorRT] clamp: min={min_val}, max={max_val}") + + result = input_trt + input_ndim = len(input_trt.shape) + + # Handle min bound (using max with min_val) + if min_val is not None: + min_const = get_trt_tensor(network, float(min_val), f"clamp_min_{node.name}") + # Broadcast constant to match input dimensions for elementwise operation + [min_const] = broadcast_tensors( + network, [min_const], input_ndim, f"clamp_min_{node.name}" + ) + layer_min = network.add_elementwise( + result, min_const, trt.ElementWiseOperation.MAX + ) + if layer_min is None: + raise RuntimeError(f"Failed to create clamp min layer for node {node.name}") + layer_min.name = f"clamp_min_{node.name}" + result = layer_min.get_output(0) + + # Handle max bound (using min with max_val) + if max_val is not None: + max_const = get_trt_tensor(network, float(max_val), f"clamp_max_{node.name}") + # Broadcast constant to match input dimensions for elementwise operation + [max_const] = broadcast_tensors( + network, [max_const], input_ndim, f"clamp_max_{node.name}" + ) + layer_max = network.add_elementwise( + result, 
max_const, trt.ElementWiseOperation.MIN + ) + if layer_max is None: + raise RuntimeError(f"Failed to create clamp max layer for node {node.name}") + layer_max.name = f"clamp_max_{node.name}" + result = layer_max.get_output(0) + + if min_val is None and max_val is None: + # No clamping needed, return input as-is + logger.warning( + f"[TensorRT] clamp node {node.name} has no min or max, returning input" + ) + + logger.debug(f"[TensorRT] Created clamp layers for node: {node.name}") + + return result + + +@converter("aten.clamp_min.default", validator_fn=validate_clamp) +def convert_clamp_min( + node: torch.fx.Node, + network: Any, # trt.INetworkDefinition + input_map: Dict[torch.fx.Node, Any], # Dict[Node, trt.ITensor] + edge_program: Optional[Any] = None, +) -> Any: # trt.ITensor + """Convert PyTorch clamp_min to TensorRT. + + PyTorch signature: + aten.clamp_min(Tensor self, Scalar min) -> Tensor + """ + try: + import tensorrt as trt + except ImportError as e: + raise ImportError("TensorRT is required for convert_clamp_min") from e + + logger.debug(f"[TensorRT] Converting clamp_min node: {node.name}") + + args = node.args + input_node = args[0] + min_val = args[1] + + if not isinstance(input_node, torch.fx.Node): + raise ValueError(f"Input to clamp_min must be a node, got {type(input_node)}") + + if input_node not in input_map: + raise ValueError(f"Input node {input_node.name} not found in input_map") + + input_trt = input_map[input_node] + input_ndim = len(input_trt.shape) + + min_const = get_trt_tensor(network, float(min_val), f"clamp_min_{node.name}") + # Broadcast constant to match input dimensions for elementwise operation + [min_const] = broadcast_tensors( + network, [min_const], input_ndim, f"clamp_min_{node.name}" + ) + layer = network.add_elementwise(input_trt, min_const, trt.ElementWiseOperation.MAX) + if layer is None: + raise RuntimeError(f"Failed to create clamp_min layer for node {node.name}") + layer.name = f"clamp_min_{node.name}" + + 
logger.debug(f"[TensorRT] Created clamp_min layer: {layer.name}") + + return layer.get_output(0) + + +@converter("aten.clamp_max.default", validator_fn=validate_clamp) +def convert_clamp_max( + node: torch.fx.Node, + network: Any, # trt.INetworkDefinition + input_map: Dict[torch.fx.Node, Any], # Dict[Node, trt.ITensor] + edge_program: Optional[Any] = None, +) -> Any: # trt.ITensor + """Convert PyTorch clamp_max to TensorRT. + + PyTorch signature: + aten.clamp_max(Tensor self, Scalar max) -> Tensor + """ + try: + import tensorrt as trt + except ImportError as e: + raise ImportError("TensorRT is required for convert_clamp_max") from e + + logger.debug(f"[TensorRT] Converting clamp_max node: {node.name}") + + args = node.args + input_node = args[0] + max_val = args[1] + + if not isinstance(input_node, torch.fx.Node): + raise ValueError(f"Input to clamp_max must be a node, got {type(input_node)}") + + if input_node not in input_map: + raise ValueError(f"Input node {input_node.name} not found in input_map") + + input_trt = input_map[input_node] + input_ndim = len(input_trt.shape) + + max_const = get_trt_tensor(network, float(max_val), f"clamp_max_{node.name}") + # Broadcast constant to match input dimensions for elementwise operation + [max_const] = broadcast_tensors( + network, [max_const], input_ndim, f"clamp_max_{node.name}" + ) + layer = network.add_elementwise(input_trt, max_const, trt.ElementWiseOperation.MIN) + if layer is None: + raise RuntimeError(f"Failed to create clamp_max layer for node {node.name}") + layer.name = f"clamp_max_{node.name}" + + logger.debug(f"[TensorRT] Created clamp_max layer: {layer.name}") + + return layer.get_output(0) + + +__all__ = [ + "convert_clamp", + "convert_clamp_min", + "convert_clamp_max", + "validate_clamp", +] diff --git a/backends/nvidia/tensorrt/converters/concat.py b/backends/nvidia/tensorrt/converters/concat.py new file mode 100644 index 00000000000..e751e722b8b --- /dev/null +++ 
b/backends/nvidia/tensorrt/converters/concat.py @@ -0,0 +1,553 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the BSD-style license found in the +# LICENSE file in the root directory of this source tree. + +# pyre-strict + +""" +TensorRT Converters for Concatenation and Split Operations. + +This module provides converters for PyTorch tensor concatenation and splitting +operations to TensorRT layers. + +Supported operations: +- aten.cat.default: Concatenate tensors along an axis +- aten.stack.default: Stack tensors along a new axis +- aten.split.Tensor: Split tensor into chunks of given size +- aten.split_with_sizes.default: Split tensor into chunks with given sizes +- aten.chunk.default: Split tensor into specified number of chunks + +Notes: +- Concatenation uses network.add_concatenation() +- Split/chunk uses network.add_slice() for each output chunk +- Stack is implemented as unsqueeze + concatenation +""" + +import logging +from typing import Any, Dict, Optional, List, Tuple, Union + +import torch +from executorch.backends.nvidia.tensorrt.converter_registry import converter +from executorch.backends.nvidia.tensorrt.converter_utils import get_node_shape + +logger: logging.Logger = logging.getLogger(__name__) + + +def _get_positive_dim(dim: int, ndim: int) -> int: + """ + Convert a potentially negative dimension index to positive. + + Args: + dim: Dimension index (can be negative). + ndim: Number of dimensions. + + Returns: + Positive dimension index. + """ + if dim < 0: + dim = ndim + dim + return dim + + +def validate_cat(node: torch.fx.Node) -> bool: + """ + Validate that a cat node can be converted to TensorRT. + + Args: + node: FX node representing the cat operation. + + Returns: + True if the node can be converted, False otherwise. 
+ """ + if node.op != "call_function": + logger.debug(f"[TensorRT] validate_cat: node {node.name} is not call_function") + return False + + args = node.args + # Args: tensors (list), dim (optional, default 0) + if len(args) < 1: + logger.debug( + f"[TensorRT] validate_cat: node {node.name} has insufficient args" + ) + return False + + tensors = args[0] + if not isinstance(tensors, (list, tuple)) or len(tensors) < 1: + logger.debug( + f"[TensorRT] validate_cat: node {node.name} has invalid tensors arg" + ) + return False + + return True + + +def validate_stack(node: torch.fx.Node) -> bool: + """ + Validate that a stack node can be converted to TensorRT. + + Args: + node: FX node representing the stack operation. + + Returns: + True if the node can be converted, False otherwise. + """ + if node.op != "call_function": + logger.debug( + f"[TensorRT] validate_stack: node {node.name} is not call_function" + ) + return False + + args = node.args + # Args: tensors (list), dim (optional, default 0) + if len(args) < 1: + logger.debug( + f"[TensorRT] validate_stack: node {node.name} has insufficient args" + ) + return False + + tensors = args[0] + if not isinstance(tensors, (list, tuple)) or len(tensors) < 1: + logger.debug( + f"[TensorRT] validate_stack: node {node.name} has invalid tensors arg" + ) + return False + + return True + + +def validate_split(node: torch.fx.Node) -> bool: + """ + Validate that a split node can be converted to TensorRT. + + Args: + node: FX node representing the split operation. + + Returns: + True if the node can be converted, False otherwise. 
+ """ + if node.op != "call_function": + logger.debug( + f"[TensorRT] validate_split: node {node.name} is not call_function" + ) + return False + + args = node.args + # Args: input, split_size_or_sections, dim (optional, default 0) + if len(args) < 2: + logger.debug( + f"[TensorRT] validate_split: node {node.name} has insufficient args" + ) + return False + + return True + + +def validate_chunk(node: torch.fx.Node) -> bool: + """ + Validate that a chunk node can be converted to TensorRT. + + Args: + node: FX node representing the chunk operation. + + Returns: + True if the node can be converted, False otherwise. + """ + if node.op != "call_function": + logger.debug( + f"[TensorRT] validate_chunk: node {node.name} is not call_function" + ) + return False + + args = node.args + # Args: input, chunks, dim (optional, default 0) + if len(args) < 2: + logger.debug( + f"[TensorRT] validate_chunk: node {node.name} has insufficient args" + ) + return False + + return True + + +@converter("aten.cat.default", validator_fn=validate_cat) +def convert_cat( + node: torch.fx.Node, + network: Any, # trt.INetworkDefinition + input_map: Dict[torch.fx.Node, Any], # Dict[Node, trt.ITensor] + edge_program: Optional[Any] = None, +) -> Any: # trt.ITensor + """ + Convert PyTorch cat to TensorRT concatenation layer. + + Args: + node: FX node representing the cat operation. + network: TensorRT network definition. + input_map: Mapping from FX nodes to TensorRT tensors. + + Returns: + TensorRT output tensor. 
+ """ + try: + import tensorrt as trt + except ImportError as e: + raise ImportError("TensorRT is required for convert_cat") from e + + logger.debug(f"[TensorRT] Converting cat node: {node.name}") + + args = node.args + + tensors = args[0] + cat_dim = args[1] if len(args) > 1 else 0 + + if not isinstance(tensors, (list, tuple)): + raise ValueError(f"tensors must be list or tuple, got {type(tensors)}") + + # Convert all input nodes to TensorRT tensors + trt_tensors = [] + for tensor_node in tensors: + if not isinstance(tensor_node, torch.fx.Node): + raise ValueError(f"Input must be node, got {type(tensor_node)}") + if tensor_node not in input_map: + raise ValueError(f"Input node {tensor_node.name} not found in input_map") + trt_tensors.append(input_map[tensor_node]) + + if len(trt_tensors) == 0: + raise ValueError("cat requires at least one input tensor") + + # Get number of dimensions from first input + ndim = len(trt_tensors[0].shape) + cat_dim = _get_positive_dim(cat_dim, ndim) + + # Create concatenation layer + layer = network.add_concatenation(trt_tensors) + if layer is None: + raise RuntimeError(f"Failed to create concatenation layer for cat {node.name}") + + layer.axis = cat_dim + layer.name = f"cat_{node.name}" + + output = layer.get_output(0) + + logger.debug( + f"[TensorRT] Created cat layer: {layer.name}, " + f"axis={cat_dim}, num_inputs={len(trt_tensors)}, output_shape={list(output.shape)}" + ) + + return output + + +@converter("aten.stack.default", validator_fn=validate_stack) +def convert_stack( + node: torch.fx.Node, + network: Any, # trt.INetworkDefinition + input_map: Dict[torch.fx.Node, Any], # Dict[Node, trt.ITensor] + edge_program: Optional[Any] = None, +) -> Any: # trt.ITensor + """ + Convert PyTorch stack to TensorRT. + + Stack is implemented as unsqueeze on each tensor followed by concatenation. + + Args: + node: FX node representing the stack operation. + network: TensorRT network definition. 
+ input_map: Mapping from FX nodes to TensorRT tensors. + + Returns: + TensorRT output tensor. + """ + try: + import tensorrt as trt + except ImportError as e: + raise ImportError("TensorRT is required for convert_stack") from e + + logger.debug(f"[TensorRT] Converting stack node: {node.name}") + + args = node.args + + tensors = args[0] + stack_dim = args[1] if len(args) > 1 else 0 + + if not isinstance(tensors, (list, tuple)): + raise ValueError(f"tensors must be list or tuple, got {type(tensors)}") + + # Convert all input nodes to TensorRT tensors + trt_tensors = [] + for tensor_node in tensors: + if not isinstance(tensor_node, torch.fx.Node): + raise ValueError(f"Input must be node, got {type(tensor_node)}") + if tensor_node not in input_map: + raise ValueError(f"Input node {tensor_node.name} not found in input_map") + trt_tensors.append(input_map[tensor_node]) + + if len(trt_tensors) == 0: + raise ValueError("stack requires at least one input tensor") + + # Get number of dimensions from first input (output will have ndim + 1) + ndim = len(trt_tensors[0].shape) + stack_dim = _get_positive_dim(stack_dim, ndim + 1) + + # Unsqueeze each tensor at the stack dimension + unsqueezed_tensors = [] + for i, trt_tensor in enumerate(trt_tensors): + input_shape = list(trt_tensor.shape) + # Build output shape with new dimension of size 1 + output_shape = input_shape[:stack_dim] + [1] + input_shape[stack_dim:] + + shuffle_layer = network.add_shuffle(trt_tensor) + if shuffle_layer is None: + raise RuntimeError( + f"Failed to create shuffle layer for stack unsqueeze {node.name}" + ) + shuffle_layer.reshape_dims = trt.Dims(output_shape) + shuffle_layer.name = f"stack_unsqueeze_{node.name}_{i}" + unsqueezed_tensors.append(shuffle_layer.get_output(0)) + + # Create concatenation layer on the new dimension + layer = network.add_concatenation(unsqueezed_tensors) + if layer is None: + raise RuntimeError( + f"Failed to create concatenation layer for stack {node.name}" + ) + + layer.axis 
= stack_dim + layer.name = f"stack_{node.name}" + + output = layer.get_output(0) + + logger.debug( + f"[TensorRT] Created stack layer: {layer.name}, " + f"dim={stack_dim}, num_inputs={len(trt_tensors)}, output_shape={list(output.shape)}" + ) + + return output + + +@converter("aten.split.Tensor", "aten.split_with_sizes.default", "aten.split_with_sizes_copy.default", validator_fn=validate_split) +def convert_split( + node: torch.fx.Node, + network: Any, # trt.INetworkDefinition + input_map: Dict[torch.fx.Node, Any], # Dict[Node, trt.ITensor] + edge_program: Optional[Any] = None, +) -> List[Any]: # List[trt.ITensor] + """ + Convert PyTorch split to TensorRT slice layers. + + Returns a list of output tensors (one for each split chunk). + + Args: + node: FX node representing the split operation. + network: TensorRT network definition. + input_map: Mapping from FX nodes to TensorRT tensors. + + Returns: + List of TensorRT output tensors. + """ + try: + import tensorrt as trt + except ImportError as e: + raise ImportError("TensorRT is required for convert_split") from e + + logger.debug(f"[TensorRT] Converting split node: {node.name}") + + args = node.args + + input_node = args[0] + split_size_or_sections = args[1] + split_dim = args[2] if len(args) > 2 else 0 + + if not isinstance(input_node, torch.fx.Node): + raise ValueError(f"Input must be node, got {type(input_node)}") + + if input_node not in input_map: + raise ValueError(f"Input node {input_node.name} not found in input_map") + + input_trt = input_map[input_node] + + # Get shape from node metadata for reliability (TRT shapes can be invalid during network building) + input_shape = list(get_node_shape(input_node) or input_trt.shape) + ndim = len(input_shape) + + # Normalize split dimension + split_dim = _get_positive_dim(split_dim, ndim) + + # Get dimension size + dim_size = input_shape[split_dim] + if dim_size <= 0: + raise ValueError( + f"Cannot split dynamic dimension {split_dim} with size {dim_size}" + ) + + # 
Determine split sizes + if isinstance(split_size_or_sections, int): + # split_size: split into chunks of this size + split_sizes = [] + remaining = dim_size + while remaining > 0: + chunk_size = min(split_size_or_sections, remaining) + split_sizes.append(chunk_size) + remaining -= chunk_size + elif isinstance(split_size_or_sections, (list, tuple)): + # split_with_sizes: list of sizes + split_sizes = list(split_size_or_sections) + if sum(split_sizes) != dim_size: + raise ValueError( + f"split_with_sizes: sum of sizes {sum(split_sizes)} != dim_size {dim_size}" + ) + else: + raise ValueError( + f"split_size_or_sections must be int or list, got {type(split_size_or_sections)}" + ) + + # Create slice layers for each chunk + outputs = [] + start_idx = 0 + + for i, chunk_size in enumerate(split_sizes): + # Build start, shape, stride tuples for slice + start = [0] * ndim + shape = list(input_shape) + stride = [1] * ndim + + start[split_dim] = start_idx + shape[split_dim] = chunk_size + + # Create slice layer + layer = network.add_slice( + input_trt, trt.Dims(start), trt.Dims(shape), trt.Dims(stride) + ) + + if layer is None: + raise RuntimeError( + f"Failed to create slice layer for split {node.name} chunk {i}" + ) + + layer.name = f"split_{node.name}_{i}" + outputs.append(layer.get_output(0)) + + start_idx += chunk_size + + logger.debug( + f"[TensorRT] Created {len(outputs)} slice layers for split, " + f"dim={split_dim}, sizes={split_sizes}" + ) + + return outputs + + +@converter("aten.chunk.default", validator_fn=validate_chunk) +def convert_chunk( + node: torch.fx.Node, + network: Any, # trt.INetworkDefinition + input_map: Dict[torch.fx.Node, Any], # Dict[Node, trt.ITensor] +) -> List[Any]: # List[trt.ITensor] + """ + Convert PyTorch chunk to TensorRT slice layers. + + Returns a list of output tensors (one for each chunk). + + Args: + node: FX node representing the chunk operation. + network: TensorRT network definition. 
@converter("aten.chunk.default", validator_fn=validate_chunk)
def convert_chunk(
    node: torch.fx.Node,
    network: Any,  # trt.INetworkDefinition
    input_map: Dict[torch.fx.Node, Any],  # Dict[Node, trt.ITensor]
) -> List[Any]:  # List[trt.ITensor]
    """
    Convert PyTorch chunk to TensorRT slice layers.

    Follows ``torch.chunk`` semantics: every chunk has size
    ``ceil(dim_size / num_chunks)`` except possibly the last, which is
    smaller; fewer than ``num_chunks`` outputs are produced when the
    dimension is too small.

    Args:
        node: FX node representing the chunk operation
            (args: input, chunks, dim=0).
        network: TensorRT network definition.
        input_map: Mapping from FX nodes to TensorRT tensors.

    Returns:
        List of TensorRT output tensors, one per produced chunk.

    Raises:
        ImportError: If TensorRT is not installed.
        ValueError: If the input is missing or the dimension is dynamic.
        RuntimeError: If a TensorRT slice layer fails to build.
    """
    try:
        import tensorrt as trt
    except ImportError as e:
        raise ImportError("TensorRT is required for convert_chunk") from e

    logger.debug(f"[TensorRT] Converting chunk node: {node.name}")

    args = node.args
    input_node = args[0]
    num_chunks = args[1]
    chunk_dim = args[2] if len(args) > 2 else 0

    if not isinstance(input_node, torch.fx.Node):
        raise ValueError(f"Input must be node, got {type(input_node)}")
    if input_node not in input_map:
        raise ValueError(f"Input node {input_node.name} not found in input_map")

    input_trt = input_map[input_node]

    # Get shape from node metadata for reliability (TRT shapes can be
    # invalid during network building).
    input_shape = list(get_node_shape(input_node) or input_trt.shape)
    ndim = len(input_shape)
    chunk_dim = _get_positive_dim(chunk_dim, ndim)

    dim_size = input_shape[chunk_dim]
    if dim_size <= 0:
        raise ValueError(
            f"Cannot chunk dynamic dimension {chunk_dim} with size {dim_size}"
        )

    # BUGFIX: torch.chunk gives every chunk size ceil(dim_size / num_chunks)
    # with only the LAST chunk smaller. The previous implementation spread
    # the remainder over the leading chunks, which disagrees with PyTorch
    # for e.g. dim_size=10, num_chunks=3: torch yields [4, 4, 2], not
    # [4, 3, 3].
    full_chunk = -(-dim_size // num_chunks)  # ceil division

    outputs = []
    start_idx = 0
    i = 0
    while start_idx < dim_size:
        chunk_size = min(full_chunk, dim_size - start_idx)

        start = [0] * ndim
        shape = list(input_shape)
        stride = [1] * ndim
        start[chunk_dim] = start_idx
        shape[chunk_dim] = chunk_size

        layer = network.add_slice(
            input_trt, trt.Dims(start), trt.Dims(shape), trt.Dims(stride)
        )
        if layer is None:
            raise RuntimeError(
                f"Failed to create slice layer for chunk {node.name} chunk {i}"
            )
        layer.name = f"chunk_{node.name}_{i}"
        outputs.append(layer.get_output(0))

        start_idx += chunk_size
        i += 1

    logger.debug(
        f"[TensorRT] Created {len(outputs)} slice layers for chunk, "
        f"dim={chunk_dim}, num_chunks={num_chunks}"
    )

    return outputs
def _get_param_tensor(
    exp_prog: Optional[ExportedProgram],
    node: Any,
) -> Optional[torch.Tensor]:
    """Resolve *node* to a constant tensor, or ``None`` if it is not one.

    Checks, in order: literal tensors, lifted params/buffers/constants on
    the ExportedProgram, and finally ``get_attr`` targets resolved against
    the owning module.
    """
    if node is None:
        return None
    if isinstance(node, torch.Tensor):
        return node
    if not isinstance(node, torch.fx.Node):
        return None

    # Lifted parameters/buffers/constants live on the ExportedProgram.
    if exp_prog is not None:
        if is_param(exp_prog, node):
            return get_param(exp_prog, node)
        if is_buffer(exp_prog, node):
            return get_buffer(exp_prog, node)
        if is_lifted_tensor_constant(exp_prog, node):
            return get_lifted_tensor_constant(exp_prog, node)

    # Fallback: resolve get_attr targets against the graph's module.
    if node.op == "get_attr":
        target = node.target
        if exp_prog is not None and isinstance(target, str):
            try:
                return getattr(exp_prog.graph_module, target)
            except AttributeError:
                pass
        try:
            if hasattr(node, "graph") and hasattr(node.graph, "owning_module"):
                if isinstance(target, str):
                    return getattr(node.graph.owning_module, target)
        except AttributeError:
            pass

    return None


@converter("aten.conv2d.default", validator_fn=validate_conv2d, needs_edge_program=True)
def convert_conv2d(
    node: torch.fx.Node,
    network: Any,
    input_map: Dict[torch.fx.Node, Any],
    edge_program: Optional[Union[ExportedProgram, torch.fx.GraphModule]] = None,
    ctx: Any = None,
) -> Any:
    """Lower aten.conv2d onto a TensorRT 2-D convolution layer.

    Weights and bias must be constants reachable through *edge_program*.

    Args:
        node: FX node for the conv2d call.
        network: TensorRT network definition.
        input_map: Mapping from FX nodes to TensorRT tensors.
        edge_program: ExportedProgram used to resolve constant weights.
        ctx: Unused conversion context (signature parity).

    Returns:
        TensorRT output tensor.
    """
    try:
        import numpy as np
        import tensorrt as trt
    except ImportError as e:
        raise ImportError("TensorRT is required for convert_conv2d.") from e

    def pick(idx: int, key: str, default: Any) -> Any:
        # Positional argument when present, otherwise keyword with default.
        return node.args[idx] if len(node.args) > idx else node.kwargs.get(key, default)

    input_node = node.args[0]
    weight_node = node.args[1]
    bias_node = pick(2, "bias", None)
    stride = pick(3, "stride", [1, 1])
    padding = pick(4, "padding", [0, 0])
    dilation = pick(5, "dilation", [1, 1])
    groups = pick(6, "groups", 1)

    if not isinstance(input_node, torch.fx.Node) or input_node not in input_map:
        raise ValueError(f"Input node {input_node} not found in input_map")
    input_trt = input_map[input_node]

    exp_prog = edge_program if isinstance(edge_program, ExportedProgram) else None
    weight_tensor = _get_param_tensor(exp_prog, weight_node)
    if weight_tensor is None:
        raise ValueError(f"Could not extract weight tensor for conv2d node {node.name}")

    weight_np = np.ascontiguousarray(
        weight_tensor.detach().cpu().numpy().astype(np.float32)
    )
    out_channels = weight_np.shape[0]
    kernel_h, kernel_w = weight_np.shape[2], weight_np.shape[3]

    # Hold a strong reference: trt.Weights does not copy, so the numpy
    # buffer must outlive engine construction.
    storage = getattr(convert_conv2d, "_weight_storage", None)
    if storage is None:
        storage = []
        convert_conv2d._weight_storage = storage
    storage.append(weight_np)

    layer = network.add_convolution_nd(
        input_trt,
        out_channels,
        trt.Dims([kernel_h, kernel_w]),
        trt.Weights(weight_np),
    )
    if layer is None:
        raise RuntimeError(f"Failed to create TensorRT convolution layer for {node.name}")

    def as_pair(value: Any) -> Any:
        # TRT expects two spatial values; a scalar applies to both axes.
        return list(value) if hasattr(value, "__iter__") else [value, value]

    layer.stride_nd = trt.Dims(as_pair(stride))
    layer.padding_nd = trt.Dims(as_pair(padding))
    layer.dilation_nd = trt.Dims(as_pair(dilation))
    layer.num_groups = groups

    if bias_node is not None:
        bias_tensor = _get_param_tensor(exp_prog, bias_node)
        if bias_tensor is not None:
            bias_np = np.ascontiguousarray(
                bias_tensor.detach().cpu().numpy().astype(np.float32)
            )
            storage.append(bias_np)
            layer.bias = trt.Weights(bias_np)

    layer.name = f"conv2d_{node.name}"
    return layer.get_output(0)
@converter(
    "aten.convolution.default", validator_fn=validate_convolution, needs_edge_program=True
)
def convert_convolution(
    node: torch.fx.Node,
    network: Any,
    input_map: Dict[torch.fx.Node, Any],
    edge_program: Optional[Union[ExportedProgram, torch.fx.GraphModule]] = None,
    ctx: Any = None,
) -> Any:
    """Convert aten.convolution (1-D and 2-D, non-transposed) to TensorRT.

    1-D convolutions are expressed as 2-D convolutions over a singleton
    height axis: a 3-D input is unsqueezed to NCHW, convolved with a
    ``1 x kernel`` filter, then squeezed back to 3-D.

    Args:
        node: FX node; args follow the aten.convolution schema
            (input, weight, bias, stride, padding, dilation, transposed,
            output_padding, groups).
        network: TensorRT network definition.
        input_map: Mapping from FX nodes to TensorRT tensors.
        edge_program: ExportedProgram used to resolve constant weights.
        ctx: Unused conversion context (signature parity).

    Returns:
        TensorRT output tensor.

    Raises:
        ImportError: If TensorRT is not installed.
        ValueError: For transposed convolutions or missing inputs/weights.
        RuntimeError: If a TensorRT layer fails to build.
    """
    try:
        import numpy as np
        import tensorrt as trt
    except ImportError as e:
        raise ImportError("TensorRT is required for convert_convolution.") from e

    args = node.args
    input_node = args[0]
    weight_node = args[1]
    bias_node = args[2]
    stride = args[3]
    padding = args[4]
    dilation = args[5]
    transposed = args[6]
    groups = args[8]

    if transposed:
        raise ValueError(f"Transposed convolution not supported for node {node.name}")

    if not isinstance(input_node, torch.fx.Node) or input_node not in input_map:
        raise ValueError(f"Input node {input_node} not found in input_map")

    input_trt = input_map[input_node]

    exp_prog = edge_program if isinstance(edge_program, ExportedProgram) else None
    weight_tensor = _get_param_tensor(exp_prog, weight_node)
    if weight_tensor is None:
        raise ValueError(f"Could not extract weight tensor for convolution node {node.name}")

    weight_np = weight_tensor.detach().cpu().numpy().astype(np.float32)
    out_channels = weight_np.shape[0]

    if not weight_np.flags['C_CONTIGUOUS']:
        weight_np = np.ascontiguousarray(weight_np)

    # Hold strong references so the numpy buffers outlive engine build
    # (trt.Weights does not copy). Initialized once, up front, for both
    # the 1-D and 2-D paths (previously duplicated in each branch).
    if not hasattr(convert_convolution, '_weight_storage'):
        convert_convolution._weight_storage = []
    storage = convert_convolution._weight_storage

    is_conv1d = len(weight_np.shape) == 3

    if is_conv1d:
        kernel_size = weight_np.shape[2]
        input_shape = input_trt.shape
        if len(input_shape) == 3:
            # Unsqueeze NCL -> NC1L so a 2-D convolution can be used.
            shuffle_in = network.add_shuffle(input_trt)
            shuffle_in.reshape_dims = trt.Dims([input_shape[0], input_shape[1], 1, input_shape[2]])
            shuffle_in.name = f"conv1d_unsqueeze_{node.name}"
            input_trt = shuffle_in.get_output(0)

        weight_4d = np.ascontiguousarray(
            weight_np.reshape(out_channels, weight_np.shape[1], 1, kernel_size)
        )
        storage.append(weight_4d)

        layer = network.add_convolution_nd(
            input_trt,
            out_channels,
            trt.Dims([1, kernel_size]),
            trt.Weights(weight_4d),
        )
        # BUGFIX: previously unchecked; a failed layer surfaced as an
        # AttributeError on the next line instead of a clear error
        # (convert_conv2d already performs this check).
        if layer is None:
            raise RuntimeError(f"Failed to create TensorRT convolution layer for {node.name}")
        layer.stride_nd = trt.Dims([1, stride[0]])
        layer.padding_nd = trt.Dims([0, padding[0]])
        layer.dilation_nd = trt.Dims([1, dilation[0]])
        layer.num_groups = groups

        if bias_node is not None:
            bias_tensor = _get_param_tensor(exp_prog, bias_node)
            if bias_tensor is not None:
                bias_np = np.ascontiguousarray(
                    bias_tensor.detach().cpu().numpy().astype(np.float32)
                )
                storage.append(bias_np)
                layer.bias = trt.Weights(bias_np)

        layer.name = f"conv1d_{node.name}"
        output = layer.get_output(0)

        # Squeeze the singleton height axis back out for 1-D outputs.
        output_shape = output.shape
        if len(output_shape) == 4 and output_shape[2] == 1:
            shuffle_out = network.add_shuffle(output)
            shuffle_out.reshape_dims = trt.Dims([output_shape[0], output_shape[1], output_shape[3]])
            shuffle_out.name = f"conv1d_squeeze_{node.name}"
            output = shuffle_out.get_output(0)
    else:
        kernel_h = weight_np.shape[2]
        kernel_w = weight_np.shape[3]

        weight_np_contiguous = np.ascontiguousarray(weight_np)
        storage.append(weight_np_contiguous)

        layer = network.add_convolution_nd(
            input_trt,
            out_channels,
            trt.Dims([kernel_h, kernel_w]),
            trt.Weights(weight_np_contiguous),
        )
        # BUGFIX: same missing None-check as the 1-D path above.
        if layer is None:
            raise RuntimeError(f"Failed to create TensorRT convolution layer for {node.name}")
        layer.stride_nd = trt.Dims(list(stride))
        layer.padding_nd = trt.Dims(list(padding))
        layer.dilation_nd = trt.Dims(list(dilation))
        layer.num_groups = groups

        if bias_node is not None:
            bias_tensor = _get_param_tensor(exp_prog, bias_node)
            if bias_tensor is not None:
                bias_np_contiguous = np.ascontiguousarray(
                    bias_tensor.detach().cpu().numpy().astype(np.float32)
                )
                storage.append(bias_np_contiguous)
                layer.bias = trt.Weights(bias_np_contiguous)

        layer.name = f"convolution_{node.name}"
        output = layer.get_output(0)

    return output


def clear_weight_storage() -> None:
    """Clear weight storage to free memory after engine build."""
    if hasattr(convert_convolution, '_weight_storage'):
        convert_convolution._weight_storage.clear()
    if hasattr(convert_conv2d, '_weight_storage'):
        convert_conv2d._weight_storage.clear()
import logging

import torch

logger: logging.Logger = logging.getLogger(__name__)


def validate_dim_order_copy(node: torch.fx.Node) -> bool:
    """
    Validate that a dim_order_copy node can be converted.

    A convertible node must be a ``call_function`` with at least one
    positional argument (the input tensor).

    Args:
        node: FX node representing the operation.

    Returns:
        True if the node can be converted, False otherwise.
    """
    if node.op != "call_function":
        return False
    if len(node.args) >= 1:
        return True
    logger.debug(
        f"[TensorRT] validate_dim_order_copy: node {node.name} has insufficient args"
    )
    return False


def validate_clone(node: torch.fx.Node) -> bool:
    """
    Validate that a clone node can be converted.

    Args:
        node: FX node representing the clone operation.

    Returns:
        True if the node can be converted, False otherwise.
    """
    if node.op != "call_function":
        return False
    if len(node.args) >= 1:
        return True
    logger.debug(
        f"[TensorRT] validate_clone: node {node.name} has insufficient args"
    )
    return False
@converter(
    "dim_order_ops._to_dim_order_copy.default",
    "exir_ops.edge.dim_order_ops._to_dim_order_copy.default",
    validator_fn=validate_dim_order_copy,
)
def convert_to_dim_order_copy(
    node: torch.fx.Node,
    network: Any,  # trt.INetworkDefinition
    input_map: Dict[torch.fx.Node, Any],  # Dict[Node, trt.ITensor]
    edge_program: Optional[Any] = None,
) -> Any:  # trt.ITensor
    """
    Lower dim_order_ops._to_dim_order_copy as a TensorRT identity layer.

    The op only changes memory layout (e.g. contiguous -> channels-last).
    TensorRT selects layouts internally, so values pass through unchanged.

    Args:
        node: FX node representing the dim_order_copy operation.
        network: TensorRT network definition.
        input_map: Mapping from FX nodes to TensorRT tensors.
        edge_program: Unused; kept for converter signature parity.

    Returns:
        TensorRT output tensor (same values as the input).
    """
    logger.debug(f"[TensorRT] Converting dim_order_copy node: {node.name}")

    src = node.args[0]
    if not isinstance(src, torch.fx.Node):
        raise ValueError(
            f"Input to dim_order_copy must be a node, got {type(src)}"
        )

    trt_in = get_trt_tensor_from_node(network, src, input_map, node.name)

    identity = network.add_identity(trt_in)
    if identity is None:
        raise RuntimeError(
            f"Failed to create identity layer for dim_order_copy {node.name}"
        )
    identity.name = f"dim_order_copy_{node.name}"
    logger.debug(f"[TensorRT] Created identity layer for dim_order_copy: {identity.name}")

    return identity.get_output(0)


@converter(
    "dim_order_ops._clone_dim_order.default",
    "exir_ops.edge.dim_order_ops._clone_dim_order.default",
    validator_fn=validate_dim_order_copy,
)
def convert_clone_dim_order(
    node: torch.fx.Node,
    network: Any,  # trt.INetworkDefinition
    input_map: Dict[torch.fx.Node, Any],  # Dict[Node, trt.ITensor]
    edge_program: Optional[Any] = None,
) -> Any:  # trt.ITensor
    """
    Lower dim_order_ops._clone_dim_order as a TensorRT identity layer.

    Cloning with a dimension order is a layout-only operation; TensorRT
    manages layout internally, so the tensor is passed through.

    Args:
        node: FX node representing the clone_dim_order operation.
        network: TensorRT network definition.
        input_map: Mapping from FX nodes to TensorRT tensors.
        edge_program: Unused; kept for converter signature parity.

    Returns:
        TensorRT output tensor (same values as the input).
    """
    logger.debug(f"[TensorRT] Converting clone_dim_order node: {node.name}")

    src = node.args[0]
    if not isinstance(src, torch.fx.Node):
        raise ValueError(
            f"Input to clone_dim_order must be a node, got {type(src)}"
        )

    trt_in = get_trt_tensor_from_node(network, src, input_map, node.name)

    identity = network.add_identity(trt_in)
    if identity is None:
        raise RuntimeError(
            f"Failed to create identity layer for clone_dim_order {node.name}"
        )
    identity.name = f"clone_dim_order_{node.name}"
    logger.debug(f"[TensorRT] Created identity layer for clone_dim_order: {identity.name}")

    return identity.get_output(0)
+ """ + logger.debug(f"[TensorRT] Converting clone node: {node.name}") + + args = node.args + input_node = args[0] + + if not isinstance(input_node, torch.fx.Node): + raise ValueError(f"Input to clone must be a node, got {type(input_node)}") + + input_trt = get_trt_tensor_from_node(network, input_node, input_map, node.name) + + layer = network.add_identity(input_trt) + if layer is None: + raise RuntimeError(f"Failed to create identity layer for clone {node.name}") + + layer.name = f"clone_{node.name}" + logger.debug(f"[TensorRT] Created identity layer for clone: {layer.name}") + + return layer.get_output(0) + + +__all__ = [ + "convert_to_dim_order_copy", + "convert_clone_dim_order", + "convert_clone", + "validate_dim_order_copy", + "validate_clone", +] diff --git a/backends/nvidia/tensorrt/converters/getitem.py b/backends/nvidia/tensorrt/converters/getitem.py new file mode 100644 index 00000000000..a76204f32c8 --- /dev/null +++ b/backends/nvidia/tensorrt/converters/getitem.py @@ -0,0 +1,110 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the BSD-style license found in the +# LICENSE file in the root directory of this source tree. + +# pyre-strict + +""" +TensorRT Converter for getitem Operations. + +This module provides converters for Python getitem operations, which are used +to extract elements from tuples/lists (e.g., extracting the first output from +batch_norm which returns multiple values). + +Supported operations: +- _operator.getitem +- operator.getitem +- getitem +""" + +import logging +from typing import Any, Dict, Optional + +import torch +from executorch.backends.nvidia.tensorrt.converter_registry import converter + +logger: logging.Logger = logging.getLogger(__name__) + + +def validate_getitem(node: torch.fx.Node) -> bool: + """ + Validate that a getitem node can be converted to TensorRT. + + Args: + node: FX node representing the getitem operation. 
+ + Returns: + True if the node can be converted, False otherwise. + """ + if node.op != "call_function": + return False + + args = node.args + # getitem takes 2 args: container and index + if len(args) < 2: + logger.debug( + f"[TensorRT] validate_getitem: node {node.name} has insufficient args" + ) + return False + + return True + + +@converter( + "_operator.getitem", + "operator.getitem", + "getitem", + validator_fn=validate_getitem, +) +def convert_getitem( + node: torch.fx.Node, + network: Any, + input_map: Dict[torch.fx.Node, Any], + edge_program: Optional[Any] = None, +) -> Any: + """ + Convert Python getitem operation to pass through the correct tensor. + + This is used when operations like batch_norm return multiple values + (output, mean, var) and we need to extract just one of them. + + Args: + node: FX node representing the getitem operation. + network: TensorRT network definition. + input_map: Mapping from FX nodes to TensorRT tensors. + edge_program: ExportedProgram for extracting weights. + + Returns: + TensorRT output tensor (the extracted item). 
+ """ + logger.debug(f"[TensorRT] Converting getitem node: {node.name}") + + args = node.args + container = args[0] + index = args[1] + + if not isinstance(container, torch.fx.Node): + raise ValueError(f"Container to getitem must be a node, got {type(container)}") + + if container not in input_map: + raise ValueError(f"Container node {container.name} not found in input_map") + + # The container should already be mapped to a TensorRT tensor + # (For batch_norm, we already return just the first output) + result = input_map[container] + + # If the result is a tuple/list (which it shouldn't be in TensorRT), + # extract the indexed element + if isinstance(result, (list, tuple)): + result = result[index] + + logger.debug( + f"[TensorRT] getitem: extracting index {index} from {container.name}" + ) + + return result + + +__all__ = ["convert_getitem", "validate_getitem"] diff --git a/backends/nvidia/tensorrt/converters/linear.py b/backends/nvidia/tensorrt/converters/linear.py new file mode 100644 index 00000000000..89896535f94 --- /dev/null +++ b/backends/nvidia/tensorrt/converters/linear.py @@ -0,0 +1,255 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the BSD-style license found in the +# LICENSE file in the root directory of this source tree. + +# pyre-strict + +""" +TensorRT Converter for Linear (fully connected) Operations. + +This module provides converters for PyTorch linear operations to TensorRT +fully connected layers. + +Supported operations: +- aten.linear.default: Fully connected layer (y = x @ weight.T + bias) + +The linear operation is implemented as: +1. Matrix multiplication: x @ weight.T +2. 
Add bias (if present) +""" + +import logging +from typing import Any, Dict, Optional, Union + +import torch + +from executorch.backends.nvidia.tensorrt.converter_registry import converter + +from torch._export.utils import ( + get_buffer, + get_lifted_tensor_constant, + get_param, + is_buffer, + is_lifted_tensor_constant, + is_param, +) +from torch.export.exported_program import ExportedProgram + +logger: logging.Logger = logging.getLogger(__name__) + + +def _is_get_attr_node(node: Any) -> bool: + """Check if node is a get_attr node.""" + return isinstance(node, torch.fx.Node) and node.op == "get_attr" + + +def _get_param_tensor( + exp_prog: Optional[ExportedProgram], + node: Any, +) -> Optional[torch.Tensor]: + """Extract a constant tensor from an ExportedProgram.""" + if node is None: + return None + + if isinstance(node, torch.Tensor): + return node + + if not isinstance(node, torch.fx.Node): + return None + + if exp_prog is not None: + if is_param(exp_prog, node): + return get_param(exp_prog, node) + elif is_buffer(exp_prog, node): + return get_buffer(exp_prog, node) + elif is_lifted_tensor_constant(exp_prog, node): + return get_lifted_tensor_constant(exp_prog, node) + + if _is_get_attr_node(node): + if exp_prog is not None: + try: + target = node.target + if isinstance(target, str): + return getattr(exp_prog.graph_module, target) + except AttributeError: + pass + try: + if hasattr(node, "graph") and hasattr(node.graph, "owning_module"): + target = node.target + if isinstance(target, str): + return getattr(node.graph.owning_module, target) + except AttributeError: + pass + + return None + + +def validate_linear(node: torch.fx.Node) -> bool: + """Validate that a linear node can be converted to TensorRT.""" + if node.op != "call_function": + return False + + args = node.args + # linear requires at least input and weight + if len(args) < 2: + logger.debug( + f"[TensorRT] validate_linear: node {node.name} has insufficient args" + ) + return False + + return True + 
@converter("aten.linear.default", validator_fn=validate_linear, needs_edge_program=True)
def convert_linear(
    node: torch.fx.Node,
    network: Any,  # trt.INetworkDefinition
    input_map: Dict[torch.fx.Node, Any],  # Dict[Node, trt.ITensor]
    edge_program: Optional[Union[ExportedProgram, torch.fx.GraphModule]] = None,
    ctx: Any = None,
) -> Any:  # trt.ITensor
    """
    Lower aten.linear (y = x @ weight.T + bias) onto TensorRT layers.

    The weight is materialized as a pre-transposed constant feeding a
    matrix-multiply layer; the optional bias is added via a broadcasted
    elementwise SUM.

    Args:
        node: FX node representing the linear operation.
        network: TensorRT network definition.
        input_map: Mapping from FX nodes to TensorRT tensors.
        edge_program: ExportedProgram for extracting weights.
        ctx: Unused conversion context (signature parity).

    Returns:
        TensorRT output tensor.

    Raises:
        ImportError: If TensorRT is not installed.
        ValueError: If required inputs are missing.
        RuntimeError: If a TensorRT layer fails to build.
    """
    try:
        import numpy as np
        import tensorrt as trt
    except ImportError as e:
        raise ImportError("TensorRT is required for convert_linear") from e

    logger.debug(f"[TensorRT] Converting linear node: {node.name}")

    # linear(input, weight, bias=None)
    x_node = node.args[0]
    w_node = node.args[1]
    b_node = node.args[2] if len(node.args) > 2 else node.kwargs.get("bias", None)

    if not isinstance(x_node, torch.fx.Node):
        raise ValueError(f"Input to linear must be a node, got {type(x_node)}")
    if x_node not in input_map:
        raise ValueError(f"Input node {x_node.name} not found in input_map")
    x_trt = input_map[x_node]

    exp_prog = edge_program if isinstance(edge_program, ExportedProgram) else None
    w_tensor = _get_param_tensor(exp_prog, w_node)
    if w_tensor is None:
        raise ValueError(
            f"Could not extract weight tensor for linear node {node.name}. "
            "Weight must be a constant tensor."
        )

    # Weight layout is (out_features, in_features).
    w_np = w_tensor.detach().cpu().numpy().astype(np.float32)
    out_features, in_features = w_np.shape[0], w_np.shape[1]
    logger.debug(
        f"[TensorRT] linear: in_features={in_features}, out_features={out_features}"
    )

    # Pre-transpose so a plain (NONE, NONE) matmul computes x @ weight.T.
    w_t = np.ascontiguousarray(w_np.T)  # (in_features, out_features)

    # Hold strong references: trt.Weights does not copy, so the numpy
    # buffers must outlive engine construction.
    storage = getattr(convert_linear, "_weight_storage", None)
    if storage is None:
        storage = []
        convert_linear._weight_storage = storage
    storage.append(w_t)

    w_const = network.add_constant(trt.Dims(w_t.shape), trt.Weights(w_t))
    if w_const is None:
        raise RuntimeError(f"Failed to create weight constant for linear {node.name}")
    w_const.name = f"linear_weight_{node.name}"

    mm = network.add_matrix_multiply(
        x_trt, trt.MatrixOperation.NONE,
        w_const.get_output(0), trt.MatrixOperation.NONE,
    )
    if mm is None:
        raise RuntimeError(f"Failed to create matmul layer for linear {node.name}")
    mm.name = f"linear_mm_{node.name}"
    output = mm.get_output(0)

    if b_node is not None:
        b_tensor = _get_param_tensor(exp_prog, b_node)
        if b_tensor is not None:
            b_np = b_tensor.detach().cpu().numpy().astype(np.float32)

            # Reshape the bias to rank(output) with leading singleton axes
            # so the elementwise SUM broadcasts over batch dimensions.
            rank = len(output.shape)
            b_view = b_np.reshape([1] * (rank - 1) + [out_features])
            storage.append(b_view)

            b_const = network.add_constant(trt.Dims(b_view.shape), trt.Weights(b_view))
            if b_const is None:
                raise RuntimeError(
                    f"Failed to create bias constant for linear {node.name}"
                )
            b_const.name = f"linear_bias_const_{node.name}"

            add = network.add_elementwise(
                output, b_const.get_output(0), trt.ElementWiseOperation.SUM
            )
            if add is None:
                raise RuntimeError(
                    f"Failed to create bias add layer for linear {node.name}"
                )
            add.name = f"linear_bias_{node.name}"
            output = add.get_output(0)

            logger.debug(f"[TensorRT] Added bias to linear layer {node.name}")

    logger.debug(f"[TensorRT] Created linear layer: {node.name}")

    return output


def clear_weight_storage() -> None:
    """Clear weight storage to free memory after engine build."""
    if hasattr(convert_linear, '_weight_storage'):
        convert_linear._weight_storage.clear()
+ +Supported operations: +- aten.avg_pool2d.default: 2D average pooling +- aten.max_pool2d.default: 2D max pooling (when indices are not used) +- aten.max_pool2d_with_indices.default: 2D max pooling with indices output +- aten.adaptive_avg_pool2d.default: Adaptive 2D average pooling (for SE blocks) + +Notes: +- TensorRT doesn't support dilation != 1 for pooling +- TensorRT doesn't support divisor_override for avg_pool +- Adaptive pooling requires static spatial dimensions +- max_pool_with_indices: indices output is NOT supported, only values are returned +""" + +import logging +from typing import Any, Dict, List, Optional, Tuple, Union + +import torch +from executorch.backends.nvidia.tensorrt.converter_registry import converter + +logger: logging.Logger = logging.getLogger(__name__) + + +def validate_avg_pool2d(node: torch.fx.Node) -> bool: + """ + Validate that an avg_pool2d node can be converted to TensorRT. + + Args: + node: FX node representing the avg_pool2d operation. + + Returns: + True if the node can be converted, False otherwise. + """ + if node.op != "call_function": + logger.debug( + f"[TensorRT] validate_avg_pool2d: node {node.name} is not call_function" + ) + return False + + args = node.args + # Minimum args: input, kernel_size + if len(args) < 2: + logger.debug( + f"[TensorRT] validate_avg_pool2d: node {node.name} has insufficient args" + ) + return False + + # Check for divisor_override - not supported by TensorRT + if len(args) > 6 and args[6] is not None: + logger.debug( + f"[TensorRT] validate_avg_pool2d: divisor_override not supported" + ) + return False + + return True + + +def validate_max_pool2d(node: torch.fx.Node) -> bool: + """ + Validate that a max_pool2d node can be converted to TensorRT. + + Args: + node: FX node representing the max_pool2d operation. + + Returns: + True if the node can be converted, False otherwise. 
+ """ + if node.op != "call_function": + logger.debug( + f"[TensorRT] validate_max_pool2d: node {node.name} is not call_function" + ) + return False + + args = node.args + # Minimum args: input, kernel_size + if len(args) < 2: + logger.debug( + f"[TensorRT] validate_max_pool2d: node {node.name} has insufficient args" + ) + return False + + # Check for dilation - only dilation=1 is supported + if len(args) > 4: + dilation = args[4] + if dilation is not None: + if isinstance(dilation, (list, tuple)): + if any(d != 1 for d in dilation): + logger.debug( + f"[TensorRT] validate_max_pool2d: dilation != 1 not supported" + ) + return False + elif dilation != 1: + logger.debug( + f"[TensorRT] validate_max_pool2d: dilation != 1 not supported" + ) + return False + + return True + + +def validate_adaptive_avg_pool2d(node: torch.fx.Node) -> bool: + """ + Validate that an adaptive_avg_pool2d node can be converted to TensorRT. + + Args: + node: FX node representing the adaptive_avg_pool2d operation. + + Returns: + True if the node can be converted, False otherwise. + """ + if node.op != "call_function": + logger.debug( + f"[TensorRT] validate_adaptive_avg_pool2d: node {node.name} is not call_function" + ) + return False + + args = node.args + # Args: input, output_size + if len(args) < 2: + logger.debug( + f"[TensorRT] validate_adaptive_avg_pool2d: node {node.name} has insufficient args" + ) + return False + + return True + + +def _extend_to_tuple(value: Any, length: int) -> Tuple[int, ...]: + """ + Extend a value to a tuple of given length. + + Args: + value: An int, tuple, or list. + length: Desired tuple length. + + Returns: + Tuple of integers with the specified length. 
+ """ + if value is None: + return (0,) * length + if isinstance(value, int): + return (value,) * length + if isinstance(value, (list, tuple)): + if len(value) == length: + return tuple(value) + if len(value) == 1: + return (value[0],) * length + return tuple(value) + + +@converter("aten.avg_pool2d.default", validator_fn=validate_avg_pool2d) +def convert_avg_pool2d( + node: torch.fx.Node, + network: Any, # trt.INetworkDefinition + input_map: Dict[torch.fx.Node, Any], # Dict[Node, trt.ITensor] + edge_program: Optional[Any] = None, +) -> Any: # trt.ITensor + """ + Convert PyTorch avg_pool2d to TensorRT pooling layer. + + Args: + node: FX node representing the avg_pool2d operation. + network: TensorRT network definition. + input_map: Mapping from FX nodes to TensorRT tensors. + + Returns: + TensorRT output tensor. + + Raises: + ImportError: If TensorRT is not installed. + ValueError: If required inputs are missing. + RuntimeError: If TensorRT layer creation fails. + """ + try: + import tensorrt as trt + except ImportError as e: + raise ImportError( + "TensorRT is required for convert_avg_pool2d" + ) from e + + logger.debug(f"[TensorRT] Converting avg_pool2d node: {node.name}") + + args = node.args + kwargs = node.kwargs + + # Extract arguments + # avg_pool2d(input, kernel_size, stride=[], padding=0, ceil_mode=False, + # count_include_pad=True, divisor_override=None) + input_node = args[0] + kernel_size = args[1] if len(args) > 1 else kwargs.get("kernel_size") + stride = args[2] if len(args) > 2 else kwargs.get("stride", []) + padding = args[3] if len(args) > 3 else kwargs.get("padding", 0) + ceil_mode = args[4] if len(args) > 4 else kwargs.get("ceil_mode", False) + count_include_pad = args[5] if len(args) > 5 else kwargs.get("count_include_pad", True) + + # Validate input + if not isinstance(input_node, torch.fx.Node): + raise ValueError(f"Input to avg_pool2d must be a node, got {type(input_node)}") + + if input_node not in input_map: + raise ValueError(f"Input node 
{input_node.name} not found in input_map") + + input_trt = input_map[input_node] + + # Normalize parameters to tuples + kernel_size = _extend_to_tuple(kernel_size, 2) + + # Default stride to kernel_size if empty or None + if stride is None or (isinstance(stride, (list, tuple)) and len(stride) == 0): + stride = kernel_size + else: + stride = _extend_to_tuple(stride, 2) + + padding = _extend_to_tuple(padding, 2) + + logger.debug( + f"[TensorRT] avg_pool2d parameters: kernel={kernel_size}, " + f"stride={stride}, padding={padding}, ceil_mode={ceil_mode}, " + f"count_include_pad={count_include_pad}" + ) + + # Create pooling layer using add_pooling_nd + layer = network.add_pooling_nd( + input=input_trt, + type=trt.PoolingType.AVERAGE, + window_size=trt.Dims(kernel_size), + ) + if layer is None: + raise RuntimeError( + f"Failed to create avg_pool2d layer for node {node.name}" + ) + + layer.stride_nd = trt.Dims(stride) + layer.padding_nd = trt.Dims(padding) + layer.name = f"avg_pool2d_{node.name}" + + # Handle count_include_pad + # TensorRT: average_count_excludes_padding = True means padding is NOT included + layer.average_count_excludes_padding = not count_include_pad + + # Handle ceil_mode + if ceil_mode: + layer.padding_mode = trt.PaddingMode.EXPLICIT_ROUND_UP + + logger.debug(f"[TensorRT] Created avg_pool2d layer: {layer.name}") + + return layer.get_output(0) + + +@converter( + "aten.max_pool2d.default", + "aten.max_pool2d_with_indices.default", + validator_fn=validate_max_pool2d, +) +def convert_max_pool2d( + node: torch.fx.Node, + network: Any, # trt.INetworkDefinition + input_map: Dict[torch.fx.Node, Any], # Dict[Node, trt.ITensor] + edge_program: Optional[Any] = None, +) -> Any: # trt.ITensor + """ + Convert PyTorch max_pool2d to TensorRT pooling layer. + + Note: For max_pool2d_with_indices, only the values output is returned. + The indices are NOT supported by TensorRT. If indices are actually used + in the model, conversion will fail during graph execution. 
+ + Args: + node: FX node representing the max_pool2d operation. + network: TensorRT network definition. + input_map: Mapping from FX nodes to TensorRT tensors. + + Returns: + TensorRT output tensor (only values, no indices). + + Raises: + ImportError: If TensorRT is not installed. + ValueError: If required inputs are missing. + RuntimeError: If TensorRT layer creation fails or dilation != 1. + """ + try: + import tensorrt as trt + except ImportError as e: + raise ImportError( + "TensorRT is required for convert_max_pool2d" + ) from e + + logger.debug(f"[TensorRT] Converting max_pool2d node: {node.name}") + + args = node.args + kwargs = node.kwargs + + # Extract arguments + # max_pool2d(input, kernel_size, stride=[], padding=0, dilation=1, ceil_mode=False) + input_node = args[0] + kernel_size = args[1] if len(args) > 1 else kwargs.get("kernel_size") + stride = args[2] if len(args) > 2 else kwargs.get("stride", []) + padding = args[3] if len(args) > 3 else kwargs.get("padding", 0) + dilation = args[4] if len(args) > 4 else kwargs.get("dilation", 1) + ceil_mode = args[5] if len(args) > 5 else kwargs.get("ceil_mode", False) + + # Validate input + if not isinstance(input_node, torch.fx.Node): + raise ValueError(f"Input to max_pool2d must be a node, got {type(input_node)}") + + if input_node not in input_map: + raise ValueError(f"Input node {input_node.name} not found in input_map") + + input_trt = input_map[input_node] + + # Normalize parameters to tuples + kernel_size = _extend_to_tuple(kernel_size, 2) + + # Default stride to kernel_size if empty or None + if stride is None or (isinstance(stride, (list, tuple)) and len(stride) == 0): + stride = kernel_size + else: + stride = _extend_to_tuple(stride, 2) + + padding = _extend_to_tuple(padding, 2) + dilation = _extend_to_tuple(dilation, 2) + + # Validate dilation + if dilation != (1, 1): + raise RuntimeError( + f"TensorRT only supports dilation=(1, 1) for max_pool2d, got {dilation}" + ) + + logger.debug( + f"[TensorRT] 
max_pool2d parameters: kernel={kernel_size}, " + f"stride={stride}, padding={padding}, ceil_mode={ceil_mode}" + ) + + # Create pooling layer using add_pooling_nd + layer = network.add_pooling_nd( + input=input_trt, + type=trt.PoolingType.MAX, + window_size=trt.Dims(kernel_size), + ) + if layer is None: + raise RuntimeError( + f"Failed to create max_pool2d layer for node {node.name}" + ) + + layer.stride_nd = trt.Dims(stride) + layer.padding_nd = trt.Dims(padding) + layer.name = f"max_pool2d_{node.name}" + + # Handle ceil_mode + if ceil_mode: + layer.padding_mode = trt.PaddingMode.EXPLICIT_ROUND_UP + + logger.debug(f"[TensorRT] Created max_pool2d layer: {layer.name}") + + return layer.get_output(0) + + +@converter("aten.adaptive_avg_pool2d.default", validator_fn=validate_adaptive_avg_pool2d) +def convert_adaptive_avg_pool2d( + node: torch.fx.Node, + network: Any, # trt.INetworkDefinition + input_map: Dict[torch.fx.Node, Any], # Dict[Node, trt.ITensor] + edge_program: Optional[Any] = None, +) -> Any: # trt.ITensor + """ + Convert PyTorch adaptive_avg_pool2d to TensorRT pooling layer. + + Adaptive pooling computes kernel_size and stride from input and output shapes: + - stride = input_size // output_size + - kernel_size = input_size - (output_size - 1) * stride + + This is critical for Squeeze-and-Excitation (SE) blocks in MobileNetV3. + + Limitations: + - Input spatial dimensions (H, W) must be static (not -1) + - Input size must be evenly divisible by output size + + Args: + node: FX node representing the adaptive_avg_pool2d operation. + network: TensorRT network definition. + input_map: Mapping from FX nodes to TensorRT tensors. + + Returns: + TensorRT output tensor. + + Raises: + ImportError: If TensorRT is not installed. + ValueError: If required inputs are missing. + RuntimeError: If TensorRT layer creation fails or constraints not met. 
+ """ + try: + import tensorrt as trt + except ImportError as e: + raise ImportError( + "TensorRT is required for convert_adaptive_avg_pool2d" + ) from e + + logger.debug(f"[TensorRT] Converting adaptive_avg_pool2d node: {node.name}") + + args = node.args + kwargs = node.kwargs + + # Extract arguments + # adaptive_avg_pool2d(input, output_size) + input_node = args[0] + output_size = args[1] if len(args) > 1 else kwargs.get("output_size") + + # Validate input + if not isinstance(input_node, torch.fx.Node): + raise ValueError(f"Input to adaptive_avg_pool2d must be a node, got {type(input_node)}") + + if input_node not in input_map: + raise ValueError(f"Input node {input_node.name} not found in input_map") + + input_trt = input_map[input_node] + + # Normalize output_size to tuple + output_size = _extend_to_tuple(output_size, 2) + + # Get input shape (expecting NCHW format) + input_shape = input_trt.shape + if len(input_shape) != 4: + raise RuntimeError( + f"adaptive_avg_pool2d expects 4D input (NCHW), got shape {input_shape}" + ) + + # Extract spatial dimensions (H, W) + input_h = input_shape[2] + input_w = input_shape[3] + + # Check for dynamic dimensions + if input_h == -1 or input_w == -1: + raise RuntimeError( + f"adaptive_avg_pool2d doesn't support dynamic spatial dimensions. " + f"Input shape: {input_shape}. H and W must be static." 
+ ) + + output_h, output_w = output_size + + # Validate divisibility + if input_h % output_h != 0: + raise RuntimeError( + f"Input height ({input_h}) must be divisible by output height ({output_h})" + ) + if input_w % output_w != 0: + raise RuntimeError( + f"Input width ({input_w}) must be divisible by output width ({output_w})" + ) + + # Calculate kernel_size and stride + # Formula: stride = input_size // output_size + # kernel = input_size - (output_size - 1) * stride + stride_h = input_h // output_h + stride_w = input_w // output_w + + kernel_h = input_h - (output_h - 1) * stride_h + kernel_w = input_w - (output_w - 1) * stride_w + + kernel_size = (kernel_h, kernel_w) + stride = (stride_h, stride_w) + + logger.debug( + f"[TensorRT] adaptive_avg_pool2d: output_size={output_size}, " + f"input_spatial=({input_h}, {input_w}), computed kernel={kernel_size}, stride={stride}" + ) + + # Create pooling layer + layer = network.add_pooling_nd( + input=input_trt, + type=trt.PoolingType.AVERAGE, + window_size=trt.Dims(kernel_size), + ) + if layer is None: + raise RuntimeError( + f"Failed to create adaptive_avg_pool2d layer for node {node.name}" + ) + + layer.stride_nd = trt.Dims(stride) + layer.name = f"adaptive_avg_pool2d_{node.name}" + + logger.debug(f"[TensorRT] Created adaptive_avg_pool2d layer: {layer.name}") + + return layer.get_output(0) + + +__all__ = [ + "convert_avg_pool2d", + "convert_max_pool2d", + "convert_adaptive_avg_pool2d", + "validate_avg_pool2d", + "validate_max_pool2d", + "validate_adaptive_avg_pool2d", +] diff --git a/backends/nvidia/tensorrt/converters/reduction.py b/backends/nvidia/tensorrt/converters/reduction.py new file mode 100644 index 00000000000..fd75187b653 --- /dev/null +++ b/backends/nvidia/tensorrt/converters/reduction.py @@ -0,0 +1,262 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. 
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.

# pyre-strict

"""
TensorRT Converters for Reduction Operations.

This module provides converters for PyTorch reduction operations to TensorRT
reduction layers.

Supported operations:
- aten.mean.dim: Reduce mean along specified dimensions
- aten.sum.dim_IntList: Reduce sum along specified dimensions

Notes:
- TensorRT uses axes as a bitmask for specifying dimensions
- keepdim parameter controls whether reduced dimensions are kept
"""

import logging
from typing import Any, Dict, List, Optional, Union

import torch
from executorch.backends.nvidia.tensorrt.converter_registry import converter
from executorch.backends.nvidia.tensorrt.converter_utils import get_node_shape

logger: logging.Logger = logging.getLogger(__name__)


def validate_mean(node: torch.fx.Node) -> bool:
    """
    Validate that a mean node can be converted to TensorRT.

    Args:
        node: FX node representing the mean operation.

    Returns:
        True if the node can be converted, False otherwise.
    """
    if node.op != "call_function":
        logger.debug(
            f"[TensorRT] validate_mean: node {node.name} is not call_function"
        )
        return False

    # Bug fix: this converter is registered for both aten.mean.dim and
    # aten.mean.default.  mean.default carries only the input tensor, so the
    # previous requirement of two positional args rejected every mean.default
    # node even though convert_mean handles dims=None by reducing over all
    # dimensions.  Only the input itself is mandatory.
    if len(node.args) < 1:
        logger.debug(
            f"[TensorRT] validate_mean: node {node.name} has insufficient args"
        )
        return False

    return True


def _get_reduce_axes(
    dims: Union[int, List[int]],
    ndim: int,
) -> int:
    """
    Convert dimension indices to TensorRT axes bitmask.

    Args:
        dims: Dimension(s) to reduce.
        ndim: Total number of dimensions.

    Returns:
        TensorRT axes bitmask (bit i set means "reduce dimension i").
    """
    if isinstance(dims, int):
        dims = [dims]

    axes = 0
    for dim in dims:
        # Handle negative dimensions
        if dim < 0:
            dim = ndim + dim
        axes |= 1 << dim

    return axes


@converter(
    "aten.mean.dim",
    "aten.mean.default",
    validator_fn=validate_mean,
)
def convert_mean(
    node: torch.fx.Node,
    network: Any,  # trt.INetworkDefinition
    input_map: Dict[torch.fx.Node, Any],  # Dict[Node, trt.ITensor]
    edge_program: Optional[Any] = None,
) -> Any:  # trt.ITensor
    """
    Convert PyTorch mean reduction to TensorRT reduce layer.

    Args:
        node: FX node representing the mean operation.
        network: TensorRT network definition.
        input_map: Mapping from FX nodes to TensorRT tensors.
        edge_program: ExportedProgram for extracting parameters.

    Returns:
        TensorRT output tensor.

    Raises:
        ImportError: If TensorRT is not installed.
        ValueError: If required inputs are missing.
        RuntimeError: If TensorRT layer creation fails.
    """
    try:
        import tensorrt as trt
    except ImportError as e:
        raise ImportError("TensorRT is required for convert_mean") from e

    logger.debug(f"[TensorRT] Converting mean node: {node.name}")

    args = node.args
    kwargs = node.kwargs

    input_node = args[0]

    if not isinstance(input_node, torch.fx.Node):
        raise ValueError(f"Input to mean must be a node, got {type(input_node)}")

    if input_node not in input_map:
        raise ValueError(f"Input node {input_node.name} not found in input_map")

    input_trt = input_map[input_node]

    # Dimensions to reduce: a single int, a list, or None (reduce all)
    dims = args[1] if len(args) > 1 else kwargs.get("dim", None)

    # keepdim controls whether reduced dims are retained with size 1
    keepdim = args[2] if len(args) > 2 else kwargs.get("keepdim", False)

    # Prefer the FX node metadata over the TRT tensor shape for rank —
    # metadata is available even when the TRT shape is dynamic.
    input_shape = get_node_shape(input_node) or tuple(input_trt.shape)
    ndim = len(input_shape)

    if dims is None:
        # Reduce over all dimensions
        axes = (1 << ndim) - 1  # All bits set
    else:
        axes = _get_reduce_axes(dims, ndim)

    logger.debug(
        f"[TensorRT] mean reduction: dims={dims}, keepdim={keepdim}, axes={axes:b}"
    )

    # Create reduce layer with AVERAGE operation
    layer = network.add_reduce(input_trt, trt.ReduceOperation.AVG, axes, keepdim)

    if layer is None:
        raise RuntimeError(f"Failed to create reduce layer for mean {node.name}")

    layer.name = f"mean_{node.name}"
    logger.debug(f"[TensorRT] Created mean reduce layer: {layer.name}")

    return layer.get_output(0)


def validate_sum(node: torch.fx.Node) -> bool:
    """
    Validate that a sum node can be converted to TensorRT.

    Args:
        node: FX node representing the sum operation.

    Returns:
        True if the node can be converted, False otherwise.
    """
    if node.op != "call_function":
        logger.debug(
            f"[TensorRT] validate_sum: node {node.name} is not call_function"
        )
        return False

    return True


@converter(
    "aten.sum.dim_IntList",
    "aten.sum.default",
    validator_fn=validate_sum,
)
def convert_sum(
    node: torch.fx.Node,
    network: Any,  # trt.INetworkDefinition
    input_map: Dict[torch.fx.Node, Any],  # Dict[Node, trt.ITensor]
    edge_program: Optional[Any] = None,
) -> Any:  # trt.ITensor
    """
    Convert PyTorch sum reduction to TensorRT reduce layer.

    Args:
        node: FX node representing the sum operation.
        network: TensorRT network definition.
        input_map: Mapping from FX nodes to TensorRT tensors.
        edge_program: ExportedProgram for extracting parameters.

    Returns:
        TensorRT output tensor.

    Raises:
        ImportError: If TensorRT is not installed.
        ValueError: If required inputs are missing.
        RuntimeError: If TensorRT layer creation fails.
    """
    try:
        import tensorrt as trt
    except ImportError as e:
        raise ImportError("TensorRT is required for convert_sum") from e

    logger.debug(f"[TensorRT] Converting sum node: {node.name}")

    args = node.args
    kwargs = node.kwargs

    input_node = args[0]

    if not isinstance(input_node, torch.fx.Node):
        raise ValueError(f"Input to sum must be a node, got {type(input_node)}")

    if input_node not in input_map:
        raise ValueError(f"Input node {input_node.name} not found in input_map")

    input_trt = input_map[input_node]

    # Dimensions to reduce: a single int, a list, or None (reduce all)
    dims = args[1] if len(args) > 1 else kwargs.get("dim", None)

    # keepdim controls whether reduced dims are retained with size 1
    keepdim = args[2] if len(args) > 2 else kwargs.get("keepdim", False)

    # Prefer the FX node metadata over the TRT tensor shape for rank —
    # metadata is available even when the TRT shape is dynamic.
    input_shape = get_node_shape(input_node) or tuple(input_trt.shape)
    ndim = len(input_shape)

    if dims is None:
        # Reduce over all dimensions
        axes = (1 << ndim) - 1  # All bits set
    else:
        axes = _get_reduce_axes(dims, ndim)

    logger.debug(
        f"[TensorRT] sum reduction: dims={dims}, keepdim={keepdim}, axes={axes:b}"
    )

    # Create reduce layer with SUM operation
    layer = network.add_reduce(input_trt, trt.ReduceOperation.SUM, axes, keepdim)

    if layer is None:
        raise RuntimeError(f"Failed to create reduce layer for sum {node.name}")

    layer.name = f"sum_{node.name}"
    logger.debug(f"[TensorRT] Created sum reduce layer: {layer.name}")

    return layer.get_output(0)


__all__ = [
    "convert_mean",
    "convert_sum",
    "validate_mean",
    "validate_sum",
]

# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
+# +# This source code is licensed under the BSD-style license found in the +# LICENSE file in the root directory of this source tree. + +# pyre-strict + +""" +TensorRT Converters for Reshape and View Operations. + +This module provides converters for PyTorch tensor reshaping operations to TensorRT +shuffle layers. + +Supported operations: +- aten.view.default: View tensor with new shape +- aten.reshape.default: Reshape tensor +- aten.flatten.using_ints: Flatten specified dimensions +- aten.squeeze.dim: Squeeze a specific dimension (remove size-1 dim) +- aten.unsqueeze.default: Unsqueeze (add dimension) +- aten.permute.default: Permute dimensions +- aten.transpose.int: Transpose two dimensions + +Notes: +- All reshaping uses network.add_shuffle() +- shuffle.reshape_dims for shape changes +- shuffle.first_transpose for permutations +- Dynamic shapes require runtime shape computation +""" + +import logging +from typing import Any, Dict, List, Optional + +import tensorrt as trt +import torch +from executorch.backends.nvidia.tensorrt.converter_registry import converter +from executorch.backends.nvidia.tensorrt.converter_utils import ( + get_node_shape, + get_trt_tensor_from_node, +) + +logger: logging.Logger = logging.getLogger(__name__) + + +def _get_positive_dim(dim: int, ndim: int) -> int: + """ + Convert a potentially negative dimension index to positive. + + Args: + dim: Dimension index (can be negative). + ndim: Number of dimensions. + + Returns: + Positive dimension index. + """ + if dim < 0: + dim = ndim + dim + return dim + + +def validate_view_reshape(node: torch.fx.Node) -> bool: + """ + Validate that a view/reshape node can be converted to TensorRT. + + Args: + node: FX node representing the operation. + + Returns: + True if the node can be converted, False otherwise. 
+ """ + if node.op != "call_function": + logger.debug( + f"[TensorRT] validate_view_reshape: node {node.name} is not call_function" + ) + return False + + args = node.args + # Args: input, size (list of dims) + if len(args) < 2: + logger.debug( + f"[TensorRT] validate_view_reshape: node {node.name} has insufficient args" + ) + return False + + return True + + +def validate_flatten(node: torch.fx.Node) -> bool: + """ + Validate that a flatten node can be converted to TensorRT. + + Args: + node: FX node representing the flatten operation. + + Returns: + True if the node can be converted, False otherwise. + """ + if node.op != "call_function": + return False + + args = node.args + # Args: input, start_dim, end_dim + if len(args) < 3: + logger.debug( + f"[TensorRT] validate_flatten: node {node.name} has insufficient args" + ) + return False + + return True + + +def validate_squeeze_unsqueeze(node: torch.fx.Node) -> bool: + """ + Validate that a squeeze/unsqueeze node can be converted to TensorRT. + + Args: + node: FX node representing the squeeze/unsqueeze operation. + + Returns: + True if the node can be converted, False otherwise. + """ + if node.op != "call_function": + return False + + args = node.args + # Args: input, dim + if len(args) < 2: + logger.debug( + f"[TensorRT] validate_squeeze_unsqueeze: node {node.name} has insufficient args" + ) + return False + + return True + + +def validate_permute(node: torch.fx.Node) -> bool: + """ + Validate that a permute node can be converted to TensorRT. + + Args: + node: FX node representing the permute operation. + + Returns: + True if the node can be converted, False otherwise. 
+ """ + if node.op != "call_function": + return False + + args = node.args + # Args: input, dims (list of permutation) + if len(args) < 2: + logger.debug( + f"[TensorRT] validate_permute: node {node.name} has insufficient args" + ) + return False + + return True + + +def validate_transpose(node: torch.fx.Node) -> bool: + """ + Validate that a transpose node can be converted to TensorRT. + + Args: + node: FX node representing the transpose operation. + + Returns: + True if the node can be converted, False otherwise. + """ + if node.op != "call_function": + return False + + args = node.args + # Args: input, dim0, dim1 + if len(args) < 3: + logger.debug( + f"[TensorRT] validate_transpose: node {node.name} has insufficient args" + ) + return False + + return True + + +def validate_select(node: torch.fx.Node) -> bool: + """ + Validate that a select node can be converted to TensorRT. + + Args: + node: FX node representing the select operation. + + Returns: + True if the node can be converted, False otherwise. + """ + if node.op != "call_function": + return False + + args = node.args + # Args: input, dim, index + if len(args) < 3: + logger.debug( + f"[TensorRT] validate_select: node {node.name} has insufficient args" + ) + return False + + return True + + +def _compute_view_output_shape( + node: torch.fx.Node, + input_node: torch.fx.Node, + input_trt: trt.ITensor, + target_shape: List[int], +) -> List[int]: + """Compute the output shape for view/reshape operations. + + Handles -1 dimension computation by calculating from input volume. + + Args: + node: FX node representing the operation (used for metadata). + input_node: FX node for the input tensor. + input_trt: TensorRT tensor for the input. + target_shape: Target shape specification (may contain -1). + + Returns: + Computed output shape with -1 resolved. + + Raises: + ValueError: If more than one -1 dimension is specified. 
+ """ + # Prefer output shape from node metadata (most reliable) + if "val" in node.meta and hasattr(node.meta["val"], "shape"): + return list(node.meta["val"].shape) + + # Fall back to computing shape from target_shape, handling -1 + input_shape = list(get_node_shape(input_node) or input_trt.shape) + + # Calculate total input volume + input_volume = 1 + for d in input_shape: + if d > 0: + input_volume *= d + + # Process target_shape, computing -1 dimensions + output_shape = [] + neg_one_idx = -1 + known_volume = 1 + + for i, dim in enumerate(target_shape): + if isinstance(dim, int): + if dim == -1: + if neg_one_idx >= 0: + raise ValueError( + f"Only one -1 dimension allowed in view/reshape, " + f"found multiple at indices {neg_one_idx} and {i}" + ) + neg_one_idx = i + output_shape.append(-1) # Placeholder + else: + output_shape.append(dim) + if dim > 0: + known_volume *= dim + else: + # Non-integer dimension (e.g., symbolic) - use 0 for TRT dynamic + output_shape.append(0) + + # Calculate the -1 dimension if present + if neg_one_idx >= 0: + if known_volume > 0: + output_shape[neg_one_idx] = input_volume // known_volume + else: + # Cannot compute -1 dimension without knowing other dimensions + output_shape[neg_one_idx] = 0 # Use 0 for TRT dynamic inference + + return output_shape + + +@converter("aten.view.default", "aten._unsafe_view.default", "aten.view_copy.default", validator_fn=validate_view_reshape) +def convert_view( + node: torch.fx.Node, + network: trt.INetworkDefinition, + input_map: Dict[torch.fx.Node, Any], + edge_program: Optional[Any] = None, +) -> trt.ITensor: + """ + Convert PyTorch view/unsafe_view to TensorRT shuffle layer. + + Args: + node: FX node representing the view operation. + network: TensorRT network definition. + input_map: Mapping from FX nodes to TensorRT tensors. + edge_program: Optional edge program (unused). + + Returns: + TensorRT output tensor. + + Raises: + ValueError: If input is invalid or not found in input_map. 
+ RuntimeError: If TensorRT layer creation fails. + """ + args = node.args + + input_node = args[0] + target_shape = args[1] + + if not isinstance(input_node, torch.fx.Node): + raise ValueError(f"Input to view must be a node, got {type(input_node)}") + + if input_node not in input_map: + raise ValueError( + f"Input node '{input_node.name}' not found in input_map for " + f"view node '{node.name}'" + ) + + input_trt = input_map[input_node] + + # Get the actual output shape from node metadata if available + # This is more reliable than TensorRT's dimension inference + output_shape = _compute_view_output_shape(node, input_node, input_trt, target_shape) + logger.debug(f"[TensorRT] view {node.name}: output_shape = {output_shape}") + + # Create shuffle layer for reshape + layer = network.add_shuffle(input_trt) + if layer is None: + raise RuntimeError(f"Failed to create shuffle layer for view {node.name}") + + layer.reshape_dims = trt.Dims(output_shape) + layer.name = f"view_{node.name}" + + return layer.get_output(0) + + +@converter("aten.reshape.default", validator_fn=validate_view_reshape) +def convert_reshape( + node: torch.fx.Node, + network: trt.INetworkDefinition, + input_map: Dict[torch.fx.Node, Any], + edge_program: Optional[Any] = None, +) -> trt.ITensor: + """ + Convert PyTorch reshape to TensorRT shuffle layer. + + Args: + node: FX node representing the reshape operation. + network: TensorRT network definition. + input_map: Mapping from FX nodes to TensorRT tensors. + edge_program: Optional edge program (unused). + + Returns: + TensorRT output tensor. + + Raises: + ValueError: If input is invalid or target_shape is malformed. + RuntimeError: If TensorRT layer creation fails. 
+ """ + logger.debug(f"[TensorRT] Converting reshape node: {node.name}") + + args = node.args + + input_node = args[0] + target_shape = args[1] + + if not isinstance(input_node, torch.fx.Node): + raise ValueError(f"Input to reshape must be a node, got {type(input_node)}") + + if input_node not in input_map: + raise ValueError( + f"Input node '{input_node.name}' not found in input_map for " + f"reshape node '{node.name}'" + ) + + if not isinstance(target_shape, (list, tuple)): + raise ValueError(f"target_shape must be list or tuple, got {type(target_shape)}") + + input_trt = input_map[input_node] + + # Use the same shape computation logic as convert_view for consistency + output_shape = _compute_view_output_shape(node, input_node, input_trt, list(target_shape)) + + layer = network.add_shuffle(input_trt) + if layer is None: + raise RuntimeError(f"Failed to create shuffle layer for reshape {node.name}") + + layer.reshape_dims = trt.Dims(output_shape) + layer.name = f"reshape_{node.name}" + logger.debug(f"[TensorRT] Created reshape layer: {layer.name}, shape={output_shape}") + + return layer.get_output(0) + + +@converter("aten.flatten.using_ints", validator_fn=validate_flatten) +def convert_flatten( + node: torch.fx.Node, + network: trt.INetworkDefinition, + input_map: Dict[torch.fx.Node, Any], + edge_program: Optional[Any] = None, +) -> trt.ITensor: + """ + Convert PyTorch flatten to TensorRT shuffle layer. + + Flatten merges dimensions from start_dim to end_dim (inclusive). + + Args: + node: FX node representing the flatten operation. + network: TensorRT network definition. + input_map: Mapping from FX nodes to TensorRT tensors. + edge_program: Optional edge program (unused). + + Returns: + TensorRT output tensor. + + Raises: + ValueError: If input is invalid or dimensions are inconsistent. + RuntimeError: If TensorRT layer creation fails. 
+ """ + logger.debug(f"[TensorRT] Converting flatten node: {node.name}") + + args = node.args + + input_node = args[0] + start_dim = args[1] + end_dim = args[2] + + if not isinstance(input_node, torch.fx.Node): + raise ValueError(f"Input to flatten must be a node, got {type(input_node)}") + + if input_node not in input_map: + raise ValueError( + f"Input node '{input_node.name}' not found in input_map for " + f"flatten node '{node.name}'" + ) + + input_trt = input_map[input_node] + + # Get shape from node metadata for reliability + input_shape = tuple(get_node_shape(input_node) or input_trt.shape) + ndim = len(input_shape) + + # Handle negative dimensions + start_dim = _get_positive_dim(start_dim, ndim) + end_dim = _get_positive_dim(end_dim, ndim) + + # Validate dimensions + if start_dim > end_dim: + raise ValueError(f"start_dim ({start_dim}) must be <= end_dim ({end_dim})") + + # Build the output shape + output_shape = [] + flatten_size = 1 + + for i, s in enumerate(input_shape): + if i < start_dim: + output_shape.append(s) + elif i <= end_dim: + if s == -1: + # Dynamic dimension - use 0 for TensorRT inference + flatten_size = 0 + elif flatten_size != 0: + flatten_size *= s + else: + if flatten_size is not None: + output_shape.append(flatten_size if flatten_size != 0 else 0) + flatten_size = None # Mark as already added + output_shape.append(s) + + # Add the flattened dimension if not yet added (end_dim is last dim) + if flatten_size is not None: + output_shape.append(flatten_size if flatten_size != 0 else 0) + + layer = network.add_shuffle(input_trt) + if layer is None: + raise RuntimeError(f"Failed to create shuffle layer for flatten {node.name}") + + layer.reshape_dims = trt.Dims(output_shape) + layer.name = f"flatten_{node.name}" + + logger.debug( + f"[TensorRT] Created flatten layer: {layer.name}, " + f"start_dim={start_dim}, end_dim={end_dim}, output_shape={output_shape}" + ) + + return layer.get_output(0) + + +@converter("aten.squeeze.dim", 
"aten.squeeze.dims", "aten.squeeze_copy.dim", "aten.squeeze_copy.dims", validator_fn=validate_squeeze_unsqueeze) +def convert_squeeze( + node: torch.fx.Node, + network: trt.INetworkDefinition, + input_map: Dict[torch.fx.Node, Any], + edge_program: Optional[Any] = None, +) -> trt.ITensor: + """ + Convert PyTorch squeeze to TensorRT shuffle layer. + + Removes dimension of size 1 at the specified position. + + Args: + node: FX node representing the squeeze operation. + network: TensorRT network definition. + input_map: Mapping from FX nodes to TensorRT tensors. + edge_program: Optional edge program (unused). + + Returns: + TensorRT output tensor. + + Raises: + ValueError: If input is invalid. + RuntimeError: If TensorRT layer creation fails. + """ + logger.debug(f"[TensorRT] Converting squeeze node: {node.name}") + + args = node.args + + input_node = args[0] + dim = args[1] + + if not isinstance(input_node, torch.fx.Node): + raise ValueError(f"Input to squeeze must be a node, got {type(input_node)}") + + input_trt = get_trt_tensor_from_node(network, input_node, input_map, node.name) + + # Get shape from node metadata for reliability + input_shape = tuple(get_node_shape(input_node) or input_trt.shape) + ndim = len(input_shape) + + # Handle dims as list (squeeze.dims variant) + if isinstance(dim, (list, tuple)): + dims_to_squeeze = [_get_positive_dim(d, ndim) for d in dim] + else: + dims_to_squeeze = [_get_positive_dim(dim, ndim)] + + # Build output shape excluding squeezed dimensions + output_shape = [] + for i, s in enumerate(input_shape): + if i in dims_to_squeeze: + # Only squeeze if size is 1 or dynamic + if s != 1 and s != -1: + logger.warning( + f"[TensorRT] squeeze on dim {i} with size {s} != 1, not squeezing" + ) + output_shape.append(s) + # else: skip this dimension (squeeze it) + else: + output_shape.append(s) + + layer = network.add_shuffle(input_trt) + if layer is None: + raise RuntimeError(f"Failed to create shuffle layer for squeeze {node.name}") + + 
layer.reshape_dims = trt.Dims(output_shape) + layer.name = f"squeeze_{node.name}" + + logger.debug( + f"[TensorRT] Created squeeze layer: {layer.name}, " + f"dims={dims_to_squeeze}, output_shape={output_shape}" + ) + + return layer.get_output(0) + + +@converter("aten.unsqueeze.default", "aten.unsqueeze_copy.default", validator_fn=validate_squeeze_unsqueeze) +def convert_unsqueeze( + node: torch.fx.Node, + network: trt.INetworkDefinition, + input_map: Dict[torch.fx.Node, Any], + edge_program: Optional[Any] = None, +) -> trt.ITensor: + """ + Convert PyTorch unsqueeze to TensorRT shuffle layer. + + Inserts a dimension of size 1 at the specified position. + + Args: + node: FX node representing the unsqueeze operation. + network: TensorRT network definition. + input_map: Mapping from FX nodes to TensorRT tensors. + edge_program: Optional edge program (unused). + + Returns: + TensorRT output tensor. + + Raises: + ValueError: If input is invalid. + RuntimeError: If TensorRT layer creation fails. 
+ """ + logger.debug(f"[TensorRT] Converting unsqueeze node: {node.name}") + + args = node.args + + input_node = args[0] + dim = args[1] + + if not isinstance(input_node, torch.fx.Node): + raise ValueError(f"Input to unsqueeze must be a node, got {type(input_node)}") + + input_trt = get_trt_tensor_from_node(network, input_node, input_map, node.name) + + # Get shape from node metadata for reliability + input_shape = list(get_node_shape(input_node) or input_trt.shape) + ndim = len(input_shape) + + # Handle negative dimension (for unsqueeze, target ndim is ndim + 1) + dim = _get_positive_dim(dim, ndim + 1) + + # Build output shape with new dimension of size 1 + output_shape = input_shape[:dim] + [1] + input_shape[dim:] + + layer = network.add_shuffle(input_trt) + if layer is None: + raise RuntimeError(f"Failed to create shuffle layer for unsqueeze {node.name}") + + layer.reshape_dims = trt.Dims(output_shape) + layer.name = f"unsqueeze_{node.name}" + + logger.debug( + f"[TensorRT] Created unsqueeze layer: {layer.name}, " + f"dim={dim}, output_shape={output_shape}" + ) + + return layer.get_output(0) + + +@converter("aten.permute.default", "aten.permute_copy.default", validator_fn=validate_permute) +def convert_permute( + node: torch.fx.Node, + network: trt.INetworkDefinition, + input_map: Dict[torch.fx.Node, Any], + edge_program: Optional[Any] = None, +) -> trt.ITensor: + """ + Convert PyTorch permute to TensorRT shuffle layer. + + Reorders dimensions according to the specified permutation. + + Args: + node: FX node representing the permute operation. + network: TensorRT network definition. + input_map: Mapping from FX nodes to TensorRT tensors. + edge_program: Optional edge program (unused). + + Returns: + TensorRT output tensor. + + Raises: + ValueError: If input is invalid or dims is malformed. + RuntimeError: If TensorRT layer creation fails. 
+ """ + args = node.args + + input_node = args[0] + dims = args[1] + + if not isinstance(input_node, torch.fx.Node): + raise ValueError(f"Input to permute must be a node, got {type(input_node)}") + + if input_node not in input_map: + raise ValueError( + f"Input node '{input_node.name}' not found in input_map for " + f"permute node '{node.name}'" + ) + + if not isinstance(dims, (list, tuple)): + raise ValueError(f"dims must be list or tuple, got {type(dims)}") + + input_trt = input_map[input_node] + + # Get ndim from node metadata for reliability + ndim = len(get_node_shape(input_node) or input_trt.shape) + + # Convert dims to list and handle negative indices + permutation = [_get_positive_dim(d, ndim) for d in dims] + + layer = network.add_shuffle(input_trt) + if layer is None: + raise RuntimeError(f"Failed to create shuffle layer for permute {node.name}") + + # Use first_transpose for permutation (applies before any reshape) + layer.first_transpose = trt.Permutation(permutation) + layer.name = f"permute_{node.name}" + + logger.debug( + f"[TensorRT] Created permute layer: {layer.name}, permutation={permutation}" + ) + + return layer.get_output(0) + + +@converter("aten.transpose.int", validator_fn=validate_transpose) +def convert_transpose( + node: torch.fx.Node, + network: trt.INetworkDefinition, + input_map: Dict[torch.fx.Node, Any], + edge_program: Optional[Any] = None, +) -> trt.ITensor: + """ + Convert PyTorch transpose to TensorRT shuffle layer. + + Swaps two dimensions of the input tensor. + + Args: + node: FX node representing the transpose operation. + network: TensorRT network definition. + input_map: Mapping from FX nodes to TensorRT tensors. + edge_program: Optional edge program (unused). + + Returns: + TensorRT output tensor. + + Raises: + ValueError: If input is invalid. + RuntimeError: If TensorRT layer creation fails. 
+ """ + logger.debug(f"[TensorRT] Converting transpose node: {node.name}") + + args = node.args + + input_node = args[0] + dim0 = args[1] + dim1 = args[2] + + if not isinstance(input_node, torch.fx.Node): + raise ValueError(f"Input to transpose must be a node, got {type(input_node)}") + + if input_node not in input_map: + raise ValueError( + f"Input node '{input_node.name}' not found in input_map for " + f"transpose node '{node.name}'" + ) + + input_trt = input_map[input_node] + + # Get ndim from node metadata for reliability + ndim = len(get_node_shape(input_node) or input_trt.shape) + + # Handle negative dimensions + dim0 = _get_positive_dim(dim0, ndim) + dim1 = _get_positive_dim(dim1, ndim) + + # Build permutation: identity with dim0 and dim1 swapped + permutation = list(range(ndim)) + permutation[dim0] = dim1 + permutation[dim1] = dim0 + + layer = network.add_shuffle(input_trt) + if layer is None: + raise RuntimeError(f"Failed to create shuffle layer for transpose {node.name}") + + layer.first_transpose = trt.Permutation(permutation) + layer.name = f"transpose_{node.name}" + + logger.debug( + f"[TensorRT] Created transpose layer: {layer.name}, " + f"dim0={dim0}, dim1={dim1}, permutation={permutation}" + ) + + return layer.get_output(0) + + +@converter("aten.select.int", "aten.select_copy.int", validator_fn=validate_select) +def convert_select( + node: torch.fx.Node, + network: trt.INetworkDefinition, + input_map: Dict[torch.fx.Node, Any], + edge_program: Optional[Any] = None, +) -> trt.ITensor: + """ + Convert PyTorch select to TensorRT slice layer. + + Select extracts a slice of size 1 along a dimension and removes that dimension. + Equivalent to tensor[dim, index]. + + Args: + node: FX node representing the select operation. + network: TensorRT network definition. + input_map: Mapping from FX nodes to TensorRT tensors. + edge_program: Optional edge program (unused). + + Returns: + TensorRT output tensor. + + Raises: + ValueError: If input is invalid. 
+ RuntimeError: If TensorRT layer creation fails. + """ + logger.debug(f"[TensorRT] Converting select node: {node.name}") + + args = node.args + + input_node = args[0] + dim = args[1] + index = args[2] + + if not isinstance(input_node, torch.fx.Node): + raise ValueError(f"Input to select must be a node, got {type(input_node)}") + + input_trt = get_trt_tensor_from_node(network, input_node, input_map, node.name) + + # Get shape from node metadata for reliability + input_shape = list(get_node_shape(input_node) or input_trt.shape) + ndim = len(input_shape) + + # Handle negative dimension + dim = _get_positive_dim(dim, ndim) + + # Handle negative index + if index < 0: + index = input_shape[dim] + index + + # Build start, shape, stride for slice operation + start = [0] * ndim + start[dim] = index + + # Shape: same as input except the selected dim has size 1 + shape = input_shape.copy() + shape[dim] = 1 + + # Stride: 1 for all dims + stride = [1] * ndim + + # Create slice layer + layer = network.add_slice( + input_trt, + start=trt.Dims(start), + shape=trt.Dims(shape), + stride=trt.Dims(stride), + ) + if layer is None: + raise RuntimeError(f"Failed to create slice layer for select {node.name}") + + layer.name = f"select_slice_{node.name}" + slice_output = layer.get_output(0) + + # Now squeeze the dimension to remove the size-1 dim + output_shape = input_shape[:dim] + input_shape[dim + 1:] + + squeeze_layer = network.add_shuffle(slice_output) + if squeeze_layer is None: + raise RuntimeError( + f"Failed to create shuffle layer for select squeeze {node.name}" + ) + + squeeze_layer.reshape_dims = trt.Dims(output_shape) + squeeze_layer.name = f"select_squeeze_{node.name}" + + logger.debug( + f"[TensorRT] Created select layer: {layer.name}, " + f"dim={dim}, index={index}, output_shape={output_shape}" + ) + + return squeeze_layer.get_output(0) + + +__all__ = [ + "convert_view", + "convert_reshape", + "convert_flatten", + "convert_squeeze", + "convert_unsqueeze", + 
"convert_permute", + "convert_transpose", + "convert_select", + "validate_view_reshape", + "validate_flatten", + "validate_squeeze_unsqueeze", + "validate_permute", + "validate_transpose", + "validate_select", + "_compute_view_output_shape", +] diff --git a/backends/nvidia/tensorrt/converters/targets.bzl b/backends/nvidia/tensorrt/converters/targets.bzl index 412b5b2c2ff..7f254f644f6 100644 --- a/backends/nvidia/tensorrt/converters/targets.bzl +++ b/backends/nvidia/tensorrt/converters/targets.bzl @@ -11,13 +11,24 @@ def define_common_targets(): name = "converters", srcs = [ "__init__.py", + "activations.py", "add.py", "addmm.py", + "batch_norm.py", + "clamp.py", + "concat.py", + "conv2d.py", + "dim_order_ops.py", "div.py", + "getitem.py", + "linear.py", "mm.py", "mul.py", "permute_copy.py", + "pooling.py", + "reduction.py", "relu.py", + "reshape.py", "sub.py", ], visibility = ["PUBLIC"], diff --git a/examples/nvidia/tensorrt/export.py b/examples/nvidia/tensorrt/export.py index 1d14605aeb2..c70aef25620 100644 --- a/examples/nvidia/tensorrt/export.py +++ b/examples/nvidia/tensorrt/export.py @@ -37,6 +37,7 @@ "add_mul", "linear", "mul", + "mv3", "softmax", } diff --git a/examples/nvidia/tensorrt/tests/test_export.py b/examples/nvidia/tensorrt/tests/test_export.py index acaf1bdda46..09b17de1a36 100644 --- a/examples/nvidia/tensorrt/tests/test_export.py +++ b/examples/nvidia/tensorrt/tests/test_export.py @@ -99,5 +99,8 @@ def test_add_bf16(self) -> None: def test_softmax(self) -> None: _export_and_verify("softmax") + def test_mv3(self) -> None: + _export_and_verify("mv3") + def test_linear(self) -> None: _export_and_verify("linear")