Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
39 changes: 26 additions & 13 deletions backends/nvidia/tensorrt/README.md
Original file line number Diff line number Diff line change
Expand Up @@ -37,7 +37,7 @@ cmake -B cmake-out \
-DEXECUTORCH_BUILD_EXTENSION_NAMED_DATA_MAP=ON \
-DCMAKE_BUILD_TYPE=Release

cmake --build cmake-out --target tensorrt_executor_runner -j$(nproc)
cmake --build cmake-out --target tensorrt_executor_runner benchmark -j$(nproc)
```

### Export and Run Models
Expand All @@ -49,6 +49,9 @@ python -m executorch.examples.nvidia.tensorrt.export -m add
# Export all supported models
python -m executorch.examples.nvidia.tensorrt.export

# Export with ONNX baseline (for benchmarking)
python -m executorch.examples.nvidia.tensorrt.export --onnx

# Run inference with the C++ runner
./cmake-out/backends/nvidia/tensorrt/tensorrt_executor_runner --model_path=add_tensorrt.pte
```
Expand Down Expand Up @@ -110,7 +113,7 @@ Alternatively, download and install from the
| Option | Type | Default | Description |
|--------|------|---------|-------------|
| `workspace_size` | int | 1GB | TensorRT builder workspace size |
| `precision` | TensorRTPrecision | FP32 | Inference precision (FP32, FP16, INT8) |
| `precision` | TensorRTPrecision | FP32 | Inference precision (FP32, FP16, BF16, INT8) |
| `strict_type_constraints` | bool | False | Enforce strict type constraints |
| `max_batch_size` | int | 1 | Maximum batch size |
| `device_id` | int | 0 | CUDA device ID |
Expand Down Expand Up @@ -158,15 +161,9 @@ with open("model_tensorrt.pte", "wb") as f:
## Supported Operations

| Category | Operations |
|----------|-----------|
| Elementwise | add, sub, mul, div, mm, relu |
| Reshape | view, reshape, squeeze, unsqueeze, permute, contiguous, clone |

## Supported Operations

| Category | Operations |
|----------|-----------|
| Elementwise | add, sub, mul, div, floor_divide, rsub |
|----------|------------|
| Elementwise | add, sub, mul, div, floor_divide, rsub, pow, abs, ceil, sqrt |
| Unary Math | cos, sin, exp, erf, log |
| Matrix | mm, addmm, bmm, linear |
| Convolution | conv2d |
| Normalization | batch_norm, layer_norm |
Expand All @@ -176,8 +173,9 @@ with open("model_tensorrt.pte", "wb") as f:
| Reduction | mean, any |
| Concat/Split | cat, split, chunk, stack |
| Comparison | eq, ne, lt, le, gt, ge, where, logical_not |
| Slicing | slice, select, index |
| Other | embedding, expand, repeat, upsample, pixel_shuffle, scaled_dot_product_attention, full |
| Slicing | slice, select, index, arange |
| Padding | constant_pad_nd |
| Other | embedding, expand, repeat, upsample, pixel_shuffle, scaled_dot_product_attention, full, dropout |

## Jetson Deployment

Expand Down Expand Up @@ -252,6 +250,21 @@ Each test exports a model with TensorRT, runs inference via ExecuTorch
pybindings, and compares outputs against eager PyTorch (atol=1e-3, rtol=1e-3)
across 3 random seeds.

### Benchmark

```bash
# Benchmark all exported models in the current directory
./cmake-out/examples/nvidia/tensorrt/benchmark

# Benchmark with options
./cmake-out/examples/nvidia/tensorrt/benchmark -d DIR -m MODEL -n 100 -w 5
```

The benchmark reports three formats per model:
- **pte** — ExecuTorch end-to-end (includes framework overhead)
- **pte-raw** — Raw TRT engine execution extracted from the .pte
- **onnx-trt** — ONNX → TRT engine (baseline, when .onnx files are present)

## Troubleshooting

