diff --git a/backends/nvidia/tensorrt/README.md b/backends/nvidia/tensorrt/README.md index 00dd7ac3ce5..4c7d4e4e46e 100644 --- a/backends/nvidia/tensorrt/README.md +++ b/backends/nvidia/tensorrt/README.md @@ -162,6 +162,23 @@ with open("model_tensorrt.pte", "wb") as f: | Elementwise | add, sub, mul, div, mm, relu | | Reshape | view, reshape, squeeze, unsqueeze, permute, contiguous, clone | +## Supported Operations + +| Category | Operations | +|----------|-----------| +| Elementwise | add, sub, mul, div, floor_divide, rsub | +| Matrix | mm, addmm, bmm, linear | +| Convolution | conv2d | +| Normalization | batch_norm, layer_norm | +| Pooling | avg_pool2d, adaptive_avg_pool2d | +| Activations | relu, sigmoid, tanh, gelu, silu, hardswish, hardsigmoid, softmax, log_softmax, clamp | +| Reshape | view, reshape, squeeze, unsqueeze, permute, transpose, flatten, unflatten, contiguous, clone | +| Reduction | mean, any | +| Concat/Split | cat, split, chunk, stack | +| Comparison | eq, ne, lt, le, gt, ge, where, logical_not | +| Slicing | slice, select, index | +| Other | embedding, expand, repeat, upsample, pixel_shuffle, scaled_dot_product_attention, full | + ## Jetson Deployment ### Performance Tuning diff --git a/backends/nvidia/tensorrt/converters/__init__.py b/backends/nvidia/tensorrt/converters/__init__.py index 58f6218e01a..2f99cc6f5ca 100644 --- a/backends/nvidia/tensorrt/converters/__init__.py +++ b/backends/nvidia/tensorrt/converters/__init__.py @@ -13,6 +13,7 @@ from executorch.backends.nvidia.tensorrt.converters import batch_norm # noqa: F401 from executorch.backends.nvidia.tensorrt.converters import bmm # noqa: F401 from executorch.backends.nvidia.tensorrt.converters import clamp # noqa: F401 +from executorch.backends.nvidia.tensorrt.converters import comparison # noqa: F401 from executorch.backends.nvidia.tensorrt.converters import concat # noqa: F401 from executorch.backends.nvidia.tensorrt.converters import conv2d # noqa: F401 from 
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.

# pyre-strict

"""
TensorRT Converters for Comparison and Logical Operations.

This module provides converters for PyTorch comparison and logical operations
to TensorRT layers.

Supported operations:
- aten.eq/ne/lt/le/gt/ge (Scalar and Tensor variants)
- aten.logical_not.default: Logical NOT
- aten.where.self: Conditional selection
- aten.any.dim: Boolean reduction
- aten.full.default / aten.full_like.default (and zeros_like/ones_like):
  constant-filled tensors
"""

import logging
from typing import Any, Dict, Optional

import torch
from executorch.backends.nvidia.tensorrt.converter_registry import converter
from executorch.backends.nvidia.tensorrt.converter_utils import (
    get_node_shape,
)

logger: logging.Logger = logging.getLogger(__name__)


def _input_rank(input_node: Any, input_trt: Any) -> int:
    """Best-effort rank of the converter input.

    Prefers the FX ``val`` metadata; falls back to the TRT tensor shape, which
    may be invalid while the network is still being built, and finally to 1.
    """
    if isinstance(input_node, torch.fx.Node):
        meta_val = input_node.meta.get("val")
        if meta_val is not None and hasattr(meta_val, "shape"):
            return len(meta_val.shape)
    try:
        return len(input_trt.shape)
    except (TypeError, ValueError):
        return 1


def _scalar_broadcast_const(
    network: Any, node: torch.fx.Node, value: Any, ndim: int, tag: str
) -> Any:
    """Materialize a scalar as a float32 constant of shape ``[1] * ndim``.

    TensorRT elementwise layers require both operands to have the same rank,
    so the constant is created with the input's rank (all-ones shape) rather
    than shape [1]. Returns the constant layer's output tensor.
    """
    import numpy as np
    import tensorrt as trt

    const_shape = [1] * ndim if ndim > 0 else [1]
    const = network.add_constant(
        const_shape, trt.Weights(np.full(const_shape, value, dtype=np.float32))
    )
    const.name = f"{tag}_const_{node.name}"
    return const.get_output(0)


def _scalar_inputs(
    node: torch.fx.Node, input_map: Dict[torch.fx.Node, Any], tag: str
) -> Any:
    """Validate and unpack ``(input_node, input_trt, scalar)`` for a Scalar op."""
    input_node, other = node.args[0], node.args[1]
    if not isinstance(input_node, torch.fx.Node):
        raise ValueError(f"Input to {tag} must be a node, got {type(input_node)}")
    return input_node, input_map[input_node], other


def _simple_scalar_compare(
    node: torch.fx.Node,
    network: Any,
    input_map: Dict[torch.fx.Node, Any],
    op_name: str,
    tag: str,
) -> Any:
    """Shared body for eq/lt/gt Scalar: one elementwise compare vs. a constant.

    The scalar constant is rank-matched to the input (see
    ``_scalar_broadcast_const``); shape-[1] constants fail TRT's equal-rank
    requirement for inputs with ndim > 1.
    """
    try:
        import numpy as np  # noqa: F401  (used via _scalar_broadcast_const)
        import tensorrt as trt
    except ImportError as e:
        raise ImportError("TensorRT and numpy are required") from e

    input_node, input_trt, other = _scalar_inputs(node, input_map, tag)
    other_trt = _scalar_broadcast_const(
        network, node, other, _input_rank(input_node, input_trt), tag
    )
    layer = network.add_elementwise(
        input_trt, other_trt, getattr(trt.ElementWiseOperation, op_name)
    )
    if layer is None:
        raise RuntimeError(f"Failed to create {tag} layer for {node.name}")
    layer.name = f"{tag}_scalar_{node.name}"
    return layer.get_output(0)


@converter("aten.eq.Scalar")
def convert_eq_scalar(
    node: torch.fx.Node,
    network: Any,  # trt.INetworkDefinition
    input_map: Dict[torch.fx.Node, Any],  # Dict[Node, trt.ITensor]
    edge_program: Optional[Any] = None,
) -> Any:  # trt.ITensor
    """Convert aten.eq.Scalar to a TensorRT EQUAL elementwise layer.

    eq.Scalar(Tensor self, Scalar other) -> Tensor (boolean).

    NOTE(review): the comparison is performed in float32; exact float equality
    is assumed acceptable for the integer-like tensors this targets — confirm.
    """
    return _simple_scalar_compare(node, network, input_map, "EQUAL", "eq")


