diff --git a/backends/nvidia/tensorrt/README.md b/backends/nvidia/tensorrt/README.md index 471e07d4b86..4c3c8272e82 100644 --- a/backends/nvidia/tensorrt/README.md +++ b/backends/nvidia/tensorrt/README.md @@ -37,7 +37,7 @@ cmake -B cmake-out \ -DEXECUTORCH_BUILD_EXTENSION_NAMED_DATA_MAP=ON \ -DCMAKE_BUILD_TYPE=Release -cmake --build cmake-out --target tensorrt_executor_runner -j$(nproc) +cmake --build cmake-out --target tensorrt_executor_runner benchmark -j$(nproc) ``` ### Export and Run Models @@ -49,6 +49,9 @@ python -m executorch.examples.nvidia.tensorrt.export -m add # Export all supported models python -m executorch.examples.nvidia.tensorrt.export +# Export with ONNX baseline (for benchmarking) +python -m executorch.examples.nvidia.tensorrt.export --onnx + # Run inference with the C++ runner ./cmake-out/backends/nvidia/tensorrt/tensorrt_executor_runner --model_path=add_tensorrt.pte ``` @@ -110,7 +113,7 @@ Alternatively, download and install from the | Option | Type | Default | Description | |--------|------|---------|-------------| | `workspace_size` | int | 1GB | TensorRT builder workspace size | -| `precision` | TensorRTPrecision | FP32 | Inference precision (FP32, FP16, INT8) | +| `precision` | TensorRTPrecision | FP32 | Inference precision (FP32, FP16, BF16, INT8) | | `strict_type_constraints` | bool | False | Enforce strict type constraints | | `max_batch_size` | int | 1 | Maximum batch size | | `device_id` | int | 0 | CUDA device ID | @@ -158,15 +161,9 @@ with open("model_tensorrt.pte", "wb") as f: ## Supported Operations | Category | Operations | -|----------|-----------| -| Elementwise | add, sub, mul, div, mm, relu | -| Reshape | view, reshape, squeeze, unsqueeze, permute, contiguous, clone | - -## Supported Operations - -| Category | Operations | -|----------|-----------| -| Elementwise | add, sub, mul, div, floor_divide, rsub | +|----------|------------| +| Elementwise | add, sub, mul, div, floor_divide, rsub, pow, abs, ceil, sqrt | +| Unary Math | 
cos, sin, exp, erf, log | | Matrix | mm, addmm, bmm, linear | | Convolution | conv2d | | Normalization | batch_norm, layer_norm | @@ -176,8 +173,9 @@ with open("model_tensorrt.pte", "wb") as f: | Reduction | mean, any | | Concat/Split | cat, split, chunk, stack | | Comparison | eq, ne, lt, le, gt, ge, where, logical_not | -| Slicing | slice, select, index | -| Other | embedding, expand, repeat, upsample, pixel_shuffle, scaled_dot_product_attention, full | +| Slicing | slice, select, index, arange | +| Padding | constant_pad_nd | +| Other | embedding, expand, repeat, upsample, pixel_shuffle, scaled_dot_product_attention, full, dropout | ## Jetson Deployment @@ -252,6 +250,21 @@ Each test exports a model with TensorRT, runs inference via ExecuTorch pybindings, and compares outputs against eager PyTorch (atol=1e-3, rtol=1e-3) across 3 random seeds. +### Benchmark + +```bash +# Benchmark all exported models in the current directory +./cmake-out/examples/nvidia/tensorrt/benchmark + +# Benchmark with options +./cmake-out/examples/nvidia/tensorrt/benchmark -d DIR -m MODEL -n 100 -w 5 +``` + +The benchmark reports three formats per model: +- **pte** — ExecuTorch end-to-end (includes framework overhead) +- **pte-raw** — Raw TRT engine execution extracted from the .pte +- **onnx-trt** — ONNX → TRT engine (baseline, when .onnx files are present) + ## Troubleshooting | Issue | Fix | diff --git a/backends/nvidia/tensorrt/backend.py b/backends/nvidia/tensorrt/backend.py index 5fdfeeb0d4a..67b8b2e5f18 100644 --- a/backends/nvidia/tensorrt/backend.py +++ b/backends/nvidia/tensorrt/backend.py @@ -224,18 +224,8 @@ def _add_params_to_input_map( if isinstance(param_tensor, torch.nn.Parameter): param_tensor = param_tensor.data if isinstance(param_tensor, torch.Tensor): - # Convert int64/int32 tensors to float32 for TensorRT compatibility - # These are often used in elementwise operations with float tensors - # (e.g., batch norm statistics in MobileNetV3) - original_dtype = 
param_tensor.dtype - if param_tensor.dtype in (torch.int32, torch.int64): - param_tensor = param_tensor.float() - logger.debug( - f"Converting param {node.name} from {original_dtype} to float32 " - f"for TensorRT compatibility" - ) - elif param_tensor.dtype == torch.float64: - param_tensor = param_tensor.float() + # get_trt_tensor handles dtype conversion (int64→int32, float64→float32) + # via create_constant in converter_utils.py input_map[node] = get_trt_tensor_fn( network, param_tensor, f"param_{node.name}" ) @@ -394,6 +384,9 @@ def _collect_io_bindings(network: Any) -> List[TensorRTIOBinding]: Returns: List of TensorRTIOBinding with input/output tensor metadata. """ + # Import here to avoid circular imports at module level + from executorch.backends.nvidia.tensorrt.converter_utils import get_safe_shape + bindings = [] # Collect inputs @@ -403,7 +396,7 @@ def _collect_io_bindings(network: Any) -> List[TensorRTIOBinding]: TensorRTIOBinding( name=tensor.name, dtype=_trt_dtype_to_string(tensor.dtype), - shape=list(tensor.shape), + shape=get_safe_shape(tensor), is_input=True, ) ) @@ -415,7 +408,7 @@ def _collect_io_bindings(network: Any) -> List[TensorRTIOBinding]: TensorRTIOBinding( name=tensor.name, dtype=_trt_dtype_to_string(tensor.dtype), - shape=list(tensor.shape), + shape=get_safe_shape(tensor), is_input=False, ) ) diff --git a/backends/nvidia/tensorrt/converter_utils.py b/backends/nvidia/tensorrt/converter_utils.py index 42b6b88ad5c..96b493d17b6 100644 --- a/backends/nvidia/tensorrt/converter_utils.py +++ b/backends/nvidia/tensorrt/converter_utils.py @@ -269,8 +269,11 @@ def get_trt_tensor( Handles: - TensorRT ITensor (returned as-is) - Python scalars (int, float) → constant tensor - - PyTorch tensors → constant tensor + - PyTorch tensors → constant tensor (including FakeTensors/subclasses) - numpy arrays → constant tensor + + Note: Uses unset_fake_temporarily to handle tensor subclasses like FakeTensor + that don't support .numpy() directly. 
This follows the TensorRT pattern. """ if isinstance(value, trt.ITensor): return value @@ -283,8 +286,32 @@ def get_trt_tensor( return create_constant(network, value, name) if isinstance(value, torch.Tensor): - value = _tensor_to_numpy(value) - return create_constant(network, value, name) + # Handle tensor subclasses (FakeTensor, etc.) that don't support .numpy() + # by temporarily exiting fake tensor mode. This follows TensorRT pattern. + try: + from torch.fx.experimental.proxy_tensor import unset_fake_temporarily + with unset_fake_temporarily(): + # Create a real tensor from the fake tensor's data if needed + if hasattr(value, '_local_scalar_dense') or not value.is_contiguous(): + value = value.contiguous() + np_value = value.detach().cpu().numpy() + except (ImportError, RuntimeError): + # Fallback: try to convert via creating a new tensor + try: + np_value = _tensor_to_numpy(value) + except RuntimeError: + # Last resort: create tensor from metadata if available + if hasattr(value, 'shape') and hasattr(value, 'dtype'): + # For FakeTensors, we may need to create a zero tensor as placeholder + # This should only happen during tracing, not actual execution + np_dtype = _torch_dtype_to_numpy(value.dtype) + np_value = np.zeros(tuple(value.shape), dtype=np_dtype) + else: + raise RuntimeError( + f"Cannot convert tensor subclass {type(value)} to numpy. " + f"Tensor may be a FakeTensor from tracing." + ) + return create_constant(network, np_value, name) if isinstance(value, np.ndarray): return create_constant(network, value, name) @@ -300,6 +327,8 @@ def create_constant( """Create a TensorRT constant tensor from numpy array. Note: TensorRT doesn't support int64 (i64), so we convert to int32. + Also, TensorRT doesn't handle 0-d tensors well in elementwise ops, + so we reshape scalars to 1-d tensors with shape (1,). 
""" # TensorRT doesn't support int64 - convert to int32 if value.dtype == np.int64: @@ -308,11 +337,39 @@ def create_constant( if value.dtype == np.float64: value = value.astype(np.float32) + # TensorRT requires at least 1-d tensors for elementwise ops. + # Reshape 0-d scalars to 1-d tensors with shape (1,). + if value.ndim == 0: + value = value.reshape((1,)) + layer = network.add_constant(value.shape, trt.Weights(value)) layer.name = f"const_{name}" return layer.get_output(0) +def get_safe_shape(tensor: trt.ITensor) -> List[int]: + """Get tensor shape safely, handling dynamic shapes. + + TensorRT tensors can have invalid shapes during network building + (e.g., negative length for dynamic dimensions). This function + safely extracts the shape as a list. + + Args: + tensor: TensorRT tensor to get shape from. + + Returns: + List of dimension sizes, or empty list if shape is invalid. + """ + try: + shape = tensor.shape + if shape is None: + return [] + shape_list = list(shape) + return shape_list + except (ValueError, TypeError): + return [] + + def broadcast_tensors( network: trt.INetworkDefinition, tensors: Sequence[trt.ITensor], @@ -335,10 +392,13 @@ def broadcast_tensors( """ result = [] for i, tensor in enumerate(tensors): - current_ndim = len(tensor.shape) + shape = get_safe_shape(tensor) + current_ndim = len(shape) if shape else target_ndim + if current_ndim < target_ndim: diff = target_ndim - current_ndim - new_shape = (1,) * diff + tuple(tensor.shape) + existing_shape = tuple(shape) if shape else tuple([-1] * current_ndim) + new_shape = (1,) * diff + existing_shape layer = network.add_shuffle(tensor) layer.reshape_dims = new_shape # Use context counter if available, otherwise use simple naming diff --git a/backends/nvidia/tensorrt/converters/__init__.py b/backends/nvidia/tensorrt/converters/__init__.py index 2f99cc6f5ca..d3a25863205 100644 --- a/backends/nvidia/tensorrt/converters/__init__.py +++ b/backends/nvidia/tensorrt/converters/__init__.py @@ -28,12 
+28,14 @@ from executorch.backends.nvidia.tensorrt.converters import permute_copy # noqa: F401 from executorch.backends.nvidia.tensorrt.converters import pixel_shuffle # noqa: F401 from executorch.backends.nvidia.tensorrt.converters import pooling # noqa: F401 +from executorch.backends.nvidia.tensorrt.converters import power # noqa: F401 from executorch.backends.nvidia.tensorrt.converters import reduction # noqa: F401 from executorch.backends.nvidia.tensorrt.converters import relu # noqa: F401 from executorch.backends.nvidia.tensorrt.converters import reshape # noqa: F401 from executorch.backends.nvidia.tensorrt.converters import sdpa # noqa: F401 from executorch.backends.nvidia.tensorrt.converters import slice # noqa: F401 from executorch.backends.nvidia.tensorrt.converters import sub # noqa: F401 +from executorch.backends.nvidia.tensorrt.converters import unary # noqa: F401 from executorch.backends.nvidia.tensorrt.converters import upsample # noqa: F401 diff --git a/backends/nvidia/tensorrt/converters/activations.py b/backends/nvidia/tensorrt/converters/activations.py index 5029ea52f6a..e40f4715086 100644 --- a/backends/nvidia/tensorrt/converters/activations.py +++ b/backends/nvidia/tensorrt/converters/activations.py @@ -32,6 +32,7 @@ import torch from executorch.backends.nvidia.tensorrt.converter_registry import converter +from executorch.backends.nvidia.tensorrt.converter_utils import get_node_shape logger: logging.Logger = logging.getLogger(__name__) @@ -489,8 +490,8 @@ def convert_log_softmax( input_trt = input_map[input_node] - # Handle negative dimension - input_shape = input_trt.shape + # Handle negative dimension using TRT tensor shape. 
+ input_shape = tuple(input_trt.shape) ndim = len(input_shape) if dim < 0: dim = ndim + dim diff --git a/backends/nvidia/tensorrt/converters/add.py b/backends/nvidia/tensorrt/converters/add.py index 05486a96fb0..b008b8b0f08 100644 --- a/backends/nvidia/tensorrt/converters/add.py +++ b/backends/nvidia/tensorrt/converters/add.py @@ -6,6 +6,7 @@ """Converter for element-wise addition operations.""" +import logging from typing import Any, Dict, Optional import tensorrt as trt @@ -20,6 +21,9 @@ ) +logger: logging.Logger = logging.getLogger(__name__) + + def _get_input_ndim(arg: Any, input_map: Dict[torch.fx.Node, Any]) -> int: """Get the number of dimensions for an elementwise input argument. @@ -36,13 +40,18 @@ def _get_input_ndim(arg: Any, input_map: Dict[torch.fx.Node, Any]) -> int: # Try to get ndim from node metadata first (most reliable) if "val" in arg.meta and hasattr(arg.meta["val"], "shape"): return len(arg.meta["val"].shape) - # Fall back to TRT tensor shape + # Fall back to TRT tensor shape - handle dynamic shapes carefully if arg in input_map: trt_tensor = input_map[arg] - shape = trt_tensor.shape - if shape is not None: - return len(shape) - # For scalars, return 0 (will be broadcast) + try: + shape = trt_tensor.shape + if shape is not None: + ndim = len(shape) + if ndim >= 0: # Valid shape + return ndim + except (ValueError, TypeError): + pass # Dynamic shape, fall through + # For scalars or unknown, return 0 (will be broadcast) return 0 @@ -55,6 +64,11 @@ def _get_elementwise_input( ) -> trt.ITensor: """Get TensorRT tensor for an elementwise operation input. + Handles: + - FX nodes already in input_map + - FX nodes that are lifted buffers/parameters (placeholder nodes with b_ or p_ prefix) + - Scalar values + Args: network: TensorRT network definition. input_map: Mapping from FX nodes to TensorRT tensors. @@ -66,15 +80,27 @@ def _get_elementwise_input( TensorRT tensor for the input. Raises: - ValueError: If arg is a Node but not found in input_map. 
+ ValueError: If arg is a Node but not found in input_map and cannot be created as constant. """ if isinstance(arg, torch.fx.Node): - if arg not in input_map: - raise ValueError( - f"Input node '{arg.name}' not found in input_map. " - f"Available nodes: {list(input_map.keys())}" - ) - return input_map[arg] + if arg in input_map: + return input_map[arg] + + # Handle lifted buffers and parameters that aren't in input_map + # These are placeholder nodes with names starting with b_ (buffers) or p_ (parameters) + # or get_attr nodes. We need to create constants from their metadata values. + if arg.op == "placeholder" or arg.op == "get_attr": + if "val" in arg.meta and isinstance(arg.meta["val"], torch.Tensor): + logger.debug(f"[TensorRT] Creating constant for lifted buffer/parameter: {arg.name}") + trt_tensor = get_trt_tensor(network, arg.meta["val"], f"const_{arg.name}", dtype) + input_map[arg] = trt_tensor # Cache for future use + return trt_tensor + + raise ValueError( + f"Input node '{arg.name}' not found in input_map and could not be created as constant. " + f"Node op: {arg.op}, target: {arg.target}. " + f"Available nodes: {list(n.name for n in input_map.keys())}" + ) return get_trt_tensor(network, arg, name, dtype) diff --git a/backends/nvidia/tensorrt/converters/comparison.py b/backends/nvidia/tensorrt/converters/comparison.py index e3b32767a04..9e30e65b76a 100644 --- a/backends/nvidia/tensorrt/converters/comparison.py +++ b/backends/nvidia/tensorrt/converters/comparison.py @@ -81,7 +81,16 @@ def convert_eq_scalar( # Use shape [1, 1, ...] 
with same ndim as input for proper broadcasting ndim = len(input_shape) const_shape = [1] * ndim if ndim > 0 else [1] - other_data = np.full(const_shape, other, dtype=np.float32) + + # Match the dtype of the input tensor for TensorRT compatibility + # TensorRT elementwise operations require matching types + input_dtype = input_trt.dtype + if input_dtype == trt.int64 or input_dtype == trt.int32: + np_dtype = np.int64 if input_dtype == trt.int64 else np.int32 + else: + np_dtype = np.float32 + + other_data = np.full(const_shape, other, dtype=np_dtype) other_weights = trt.Weights(other_data) other_const = network.add_constant(const_shape, other_weights) other_const.name = f"eq_const_{node.name}" @@ -121,8 +130,15 @@ def convert_ne_scalar( input_trt = input_map[input_node] + # Match the dtype of the input tensor for TensorRT compatibility + input_dtype = input_trt.dtype + if input_dtype == trt.int64 or input_dtype == trt.int32: + np_dtype = np.int64 if input_dtype == trt.int64 else np.int32 + else: + np_dtype = np.float32 + # Create constant for scalar - other_weights = trt.Weights(np.array([other], dtype=np.float32)) + other_weights = trt.Weights(np.array([other], dtype=np_dtype)) other_const = network.add_constant([1], other_weights) other_const.name = f"ne_const_{node.name}" @@ -165,7 +181,14 @@ def convert_lt_scalar( input_trt = input_map[input_node] - other_weights = trt.Weights(np.array([other], dtype=np.float32)) + # Match the dtype of the input tensor for TensorRT compatibility + input_dtype = input_trt.dtype + if input_dtype == trt.int64 or input_dtype == trt.int32: + np_dtype = np.int64 if input_dtype == trt.int64 else np.int32 + else: + np_dtype = np.float32 + + other_weights = trt.Weights(np.array([other], dtype=np_dtype)) other_const = network.add_constant([1], other_weights) other_const.name = f"lt_const_{node.name}" @@ -203,7 +226,14 @@ def convert_gt_scalar( input_trt = input_map[input_node] - other_weights = trt.Weights(np.array([other], 
dtype=np.float32)) + # Match the dtype of the input tensor for TensorRT compatibility + input_dtype = input_trt.dtype + if input_dtype == trt.int64 or input_dtype == trt.int32: + np_dtype = np.int64 if input_dtype == trt.int64 else np.int32 + else: + np_dtype = np.float32 + + other_weights = trt.Weights(np.array([other], dtype=np_dtype)) other_const = network.add_constant([1], other_weights) other_const.name = f"gt_const_{node.name}" @@ -260,10 +290,17 @@ def convert_ge_scalar( except (TypeError, ValueError): input_shape = [1] + # Match the dtype of the input tensor for TensorRT compatibility + input_dtype = input_trt.dtype + if input_dtype == trt.int64 or input_dtype == trt.int32: + np_dtype = np.int64 if input_dtype == trt.int64 else np.int32 + else: + np_dtype = np.float32 + # Create constant with proper shape for broadcasting ndim = len(input_shape) const_shape = [1] * ndim if ndim > 0 else [1] - other_data = np.full(const_shape, other, dtype=np.float32) + other_data = np.full(const_shape, other, dtype=np_dtype) other_weights = trt.Weights(other_data) other_const = network.add_constant(const_shape, other_weights) other_const.name = f"ge_const_{node.name}" @@ -338,10 +375,17 @@ def convert_le_scalar( except (TypeError, ValueError): input_shape = [1] + # Match the dtype of the input tensor for TensorRT compatibility + input_dtype = input_trt.dtype + if input_dtype == trt.int64 or input_dtype == trt.int32: + np_dtype = np.int64 if input_dtype == trt.int64 else np.int32 + else: + np_dtype = np.float32 + # Create constant with proper shape for broadcasting ndim = len(input_shape) const_shape = [1] * ndim if ndim > 0 else [1] - other_data = np.full(const_shape, other, dtype=np.float32) + other_data = np.full(const_shape, other, dtype=np_dtype) other_weights = trt.Weights(other_data) other_const = network.add_constant(const_shape, other_weights) other_const.name = f"le_const_{node.name}" @@ -676,6 +720,10 @@ def convert_where( Convert PyTorch where to TensorRT select 
layer. where.self(Tensor condition, Tensor self, Tensor other) -> Tensor + + Note: TensorRT select requires the condition to be boolean type. + If the condition is not boolean, we convert it by comparing != 0. + Also handles broadcasting for scalar inputs. """ try: import tensorrt as trt @@ -695,19 +743,93 @@ def convert_where( self_trt = input_map[self_node] if isinstance(self_node, torch.fx.Node) else None other_trt = input_map[other_node] if isinstance(other_node, torch.fx.Node) else None - # Handle scalar inputs + # Get condition shape for broadcasting reference + cond_shape = get_node_shape(condition_node) + if cond_shape is None: + try: + cond_shape = tuple(condition_trt.shape) + except (ValueError, TypeError): + cond_shape = None + + max_ndim = len(cond_shape) if cond_shape else 0 + + # Handle scalar and tensor inputs - track shapes for broadcasting import numpy as np - if self_trt is None: - self_weights = trt.Weights(np.array([self_node], dtype=np.float32)) - self_const = network.add_constant([1], self_weights) - self_const.name = f"where_self_const_{node.name}" - self_trt = self_const.get_output(0) - - if other_trt is None: - other_weights = trt.Weights(np.array([other_node], dtype=np.float32)) - other_const = network.add_constant([1], other_weights) - other_const.name = f"where_other_const_{node.name}" - other_trt = other_const.get_output(0) + + def get_input_trt_and_shape(node_or_val, input_trt, name_suffix): + """Get TRT tensor and its shape, handling scalars.""" + if input_trt is not None: + # It's already a tensor + if isinstance(node_or_val, torch.fx.Node): + shape = get_node_shape(node_or_val) + else: + shape = None + if shape is None: + try: + shape = tuple(input_trt.shape) + except (ValueError, TypeError): + shape = (1,) + return input_trt, shape + else: + # It's a scalar - create constant with shape [1] + val = float(node_or_val) if not isinstance(node_or_val, (int, float)) else node_or_val + weights = trt.Weights(np.array([val], 
dtype=np.float32)) + const = network.add_constant([1], weights) + const.name = f"where_{name_suffix}_const_{node.name}" + return const.get_output(0), (1,) + + self_trt, self_shape = get_input_trt_and_shape(self_node, self_trt, "self") + other_trt, other_shape = get_input_trt_and_shape(other_node, other_trt, "other") + + # Update max_ndim based on all inputs + max_ndim = max(max_ndim, len(self_shape), len(other_shape)) + + def prepend_ones_to_shape(tensor, tensor_shape, target_ndim, name_suffix): + """Prepend 1s to tensor shape for broadcasting.""" + current_ndim = len(tensor_shape) + if current_ndim < target_ndim: + diff = target_ndim - current_ndim + new_shape = (1,) * diff + tuple(tensor_shape) + shuffle = network.add_shuffle(tensor) + shuffle.reshape_dims = trt.Dims(new_shape) + shuffle.name = f"where_broadcast_{name_suffix}_{node.name}" + return shuffle.get_output(0) + return tensor + + # Broadcast all inputs to max_ndim + if cond_shape and len(cond_shape) < max_ndim: + condition_trt = prepend_ones_to_shape(condition_trt, cond_shape, max_ndim, "cond") + self_trt = prepend_ones_to_shape(self_trt, self_shape, max_ndim, "self") + other_trt = prepend_ones_to_shape(other_trt, other_shape, max_ndim, "other") + + # TensorRT select requires boolean condition + # If condition is not boolean, convert it by comparing != 0 + # Pattern from TensorRT: cast to float, then compare with 0 + if condition_trt.dtype != trt.bool: + # Cast condition to float32 first + cast_layer = network.add_identity(condition_trt) + cast_layer.set_output_type(0, trt.float32) + cast_layer.name = f"where_cast_cond_{node.name}" + float_condition = cast_layer.get_output(0) + + # Create zero constant for comparison with broadcast-compatible shape + zero_shape = [1] * max_ndim if max_ndim > 0 else [1] + zero_weights = trt.Weights(np.zeros(zero_shape, dtype=np.float32)) + zero_const = network.add_constant(zero_shape, zero_weights) + zero_const.name = f"where_zero_const_{node.name}" + + # Compare 
condition != 0 to get boolean (using EQUAL then NOT) + eq_layer = network.add_elementwise( + float_condition, + zero_const.get_output(0), + trt.ElementWiseOperation.EQUAL + ) + eq_layer.name = f"where_eq_zero_{node.name}" + + # NOT the result to get != 0 + not_layer = network.add_unary(eq_layer.get_output(0), trt.UnaryOperation.NOT) + not_layer.name = f"where_not_{node.name}" + condition_trt = not_layer.get_output(0) # TensorRT select: output = condition ? self : other layer = network.add_select(condition_trt, self_trt, other_trt) diff --git a/backends/nvidia/tensorrt/converters/concat.py b/backends/nvidia/tensorrt/converters/concat.py index e751e722b8b..473fad8b423 100644 --- a/backends/nvidia/tensorrt/converters/concat.py +++ b/backends/nvidia/tensorrt/converters/concat.py @@ -206,18 +206,41 @@ def convert_cat( # Convert all input nodes to TensorRT tensors trt_tensors = [] + input_nodes = [] for tensor_node in tensors: if not isinstance(tensor_node, torch.fx.Node): raise ValueError(f"Input must be node, got {type(tensor_node)}") if tensor_node not in input_map: raise ValueError(f"Input node {tensor_node.name} not found in input_map") trt_tensors.append(input_map[tensor_node]) + input_nodes.append(tensor_node) if len(trt_tensors) == 0: raise ValueError("cat requires at least one input tensor") - # Get number of dimensions from first input - ndim = len(trt_tensors[0].shape) + # Get number of dimensions from node metadata (more reliable for dynamic shapes) + # TRT tensor.shape can have invalid length for dynamic shapes during network building + ndim = None + for input_node in input_nodes: + shape = get_node_shape(input_node) + if shape is not None: + ndim = len(shape) + break + + # Fallback to TRT shape if no metadata available + if ndim is None: + try: + trt_shape = trt_tensors[0].shape + if trt_shape is not None: + ndim = len(trt_shape) + if ndim < 0: + ndim = None + except (ValueError, TypeError): + pass + + if ndim is None: + raise ValueError("Cannot determine 
number of dimensions for cat operation") + cat_dim = _get_positive_dim(cat_dim, ndim) # Create concatenation layer @@ -230,9 +253,15 @@ def convert_cat( output = layer.get_output(0) + # Safe output shape logging + try: + out_shape = list(output.shape) if output.shape is not None else "dynamic" + except (ValueError, TypeError): + out_shape = "dynamic" + logger.debug( f"[TensorRT] Created cat layer: {layer.name}, " - f"axis={cat_dim}, num_inputs={len(trt_tensors)}, output_shape={list(output.shape)}" + f"axis={cat_dim}, num_inputs={len(trt_tensors)}, output_shape={out_shape}" ) return output @@ -275,24 +304,49 @@ def convert_stack( # Convert all input nodes to TensorRT tensors trt_tensors = [] + input_nodes = [] for tensor_node in tensors: if not isinstance(tensor_node, torch.fx.Node): raise ValueError(f"Input must be node, got {type(tensor_node)}") if tensor_node not in input_map: raise ValueError(f"Input node {tensor_node.name} not found in input_map") trt_tensors.append(input_map[tensor_node]) + input_nodes.append(tensor_node) if len(trt_tensors) == 0: raise ValueError("stack requires at least one input tensor") - # Get number of dimensions from first input (output will have ndim + 1) - ndim = len(trt_tensors[0].shape) + # Get number of dimensions from node metadata (output will have ndim + 1) + ndim = None + input_shape = None + for input_node in input_nodes: + shape = get_node_shape(input_node) + if shape is not None: + ndim = len(shape) + input_shape = list(shape) + break + + # Fallback to TRT shape if no metadata available + if ndim is None: + try: + trt_shape = trt_tensors[0].shape + if trt_shape is not None: + ndim = len(trt_shape) + if ndim >= 0: + input_shape = list(trt_shape) + else: + ndim = None + except (ValueError, TypeError): + pass + + if ndim is None or input_shape is None: + raise ValueError("Cannot determine number of dimensions for stack operation") + stack_dim = _get_positive_dim(stack_dim, ndim + 1) # Unsqueeze each tensor at the stack 
dimension unsqueezed_tensors = [] for i, trt_tensor in enumerate(trt_tensors): - input_shape = list(trt_tensor.shape) # Build output shape with new dimension of size 1 output_shape = input_shape[:stack_dim] + [1] + input_shape[stack_dim:] @@ -317,9 +371,15 @@ def convert_stack( output = layer.get_output(0) + # Safe output shape logging + try: + out_shape = list(output.shape) if output.shape is not None else "dynamic" + except (ValueError, TypeError): + out_shape = "dynamic" + logger.debug( f"[TensorRT] Created stack layer: {layer.name}, " - f"dim={stack_dim}, num_inputs={len(trt_tensors)}, output_shape={list(output.shape)}" + f"dim={stack_dim}, num_inputs={len(trt_tensors)}, output_shape={out_shape}" ) return output diff --git a/backends/nvidia/tensorrt/converters/conv2d.py b/backends/nvidia/tensorrt/converters/conv2d.py index 9a4e8c74434..d3455c93ca3 100644 --- a/backends/nvidia/tensorrt/converters/conv2d.py +++ b/backends/nvidia/tensorrt/converters/conv2d.py @@ -43,9 +43,7 @@ def validate_convolution(node: torch.fx.Node) -> bool: return False if len(node.args) < 9: return False - # Transposed convolution not supported - if node.args[6]: - return False + # Both regular and transposed convolutions are now supported return True @@ -175,7 +173,10 @@ def convert_convolution( edge_program: Optional[Union[ExportedProgram, torch.fx.GraphModule]] = None, ctx: Any = None, ) -> Any: - """Convert PyTorch convolution operation to TensorRT convolution layer.""" + """Convert PyTorch convolution operation to TensorRT convolution layer. + + Supports both regular convolution and transposed convolution (deconvolution). 
+ """ try: import tensorrt as trt import numpy as np @@ -190,11 +191,9 @@ def convert_convolution( padding = args[4] dilation = args[5] transposed = args[6] + _output_padding = args[7] # Not applied: TRT handles output size via padding/stride groups = args[8] - if transposed: - raise ValueError(f"Transposed convolution not supported for node {node.name}") - if not isinstance(input_node, torch.fx.Node) or input_node not in input_map: raise ValueError(f"Input node {input_node} not found in input_map") @@ -206,13 +205,103 @@ def convert_convolution( raise ValueError(f"Could not extract weight tensor for convolution node {node.name}") weight_np = weight_tensor.detach().cpu().numpy().astype(np.float32) - out_channels = weight_np.shape[0] if not weight_np.flags['C_CONTIGUOUS']: weight_np = np.ascontiguousarray(weight_np) + # Store weight to prevent GC before engine build completes + if not hasattr(convert_convolution, '_weight_storage'): + convert_convolution._weight_storage = [] + convert_convolution._weight_storage.append(weight_np) + is_conv1d = len(weight_np.shape) == 3 + if transposed: + # Transposed convolution (deconvolution) + # For transposed conv, weight shape is [in_channels, out_channels/groups, ...] 
+ # (opposite of regular conv which is [out_channels, in_channels/groups, ...]) + out_channels = weight_np.shape[1] * groups + + if is_conv1d: + kernel_size = weight_np.shape[2] + input_shape = input_trt.shape + + # Expand to 4D for TensorRT + if len(input_shape) == 3: + shuffle_in = network.add_shuffle(input_trt) + shuffle_in.reshape_dims = trt.Dims([input_shape[0], input_shape[1], 1, input_shape[2]]) + shuffle_in.name = f"deconv1d_unsqueeze_{node.name}" + input_trt = shuffle_in.get_output(0) + + # Reshape weight to 4D: [in_ch, out_ch/groups, 1, kernel_size] + weight_4d = np.ascontiguousarray( + weight_np.reshape(weight_np.shape[0], weight_np.shape[1], 1, kernel_size) + ) + convert_convolution._weight_storage.append(weight_4d) + + layer = network.add_deconvolution_nd( + input_trt, + out_channels, + trt.Dims([1, kernel_size]), + trt.Weights(weight_4d), + ) + layer.stride_nd = trt.Dims([1, stride[0]]) + layer.padding_nd = trt.Dims([0, padding[0]]) + # Note: TensorRT doesn't support dilation for deconvolution in most versions + layer.num_groups = groups + + if bias_node is not None: + bias_tensor = _get_param_tensor(exp_prog, bias_node) + if bias_tensor is not None: + bias_np = np.ascontiguousarray( + bias_tensor.detach().cpu().numpy().astype(np.float32) + ) + convert_convolution._weight_storage.append(bias_np) + layer.bias = trt.Weights(bias_np) + + layer.name = f"deconv1d_{node.name}" + output = layer.get_output(0) + + # Squeeze back to 3D if needed + output_shape = output.shape + if len(output_shape) == 4 and output_shape[2] == 1: + shuffle_out = network.add_shuffle(output) + shuffle_out.reshape_dims = trt.Dims([output_shape[0], output_shape[1], output_shape[3]]) + shuffle_out.name = f"deconv1d_squeeze_{node.name}" + output = shuffle_out.get_output(0) + else: + # 2D transposed convolution + kernel_h = weight_np.shape[2] + kernel_w = weight_np.shape[3] + + layer = network.add_deconvolution_nd( + input_trt, + out_channels, + trt.Dims([kernel_h, kernel_w]), + 
trt.Weights(weight_np), + ) + layer.stride_nd = trt.Dims(list(stride)) + layer.padding_nd = trt.Dims(list(padding)) + layer.num_groups = groups + + if bias_node is not None: + bias_tensor = _get_param_tensor(exp_prog, bias_node) + if bias_tensor is not None: + bias_np = np.ascontiguousarray( + bias_tensor.detach().cpu().numpy().astype(np.float32) + ) + layer.bias = trt.Weights(bias_np) + convert_convolution._weight_storage.append(bias_np) + + layer.name = f"deconvolution_{node.name}" + output = layer.get_output(0) + + logger.debug(f"[TensorRT] Created transposed convolution layer: {layer.name}") + return output + + # Regular convolution (existing code path) + out_channels = weight_np.shape[0] + if is_conv1d: kernel_size = weight_np.shape[2] input_shape = input_trt.shape @@ -225,10 +314,6 @@ def convert_convolution( weight_4d = np.ascontiguousarray( weight_np.reshape(out_channels, weight_np.shape[1], 1, kernel_size) ) - - # Store weight to prevent GC before engine build completes - if not hasattr(convert_convolution, '_weight_storage'): - convert_convolution._weight_storage = [] convert_convolution._weight_storage.append(weight_4d) layer = network.add_convolution_nd( @@ -248,7 +333,6 @@ def convert_convolution( bias_np = np.ascontiguousarray( bias_tensor.detach().cpu().numpy().astype(np.float32) ) - # Store bias to prevent GC before engine build completes convert_convolution._weight_storage.append(bias_np) layer.bias = trt.Weights(bias_np) @@ -265,10 +349,7 @@ def convert_convolution( kernel_h = weight_np.shape[2] kernel_w = weight_np.shape[3] - # Store weight to prevent GC before engine build completes weight_np_contiguous = np.ascontiguousarray(weight_np) - if not hasattr(convert_convolution, '_weight_storage'): - convert_convolution._weight_storage = [] convert_convolution._weight_storage.append(weight_np_contiguous) layer = network.add_convolution_nd( diff --git a/backends/nvidia/tensorrt/converters/div.py b/backends/nvidia/tensorrt/converters/div.py index 
6885c602a8d..080b5035af6 100644 --- a/backends/nvidia/tensorrt/converters/div.py +++ b/backends/nvidia/tensorrt/converters/div.py @@ -6,6 +6,7 @@ """Converter for element-wise division operations.""" +import logging from typing import Any, Dict, Optional import tensorrt as trt @@ -21,6 +22,9 @@ ) +logger: logging.Logger = logging.getLogger(__name__) + + def _get_elementwise_input( network: trt.INetworkDefinition, input_map: Dict[torch.fx.Node, Any], @@ -28,14 +32,32 @@ def _get_elementwise_input( name: str, dtype: Optional[torch.dtype], ) -> trt.ITensor: - """Get TensorRT tensor for an elementwise operation input.""" + """Get TensorRT tensor for an elementwise operation input. + + Handles: + - FX nodes already in input_map + - FX nodes that are lifted buffers/parameters (placeholder nodes with b_ or p_ prefix) + - Scalar values + """ if isinstance(arg, torch.fx.Node): - if arg not in input_map: - raise ValueError( - f"Input node '{arg.name}' not found in input_map. " - f"Available nodes: {list(input_map.keys())}" - ) - return input_map[arg] + if arg in input_map: + return input_map[arg] + + # Handle lifted buffers and parameters that aren't in input_map + # These are placeholder nodes with names starting with b_ (buffers) or p_ (parameters) + # or get_attr nodes. We need to create constants from their metadata values. + if arg.op == "placeholder" or arg.op == "get_attr": + if "val" in arg.meta and isinstance(arg.meta["val"], torch.Tensor): + logger.debug(f"[TensorRT] Creating constant for lifted buffer/parameter: {arg.name}") + trt_tensor = get_trt_tensor(network, arg.meta["val"], f"const_{arg.name}", dtype) + input_map[arg] = trt_tensor # Cache for future use + return trt_tensor + + raise ValueError( + f"Input node '{arg.name}' not found in input_map and could not be created as constant. " + f"Node op: {arg.op}, target: {arg.target}. 
" + f"Available nodes: {list(n.name for n in input_map.keys())}" + ) return get_trt_tensor(network, arg, name, dtype) diff --git a/backends/nvidia/tensorrt/converters/embedding.py b/backends/nvidia/tensorrt/converters/embedding.py index 9989ac498e7..593415097ae 100644 --- a/backends/nvidia/tensorrt/converters/embedding.py +++ b/backends/nvidia/tensorrt/converters/embedding.py @@ -15,7 +15,6 @@ import logging from typing import Any, Dict, Optional -import numpy as np import torch from executorch.backends.nvidia.tensorrt.converter_registry import converter from executorch.backends.nvidia.tensorrt.converter_utils import create_constant @@ -128,6 +127,14 @@ def convert_embedding( indices_trt = input_map[indices_node] + # TensorRT gather requires int32 or int64 indices + # Cast indices to int32 if they are not already integer type + if indices_trt.dtype not in (trt.int32, trt.int64): + logger.debug(f"[TensorRT] Casting embedding indices from {indices_trt.dtype} to int32") + cast_layer = network.add_cast(indices_trt, trt.int32) + cast_layer.name = f"embedding_indices_cast_{node.name}" + indices_trt = cast_layer.get_output(0) + weight_shape = weight_trt.shape indices_shape = indices_trt.shape diff --git a/backends/nvidia/tensorrt/converters/mul.py b/backends/nvidia/tensorrt/converters/mul.py index 54a6f9ceec2..b10095b8c58 100644 --- a/backends/nvidia/tensorrt/converters/mul.py +++ b/backends/nvidia/tensorrt/converters/mul.py @@ -6,6 +6,7 @@ """Converter for element-wise multiplication operations.""" +import logging from typing import Any, Dict, Optional import numpy as np @@ -21,6 +22,9 @@ ) +logger: logging.Logger = logging.getLogger(__name__) + + def _get_elementwise_input( network: trt.INetworkDefinition, input_map: Dict[torch.fx.Node, Any], @@ -28,14 +32,32 @@ def _get_elementwise_input( name: str, dtype: Optional[torch.dtype], ) -> trt.ITensor: - """Get TensorRT tensor for an elementwise operation input.""" + """Get TensorRT tensor for an elementwise operation input. 
+ + Handles: + - FX nodes already in input_map + - FX nodes that are lifted buffers/parameters (placeholder nodes with b_ or p_ prefix) + - Scalar values + """ if isinstance(arg, torch.fx.Node): - if arg not in input_map: - raise ValueError( - f"Input node '{arg.name}' not found in input_map. " - f"Available nodes: {list(input_map.keys())}" - ) - return input_map[arg] + if arg in input_map: + return input_map[arg] + + # Handle lifted buffers and parameters that aren't in input_map + # These are placeholder nodes with names starting with b_ (buffers) or p_ (parameters) + # or get_attr nodes. We need to create constants from their metadata values. + if arg.op == "placeholder" or arg.op == "get_attr": + if "val" in arg.meta and isinstance(arg.meta["val"], torch.Tensor): + logger.debug(f"[TensorRT] Creating constant for lifted buffer/parameter: {arg.name}") + trt_tensor = get_trt_tensor(network, arg.meta["val"], f"const_{arg.name}", dtype) + input_map[arg] = trt_tensor # Cache for future use + return trt_tensor + + raise ValueError( + f"Input node '{arg.name}' not found in input_map and could not be created as constant. " + f"Node op: {arg.op}, target: {arg.target}. " + f"Available nodes: {list(n.name for n in input_map.keys())}" + ) return get_trt_tensor(network, arg, name, dtype) diff --git a/backends/nvidia/tensorrt/converters/power.py b/backends/nvidia/tensorrt/converters/power.py new file mode 100644 index 00000000000..2d3d493a281 --- /dev/null +++ b/backends/nvidia/tensorrt/converters/power.py @@ -0,0 +1,217 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the BSD-style license found in the +# LICENSE file in the root directory of this source tree. + +"""TensorRT Converters for Power Operations. 
+ +Supported operations: +- aten.pow.Tensor_Scalar: Tensor raised to scalar power +- aten.pow.Tensor_Tensor: Tensor raised to tensor power + +These operations are commonly used in transformer models for layer normalization +decomposition and other mathematical computations. + +Design patterns follow TensorRT best practices including: +- Type promotion to ensure consistent dtypes for POW operation +- Proper type casting for integer inputs (POW supports float32/int8 only) +- Proper broadcasting using broadcast_tensors utility +""" + +import logging +from typing import Any, Dict, Optional + +import numpy as np +import tensorrt as trt +import torch +from executorch.backends.nvidia.tensorrt.converter_registry import converter +from executorch.backends.nvidia.tensorrt.converter_utils import ( + broadcast_tensors, + create_constant, + get_node_shape, +) + +logger: logging.Logger = logging.getLogger(__name__) + + +def validate_pow(node: torch.fx.Node) -> bool: + """Validate that a pow node can be converted to TensorRT.""" + if node.op != "call_function": + return False + if len(node.args) < 2: + return False + if not isinstance(node.args[0], torch.fx.Node): + return False + return True + + +def _ensure_float_dtype( + network: trt.INetworkDefinition, + tensor: trt.ITensor, + node_name: str, + suffix: str, +) -> trt.ITensor: + """Ensure tensor is float type for POW operation. + + TensorRT POW operation supports only float32, float16, and int8. + This follows TensorRT pattern of promoting types. + + Args: + network: TensorRT network definition. + tensor: Input tensor to potentially cast. + node_name: Node name for layer naming. + suffix: Suffix for layer naming. + + Returns: + Tensor with float dtype. 
+ """ + if tensor.dtype in (trt.int32, trt.int64): + cast_layer = network.add_cast(tensor, trt.float32) + if cast_layer is None: + raise RuntimeError(f"Failed to create cast layer for {node_name}") + cast_layer.name = f"pow_cast_{suffix}_{node_name}" + return cast_layer.get_output(0) + return tensor + + +def _get_dtype_for_scalar(base_tensor: trt.ITensor) -> np.dtype: + """Get appropriate numpy dtype for scalar exponent based on base tensor type. + + Following TensorRT pattern where scalar inherits dtype from tensor. + """ + if base_tensor.dtype == trt.float16: + return np.float16 + elif base_tensor.dtype == trt.float32: + return np.float32 + else: + return np.float32 + + +@converter("aten.pow.Tensor_Scalar", validator_fn=validate_pow) +def convert_pow_tensor_scalar( + node: torch.fx.Node, + network: trt.INetworkDefinition, + input_map: Dict[torch.fx.Node, Any], + edge_program: Optional[Any] = None, +) -> trt.ITensor: + """Convert PyTorch pow.Tensor_Scalar to TensorRT. + + PyTorch signature: aten.pow.Tensor_Scalar(Tensor self, Scalar exponent) -> Tensor + Computes self ** exponent element-wise. 
+ + Following TensorRT patterns: + - Type promotion for consistent dtypes + - Proper scalar to tensor conversion + - Broadcasting to match input dimensions + """ + logger.debug(f"[TensorRT] Converting pow.Tensor_Scalar node: {node.name}") + + input_node = node.args[0] + exponent = node.args[1] + + if input_node not in input_map: + raise ValueError(f"Input node '{input_node.name}' not found in input_map") + + input_trt = input_map[input_node] + input_trt = _ensure_float_dtype(network, input_trt, node.name, "input") + + exponent_value = float(exponent) + exponent_dtype = _get_dtype_for_scalar(input_trt) + exponent_np = np.array([exponent_value], dtype=exponent_dtype) + exponent_trt = create_constant(network, exponent_np, f"pow_exp_{node.name}") + + # Get target ndim for broadcasting from input node metadata or tensor + input_shape = get_node_shape(input_node) + if input_shape is not None: + target_ndim = len(input_shape) + else: + target_ndim = len(input_trt.shape) + target_ndim = max(target_ndim, 1) + + # Broadcast exponent to match input dimensions for TensorRT elementwise op + [exponent_trt] = broadcast_tensors( + network, [exponent_trt], target_ndim, f"pow_exp_{node.name}" + ) + + layer = network.add_elementwise( + input_trt, exponent_trt, trt.ElementWiseOperation.POW + ) + if layer is None: + raise RuntimeError(f"Failed to create pow layer for {node.name}") + layer.name = f"pow_{node.name}" + + return layer.get_output(0) + + +@converter("aten.pow.Tensor_Tensor", validator_fn=validate_pow) +def convert_pow_tensor_tensor( + node: torch.fx.Node, + network: trt.INetworkDefinition, + input_map: Dict[torch.fx.Node, Any], + edge_program: Optional[Any] = None, +) -> trt.ITensor: + """Convert PyTorch pow.Tensor_Tensor to TensorRT. + + PyTorch signature: aten.pow.Tensor_Tensor(Tensor self, Tensor exponent) -> Tensor + Computes self ** exponent element-wise. 
+ + Following TensorRT patterns: + - Type promotion for both operands + - Proper handling of constant exponents + - Broadcasting to match dimensions + """ + logger.debug(f"[TensorRT] Converting pow.Tensor_Tensor node: {node.name}") + + input_node = node.args[0] + exponent_node = node.args[1] + + if input_node not in input_map: + raise ValueError(f"Input node '{input_node.name}' not found in input_map") + + input_trt = input_map[input_node] + input_trt = _ensure_float_dtype(network, input_trt, node.name, "input") + + if isinstance(exponent_node, torch.fx.Node): + if exponent_node not in input_map: + raise ValueError( + f"Exponent node '{exponent_node.name}' not found in input_map" + ) + exponent_trt = input_map[exponent_node] + exponent_trt = _ensure_float_dtype(network, exponent_trt, node.name, "exp") + else: + exponent_value = float(exponent_node) + exponent_dtype = _get_dtype_for_scalar(input_trt) + exponent_np = np.array([exponent_value], dtype=exponent_dtype) + exponent_trt = create_constant(network, exponent_np, f"pow_exp_{node.name}") + + # Determine target ndim for broadcasting + input_shape = get_node_shape(input_node) + if input_shape is not None: + input_ndim = len(input_shape) + else: + input_ndim = len(input_trt.shape) + + exp_ndim = len(exponent_trt.shape) + target_ndim = max(input_ndim, exp_ndim, 1) + + # Broadcast both tensors to target dimensions + [input_trt, exponent_trt] = broadcast_tensors( + network, [input_trt, exponent_trt], target_ndim, f"pow_{node.name}" + ) + + layer = network.add_elementwise( + input_trt, exponent_trt, trt.ElementWiseOperation.POW + ) + if layer is None: + raise RuntimeError(f"Failed to create pow layer for {node.name}") + layer.name = f"pow_{node.name}" + + return layer.get_output(0) + + +__all__ = [ + "convert_pow_tensor_scalar", + "convert_pow_tensor_tensor", + "validate_pow", +] diff --git a/backends/nvidia/tensorrt/converters/slice.py b/backends/nvidia/tensorrt/converters/slice.py index f6b146eac70..9218f4ba8d7 
100644 --- a/backends/nvidia/tensorrt/converters/slice.py +++ b/backends/nvidia/tensorrt/converters/slice.py @@ -25,12 +25,13 @@ import logging import sys -from typing import Any, Dict, List, Optional +from typing import Any, Dict, Optional import torch from executorch.backends.nvidia.tensorrt.converter_registry import converter from executorch.backends.nvidia.tensorrt.converter_utils import ( get_node_shape, + get_trt_tensor, ) logger: logging.Logger = logging.getLogger(__name__) @@ -77,10 +78,25 @@ def convert_slice( if not isinstance(input_node, torch.fx.Node): raise ValueError(f"Input to slice must be a node, got {type(input_node)}") + # Handle case where input node is not in input_map (e.g., get_attr or lifted buffer) if input_node not in input_map: - raise ValueError(f"Input node {input_node.name} not found in input_map") - - input_trt = input_map[input_node] + # Try to get the value from node metadata or create a constant + input_val = None + if "val" in input_node.meta and isinstance(input_node.meta["val"], torch.Tensor): + input_val = input_node.meta["val"] + + if input_val is not None: + input_trt = get_trt_tensor(network, input_val, f"const_{input_node.name}") + input_map[input_node] = input_trt # Cache for future use + logger.debug(f"[TensorRT] Created constant tensor for {input_node.name} from metadata") + else: + raise ValueError( + f"Input node {input_node.name} not found in input_map and no metadata value available. " + f"This may be a lifted buffer that wasn't properly added. " + f"Node op: {input_node.op}, target: {input_node.target}" + ) + else: + input_trt = input_map[input_node] # Get shape from node metadata for reliability (TRT shapes can be invalid during network building) input_shape = list(get_node_shape(input_node) or input_trt.shape) @@ -162,7 +178,11 @@ def convert_index( index.Tensor(Tensor self, Tensor?[] indices) This operation indexes into a tensor using tensor indices. - For simple 1D indexing along first dimension, we use gather. 
+ Supports both single-dimension and multi-dimension indexing (advanced indexing). + + For single index tensor: uses simple gather layer + For multiple index tensors: uses transpose + flatten + gather + reshape pattern + following TensorRT implementation. Args: node: FX node representing the index operation. @@ -174,6 +194,7 @@ def convert_index( """ try: import tensorrt as trt + import numpy as np except ImportError as e: raise ImportError("TensorRT is required for convert_index") from e @@ -192,36 +213,264 @@ def convert_index( input_trt = input_map[input_node] - # indices is a list of optional tensors, one per dimension - # Find the first non-None index - gather_dim = None - index_tensor = None + # Get input shape from metadata (more reliable for dynamic shapes) + input_shape = list(get_node_shape(input_node) or input_trt.shape) + rank = len(input_shape) + + # Collect non-None indices and their positions + adv_indx_indices = [] # Dimension indices where indexing is applied + tensor_indices = [] # The actual index tensors for i, idx in enumerate(indices): if idx is not None and isinstance(idx, torch.fx.Node): - if gather_dim is not None: + adv_indx_indices.append(i) + if idx not in input_map: + raise ValueError(f"Index node {idx.name} not found in input_map") + tensor_indices.append(input_map[idx]) + + if not tensor_indices: + # No valid index tensors - just return input cast to int32 + cast_layer = network.add_cast(input_trt, trt.int32) + cast_layer.name = f"index_casted_{node.name}" + return cast_layer.get_output(0) + + if len(tensor_indices) == 1: + # Simple single-dimension indexing - use gather directly + gather_dim = adv_indx_indices[0] + index_tensor = tensor_indices[0] + + # Ensure index is int32 for TensorRT gather + if index_tensor.dtype != trt.int32: + cast_layer = network.add_cast(index_tensor, trt.int32) + cast_layer.name = f"index_cast_{node.name}" + index_tensor = cast_layer.get_output(0) + + layer = network.add_gather(input_trt, index_tensor, 
axis=gather_dim) + if layer is None: + raise RuntimeError(f"Failed to create gather layer for index {node.name}") + layer.name = f"index_gather_{node.name}" + + logger.debug( + f"[TensorRT] Created simple index/gather layer: {layer.name}, gather_dim={gather_dim}" + ) + return layer.get_output(0) + + # Multiple index tensors - advanced indexing + # Follow TensorRT pattern: transpose -> flatten -> compute linear index -> gather -> reshape + logger.debug(f"[TensorRT] Advanced indexing with {len(tensor_indices)} index tensors at dims {adv_indx_indices}") + + adv_indx_count = len(adv_indx_indices) + + # Step 1: Transpose input to move indexed dimensions to the front + # new_order: [indexed dims..., non-indexed dims...] + new_order = adv_indx_indices.copy() + for i in range(rank): + if i not in adv_indx_indices: + new_order.append(i) + + transpose_layer = network.add_shuffle(input_trt) + transpose_layer.second_transpose = trt.Permutation(new_order) + transpose_layer.name = f"index_transpose_{node.name}" + transpose_tensor = transpose_layer.get_output(0) + + logger.debug(f"[TensorRT] Transpose order: {new_order}") + + # Step 2: Flatten the indexed dimensions into one, and non-indexed dims into another + # Result shape: [prod(indexed_dims), prod(non_indexed_dims)] + + # Compute products for reshape + mult_d0 = 1 # Product of indexed dimension sizes + for i in range(adv_indx_count): + dim_size = input_shape[adv_indx_indices[i]] + if isinstance(dim_size, int) and dim_size > 0: + mult_d0 *= dim_size + else: + # Dynamic dimension - fall back to simpler approach + raise NotImplementedError( + f"Dynamic shapes in indexed dimensions not fully supported. " + f"Dimension {adv_indx_indices[i]} has dynamic size." 
+ ) + + mult_d1 = 1 # Product of non-indexed dimension sizes + for i in range(rank): + if i not in adv_indx_indices: + dim_size = input_shape[i] + if isinstance(dim_size, int) and dim_size > 0: + mult_d1 *= dim_size + else: raise NotImplementedError( - "Multiple index tensors not supported, only single dimension indexing" + f"Dynamic shapes in non-indexed dimensions not fully supported. " + f"Dimension {i} has dynamic size." ) - gather_dim = i - index_tensor = input_map.get(idx) - - if gather_dim is None or index_tensor is None: - raise ValueError("No valid index tensor found in indices") - - # Create gather layer - layer = network.add_gather(input_trt, index_tensor, axis=gather_dim) - if layer is None: - raise RuntimeError(f"Failed to create gather layer for index {node.name}") - - layer.name = f"index_gather_{node.name}" + # Create reshape to [mult_d0, mult_d1] + flatten_layer = network.add_shuffle(transpose_tensor) + flatten_layer.reshape_dims = trt.Dims([mult_d0, mult_d1]) + flatten_layer.name = f"index_flatten_{node.name}" + flatten_tensor = flatten_layer.get_output(0) + + logger.debug(f"[TensorRT] Flattened shape: [{mult_d0}, {mult_d1}]") + + # Step 3: Compute cumulative linear index following TensorRT formula + # tensor_index = sum_{i=1}^m (ind_i * prod_{j=i+1}^m (dim_j)) + # where ind_i is the i-th index tensor and dim_j is the size of the j-th indexed dimension + + # First, find the maximum number of dimensions across all index tensors + # TensorRT requires matching dimensions for elementwise operations + max_ndim = 1 + for idx_tensor in tensor_indices: + idx_ndim = len(idx_tensor.shape) + if idx_ndim > max_ndim: + max_ndim = idx_ndim + + logger.debug(f"[TensorRT] Max ndim across index tensors: {max_ndim}") + + # Helper function to ensure tensor has max_ndim dimensions by prepending 1s + def ensure_ndim(tensor: Any, name_suffix: str) -> Any: + """Reshape tensor to have max_ndim dimensions by prepending 1s.""" + current_ndim = len(tensor.shape) + if 
current_ndim < max_ndim: + # Need to prepend (max_ndim - current_ndim) dimensions of size 1 + new_shape = [1] * (max_ndim - current_ndim) + list(tensor.shape) + reshape_layer = network.add_shuffle(tensor) + reshape_layer.reshape_dims = trt.Dims(new_shape) + reshape_layer.name = f"index_reshape_{name_suffix}" + return reshape_layer.get_output(0) + return tensor + + # Start with the last index tensor (no multiplication needed for the last one) + cum_index = tensor_indices[adv_indx_count - 1] + + # Ensure int32 type + if cum_index.dtype != trt.int32: + cast_layer = network.add_cast(cum_index, trt.int32) + cast_layer.name = f"index_cast_last_{node.name}" + cum_index = cast_layer.get_output(0) + + # Ensure cum_index has max_ndim dimensions + cum_index = ensure_ndim(cum_index, f"last_{node.name}") + + # The multiplier accumulates the product of indexed dimension sizes + # Start with the size of the LAST indexed dimension + multiplier = input_shape[adv_indx_indices[adv_indx_count - 1]] + + logger.debug(f"[TensorRT] Starting multiplier: {multiplier} (size of dim {adv_indx_indices[adv_indx_count - 1]})") + + # Process from second-to-last index tensor backwards to first + for i in range(adv_indx_count - 2, -1, -1): + idx_tensor = tensor_indices[i] + + # Ensure int32 type + if idx_tensor.dtype != trt.int32: + cast_layer = network.add_cast(idx_tensor, trt.int32) + cast_layer.name = f"index_cast_{i}_{node.name}" + idx_tensor = cast_layer.get_output(0) + + # Ensure idx_tensor has max_ndim dimensions + idx_tensor = ensure_ndim(idx_tensor, f"{i}_{node.name}") + + # Create multiplier constant with max_ndim dimensions for proper broadcasting + mult_shape = [1] * max_ndim + mult_const = network.add_constant( + mult_shape, + trt.Weights(np.array([multiplier], dtype=np.int32).reshape(mult_shape)) + ) + mult_const.name = f"index_mult_const_{i}_{node.name}" - logger.debug( - f"[TensorRT] Created index/gather layer: {layer.name}, gather_dim={gather_dim}" - ) + # adv_index = idx_tensor * 
multiplier + mul_layer = network.add_elementwise( + idx_tensor, + mult_const.get_output(0), + trt.ElementWiseOperation.PROD + ) + mul_layer.name = f"index_mul_{i}_{node.name}" - return layer.get_output(0) + # cum_index = cum_index + adv_index + add_layer = network.add_elementwise( + cum_index, + mul_layer.get_output(0), + trt.ElementWiseOperation.SUM + ) + add_layer.name = f"index_add_{i}_{node.name}" + cum_index = add_layer.get_output(0) + + # Update multiplier for next iteration: multiplier *= dim_size[current_indexed_dim] + dim_size = input_shape[adv_indx_indices[i]] + if isinstance(dim_size, int) and dim_size > 0: + multiplier *= dim_size + else: + raise NotImplementedError( + f"Dynamic shapes in indexed dimensions not fully supported. " + f"Dimension {adv_indx_indices[i]} has dynamic size." + ) + + logger.debug(f"[TensorRT] After index {i}: multiplier = {multiplier}") + + logger.debug(f"[TensorRT] Computed cumulative index") + + # Step 4: Gather using cumulative index on the flattened dimension 0 + gather_layer = network.add_gather(flatten_tensor, cum_index, axis=0) + if gather_layer is None: + raise RuntimeError(f"Failed to create gather layer for advanced index {node.name}") + gather_layer.name = f"index_gather_adv_{node.name}" + gather_out = gather_layer.get_output(0) + + logger.debug(f"[TensorRT] Gather output shape: {gather_out.shape}") + + # Step 5: Reshape output to match expected shape from node metadata + # The gather output shape is [cum_index_shape..., mult_d1] + # We reshape directly to the expected output shape from node metadata + + # Get expected output shape from node metadata (most reliable source) + expected_output_shape = get_node_shape(node) + + if expected_output_shape is not None: + output_shape = list(expected_output_shape) + logger.debug(f"[TensorRT] Using expected output shape from metadata: {output_shape}") + + # Reshape gather output directly to expected shape + reshape_layer = network.add_shuffle(gather_out) + 
reshape_layer.reshape_dims = trt.Dims(output_shape) + reshape_layer.name = f"index_final_reshape_{node.name}" + final_output = reshape_layer.get_output(0) + else: + # Fallback: Get index tensor shape from metadata and compute output shape + # Get the broadcast shape of all index tensors from their node metadata + idx_broadcast_shape = [] + for i, idx in enumerate(indices): + if idx is not None and isinstance(idx, torch.fx.Node): + idx_shape = get_node_shape(idx) + if idx_shape is not None: + idx_shape_list = list(idx_shape) + # Broadcast shapes - keep the max along each dimension + if not idx_broadcast_shape: + idx_broadcast_shape = idx_shape_list + else: + # Extend to match dimensions + while len(idx_broadcast_shape) < len(idx_shape_list): + idx_broadcast_shape.insert(0, 1) + while len(idx_shape_list) < len(idx_broadcast_shape): + idx_shape_list.insert(0, 1) + # Take max at each position + idx_broadcast_shape = [max(a, b) for a, b in zip(idx_broadcast_shape, idx_shape_list)] + + # Collect non-indexed dimensions + non_indexed_dims = [] + for i in range(rank): + if i not in adv_indx_indices: + non_indexed_dims.append(input_shape[i]) + + # Output shape: [idx_broadcast_shape..., non_indexed_dims...] 
+ output_shape = idx_broadcast_shape + non_indexed_dims + logger.debug(f"[TensorRT] Computed output shape: {output_shape}") + + reshape_layer = network.add_shuffle(gather_out) + reshape_layer.reshape_dims = trt.Dims(output_shape) + reshape_layer.name = f"index_final_reshape_{node.name}" + final_output = reshape_layer.get_output(0) + + logger.debug(f"[TensorRT] Created advanced index with output shape {output_shape}") + return final_output @converter("aten.contiguous.default") diff --git a/backends/nvidia/tensorrt/converters/sub.py b/backends/nvidia/tensorrt/converters/sub.py index 011041b3c24..c05c22d86d5 100644 --- a/backends/nvidia/tensorrt/converters/sub.py +++ b/backends/nvidia/tensorrt/converters/sub.py @@ -6,6 +6,7 @@ """Converter for element-wise subtraction operations.""" +import logging from typing import Any, Dict, Optional import tensorrt as trt @@ -20,6 +21,9 @@ ) +logger: logging.Logger = logging.getLogger(__name__) + + def _get_elementwise_input( network: trt.INetworkDefinition, input_map: Dict[torch.fx.Node, Any], @@ -27,14 +31,32 @@ def _get_elementwise_input( name: str, dtype: Optional[torch.dtype], ) -> trt.ITensor: - """Get TensorRT tensor for an elementwise operation input.""" + """Get TensorRT tensor for an elementwise operation input. + + Handles: + - FX nodes already in input_map + - FX nodes that are lifted buffers/parameters (placeholder nodes with b_ or p_ prefix) + - Scalar values + """ if isinstance(arg, torch.fx.Node): - if arg not in input_map: - raise ValueError( - f"Input node '{arg.name}' not found in input_map. " - f"Available nodes: {list(input_map.keys())}" - ) - return input_map[arg] + if arg in input_map: + return input_map[arg] + + # Handle lifted buffers and parameters that aren't in input_map + # These are placeholder nodes with names starting with b_ (buffers) or p_ (parameters) + # or get_attr nodes. We need to create constants from their metadata values. 
+ if arg.op == "placeholder" or arg.op == "get_attr": + if "val" in arg.meta and isinstance(arg.meta["val"], torch.Tensor): + logger.debug(f"[TensorRT] Creating constant for lifted buffer/parameter: {arg.name}") + trt_tensor = get_trt_tensor(network, arg.meta["val"], f"const_{arg.name}", dtype) + input_map[arg] = trt_tensor # Cache for future use + return trt_tensor + + raise ValueError( + f"Input node '{arg.name}' not found in input_map and could not be created as constant. " + f"Node op: {arg.op}, target: {arg.target}. " + f"Available nodes: {list(n.name for n in input_map.keys())}" + ) return get_trt_tensor(network, arg, name, dtype) diff --git a/backends/nvidia/tensorrt/converters/targets.bzl b/backends/nvidia/tensorrt/converters/targets.bzl index 23ce33fdb22..13d57eab477 100644 --- a/backends/nvidia/tensorrt/converters/targets.bzl +++ b/backends/nvidia/tensorrt/converters/targets.bzl @@ -32,12 +32,14 @@ def define_common_targets(): "permute_copy.py", "pixel_shuffle.py", "pooling.py", + "power.py", "reduction.py", "relu.py", "reshape.py", "sdpa.py", "slice.py", "sub.py", + "unary.py", "upsample.py", ], visibility = ["PUBLIC"], diff --git a/backends/nvidia/tensorrt/converters/unary.py b/backends/nvidia/tensorrt/converters/unary.py new file mode 100644 index 00000000000..841dd504522 --- /dev/null +++ b/backends/nvidia/tensorrt/converters/unary.py @@ -0,0 +1,350 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the BSD-style license found in the +# LICENSE file in the root directory of this source tree. + +"""TensorRT Converters for Unary Operations. 
+ +Supported operations: +- aten.sqrt.default: Square root +- aten.rsqrt.default: Reciprocal square root (1/sqrt(x)) +- aten.exp.default: Exponential +- aten.log.default: Natural logarithm +- aten.neg.default: Negation +- aten.abs.default: Absolute value +- aten.sin.default: Sine +- aten.cos.default: Cosine +- aten.floor.default: Floor +- aten.ceil.default: Ceiling +- aten.erf.default: Error function +- aten.reciprocal.default: Reciprocal (1/x) + +These operations are commonly used in transformer models for layer normalization +and attention mechanisms. Design patterns follow TensorRT best practices. +""" + +import logging +from typing import Any, Dict, Optional + +import tensorrt as trt +import torch +from executorch.backends.nvidia.tensorrt.converter_registry import converter + +logger: logging.Logger = logging.getLogger(__name__) + + +def validate_unary(node: torch.fx.Node) -> bool: + """Validate that a unary node can be converted to TensorRT.""" + if node.op != "call_function": + return False + if len(node.args) < 1: + return False + if not isinstance(node.args[0], torch.fx.Node): + return False + return True + + +def _cast_to_float_if_needed( + network: trt.INetworkDefinition, + input_trt: trt.ITensor, + node_name: str, +) -> trt.ITensor: + """Cast integer tensors to float32 for operations that require float input. + + TensorRT unary operations like sqrt, exp, log, sin, cos, etc. don't support + int8 or int32 input types. This follows the TensorRT pattern of + automatically casting to float32. + + Args: + network: TensorRT network definition. + input_trt: Input tensor. + node_name: Node name for layer naming. + + Returns: + Input tensor (casted if necessary). 
+ """ + if input_trt.dtype in (trt.int8, trt.int32): + cast_layer = network.add_cast(input_trt, trt.float32) + if cast_layer is None: + raise RuntimeError(f"Failed to create cast layer for {node_name}") + cast_layer.name = f"cast_to_float_{node_name}" + return cast_layer.get_output(0) + return input_trt + + +def _convert_unary_base( + node: torch.fx.Node, + network: trt.INetworkDefinition, + input_map: Dict[torch.fx.Node, Any], + operation_type: trt.UnaryOperation, + op_name: str, + cast_to_float: bool = True, +) -> trt.ITensor: + """Base function for unary operation conversion. + + This follows the TensorRT architecture pattern where a base function + handles common logic (input validation, optional casting, layer creation) + and specific converters can call this with appropriate parameters. + + Args: + node: FX node to convert. + network: TensorRT network definition. + input_map: Map of FX nodes to TensorRT tensors. + operation_type: TensorRT unary operation type. + op_name: Name of the operation for logging/layer naming. + cast_to_float: Whether to cast int inputs to float32. + + Returns: + Output TensorRT tensor. + """ + logger.debug(f"[TensorRT] Converting {op_name} node: {node.name}") + + input_node = node.args[0] + if input_node not in input_map: + raise ValueError(f"Input node '{input_node.name}' not found in input_map") + + input_trt = input_map[input_node] + + if cast_to_float: + input_trt = _cast_to_float_if_needed(network, input_trt, node.name) + + layer = network.add_unary(input_trt, operation_type) + if layer is None: + raise RuntimeError(f"Failed to create {op_name} layer for {node.name}") + layer.name = f"{op_name}_{node.name}" + + return layer.get_output(0) + + +@converter("aten.sqrt.default", validator_fn=validate_unary) +def convert_sqrt( + node: torch.fx.Node, + network: trt.INetworkDefinition, + input_map: Dict[torch.fx.Node, Any], + edge_program: Optional[Any] = None, +) -> trt.ITensor: + """Convert PyTorch sqrt to TensorRT. 
+
+    PyTorch signature: aten.sqrt(Tensor self) -> Tensor
+    """
+    return _convert_unary_base(
+        node, network, input_map, trt.UnaryOperation.SQRT, "sqrt"
+    )
+
+
+@converter("aten.rsqrt.default", validator_fn=validate_unary)
+def convert_rsqrt(
+    node: torch.fx.Node,
+    network: trt.INetworkDefinition,
+    input_map: Dict[torch.fx.Node, Any],
+    edge_program: Optional[Any] = None,
+) -> trt.ITensor:
+    """Convert PyTorch rsqrt to TensorRT.
+
+    PyTorch signature: aten.rsqrt(Tensor self) -> Tensor
+    Computes 1/sqrt(x).
+
+    Implemented as sqrt(x) followed by reciprocal, following the TensorRT pattern.
+    """
+    logger.debug(f"[TensorRT] Converting rsqrt node: {node.name}")
+
+    input_node = node.args[0]
+    if input_node not in input_map:
+        raise ValueError(f"Input node '{input_node.name}' not found in input_map")
+
+    input_trt = input_map[input_node]
+    input_trt = _cast_to_float_if_needed(network, input_trt, node.name)
+
+    # First compute sqrt(x)
+    sqrt_layer = network.add_unary(input_trt, trt.UnaryOperation.SQRT)
+    if sqrt_layer is None:
+        raise RuntimeError(f"Failed to create sqrt layer for rsqrt {node.name}")
+    sqrt_layer.name = f"rsqrt_sqrt_{node.name}"
+
+    # Then compute 1/sqrt(x) using reciprocal
+    recip_layer = network.add_unary(sqrt_layer.get_output(0), trt.UnaryOperation.RECIP)
+    if recip_layer is None:
+        raise RuntimeError(f"Failed to create reciprocal layer for rsqrt {node.name}")
+    recip_layer.name = f"rsqrt_{node.name}"
+
+    return recip_layer.get_output(0)
+
+
+@converter("aten.exp.default", validator_fn=validate_unary)
+def convert_exp(
+    node: torch.fx.Node,
+    network: trt.INetworkDefinition,
+    input_map: Dict[torch.fx.Node, Any],
+    edge_program: Optional[Any] = None,
+) -> trt.ITensor:
+    """Convert PyTorch exp to TensorRT. 
+ + PyTorch signature: aten.exp(Tensor self) -> Tensor + """ + return _convert_unary_base( + node, network, input_map, trt.UnaryOperation.EXP, "exp" + ) + + +@converter("aten.log.default", validator_fn=validate_unary) +def convert_log( + node: torch.fx.Node, + network: trt.INetworkDefinition, + input_map: Dict[torch.fx.Node, Any], + edge_program: Optional[Any] = None, +) -> trt.ITensor: + """Convert PyTorch log (natural logarithm) to TensorRT. + + PyTorch signature: aten.log(Tensor self) -> Tensor + """ + return _convert_unary_base( + node, network, input_map, trt.UnaryOperation.LOG, "log" + ) + + +@converter("aten.neg.default", validator_fn=validate_unary) +def convert_neg( + node: torch.fx.Node, + network: trt.INetworkDefinition, + input_map: Dict[torch.fx.Node, Any], + edge_program: Optional[Any] = None, +) -> trt.ITensor: + """Convert PyTorch neg (negation) to TensorRT. + + PyTorch signature: aten.neg(Tensor self) -> Tensor + """ + return _convert_unary_base( + node, network, input_map, trt.UnaryOperation.NEG, "neg" + ) + + +@converter("aten.abs.default", validator_fn=validate_unary) +def convert_abs( + node: torch.fx.Node, + network: trt.INetworkDefinition, + input_map: Dict[torch.fx.Node, Any], + edge_program: Optional[Any] = None, +) -> trt.ITensor: + """Convert PyTorch abs to TensorRT. + + PyTorch signature: aten.abs(Tensor self) -> Tensor + Note: ABS supports all data types, no float casting needed. + """ + return _convert_unary_base( + node, network, input_map, trt.UnaryOperation.ABS, "abs", cast_to_float=False + ) + + +@converter("aten.sin.default", validator_fn=validate_unary) +def convert_sin( + node: torch.fx.Node, + network: trt.INetworkDefinition, + input_map: Dict[torch.fx.Node, Any], + edge_program: Optional[Any] = None, +) -> trt.ITensor: + """Convert PyTorch sin to TensorRT. 
+ + PyTorch signature: aten.sin(Tensor self) -> Tensor + """ + return _convert_unary_base( + node, network, input_map, trt.UnaryOperation.SIN, "sin" + ) + + +@converter("aten.cos.default", validator_fn=validate_unary) +def convert_cos( + node: torch.fx.Node, + network: trt.INetworkDefinition, + input_map: Dict[torch.fx.Node, Any], + edge_program: Optional[Any] = None, +) -> trt.ITensor: + """Convert PyTorch cos to TensorRT. + + PyTorch signature: aten.cos(Tensor self) -> Tensor + """ + return _convert_unary_base( + node, network, input_map, trt.UnaryOperation.COS, "cos" + ) + + +@converter("aten.floor.default", validator_fn=validate_unary) +def convert_floor( + node: torch.fx.Node, + network: trt.INetworkDefinition, + input_map: Dict[torch.fx.Node, Any], + edge_program: Optional[Any] = None, +) -> trt.ITensor: + """Convert PyTorch floor to TensorRT. + + PyTorch signature: aten.floor(Tensor self) -> Tensor + """ + return _convert_unary_base( + node, network, input_map, trt.UnaryOperation.FLOOR, "floor" + ) + + +@converter("aten.ceil.default", validator_fn=validate_unary) +def convert_ceil( + node: torch.fx.Node, + network: trt.INetworkDefinition, + input_map: Dict[torch.fx.Node, Any], + edge_program: Optional[Any] = None, +) -> trt.ITensor: + """Convert PyTorch ceil to TensorRT. + + PyTorch signature: aten.ceil(Tensor self) -> Tensor + """ + return _convert_unary_base( + node, network, input_map, trt.UnaryOperation.CEIL, "ceil" + ) + + +@converter("aten.erf.default", validator_fn=validate_unary) +def convert_erf( + node: torch.fx.Node, + network: trt.INetworkDefinition, + input_map: Dict[torch.fx.Node, Any], + edge_program: Optional[Any] = None, +) -> trt.ITensor: + """Convert PyTorch erf (error function) to TensorRT. + + PyTorch signature: aten.erf(Tensor self) -> Tensor + Used in GELU activation and other transformer operations. 
+ """ + return _convert_unary_base( + node, network, input_map, trt.UnaryOperation.ERF, "erf" + ) + + +@converter("aten.reciprocal.default", validator_fn=validate_unary) +def convert_reciprocal( + node: torch.fx.Node, + network: trt.INetworkDefinition, + input_map: Dict[torch.fx.Node, Any], + edge_program: Optional[Any] = None, +) -> trt.ITensor: + """Convert PyTorch reciprocal (1/x) to TensorRT. + + PyTorch signature: aten.reciprocal(Tensor self) -> Tensor + """ + return _convert_unary_base( + node, network, input_map, trt.UnaryOperation.RECIP, "reciprocal" + ) + + +__all__ = [ + "convert_sqrt", + "convert_rsqrt", + "convert_exp", + "convert_log", + "convert_neg", + "convert_abs", + "convert_sin", + "convert_cos", + "convert_floor", + "convert_ceil", + "convert_erf", + "convert_reciprocal", + "validate_unary", +] diff --git a/backends/nvidia/tensorrt/partitioner/operator_support.py b/backends/nvidia/tensorrt/partitioner/operator_support.py index fb216c9fd6a..0ec5fe08df4 100644 --- a/backends/nvidia/tensorrt/partitioner/operator_support.py +++ b/backends/nvidia/tensorrt/partitioner/operator_support.py @@ -31,6 +31,7 @@ class TensorRTOperatorSupport(OperatorSupportBase): "_scaled_dot_product_flash_attention.default", "_softmax.default", "_unsafe_view.default", + "abs.default", "adaptive_avg_pool2d.default", "add.Tensor", "add_.Tensor", @@ -42,6 +43,7 @@ class TensorRTOperatorSupport(OperatorSupportBase): "batch_norm.default", "bmm.default", "cat.default", + "ceil.default", "chunk.default", "clamp.default", "clamp_max.default", @@ -52,15 +54,19 @@ class TensorRTOperatorSupport(OperatorSupportBase): "copy.default", "conv2d.default", "convolution.default", + "cos.default", "div.Tensor", "div.Tensor_mode", "dropout.default", "dropout_.default", "embedding.default", "eq.Scalar", + "erf.default", + "exp.default", "expand.default", "expand_copy.default", "flatten.using_ints", + "floor.default", "full.default", "full_like.default", "ge.Scalar", @@ -75,6 +81,7 @@ class 
TensorRTOperatorSupport(OperatorSupportBase): "layer_norm.default", "le.Scalar", "linear.default", + "log.default", "log_softmax.int", "logical_not.default", "lt.Scalar", @@ -87,14 +94,19 @@ class TensorRTOperatorSupport(OperatorSupportBase): "mul_.Tensor", "native_layer_norm.default", "ne.Scalar", + "neg.default", "ones_like.default", "permute.default", "permute_copy.default", "pixel_shuffle.default", + "pow.Tensor_Scalar", + "pow.Tensor_Tensor", + "reciprocal.default", "relu.default", "relu_.default", "repeat.default", "reshape.default", + "rsqrt.default", "rsub.Scalar", "scalar_tensor.default", "scaled_dot_product_attention.default", @@ -102,18 +114,21 @@ class TensorRTOperatorSupport(OperatorSupportBase): "select_copy.int", "sigmoid.default", "silu.default", + "sin.default", "slice.Tensor", "slice_copy.Tensor", "softmax.int", "split.Tensor", "split_with_sizes.default", "split_with_sizes_copy.default", + "sqrt.default", "squeeze.dim", "squeeze.dims", "squeeze_copy.dim", "squeeze_copy.dims", "stack.default", "sub.Tensor", + "sum.dim_IntList", "tanh.default", "transpose.int", "unflatten.int", diff --git a/examples/nvidia/tensorrt/README.md b/examples/nvidia/tensorrt/README.md index 4ec66f24e9c..ebd2ec2b8f6 100644 --- a/examples/nvidia/tensorrt/README.md +++ b/examples/nvidia/tensorrt/README.md @@ -44,9 +44,21 @@ on systems with NVIDIA GPUs. ### Supported Models Currently supported models: -- `add` - Simple element-wise addition -More models will be added as converters are implemented. 
+**Toy Models:** +- `add`, `mul`, `linear`, `add_mul`, `softmax`, `conv1d` + +**Vision Models:** +- `mv2` (MobileNetV2), `mv3` (MobileNetV3), `resnet18`, `resnet50` +- `ic3` (InceptionV3), `ic4` (InceptionV4), `dl3` (DeepLabV3) +- `edsr` (Super-resolution), `efficient_sam` + +**Audio/NLP Models:** +- `w2l` (Wav2Letter), `mobilebert` +- `emformer_join`, `emformer_predict`, `emformer_transcribe` + +**Other:** +- `sdpa` (Scaled Dot-Product Attention) Run `--help` to see all available options: diff --git a/examples/nvidia/tensorrt/benchmark.py b/examples/nvidia/tensorrt/benchmark.py new file mode 100644 index 00000000000..3b97d058e94 --- /dev/null +++ b/examples/nvidia/tensorrt/benchmark.py @@ -0,0 +1,220 @@ +#!/usr/bin/env python3 +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the BSD-style license found in the +# LICENSE file in the root directory of this source tree. + +# pyre-unsafe + +"""Benchmark script for TensorRT backend. + +Exports supported models with TensorRT delegate and prepares them for benchmarking. +Use the C++ benchmark runner for actual inference timing. 
+ +Usage: + # Export models for benchmarking: + python -m executorch.examples.nvidia.tensorrt.benchmark -m mv2 mv3 + + # Run benchmark with C++ runner (after building with cmake): + ./cmake-out/backends/nvidia/tensorrt/benchmark_runner_tensorrt \ + --model_path=/tmp/benchmark/mv3_tensorrt.pte --num_executions=100 +""" + +import argparse +import logging +import os +from typing import Any, Dict, List, Tuple + +import torch +from torch.export import export + +from executorch.examples.models import MODEL_NAME_TO_MODEL +from executorch.examples.models.model_factory import EagerModelFactory +from executorch.exir import to_edge_transform_and_lower + +FORMAT = "[%(levelname)s %(asctime)s %(filename)s:%(lineno)s] %(message)s" +logging.basicConfig(level=logging.INFO, format=FORMAT) +logger = logging.getLogger(__name__) + +# Models supported by TensorRT backend +TENSORRT_SUPPORTED_MODELS = [ + "add", + "mul", + "linear", + "add_mul", + "softmax", + "conv1d", + "mv2", + "mv3", + "resnet18", + "resnet50", + "w2l", + "ic3", + "ic4", + "dl3", + "edsr", + "emformer_join", + "sdpa", + "mobilebert", + "efficient_sam", +] + +# Default number of inference iterations +DEFAULT_NUM_ITERATIONS = 100 + +# Seed for reproducible random input generation +BENCHMARK_SEED = 2025 + + +def get_model_and_inputs(model_name: str) -> Tuple[torch.nn.Module, Tuple[Any, ...]]: + """Create model and example inputs from the model factory.""" + if model_name not in MODEL_NAME_TO_MODEL: + raise ValueError(f"Model {model_name} not in MODEL_NAME_TO_MODEL") + + model, example_inputs, _, _ = EagerModelFactory.create_model( + *MODEL_NAME_TO_MODEL[model_name] + ) + model.eval() + return model, example_inputs + + +def export_tensorrt( + model: torch.nn.Module, example_inputs: Tuple[Any, ...], output_path: str +) -> bool: + """Export model with TensorRT delegate and save to file.""" + try: + from executorch.backends.nvidia.tensorrt.partitioner import TensorRTPartitioner + + exported = export(model, example_inputs) + 
edge_program = to_edge_transform_and_lower( + exported, + partitioner=[TensorRTPartitioner()], + ) + exec_prog = edge_program.to_executorch() + + # Save to file + os.makedirs(os.path.dirname(output_path), exist_ok=True) + with open(output_path, "wb") as f: + f.write(exec_prog.buffer) + return True + except Exception as e: + logger.warning(f"TensorRT export failed: {e}") + return False + + +def benchmark_model( + model_name: str, + output_dir: str, + num_iterations: int = DEFAULT_NUM_ITERATIONS, +) -> Dict[str, Any]: + """Export a model with TensorRT for benchmarking.""" + logger.info(f"Exporting {model_name} for benchmarking...") + + result = { + "model": model_name, + "num_iterations": num_iterations, + "exported": False, + "pte_path": None, + } + + try: + model, example_inputs = get_model_and_inputs(model_name) + + # TensorRT export + trt_path = os.path.join(output_dir, f"{model_name}_tensorrt.pte") + if export_tensorrt(model, example_inputs, trt_path): + result["exported"] = True + result["pte_path"] = trt_path + logger.info(f" Exported to {trt_path}") + + except Exception as e: + logger.error(f"Error exporting {model_name}: {e}") + + return result + + +def print_results(results: List[Dict[str, Any]], output_dir: str) -> None: + """Print export results and C++ runner commands.""" + print("\n" + "=" * 80) + print("TENSORRT BENCHMARK EXPORT RESULTS") + print("=" * 80) + print(f"Output directory: {output_dir}") + print(f"Iterations: {results[0]['num_iterations'] if results else 'N/A'}") + print("-" * 80) + print(f"{'Model':<15} {'Status':<20} {'Path':<45}") + print("-" * 80) + + for r in results: + model = r["model"] + status = "Exported" if r["exported"] else "Failed" + path = r["pte_path"] or "N/A" + print(f"{model:<15} {status:<20} {path:<45}") + + print("-" * 80) + print("\nTo run benchmarks, use the C++ runner (after building with cmake):") + print("=" * 80) + for r in results: + if r["exported"]: + num_iter = r["num_iterations"] + pte_path = r["pte_path"] + 
print( + f"# {r['model']}:\n" + f"./cmake-out/backends/nvidia/tensorrt/benchmark_runner_tensorrt " + f"--model_path={pte_path} --num_executions={num_iter}\n" + ) + print("=" * 80) + + +def main() -> None: + parser = argparse.ArgumentParser(description="Benchmark TensorRT backend") + parser.add_argument( + "-m", + "--models", + nargs="+", + default=TENSORRT_SUPPORTED_MODELS, + help=f"Models to benchmark. Default: {TENSORRT_SUPPORTED_MODELS}", + ) + parser.add_argument( + "-n", + "--num_iterations", + type=int, + default=DEFAULT_NUM_ITERATIONS, + help=f"Number of inference iterations per model. Default: {DEFAULT_NUM_ITERATIONS}", + ) + parser.add_argument( + "-o", + "--output_dir", + default="/tmp/benchmark", + help="Output directory for exported models. Default: /tmp/benchmark", + ) + parser.add_argument( + "-v", + "--verbose", + action="store_true", + help="Enable verbose logging", + ) + args = parser.parse_args() + + if args.verbose: + logger.setLevel(logging.DEBUG) + + os.makedirs(args.output_dir, exist_ok=True) + + results = [] + for model_name in args.models: + if model_name not in TENSORRT_SUPPORTED_MODELS: + logger.warning( + f"Model {model_name} not in supported models, skipping. 
" + f"Supported: {TENSORRT_SUPPORTED_MODELS}" + ) + continue + result = benchmark_model(model_name, args.output_dir, args.num_iterations) + results.append(result) + + print_results(results, args.output_dir) + + +if __name__ == "__main__": + with torch.no_grad(): + main() diff --git a/examples/nvidia/tensorrt/export.py b/examples/nvidia/tensorrt/export.py index 5cb0f8607c0..07dcdea2c38 100644 --- a/examples/nvidia/tensorrt/export.py +++ b/examples/nvidia/tensorrt/export.py @@ -43,6 +43,7 @@ # "emformer_predict", # TODO: passes 1/3 seeds — precision sensitive with randomized inputs "emformer_transcribe", "ic3", + "ic4", "linear", "mul", "mv2", diff --git a/examples/nvidia/tensorrt/tests/TARGETS b/examples/nvidia/tensorrt/tests/TARGETS index 71c30ea2622..ccd03ac9d19 100644 --- a/examples/nvidia/tensorrt/tests/TARGETS +++ b/examples/nvidia/tensorrt/tests/TARGETS @@ -49,6 +49,61 @@ manifold_get( visibility = ["PUBLIC"], ) +manifold_get( + name = "efficient_sam_weights", + out = "efficient_sam_vitt.pt", + api_key = "executorch-key", + artifact_path = "tree/models/tensorrt/weights/efficient_sam_vitt.pt", + bucket_name = "executorch", + sha1 = "5087209b2c62c518450344e8447f79ae29e32a21", + timeout_msec = 120000, + visibility = ["PUBLIC"], +) + +manifold_get( + name = "resnet18_weights", + out = "resnet18-f37072fd.pth", + api_key = "executorch-key", + artifact_path = "tree/models/tensorrt/weights/resnet18-f37072fd.pth", + bucket_name = "executorch", + sha1 = "93e13d94f74fdf476689608f146f47bde96b30b0", + timeout_msec = 120000, + visibility = ["PUBLIC"], +) + +manifold_get( + name = "resnet50_weights", + out = "resnet50-0676ba61.pth", + api_key = "executorch-key", + artifact_path = "tree/models/tensorrt/weights/resnet50-0676ba61.pth", + bucket_name = "executorch", + sha1 = "6ba9789036078cf8bace8dd75a770f46789c350c", + timeout_msec = 120000, + visibility = ["PUBLIC"], +) + +manifold_get( + name = "mv3_weights", + out = "mobilenet_v3_small-047dcff4.pth", + api_key = 
"executorch-key", + artifact_path = "tree/models/tensorrt/weights/mobilenet_v3_small-047dcff4.pth", + bucket_name = "executorch", + sha1 = "af9828929cb043737714380ab5af08b7ef76c5b2", + timeout_msec = 120000, + visibility = ["PUBLIC"], +) + +manifold_get( + name = "emformer_weights", + out = "emformer_rnnt_base_librispeech.pt", + api_key = "executorch-key", + artifact_path = "tree/models/tensorrt/weights/emformer_rnnt_base_librispeech.pt", + bucket_name = "executorch", + sha1 = "7b2a073aebce641dd9c471f454099f1bcb534208", + timeout_msec = 120000, + visibility = ["PUBLIC"], +) + # Export correctness tests: exports each supported model with TensorRT # and compares inference outputs against eager PyTorch on GPU. # buck2 test fbcode//executorch/examples/nvidia/tensorrt/tests:test_export @@ -61,8 +116,13 @@ python_unittest_remote_gpu( "HTTP_PROXY": "http://fwdproxy.any:8080", "DOG_JPG": "$(location :dog_jpg)", "EDSR_WEIGHTS": "$(location :edsr_weights)", + "EFFICIENT_SAM_WEIGHTS": "$(location :efficient_sam_weights)", + "EMFORMER_WEIGHTS": "$(location :emformer_weights)", "IC4_WEIGHTS": "$(location :ic4_weights)", "MV2_WEIGHTS": "$(location :mv2_weights)", + "MV3_WEIGHTS": "$(location :mv3_weights)", + "RESNET18_WEIGHTS": "$(location :resnet18_weights)", + "RESNET50_WEIGHTS": "$(location :resnet50_weights)", }, keep_gpu_sections = True, labels = ["long_running"], diff --git a/examples/nvidia/tensorrt/tests/test_export.py b/examples/nvidia/tensorrt/tests/test_export.py index 2906f56b01e..7905355514b 100644 --- a/examples/nvidia/tensorrt/tests/test_export.py +++ b/examples/nvidia/tensorrt/tests/test_export.py @@ -22,10 +22,18 @@ logging.basicConfig(level=logging.INFO) # Mapping from env var to expected cache filename. -# The test TARGETS provides these via manifold_get + $(location). +# When set (e.g., by CI), these env vars point to pre-downloaded weight files +# that get copied into the torch cache to avoid network downloads. 
_WEIGHT_ENV_VARS = { + "DOG_JPG": "dog.jpg", "EDSR_WEIGHTS": "edsr64_x2.pt", + "EFFICIENT_SAM_WEIGHTS": "efficient_sam_vitt.pt", + "EMFORMER_WEIGHTS": "emformer_rnnt_base_librispeech.pt", + "IC4_WEIGHTS": "inceptionv4-8e4777a0.pth", + "MV2_WEIGHTS": "mobilenet_v2-b0353104.pth", "MV3_WEIGHTS": "mobilenet_v3_small-047dcff4.pth", + "RESNET18_WEIGHTS": "resnet18-f37072fd.pth", + "RESNET50_WEIGHTS": "resnet50-0676ba61.pth", } @@ -41,21 +49,14 @@ def _populate_weight_cache() -> None: src = os.environ.get(env_var) if src and os.path.isfile(src): if env_var == "DOG_JPG": - # MV2Model downloads dog.jpg to CWD dst = os.path.join(os.getcwd(), filename) - elif env_var.startswith("MOBILEBERT_"): - # Pre-populate HuggingFace cache for mobilebert - hf_dir = os.path.join( - os.environ.get("HF_HOME", os.path.expanduser("~/.cache/huggingface")), - "hub", "models--google--mobilebert-uncased", - "snapshots", "manifold", + elif env_var == "EMFORMER_WEIGHTS": + torchaudio_dir = os.path.join( + os.environ.get("TORCH_HOME", os.path.expanduser("~/.cache/torch")), + "hub", "torchaudio", "models", ) - os.makedirs(hf_dir, exist_ok=True) - refs_dir = os.path.join(os.path.dirname(hf_dir), "refs") - os.makedirs(refs_dir, exist_ok=True) - with open(os.path.join(refs_dir, "main"), "w") as rf: - rf.write("manifold") - dst = os.path.join(hf_dir, filename) + os.makedirs(torchaudio_dir, exist_ok=True) + dst = os.path.join(torchaudio_dir, filename) else: dst = os.path.join(cache_dir, filename) if not os.path.exists(dst): @@ -155,3 +156,15 @@ def test_resnet18(self) -> None: def test_resnet50(self) -> None: _export_and_verify("resnet50") + + def test_edsr(self) -> None: + _export_and_verify("edsr") + + def test_emformer_transcribe(self) -> None: + _export_and_verify("emformer_transcribe") + + def test_efficient_sam(self) -> None: + _export_and_verify("efficient_sam") + + def test_ic4(self) -> None: + _export_and_verify("ic4")