| Issue | Fix |
Expand Down
21 changes: 7 additions & 14 deletions backends/nvidia/tensorrt/backend.py
Original file line number Diff line number Diff line change
Expand Up @@ -224,18 +224,8 @@ def _add_params_to_input_map(
if isinstance(param_tensor, torch.nn.Parameter):
param_tensor = param_tensor.data
if isinstance(param_tensor, torch.Tensor):
# Convert int64/int32 tensors to float32 for TensorRT compatibility
# These are often used in elementwise operations with float tensors
# (e.g., batch norm statistics in MobileNetV3)
original_dtype = param_tensor.dtype
if param_tensor.dtype in (torch.int32, torch.int64):
param_tensor = param_tensor.float()
logger.debug(
f"Converting param {node.name} from {original_dtype} to float32 "
f"for TensorRT compatibility"
)
elif param_tensor.dtype == torch.float64:
param_tensor = param_tensor.float()
# get_trt_tensor handles dtype conversion (int64→int32, float64→float32)
# via create_constant in converter_utils.py
input_map[node] = get_trt_tensor_fn(
network, param_tensor, f"param_{node.name}"
)
Expand Down Expand Up @@ -394,6 +384,9 @@ def _collect_io_bindings(network: Any) -> List[TensorRTIOBinding]:
Returns:
List of TensorRTIOBinding with input/output tensor metadata.
"""
# Import here to avoid circular imports at module level
from executorch.backends.nvidia.tensorrt.converter_utils import get_safe_shape

bindings = []

# Collect inputs
Expand All @@ -403,7 +396,7 @@ def _collect_io_bindings(network: Any) -> List[TensorRTIOBinding]:
TensorRTIOBinding(
name=tensor.name,
dtype=_trt_dtype_to_string(tensor.dtype),
shape=list(tensor.shape),
shape=get_safe_shape(tensor),
is_input=True,
)
)
Expand All @@ -415,7 +408,7 @@ def _collect_io_bindings(network: Any) -> List[TensorRTIOBinding]:
TensorRTIOBinding(
name=tensor.name,
dtype=_trt_dtype_to_string(tensor.dtype),
shape=list(tensor.shape),
shape=get_safe_shape(tensor),
is_input=False,
)
)
Expand Down
70 changes: 65 additions & 5 deletions backends/nvidia/tensorrt/converter_utils.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,3 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
Expand Down Expand Up @@ -269,8 +269,11 @@
Handles:
- TensorRT ITensor (returned as-is)
- Python scalars (int, float) → constant tensor
- PyTorch tensors → constant tensor
- PyTorch tensors → constant tensor (including FakeTensors/subclasses)
- numpy arrays → constant tensor

Note: Uses unset_fake_temporarily to handle tensor subclasses like FakeTensor
that don't support .numpy() directly. This follows the TensorRT pattern.
"""
if isinstance(value, trt.ITensor):
return value
Expand All @@ -283,8 +286,32 @@
return create_constant(network, value, name)

if isinstance(value, torch.Tensor):
value = _tensor_to_numpy(value)
return create_constant(network, value, name)
# Handle tensor subclasses (FakeTensor, etc.) that don't support .numpy()
# by temporarily exiting fake tensor mode. This follows TensorRT pattern.
try:
from torch.fx.experimental.proxy_tensor import unset_fake_temporarily
with unset_fake_temporarily():
# Create a real tensor from the fake tensor's data if needed
if hasattr(value, '_local_scalar_dense') or not value.is_contiguous():
value = value.contiguous()
np_value = value.detach().cpu().numpy()
except (ImportError, RuntimeError):
# Fallback: try to convert via creating a new tensor
try:
np_value = _tensor_to_numpy(value)
except RuntimeError:
# Last resort: create tensor from metadata if available
if hasattr(value, 'shape') and hasattr(value, 'dtype'):
# For FakeTensors, we may need to create a zero tensor as placeholder
# This should only happen during tracing, not actual execution
np_dtype = _torch_dtype_to_numpy(value.dtype)
np_value = np.zeros(tuple(value.shape), dtype=np_dtype)
else:
raise RuntimeError(
f"Cannot convert tensor subclass {type(value)} to numpy. "
f"Tensor may be a FakeTensor from tracing."
)
return create_constant(network, np_value, name)

if isinstance(value, np.ndarray):
return create_constant(network, value, name)
Expand All @@ -300,6 +327,8 @@
"""Create a TensorRT constant tensor from numpy array.

Note: TensorRT doesn't support int64 (i64), so we convert to int32.
Also, TensorRT doesn't handle 0-d tensors well in elementwise ops,
so we reshape scalars to 1-d tensors with shape (1,).
"""
# TensorRT doesn't support int64 - convert to int32
if value.dtype == np.int64:
Expand All @@ -308,11 +337,39 @@
if value.dtype == np.float64:
value = value.astype(np.float32)

# TensorRT requires at least 1-d tensors for elementwise ops.
# Reshape 0-d scalars to 1-d tensors with shape (1,).
if value.ndim == 0:
value = value.reshape((1,))

layer = network.add_constant(value.shape, trt.Weights(value))
layer.name = f"const_{name}"
return layer.get_output(0)


def get_safe_shape(tensor: trt.ITensor) -> List[int]:
    """Extract a tensor's shape as a plain list, tolerating dynamic shapes.

    While a network is being built, TensorRT may report shapes that cannot
    be materialized (e.g., a negative length when dimensions are dynamic).
    Rather than propagate that failure, fall back to an empty list.

    Args:
        tensor: TensorRT tensor whose shape should be read.

    Returns:
        The dimension sizes as a list, or an empty list when the shape is
        missing or cannot be converted.
    """
    try:
        raw_shape = tensor.shape
    except (ValueError, TypeError):
        return []
    if raw_shape is None:
        return []
    try:
        return list(raw_shape)
    except (ValueError, TypeError):
        return []


def broadcast_tensors(
network: trt.INetworkDefinition,
tensors: Sequence[trt.ITensor],
Expand All @@ -335,10 +392,13 @@
"""
result = []
for i, tensor in enumerate(tensors):
current_ndim = len(tensor.shape)
shape = get_safe_shape(tensor)
current_ndim = len(shape) if shape else target_ndim

if current_ndim < target_ndim:
diff = target_ndim - current_ndim
new_shape = (1,) * diff + tuple(tensor.shape)
existing_shape = tuple(shape) if shape else tuple([-1] * current_ndim)
new_shape = (1,) * diff + existing_shape
layer = network.add_shuffle(tensor)
layer.reshape_dims = new_shape
# Use context counter if available, otherwise use simple naming
Expand Down
2 changes: 2 additions & 0 deletions backends/nvidia/tensorrt/converters/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -28,12 +28,14 @@
from executorch.backends.nvidia.tensorrt.converters import permute_copy # noqa: F401
from executorch.backends.nvidia.tensorrt.converters import pixel_shuffle # noqa: F401
from executorch.backends.nvidia.tensorrt.converters import pooling # noqa: F401
from executorch.backends.nvidia.tensorrt.converters import power # noqa: F401
from executorch.backends.nvidia.tensorrt.converters import reduction # noqa: F401
from executorch.backends.nvidia.tensorrt.converters import relu # noqa: F401
from executorch.backends.nvidia.tensorrt.converters import reshape # noqa: F401
from executorch.backends.nvidia.tensorrt.converters import sdpa # noqa: F401
from executorch.backends.nvidia.tensorrt.converters import slice # noqa: F401
from executorch.backends.nvidia.tensorrt.converters import sub # noqa: F401
from executorch.backends.nvidia.tensorrt.converters import unary # noqa: F401
from executorch.backends.nvidia.tensorrt.converters import upsample # noqa: F401


Expand Down
5 changes: 3 additions & 2 deletions backends/nvidia/tensorrt/converters/activations.py
Original file line number Diff line number Diff line change
Expand Up @@ -32,6 +32,7 @@

import torch
from executorch.backends.nvidia.tensorrt.converter_registry import converter
from executorch.backends.nvidia.tensorrt.converter_utils import get_node_shape

logger: logging.Logger = logging.getLogger(__name__)

Expand Down Expand Up @@ -489,8 +490,8 @@ def convert_log_softmax(

input_trt = input_map[input_node]

# Handle negative dimension
input_shape = input_trt.shape
# Handle negative dimension using TRT tensor shape.
input_shape = tuple(input_trt.shape)
ndim = len(input_shape)
if dim < 0:
dim = ndim + dim
Expand Down
50 changes: 38 additions & 12 deletions backends/nvidia/tensorrt/converters/add.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@

"""Converter for element-wise addition operations."""

import logging
from typing import Any, Dict, Optional

import tensorrt as trt
Expand All @@ -20,6 +21,9 @@
)


logger: logging.Logger = logging.getLogger(__name__)


def _get_input_ndim(arg: Any, input_map: Dict[torch.fx.Node, Any]) -> int:
"""Get the number of dimensions for an elementwise input argument.

Expand All @@ -36,13 +40,18 @@ def _get_input_ndim(arg: Any, input_map: Dict[torch.fx.Node, Any]) -> int:
# Try to get ndim from node metadata first (most reliable)
if "val" in arg.meta and hasattr(arg.meta["val"], "shape"):
return len(arg.meta["val"].shape)
# Fall back to TRT tensor shape
# Fall back to TRT tensor shape - handle dynamic shapes carefully
if arg in input_map:
trt_tensor = input_map[arg]
shape = trt_tensor.shape
if shape is not None:
return len(shape)
# For scalars, return 0 (will be broadcast)
try:
shape = trt_tensor.shape
if shape is not None:
ndim = len(shape)
if ndim >= 0: # Valid shape
return ndim
except (ValueError, TypeError):
pass # Dynamic shape, fall through
# For scalars or unknown, return 0 (will be broadcast)
return 0


Expand All @@ -55,6 +64,11 @@ def _get_elementwise_input(
) -> trt.ITensor:
"""Get TensorRT tensor for an elementwise operation input.

Handles:
- FX nodes already in input_map
- FX nodes that are lifted buffers/parameters (placeholder nodes with b_ or p_ prefix)
- Scalar values

Args:
network: TensorRT network definition.
input_map: Mapping from FX nodes to TensorRT tensors.
Expand All @@ -66,15 +80,27 @@ def _get_elementwise_input(
TensorRT tensor for the input.

Raises:
ValueError: If arg is a Node but not found in input_map.
ValueError: If arg is a Node but not found in input_map and cannot be created as constant.
"""
if isinstance(arg, torch.fx.Node):
if arg not in input_map:
raise ValueError(
f"Input node '{arg.name}' not found in input_map. "
f"Available nodes: {list(input_map.keys())}"
)
return input_map[arg]
if arg in input_map:
return input_map[arg]

# Handle lifted buffers and parameters that aren't in input_map
# These are placeholder nodes with names starting with b_ (buffers) or p_ (parameters)
# or get_attr nodes. We need to create constants from their metadata values.
if arg.op == "placeholder" or arg.op == "get_attr":
if "val" in arg.meta and isinstance(arg.meta["val"], torch.Tensor):
logger.debug(f"[TensorRT] Creating constant for lifted buffer/parameter: {arg.name}")
trt_tensor = get_trt_tensor(network, arg.meta["val"], f"const_{arg.name}", dtype)
input_map[arg] = trt_tensor # Cache for future use
return trt_tensor

raise ValueError(
f"Input node '{arg.name}' not found in input_map and could not be created as constant. "
f"Node op: {arg.op}, target: {arg.target}. "
f"Available nodes: {list(n.name for n in input_map.keys())}"
)
return get_trt_tensor(network, arg, name, dtype)


Expand Down
Loading
Loading