@converter("aten.ne.Scalar")
def convert_ne_scalar(
    node: torch.fx.Node,
    network: Any,
    input_map: Dict[torch.fx.Node, Any],
    edge_program: Optional[Any] = None,
) -> Any:
    """Convert aten.ne.Scalar as NOT(EQUAL(x, scalar)).

    Fix: the scalar constant now has the input's rank ([1]*ndim) instead of
    shape [1], matching eq/ge/le and satisfying TRT's equal-rank rule.
    """
    try:
        import numpy as np  # noqa: F401
        import tensorrt as trt
    except ImportError as e:
        raise ImportError("TensorRT and numpy are required") from e

    input_node, input_trt, other = _scalar_inputs(node, input_map, "ne")
    other_trt = _scalar_broadcast_const(
        network, node, other, _input_rank(input_node, input_trt), "ne"
    )

    # First compute EQUAL, then logical NOT.
    eq_layer = network.add_elementwise(
        input_trt, other_trt, trt.ElementWiseOperation.EQUAL
    )
    if eq_layer is None:
        raise RuntimeError(f"Failed to create eq layer for ne_{node.name}")
    eq_layer.name = f"ne_eq_{node.name}"

    not_layer = network.add_unary(eq_layer.get_output(0), trt.UnaryOperation.NOT)
    if not_layer is None:
        raise RuntimeError(f"Failed to create ne layer for {node.name}")
    not_layer.name = f"ne_scalar_{node.name}"
    return not_layer.get_output(0)


@converter("aten.lt.Scalar")
def convert_lt_scalar(
    node: torch.fx.Node,
    network: Any,
    input_map: Dict[torch.fx.Node, Any],
    edge_program: Optional[Any] = None,
) -> Any:
    """Convert aten.lt.Scalar to a TensorRT LESS elementwise layer.

    Fix: rank-matched scalar constant (was shape [1]).
    """
    return _simple_scalar_compare(node, network, input_map, "LESS", "lt")


@converter("aten.gt.Scalar")
def convert_gt_scalar(
    node: torch.fx.Node,
    network: Any,
    input_map: Dict[torch.fx.Node, Any],
    edge_program: Optional[Any] = None,
) -> Any:
    """Convert aten.gt.Scalar to a TensorRT GREATER elementwise layer.

    Fix: rank-matched scalar constant (was shape [1]).
    """
    return _simple_scalar_compare(node, network, input_map, "GREATER", "gt")


@converter("aten.ge.Scalar")
def convert_ge_scalar(
    node: torch.fx.Node,
    network: Any,
    input_map: Dict[torch.fx.Node, Any],
    edge_program: Optional[Any] = None,
) -> Any:
    """Convert aten.ge.Scalar: ge(x, s) = gt(x, s) OR eq(x, s).

    TensorRT's ElementWiseOperation has no native GE, so GREATER and EQUAL are
    combined with OR.
    """
    try:
        import numpy as np  # noqa: F401
        import tensorrt as trt
    except ImportError as e:
        raise ImportError("TensorRT and numpy are required") from e

    input_node, input_trt, other = _scalar_inputs(node, input_map, "ge")
    other_tensor = _scalar_broadcast_const(
        network, node, other, _input_rank(input_node, input_trt), "ge"
    )

    gt_layer = network.add_elementwise(
        input_trt, other_tensor, trt.ElementWiseOperation.GREATER
    )
    if gt_layer is None:
        raise RuntimeError(f"Failed to create gt layer for ge_{node.name}")
    gt_layer.name = f"ge_gt_{node.name}"

    eq_layer = network.add_elementwise(
        input_trt, other_tensor, trt.ElementWiseOperation.EQUAL
    )
    if eq_layer is None:
        raise RuntimeError(f"Failed to create eq layer for ge_{node.name}")
    eq_layer.name = f"ge_eq_{node.name}"

    or_layer = network.add_elementwise(
        gt_layer.get_output(0), eq_layer.get_output(0), trt.ElementWiseOperation.OR
    )
    if or_layer is None:
        raise RuntimeError(f"Failed to create or layer for ge_{node.name}")
    or_layer.name = f"ge_scalar_{node.name}"
    return or_layer.get_output(0)
@converter("aten.le.Scalar")
def convert_le_scalar(
    node: torch.fx.Node,
    network: Any,
    input_map: Dict[torch.fx.Node, Any],
    edge_program: Optional[Any] = None,
) -> Any:
    """Convert aten.le.Scalar: le(x, s) = lt(x, s) OR eq(x, s).

    TensorRT has no native LE elementwise op, so LESS and EQUAL are OR-ed.
    The scalar is materialized as a [1]*ndim constant because TRT elementwise
    layers require operands of equal rank.
    """
    try:
        import numpy as np
        import tensorrt as trt
    except ImportError as e:
        raise ImportError("TensorRT and numpy are required") from e

    input_node, other = node.args[0], node.args[1]
    if not isinstance(input_node, torch.fx.Node):
        raise ValueError(f"Input to le must be a node, got {type(input_node)}")
    input_trt = input_map[input_node]

    # Prefer FX metadata for the rank; the TRT shape may be invalid while the
    # network is still under construction.
    meta_val = input_node.meta.get("val")
    if meta_val is not None and hasattr(meta_val, "shape"):
        ndim = len(meta_val.shape)
    else:
        try:
            ndim = len(input_trt.shape)
        except (TypeError, ValueError):
            ndim = 1

    const_shape = [1] * ndim if ndim > 0 else [1]
    other_const = network.add_constant(
        const_shape, trt.Weights(np.full(const_shape, other, dtype=np.float32))
    )
    other_const.name = f"le_const_{node.name}"
    other_tensor = other_const.get_output(0)

    lt_layer = network.add_elementwise(
        input_trt, other_tensor, trt.ElementWiseOperation.LESS
    )
    if lt_layer is None:
        raise RuntimeError(f"Failed to create lt layer for le_{node.name}")
    lt_layer.name = f"le_lt_{node.name}"

    eq_layer = network.add_elementwise(
        input_trt, other_tensor, trt.ElementWiseOperation.EQUAL
    )
    if eq_layer is None:
        raise RuntimeError(f"Failed to create eq layer for le_{node.name}")
    eq_layer.name = f"le_eq_{node.name}"

    or_layer = network.add_elementwise(
        lt_layer.get_output(0), eq_layer.get_output(0), trt.ElementWiseOperation.OR
    )
    if or_layer is None:
        raise RuntimeError(f"Failed to create or layer for le_{node.name}")
    or_layer.name = f"le_scalar_{node.name}"
    return or_layer.get_output(0)


def _tensor_pair(
    node: torch.fx.Node, input_map: Dict[torch.fx.Node, Any], tag: str
) -> Any:
    """Validate both operands of a Tensor comparison and return their TRT tensors."""
    input_node, other_node = node.args[0], node.args[1]
    if not isinstance(input_node, torch.fx.Node):
        raise ValueError(f"Input to {tag} must be a node, got {type(input_node)}")
    if not isinstance(other_node, torch.fx.Node):
        raise ValueError(f"Other input to {tag} must be a node, got {type(other_node)}")
    return input_map[input_node], input_map[other_node]


def _simple_tensor_compare(
    node: torch.fx.Node,
    network: Any,
    input_map: Dict[torch.fx.Node, Any],
    op_name: str,
    tag: str,
) -> Any:
    """One elementwise comparison between two tensor operands."""
    import tensorrt as trt

    input_trt, other_trt = _tensor_pair(node, input_map, tag)
    layer = network.add_elementwise(
        input_trt, other_trt, getattr(trt.ElementWiseOperation, op_name)
    )
    if layer is None:
        raise RuntimeError(f"Failed to create {tag} layer for {node.name}")
    layer.name = f"{tag}_tensor_{node.name}"
    return layer.get_output(0)


def _or_equal_tensor(
    node: torch.fx.Node,
    network: Any,
    input_map: Dict[torch.fx.Node, Any],
    strict_op: str,
    strict_tag: str,
    tag: str,
) -> Any:
    """Build ``strict_op(x, y) OR eq(x, y)`` — shared by ge/le Tensor.

    TensorRT has no native GE/LE elementwise operations.
    """
    import tensorrt as trt

    input_trt, other_trt = _tensor_pair(node, input_map, tag)

    strict_layer = network.add_elementwise(
        input_trt, other_trt, getattr(trt.ElementWiseOperation, strict_op)
    )
    if strict_layer is None:
        raise RuntimeError(f"Failed to create {strict_tag} layer for {tag}_{node.name}")
    strict_layer.name = f"{tag}_{strict_tag}_{node.name}"

    eq_layer = network.add_elementwise(
        input_trt, other_trt, trt.ElementWiseOperation.EQUAL
    )
    if eq_layer is None:
        raise RuntimeError(f"Failed to create eq layer for {tag}_{node.name}")
    eq_layer.name = f"{tag}_eq_{node.name}"

    or_layer = network.add_elementwise(
        strict_layer.get_output(0),
        eq_layer.get_output(0),
        trt.ElementWiseOperation.OR,
    )
    if or_layer is None:
        raise RuntimeError(f"Failed to create or layer for {tag}_{node.name}")
    or_layer.name = f"{tag}_tensor_{node.name}"
    return or_layer.get_output(0)


@converter("aten.ge.Tensor")
def convert_ge_tensor(
    node: torch.fx.Node,
    network: Any,
    input_map: Dict[torch.fx.Node, Any],
    edge_program: Optional[Any] = None,
) -> Any:
    """Convert aten.ge.Tensor: ge(x, y) = gt(x, y) OR eq(x, y)."""
    try:
        import tensorrt as trt  # noqa: F401
    except ImportError as e:
        raise ImportError("TensorRT is required") from e
    return _or_equal_tensor(node, network, input_map, "GREATER", "gt", "ge")


@converter("aten.le.Tensor")
def convert_le_tensor(
    node: torch.fx.Node,
    network: Any,
    input_map: Dict[torch.fx.Node, Any],
    edge_program: Optional[Any] = None,
) -> Any:
    """Convert aten.le.Tensor: le(x, y) = lt(x, y) OR eq(x, y)."""
    try:
        import tensorrt as trt  # noqa: F401
    except ImportError as e:
        raise ImportError("TensorRT is required") from e
    return _or_equal_tensor(node, network, input_map, "LESS", "lt", "le")


@converter("aten.eq.Tensor")
def convert_eq_tensor(
    node: torch.fx.Node,
    network: Any,
    input_map: Dict[torch.fx.Node, Any],
    edge_program: Optional[Any] = None,
) -> Any:
    """Convert aten.eq.Tensor to a TensorRT EQUAL elementwise layer."""
    try:
        import tensorrt as trt  # noqa: F401
    except ImportError as e:
        raise ImportError("TensorRT is required") from e
    return _simple_tensor_compare(node, network, input_map, "EQUAL", "eq")


@converter("aten.ne.Tensor")
def convert_ne_tensor(
    node: torch.fx.Node,
    network: Any,
    input_map: Dict[torch.fx.Node, Any],
    edge_program: Optional[Any] = None,
) -> Any:
    """Convert aten.ne.Tensor as NOT(EQUAL(x, y))."""
    try:
        import tensorrt as trt
    except ImportError as e:
        raise ImportError("TensorRT is required") from e

    input_trt, other_trt = _tensor_pair(node, input_map, "ne")

    eq_layer = network.add_elementwise(
        input_trt, other_trt, trt.ElementWiseOperation.EQUAL
    )
    if eq_layer is None:
        raise RuntimeError(f"Failed to create eq layer for ne_{node.name}")
    eq_layer.name = f"ne_eq_{node.name}"

    not_layer = network.add_unary(eq_layer.get_output(0), trt.UnaryOperation.NOT)
    if not_layer is None:
        raise RuntimeError(f"Failed to create not layer for ne_{node.name}")
    not_layer.name = f"ne_tensor_{node.name}"
    return not_layer.get_output(0)


@converter("aten.lt.Tensor")
def convert_lt_tensor(
    node: torch.fx.Node,
    network: Any,
    input_map: Dict[torch.fx.Node, Any],
    edge_program: Optional[Any] = None,
) -> Any:
    """Convert aten.lt.Tensor to a TensorRT LESS elementwise layer."""
    try:
        import tensorrt as trt  # noqa: F401
    except ImportError as e:
        raise ImportError("TensorRT is required") from e
    return _simple_tensor_compare(node, network, input_map, "LESS", "lt")


@converter("aten.gt.Tensor")
def convert_gt_tensor(
    node: torch.fx.Node,
    network: Any,
    input_map: Dict[torch.fx.Node, Any],
    edge_program: Optional[Any] = None,
) -> Any:
    """Convert aten.gt.Tensor to a TensorRT GREATER elementwise layer."""
    try:
        import tensorrt as trt  # noqa: F401
    except ImportError as e:
        raise ImportError("TensorRT is required") from e
    return _simple_tensor_compare(node, network, input_map, "GREATER", "gt")
@converter("aten.logical_not.default")
def convert_logical_not(
    node: torch.fx.Node,
    network: Any,
    input_map: Dict[torch.fx.Node, Any],
    edge_program: Optional[Any] = None,
) -> Any:
    """Convert aten.logical_not.default to a TensorRT unary NOT layer.

    logical_not.default(Tensor self) -> Tensor (boolean).
    """
    try:
        import tensorrt as trt
    except ImportError as e:
        raise ImportError("TensorRT is required") from e

    input_node = node.args[0]
    if not isinstance(input_node, torch.fx.Node):
        raise ValueError(f"Input to logical_not must be a node, got {type(input_node)}")

    layer = network.add_unary(input_map[input_node], trt.UnaryOperation.NOT)
    if layer is None:
        raise RuntimeError(f"Failed to create logical_not layer for {node.name}")
    layer.name = f"logical_not_{node.name}"
    return layer.get_output(0)


@converter("aten.where.self", "aten.where.ScalarSelf")
def convert_where(
    node: torch.fx.Node,
    network: Any,
    input_map: Dict[torch.fx.Node, Any],
    edge_program: Optional[Any] = None,
) -> Any:
    """Convert aten.where to a TensorRT select layer.

    where.self(Tensor condition, Tensor self, Tensor other) -> Tensor

    Fixes vs. original:
    - numpy is imported inside the guarded try alongside tensorrt.
    - Scalar self/other operands are materialized as [1]*ndim constants (rank
      of the condition) — TRT's select layer requires all three operands to
      have the same rank, so shape-[1] constants fail for ndim > 1.
    """
    try:
        import numpy as np
        import tensorrt as trt
    except ImportError as e:
        raise ImportError("TensorRT is required") from e

    condition_node, self_node, other_node = node.args[0], node.args[1], node.args[2]
    if not isinstance(condition_node, torch.fx.Node):
        raise ValueError(f"Condition must be a node, got {type(condition_node)}")
    condition_trt = input_map[condition_node]

    # Rank of the condition, preferring FX metadata over the TRT shape.
    cond_meta = condition_node.meta.get("val")
    if cond_meta is not None and hasattr(cond_meta, "shape"):
        ndim = len(cond_meta.shape)
    else:
        try:
            ndim = len(condition_trt.shape)
        except (TypeError, ValueError):
            ndim = 1
    const_shape = [1] * ndim if ndim > 0 else [1]

    def _scalar_branch(value: Any, tag: str) -> Any:
        # Rank-matched constant for a scalar select branch.
        const = network.add_constant(
            const_shape, trt.Weights(np.full(const_shape, value, dtype=np.float32))
        )
        const.name = f"where_{tag}_const_{node.name}"
        return const.get_output(0)

    self_trt = (
        input_map[self_node]
        if isinstance(self_node, torch.fx.Node)
        else _scalar_branch(self_node, "self")
    )
    other_trt = (
        input_map[other_node]
        if isinstance(other_node, torch.fx.Node)
        else _scalar_branch(other_node, "other")
    )

    # TensorRT select: output = condition ? self : other
    layer = network.add_select(condition_trt, self_trt, other_trt)
    if layer is None:
        raise RuntimeError(f"Failed to create where/select layer for {node.name}")
    layer.name = f"where_{node.name}"
    return layer.get_output(0)


@converter("aten.any.dim")
def convert_any_dim(
    node: torch.fx.Node,
    network: Any,
    input_map: Dict[torch.fx.Node, Any],
    edge_program: Optional[Any] = None,
) -> Any:
    """Convert aten.any.dim to TensorRT layers.

    any.dim(Tensor self, int dim, bool keepdim=False) -> Tensor
    Returns True if any value along ``dim`` is True.

    TensorRT has no native boolean reduce, so this:
    1. casts bool to float (true=1, false=0),
    2. SUMs along the dimension,
    3. compares the sum > 0.
    """
    try:
        import numpy as np
        import tensorrt as trt
    except ImportError as e:
        raise ImportError("TensorRT is required") from e

    args, kwargs = node.args, node.kwargs
    input_node = args[0]
    dim = args[1] if len(args) > 1 else kwargs.get("dim")
    keepdim = args[2] if len(args) > 2 else kwargs.get("keepdim", False)

    if not isinstance(input_node, torch.fx.Node):
        raise ValueError(f"Input to any must be a node, got {type(input_node)}")
    if dim is None:
        # any.dim's schema requires a dim; fail with a clear message instead
        # of a TypeError on the comparison below.
        raise ValueError(f"any.dim requires a dim argument for {node.name}")

    input_trt = input_map[input_node]

    # ndim from node metadata for reliability (TRT shapes can be invalid
    # while the network is being built).
    ndim = len(get_node_shape(input_node) or input_trt.shape)
    if dim < 0:
        dim = ndim + dim

    # Cast to float (identity layer with forced float32 output type).
    identity = network.add_identity(input_trt)
    identity.set_output_type(0, trt.float32)
    identity.name = f"any_cast_{node.name}"

    # Reduce-SUM along the requested dimension (axes is a bitmask).
    reduce_layer = network.add_reduce(
        identity.get_output(0),
        trt.ReduceOperation.SUM,
        1 << dim,
        keepdim,
    )
    reduce_layer.name = f"any_sum_{node.name}"

    # Output shape: prefer node metadata, else derive from the input shape.
    if "val" in node.meta and hasattr(node.meta["val"], "shape"):
        output_shape = list(node.meta["val"].shape)
    else:
        input_shape = get_node_shape(input_node) or tuple(input_trt.shape)
        output_shape = list(input_shape)
        if keepdim:
            output_shape[dim] = 1
        else:
            output_shape.pop(dim)

    # Zero constant with rank matching the reduced output (all-ones shape)
    # so the GREATER comparison broadcasts.
    output_ndim = len(output_shape)
    const_shape = [1] * output_ndim if output_ndim > 0 else [1]
    zero_const = network.add_constant(
        const_shape, trt.Weights(np.zeros(const_shape, dtype=np.float32))
    )
    zero_const.name = f"any_zero_const_{node.name}"

    gt_layer = network.add_elementwise(
        reduce_layer.get_output(0),
        zero_const.get_output(0),
        trt.ElementWiseOperation.GREATER,
    )
    if gt_layer is None:
        raise RuntimeError(f"Failed to create any layer for {node.name}")
    gt_layer.name = f"any_{node.name}"
    return gt_layer.get_output(0)
+ """ + try: + import tensorrt as trt + import numpy as np + except ImportError as e: + raise ImportError("TensorRT and numpy are required") from e + + args = node.args + input_node = args[0] + + # Determine fill value based on op + target_name = str(node.target) + if "zeros_like" in target_name: + fill_value = 0.0 + elif "ones_like" in target_name: + fill_value = 1.0 + else: + # full_like - fill value is second arg + fill_value = args[1] if len(args) > 1 else 0.0 + + if not isinstance(input_node, torch.fx.Node): + raise ValueError(f"Input must be a node, got {type(input_node)}") + + # Get shape from node metadata instead of TensorRT tensor + # This is more reliable as TensorRT shapes may have -1 for dynamic dims + if "val" in node.meta and hasattr(node.meta["val"], "shape"): + output_shape = list(node.meta["val"].shape) + elif "val" in input_node.meta and hasattr(input_node.meta["val"], "shape"): + output_shape = list(input_node.meta["val"].shape) + else: + # Fall back to TensorRT shape, but check for -1 values + input_trt = input_map[input_node] + output_shape = list(input_trt.shape) + if any(d == -1 for d in output_shape): + raise ValueError( + f"Cannot create full_like with dynamic shape {output_shape}. " + "Shape must be static." + ) + + logger.debug(f"[TensorRT] full_like: shape={output_shape}, fill_value={fill_value}") + + # Create constant filled with value + fill_array = np.full(output_shape, fill_value, dtype=np.float32) + fill_weights = trt.Weights(fill_array) + layer = network.add_constant(trt.Dims(output_shape), fill_weights) + + if layer is None: + raise RuntimeError(f"Failed to create full_like layer for {node.name}") + + layer.name = f"full_like_{node.name}" + return layer.get_output(0) + + +@converter("aten.full.default") +def convert_full( + node: torch.fx.Node, + network: Any, + input_map: Dict[torch.fx.Node, Any], + edge_program: Optional[Any] = None, +) -> Any: + """ + Convert PyTorch full to TensorRT. 
+ + full.default(int[] size, Scalar fill_value, ...) -> Tensor + """ + try: + import tensorrt as trt + import numpy as np + except ImportError as e: + raise ImportError("TensorRT and numpy are required") from e + + from executorch.backends.nvidia.tensorrt.converter_utils import ( + _torch_dtype_to_numpy, + ) + + args = node.args + + size = args[0] + fill_value = args[1] if len(args) > 1 else 0.0 + + # Convert size to list + if isinstance(size, (list, tuple)): + shape = list(size) + else: + shape = [size] + + # Determine dtype from node metadata or kwargs, defaulting to float32. + np_dtype = np.float32 + dtype_kwarg = node.kwargs.get("dtype", None) + if dtype_kwarg is not None and isinstance(dtype_kwarg, torch.dtype): + np_dtype = _torch_dtype_to_numpy(dtype_kwarg) + elif "val" in node.meta: + val = node.meta["val"] + if hasattr(val, "dtype"): + np_dtype = _torch_dtype_to_numpy(val.dtype) + + fill_array = np.full(shape, fill_value, dtype=np_dtype) + fill_weights = trt.Weights(fill_array) + layer = network.add_constant(trt.Dims(shape), fill_weights) + + if layer is None: + raise RuntimeError(f"Failed to create full layer for {node.name}") + + layer.name = f"full_{node.name}" + return layer.get_output(0) + + +@converter("aten.scalar_tensor.default") +def convert_scalar_tensor( + node: torch.fx.Node, + network: Any, + input_map: Dict[torch.fx.Node, Any], + edge_program: Optional[Any] = None, +) -> Any: + """Convert aten.scalar_tensor to a TRT constant.""" + try: + import tensorrt as trt + import numpy as np + except ImportError as e: + raise ImportError("TensorRT and numpy are required") from e + + value = node.args[0] + np_value = np.array([float(value)], dtype=np.float32) + layer = network.add_constant([1], trt.Weights(np_value)) + layer.name = f"scalar_tensor_{node.name}" + return layer.get_output(0) + + +@converter("aten.arange.start_step") +def convert_arange( + node: torch.fx.Node, + network: Any, + input_map: Dict[torch.fx.Node, Any], + edge_program: 
Optional[Any] = None, +) -> Any: + """Convert aten.arange.start_step to a TRT constant. + + Computes the range on the host and embeds it as a constant tensor. + """ + try: + import tensorrt as trt + import numpy as np + except ImportError as e: + raise ImportError("TensorRT and numpy are required") from e + + start = node.args[0] + end = node.args[1] if len(node.args) > 1 else node.kwargs.get("end") + step = node.args[2] if len(node.args) > 2 else node.kwargs.get("step", 1) + + values = np.arange(start, end, step, dtype=np.float32) + layer = network.add_constant(list(values.shape), trt.Weights(values)) + layer.name = f"arange_{node.name}" + return layer.get_output(0) + + +@converter("aten.constant_pad_nd.default") +def convert_constant_pad_nd( + node: torch.fx.Node, + network: Any, + input_map: Dict[torch.fx.Node, Any], + edge_program: Optional[Any] = None, +) -> Any: + """Convert aten.constant_pad_nd to TRT using ISliceLayer with FILL mode. + + pad format: [left_last_dim, right_last_dim, left_2nd_last, right_2nd_last, ...] 
+ """ + try: + import tensorrt as trt + import numpy as np + except ImportError as e: + raise ImportError("TensorRT and numpy are required") from e + + input_node = node.args[0] + pad = list(node.args[1]) + value = float(node.args[2]) if len(node.args) > 2 else 0.0 + + if not isinstance(input_node, torch.fx.Node): + raise ValueError(f"Input must be a node, got {type(input_node)}") + + input_trt = input_map[input_node] + input_shape = list(input_trt.shape) + ndim = len(input_shape) + num_padded_dims = len(pad) // 2 + + start = [0] * ndim + output_shape = list(input_shape) + + for i in range(num_padded_dims): + dim = ndim - 1 - i + left_pad = pad[2 * i] + right_pad = pad[2 * i + 1] + start[dim] = -left_pad + output_shape[dim] = input_shape[dim] + left_pad + right_pad + + layer = network.add_slice( + input_trt, start=start, shape=output_shape, stride=[1] * ndim + ) + layer.mode = trt.SampleMode.FILL + + fill_const = network.add_constant( + [1] * ndim, trt.Weights(np.array([value], dtype=np.float32)) + ) + fill_const.name = f"pad_fill_{node.name}" + layer.set_input(4, fill_const.get_output(0)) + + layer.name = f"constant_pad_nd_{node.name}" + return layer.get_output(0) + + +@converter("aten.alias_copy.default") +def convert_alias_copy( + node: torch.fx.Node, + network: Any, + input_map: Dict[torch.fx.Node, Any], + edge_program: Optional[Any] = None, +) -> Any: + """Pass-through for alias_copy (no-op in TRT).""" + return input_map[node.args[0]] + + +@converter("aten.copy.default") +def convert_copy( + node: torch.fx.Node, + network: Any, + input_map: Dict[torch.fx.Node, Any], + edge_program: Optional[Any] = None, +) -> Any: + """Convert aten.copy to TRT pass-through. + + PyTorch signature: aten.copy(Tensor self, Tensor src) -> Tensor + Copies src into self and returns self. For TRT, we return the src tensor + (arg[1]) since TRT manages memory internally — effectively a pass-through. 
+ """ + src_node = node.args[1] + if isinstance(src_node, torch.fx.Node) and src_node in input_map: + return input_map[src_node] + raise ValueError(f"Source node for copy not found in input_map: {node.name}") + + +__all__ = [ + "convert_eq_scalar", + "convert_eq_tensor", + "convert_ne_scalar", + "convert_ne_tensor", + "convert_lt_scalar", + "convert_lt_tensor", + "convert_gt_scalar", + "convert_gt_tensor", + "convert_ge_scalar", + "convert_ge_tensor", + "convert_le_scalar", + "convert_le_tensor", + "convert_logical_not", + "convert_where", + "convert_any_dim", + "convert_full_like", + "convert_full", +] diff --git a/backends/nvidia/tensorrt/converters/slice.py b/backends/nvidia/tensorrt/converters/slice.py new file mode 100644 index 00000000000..f6b146eac70 --- /dev/null +++ b/backends/nvidia/tensorrt/converters/slice.py @@ -0,0 +1,413 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the BSD-style license found in the +# LICENSE file in the root directory of this source tree. + +# pyre-strict + +""" +TensorRT Converters for Slice Operations. + +This module provides converters for PyTorch tensor slicing operations to TensorRT +slice layers. 
+ +Supported operations: +- aten.slice.Tensor: Slice along a dimension with start, end, step +- aten.slice_copy.Tensor: Copy variant of slice +- aten.index.Tensor: Index selection with tensor indices + +Notes: +- Slice uses network.add_slice() +- Start, shape, stride computed from slice parameters +- Handles negative indices and None values +""" + +import logging +import sys +from typing import Any, Dict, List, Optional + +import torch +from executorch.backends.nvidia.tensorrt.converter_registry import converter +from executorch.backends.nvidia.tensorrt.converter_utils import ( + get_node_shape, +) + +logger: logging.Logger = logging.getLogger(__name__) + + +def _get_positive_dim(dim: int, ndim: int) -> int: + """Convert a potentially negative dimension index to positive.""" + if dim < 0: + dim = ndim + dim + return dim + + +@converter("aten.slice.Tensor", "aten.slice_copy.Tensor") +def convert_slice( + node: torch.fx.Node, + network: Any, # trt.INetworkDefinition + input_map: Dict[torch.fx.Node, Any], # Dict[Node, trt.ITensor] + edge_program: Optional[Any] = None, +) -> Any: # trt.ITensor + """ + Convert PyTorch slice to TensorRT slice layer. + + slice.Tensor(Tensor self, int dim=0, int? start=None, int? end=None, int step=1) + + Args: + node: FX node representing the slice operation. + network: TensorRT network definition. + input_map: Mapping from FX nodes to TensorRT tensors. + + Returns: + TensorRT output tensor. 
+ """ + try: + import tensorrt as trt + except ImportError as e: + raise ImportError("TensorRT is required for convert_slice") from e + + logger.debug(f"[TensorRT] Converting slice node: {node.name}") + + args = node.args + kwargs = node.kwargs + + input_node = args[0] + if not isinstance(input_node, torch.fx.Node): + raise ValueError(f"Input to slice must be a node, got {type(input_node)}") + + if input_node not in input_map: + raise ValueError(f"Input node {input_node.name} not found in input_map") + + input_trt = input_map[input_node] + + # Get shape from node metadata for reliability (TRT shapes can be invalid during network building) + input_shape = list(get_node_shape(input_node) or input_trt.shape) + ndim = len(input_shape) + + # Parse arguments with defaults + dim = args[1] if len(args) > 1 else kwargs.get("dim", 0) + start = args[2] if len(args) > 2 else kwargs.get("start", None) + end = args[3] if len(args) > 3 else kwargs.get("end", None) + step = args[4] if len(args) > 4 else kwargs.get("step", 1) + + # Handle None values + if start is None: + start = 0 + if end is None or end == sys.maxsize: + end = input_shape[dim] + + # Handle negative dimension + dim = _get_positive_dim(dim, ndim) + + # Handle negative indices + dim_size = input_shape[dim] + if isinstance(start, int) and start < 0: + start = max(0, dim_size + start) + if isinstance(end, int) and end < 0: + end = max(0, dim_size + end) + + # Clamp to valid range + if isinstance(start, int): + start = max(0, min(start, dim_size)) + if isinstance(end, int): + end = max(0, min(end, dim_size)) + + # Build start, shape, stride for slice + start_slice = [0] * ndim + start_slice[dim] = start + + # Compute output shape + output_shape = input_shape.copy() + if isinstance(end, int) and isinstance(start, int) and isinstance(step, int): + slice_len = max(0, (end - start + step - 1) // step) if step > 0 else 0 + output_shape[dim] = slice_len + else: + output_shape[dim] = 0 # Dynamic + + stride_slice = [1] * ndim 
+ stride_slice[dim] = step + + layer = network.add_slice( + input_trt, + start=trt.Dims(start_slice), + shape=trt.Dims(output_shape), + stride=trt.Dims(stride_slice), + ) + + if layer is None: + raise RuntimeError(f"Failed to create slice layer for {node.name}") + + layer.name = f"slice_{node.name}" + + logger.debug( + f"[TensorRT] Created slice layer: {layer.name}, " + f"dim={dim}, start={start}, end={end}, step={step}, output_shape={output_shape}" + ) + + return layer.get_output(0) + + +@converter("aten.index.Tensor") +def convert_index( + node: torch.fx.Node, + network: Any, # trt.INetworkDefinition + input_map: Dict[torch.fx.Node, Any], # Dict[Node, trt.ITensor] + edge_program: Optional[Any] = None, +) -> Any: # trt.ITensor + """ + Convert PyTorch index to TensorRT gather layer. + + index.Tensor(Tensor self, Tensor?[] indices) + + This operation indexes into a tensor using tensor indices. + For simple 1D indexing along first dimension, we use gather. + + Args: + node: FX node representing the index operation. + network: TensorRT network definition. + input_map: Mapping from FX nodes to TensorRT tensors. + + Returns: + TensorRT output tensor. 
+ """ + try: + import tensorrt as trt + except ImportError as e: + raise ImportError("TensorRT is required for convert_index") from e + + logger.debug(f"[TensorRT] Converting index node: {node.name}") + + args = node.args + + input_node = args[0] + indices = args[1] + + if not isinstance(input_node, torch.fx.Node): + raise ValueError(f"Input to index must be a node, got {type(input_node)}") + + if input_node not in input_map: + raise ValueError(f"Input node {input_node.name} not found in input_map") + + input_trt = input_map[input_node] + + # indices is a list of optional tensors, one per dimension + # Find the first non-None index + gather_dim = None + index_tensor = None + + for i, idx in enumerate(indices): + if idx is not None and isinstance(idx, torch.fx.Node): + if gather_dim is not None: + raise NotImplementedError( + "Multiple index tensors not supported, only single dimension indexing" + ) + gather_dim = i + index_tensor = input_map.get(idx) + + if gather_dim is None or index_tensor is None: + raise ValueError("No valid index tensor found in indices") + + # Create gather layer + layer = network.add_gather(input_trt, index_tensor, axis=gather_dim) + + if layer is None: + raise RuntimeError(f"Failed to create gather layer for index {node.name}") + + layer.name = f"index_gather_{node.name}" + + logger.debug( + f"[TensorRT] Created index/gather layer: {layer.name}, gather_dim={gather_dim}" + ) + + return layer.get_output(0) + + +@converter("aten.contiguous.default") +def convert_contiguous( + node: torch.fx.Node, + network: Any, # trt.INetworkDefinition + input_map: Dict[torch.fx.Node, Any], # Dict[Node, trt.ITensor] + edge_program: Optional[Any] = None, +) -> Any: # trt.ITensor + """ + Convert PyTorch contiguous to TensorRT. + + contiguous is a no-op in TensorRT as all tensors are contiguous. + + Args: + node: FX node representing the contiguous operation. + network: TensorRT network definition. + input_map: Mapping from FX nodes to TensorRT tensors. 
+ + Returns: + TensorRT input tensor (passthrough). + """ + logger.debug(f"[TensorRT] Converting contiguous node (no-op): {node.name}") + + args = node.args + input_node = args[0] + + if not isinstance(input_node, torch.fx.Node): + raise ValueError(f"Input to contiguous must be a node, got {type(input_node)}") + + if input_node not in input_map: + raise ValueError(f"Input node {input_node.name} not found in input_map") + + # contiguous is a no-op in TensorRT + return input_map[input_node] + + +@converter("aten.unflatten.int") +def convert_unflatten( + node: torch.fx.Node, + network: Any, # trt.INetworkDefinition + input_map: Dict[torch.fx.Node, Any], # Dict[Node, trt.ITensor] + edge_program: Optional[Any] = None, +) -> Any: # trt.ITensor + """ + Convert PyTorch unflatten to TensorRT shuffle layer. + + unflatten.int(Tensor self, int dim, int[] sizes) + + Unflatten a dimension of the input tensor into multiple dimensions. + + Args: + node: FX node representing the unflatten operation. + network: TensorRT network definition. + input_map: Mapping from FX nodes to TensorRT tensors. + + Returns: + TensorRT output tensor. 
+ """ + try: + import tensorrt as trt + except ImportError as e: + raise ImportError("TensorRT is required for convert_unflatten") from e + + logger.debug(f"[TensorRT] Converting unflatten node: {node.name}") + + args = node.args + + input_node = args[0] + dim = args[1] + sizes = args[2] + + if not isinstance(input_node, torch.fx.Node): + raise ValueError(f"Input to unflatten must be a node, got {type(input_node)}") + + if input_node not in input_map: + raise ValueError(f"Input node {input_node.name} not found in input_map") + + input_trt = input_map[input_node] + + # Get shape from node metadata for reliability (TRT shapes can be invalid during network building) + input_shape = list(get_node_shape(input_node) or input_trt.shape) + ndim = len(input_shape) + + # Handle negative dimension + dim = _get_positive_dim(dim, ndim) + + # Build output shape by replacing dim with sizes + output_shape = input_shape[:dim] + list(sizes) + input_shape[dim + 1:] + + layer = network.add_shuffle(input_trt) + if layer is None: + raise RuntimeError(f"Failed to create shuffle layer for unflatten {node.name}") + + layer.reshape_dims = trt.Dims(output_shape) + layer.name = f"unflatten_{node.name}" + + logger.debug( + f"[TensorRT] Created unflatten layer: {layer.name}, " + f"dim={dim}, sizes={sizes}, output_shape={output_shape}" + ) + + return layer.get_output(0) + + +@converter("aten.rsub.Scalar") +def convert_rsub_scalar( + node: torch.fx.Node, + network: Any, # trt.INetworkDefinition + input_map: Dict[torch.fx.Node, Any], # Dict[Node, trt.ITensor] + edge_program: Optional[Any] = None, +) -> Any: # trt.ITensor + """ + Convert PyTorch rsub (reverse subtract) with scalar to TensorRT. + + rsub.Scalar(Tensor self, Scalar other, Scalar alpha=1) -> Tensor + Computes: other - self * alpha + + Args: + node: FX node representing the rsub operation. + network: TensorRT network definition. + input_map: Mapping from FX nodes to TensorRT tensors. + + Returns: + TensorRT output tensor. 
+ """ + try: + import tensorrt as trt + import numpy as np + except ImportError as e: + raise ImportError("TensorRT and numpy are required for convert_rsub_scalar") from e + + logger.debug(f"[TensorRT] Converting rsub.Scalar node: {node.name}") + + args = node.args + + input_node = args[0] + other = args[1] # Scalar to subtract from + alpha = args[2] if len(args) > 2 else 1.0 + + if not isinstance(input_node, torch.fx.Node): + raise ValueError(f"Input to rsub must be a node, got {type(input_node)}") + + if input_node not in input_map: + raise ValueError(f"Input node {input_node.name} not found in input_map") + + input_trt = input_map[input_node] + + # If alpha != 1, first multiply input by alpha + if alpha != 1.0: + alpha_weights = trt.Weights(np.array([alpha], dtype=np.float32)) + alpha_const = network.add_constant([1], alpha_weights) + alpha_const.name = f"rsub_alpha_const_{node.name}" + + mul_layer = network.add_elementwise( + input_trt, alpha_const.get_output(0), trt.ElementWiseOperation.PROD + ) + mul_layer.name = f"rsub_mul_alpha_{node.name}" + input_trt = mul_layer.get_output(0) + + # Create constant for 'other' scalar + other_weights = trt.Weights(np.array([other], dtype=np.float32)) + other_const = network.add_constant([1], other_weights) + other_const.name = f"rsub_other_const_{node.name}" + + # Compute: other - input (reverse subtract) + layer = network.add_elementwise( + other_const.get_output(0), input_trt, trt.ElementWiseOperation.SUB + ) + + if layer is None: + raise RuntimeError(f"Failed to create elementwise layer for rsub {node.name}") + + layer.name = f"rsub_scalar_{node.name}" + + logger.debug( + f"[TensorRT] Created rsub.Scalar layer: {layer.name}, other={other}, alpha={alpha}" + ) + + return layer.get_output(0) + + +__all__ = [ + "convert_slice", + "convert_index", + "convert_contiguous", + "convert_unflatten", + "convert_rsub_scalar", +] diff --git a/backends/nvidia/tensorrt/converters/targets.bzl 
b/backends/nvidia/tensorrt/converters/targets.bzl index 00648c5e49a..23ce33fdb22 100644 --- a/backends/nvidia/tensorrt/converters/targets.bzl +++ b/backends/nvidia/tensorrt/converters/targets.bzl @@ -17,6 +17,7 @@ def define_common_targets(): "batch_norm.py", "bmm.py", "clamp.py", + "comparison.py", "concat.py", "conv2d.py", "dim_order_ops.py", @@ -35,6 +36,7 @@ def define_common_targets(): "relu.py", "reshape.py", "sdpa.py", + "slice.py", "sub.py", "upsample.py", ], diff --git a/backends/nvidia/tensorrt/partitioner/operator_support.py b/backends/nvidia/tensorrt/partitioner/operator_support.py index d565e2e3998..fb216c9fd6a 100644 --- a/backends/nvidia/tensorrt/partitioner/operator_support.py +++ b/backends/nvidia/tensorrt/partitioner/operator_support.py @@ -24,17 +24,108 @@ class TensorRTOperatorSupport(OperatorSupportBase): # Operations that have TensorRT converters (sorted alphabetically). SUPPORTED_OPS: Set[str] = { + "_log_softmax.default", + "_native_batch_norm_legit.default", + "_native_batch_norm_legit_no_training.default", + "_scaled_dot_product_efficient_attention.default", + "_scaled_dot_product_flash_attention.default", + "_softmax.default", + "_unsafe_view.default", + "adaptive_avg_pool2d.default", "add.Tensor", "add_.Tensor", + "addmm.default", + "alias_copy.default", + "any.dim", + "arange.start_step", + "avg_pool2d.default", + "batch_norm.default", + "bmm.default", + "cat.default", + "chunk.default", + "clamp.default", + "clamp_max.default", + "clamp_min.default", + "clone.default", + "constant_pad_nd.default", + "contiguous.default", + "conv2d.default", + "convolution.default", + "copy.default", "div.Tensor", "div.Tensor_mode", + "dropout.default", + "dropout_.default", + "embedding.default", + "eq.Scalar", + "expand.default", + "expand_copy.default", + "flatten.using_ints", + "full.default", + "full_like.default", + "ge.Scalar", + "gelu.default", + "gt.Scalar", + "hardsigmoid.default", + "hardswish.default", + "hardswish_.default", + 
"hardtanh.default", + "hardtanh_.default", + "index.Tensor", + "layer_norm.default", + "le.Scalar", + "linear.default", + "log_softmax.int", + "logical_not.default", + "lt.Scalar", + "max_pool2d.default", + "max_pool2d_with_indices.default", + "mean.dim", "mm.default", "mul.Scalar", "mul.Tensor", "mul_.Tensor", + "native_layer_norm.default", + "ne.Scalar", + "ones_like.default", + "permute.default", + "permute_copy.default", + "pixel_shuffle.default", "relu.default", "relu_.default", + "repeat.default", + "reshape.default", + "rsub.Scalar", + "scalar_tensor.default", + "scaled_dot_product_attention.default", + "select.int", + "select_copy.int", + "sigmoid.default", + "silu.default", + "slice.Tensor", + "slice_copy.Tensor", + "softmax.int", + "split.Tensor", + "split_with_sizes.default", + "split_with_sizes_copy.default", + "squeeze.dim", + "squeeze.dims", + "squeeze_copy.dim", + "squeeze_copy.dims", + "stack.default", "sub.Tensor", + "tanh.default", + "transpose.int", + "unflatten.int", + "unsqueeze.default", + "unsqueeze_copy.default", + "upsample_bilinear2d.vec", + "upsample_nearest2d.vec", + "view.default", + "view_copy.default", + "where.ScalarSelf", + "where.self", + "zeros_like.default", } # Glue operations that don't compute but are needed to keep partitions connected. @@ -61,6 +152,7 @@ class TensorRTOperatorSupport(OperatorSupportBase): torch.bool, torch.bfloat16, torch.float32, + torch.int64, } def is_node_supported(self, submodules: dict, node: torch.fx.Node) -> bool: diff --git a/examples/models/mobilenet_v2/model.py b/examples/models/mobilenet_v2/model.py index 32e82197e46..545a8931c92 100644 --- a/examples/models/mobilenet_v2/model.py +++ b/examples/models/mobilenet_v2/model.py @@ -5,6 +5,7 @@ # LICENSE file in the root directory of this source tree. 
import logging +import os import torch @@ -36,10 +37,11 @@ def get_example_inputs(self): "https://github.com/pytorch/hub/raw/master/images/dog.jpg", "dog.jpg", ) - try: - urllib.URLopener().retrieve(url, filename) - except: - urllib.request.urlretrieve(url, filename) + if not os.path.exists(filename): + try: + urllib.URLopener().retrieve(url, filename) + except Exception: + urllib.request.urlretrieve(url, filename) from PIL import Image from torchvision import transforms diff --git a/examples/nvidia/tensorrt/export.py b/examples/nvidia/tensorrt/export.py index a3b823cf402..8099f376280 100644 --- a/examples/nvidia/tensorrt/export.py +++ b/examples/nvidia/tensorrt/export.py @@ -38,15 +38,18 @@ "conv1d", "dl3", "edsr", - # "efficient_sam", # TODO: diff ~41 — likely bicubic interpolation decomposition or ConvTranspose2d issue + "efficient_sam", "emformer_join", # "emformer_predict", # TODO: passes 1/3 seeds — precision sensitive with randomized inputs "emformer_transcribe", "ic3", "linear", "mul", + "mv2", "mv3", "sdpa", + "resnet18", + "resnet50", "softmax", "w2l", } diff --git a/examples/nvidia/tensorrt/tests/TARGETS b/examples/nvidia/tensorrt/tests/TARGETS index 9bb885f8900..71c30ea2622 100644 --- a/examples/nvidia/tensorrt/tests/TARGETS +++ b/examples/nvidia/tensorrt/tests/TARGETS @@ -16,6 +16,39 @@ manifold_get( visibility = ["PUBLIC"], ) +manifold_get( + name = "dog_jpg", + out = "dog.jpg", + api_key = "executorch-key", + artifact_path = "tree/models/tensorrt/weights/dog.jpg", + bucket_name = "executorch", + sha1 = "cec849467bf5701d3f79311c7b564586b57d75bd", + timeout_msec = 120000, + visibility = ["PUBLIC"], +) + +manifold_get( + name = "mv2_weights", + out = "mobilenet_v2-b0353104.pth", + api_key = "executorch-key", + artifact_path = "tree/models/tensorrt/weights/mobilenet_v2-b0353104.pth", + bucket_name = "executorch", + sha1 = "9d6df55a618d1707f020679b8cd68c91d4dec003", + timeout_msec = 120000, + visibility = ["PUBLIC"], +) + +manifold_get( + name = 
"ic4_weights", + out = "inceptionv4-8e4777a0.pth", + api_key = "executorch-key", + artifact_path = "tree/models/tensorrt/weights/inceptionv4-8e4777a0.pth", + bucket_name = "executorch", + sha1 = "37267955de289b48105a459fe11ace3c697923f2", + timeout_msec = 120000, + visibility = ["PUBLIC"], +) + # Export correctness tests: exports each supported model with TensorRT # and compares inference outputs against eager PyTorch on GPU. # buck2 test fbcode//executorch/examples/nvidia/tensorrt/tests:test_export @@ -26,7 +59,10 @@ python_unittest_remote_gpu( env = { "HTTPS_PROXY": "http://fwdproxy.any:8080", "HTTP_PROXY": "http://fwdproxy.any:8080", + "DOG_JPG": "$(location :dog_jpg)", "EDSR_WEIGHTS": "$(location :edsr_weights)", + "IC4_WEIGHTS": "$(location :ic4_weights)", + "MV2_WEIGHTS": "$(location :mv2_weights)", }, keep_gpu_sections = True, labels = ["long_running"], diff --git a/examples/nvidia/tensorrt/tests/test_export.py b/examples/nvidia/tensorrt/tests/test_export.py index 79b57d65b4c..81d50b6782b 100644 --- a/examples/nvidia/tensorrt/tests/test_export.py +++ b/examples/nvidia/tensorrt/tests/test_export.py @@ -25,7 +25,10 @@ # The test TARGETS provides these via manifold_get + $(location). # Entries are added as models are enabled in later commits. 
_WEIGHT_ENV_VARS = { + "DOG_JPG": "dog.jpg", + "EDSR_WEIGHTS": "edsr64_x2.pt", + "IC4_WEIGHTS": "inceptionv4-8e4777a0.pth", + "MV2_WEIGHTS": "mobilenet_v2-b0353104.pth", } @@ -40,7 +43,24 @@ def _populate_weight_cache() -> None: for env_var, filename in _WEIGHT_ENV_VARS.items(): src = os.environ.get(env_var) if src and os.path.isfile(src): - dst = os.path.join(cache_dir, filename) + if env_var == "DOG_JPG": + # MV2Model downloads dog.jpg to CWD + dst = os.path.join(os.getcwd(), filename) + elif env_var.startswith("MOBILEBERT_"): + # Pre-populate HuggingFace cache for mobilebert + hf_dir = os.path.join( + os.environ.get("HF_HOME", os.path.expanduser("~/.cache/huggingface")), + "hub", "models--google--mobilebert-uncased", + "snapshots", "manifold", + ) + os.makedirs(hf_dir, exist_ok=True) + refs_dir = os.path.join(os.path.dirname(hf_dir), "refs") + os.makedirs(refs_dir, exist_ok=True) + with open(os.path.join(refs_dir, "main"), "w") as rf: + rf.write("manifold") + dst = os.path.join(hf_dir, filename) + else: + dst = os.path.join(cache_dir, filename) if not os.path.exists(dst): shutil.copy2(src, dst) logger.info(f"Cached {filename} from {src}") @@ -129,3 +149,12 @@ def test_mv3(self) -> None: def test_ic3(self) -> None: _export_and_verify("ic3") + + def test_mv2(self) -> None: + _export_and_verify("mv2") + + def test_resnet18(self) -> None: + _export_and_verify("resnet18") + + def test_resnet50(self) -> None: + _export_and_verify("resnet50")