Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
68 changes: 68 additions & 0 deletions backends/nvidia/tensorrt/backend.py
Original file line number Diff line number Diff line change
Expand Up @@ -30,6 +30,9 @@
TensorRTBlobMetadata,
TensorRTIOBinding,
)
from executorch.backends.nvidia.tensorrt.converters import (
clear_converter_weight_storage,
)

logger = logging.getLogger(__name__)
logger.setLevel(logging.WARNING)
Expand Down Expand Up @@ -100,6 +103,10 @@ def preprocess(

# Build the network
input_map = _add_network_inputs(network, input_nodes, torch_dtype_to_trt)
# Add params/buffers as constant tensors
_add_params_to_input_map(
graph_module, edge_program, network, input_map, get_trt_tensor
)
_process_graph_nodes(
graph_module, edge_program, network, input_map, get_trt_tensor, get_op_name, ctx
)
Expand Down Expand Up @@ -173,6 +180,67 @@ def _is_param_or_buffer(
return False


def _add_params_to_input_map(
graph_module: torch.fx.GraphModule,
exported_program: ExportedProgram,
network: Any,
input_map: Dict[torch.fx.Node, Any],
get_trt_tensor_fn: Any,
) -> None:
"""Add parameters and buffers as constant TensorRT tensors to input_map.

In ExecuTorch's edge dialect, parameters are often "lifted" as placeholder
inputs rather than get_attr nodes. This function identifies these placeholder
nodes that represent parameters/buffers and adds them to input_map as
TensorRT constant tensors.
"""
for node in graph_module.graph.nodes:
if node.op == "placeholder":
# Skip if already in input_map (it's a real input, not a param)
if node in input_map:
continue

param_tensor = None

# Try to get from state_dict first
if hasattr(exported_program, "state_dict"):
if node.name in exported_program.state_dict:
param_tensor = exported_program.state_dict[node.name]

# Try to get from graph_signature mapping
if param_tensor is None and hasattr(exported_program, "graph_signature"):
sig = exported_program.graph_signature
param_name = None
if hasattr(sig, "inputs_to_parameters"):
param_name = sig.inputs_to_parameters.get(node.name)
if param_name is None and hasattr(sig, "inputs_to_buffers"):
param_name = sig.inputs_to_buffers.get(node.name)

if param_name is not None and hasattr(exported_program, "state_dict"):
param_tensor = exported_program.state_dict.get(param_name)

# If we found a parameter tensor, add it to input_map
if param_tensor is not None:
if isinstance(param_tensor, torch.nn.Parameter):
param_tensor = param_tensor.data
if isinstance(param_tensor, torch.Tensor):
# Convert int64/int32 tensors to float32 for TensorRT compatibility
# These are often used in elementwise operations with float tensors
# (e.g., batch norm statistics in MobileNetV3)
original_dtype = param_tensor.dtype
if param_tensor.dtype in (torch.int32, torch.int64):
param_tensor = param_tensor.float()
logger.debug(
f"Converting param {node.name} from {original_dtype} to float32 "
f"for TensorRT compatibility"
)
elif param_tensor.dtype == torch.float64:
param_tensor = param_tensor.float()
input_map[node] = get_trt_tensor_fn(
network, param_tensor, f"param_{node.name}"
)


def _get_tensor_shape_and_dtype(
node: torch.fx.Node,
) -> Tuple[Optional[Tuple[int, ...]], Optional[torch.dtype]]:
Expand Down
18 changes: 18 additions & 0 deletions backends/nvidia/tensorrt/converters/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,11 +7,29 @@
"""TensorRT converters for ExecuTorch operations."""

# Import converters to trigger registration via @converter decorator
from executorch.backends.nvidia.tensorrt.converters import activations # noqa: F401
from executorch.backends.nvidia.tensorrt.converters import add # noqa: F401
from executorch.backends.nvidia.tensorrt.converters import addmm # noqa: F401
from executorch.backends.nvidia.tensorrt.converters import batch_norm # noqa: F401
from executorch.backends.nvidia.tensorrt.converters import clamp # noqa: F401
from executorch.backends.nvidia.tensorrt.converters import concat # noqa: F401
from executorch.backends.nvidia.tensorrt.converters import conv2d # noqa: F401
from executorch.backends.nvidia.tensorrt.converters import dim_order_ops # noqa: F401
from executorch.backends.nvidia.tensorrt.converters import div # noqa: F401
from executorch.backends.nvidia.tensorrt.converters import getitem # noqa: F401
from executorch.backends.nvidia.tensorrt.converters import linear # noqa: F401
from executorch.backends.nvidia.tensorrt.converters import mm # noqa: F401
from executorch.backends.nvidia.tensorrt.converters import mul # noqa: F401
from executorch.backends.nvidia.tensorrt.converters import permute_copy # noqa: F401
from executorch.backends.nvidia.tensorrt.converters import pooling # noqa: F401
from executorch.backends.nvidia.tensorrt.converters import reduction # noqa: F401
from executorch.backends.nvidia.tensorrt.converters import relu # noqa: F401
from executorch.backends.nvidia.tensorrt.converters import reshape # noqa: F401
from executorch.backends.nvidia.tensorrt.converters import sub # noqa: F401


def clear_converter_weight_storage() -> None:
    """Clear weight storage to free memory after engine build.

    Each listed converter module caches weights during network
    construction; drop them all so the memory can be reclaimed.
    """
    for converter_module in (conv2d, batch_norm, linear):
        converter_module.clear_weight_storage()
Loading
Loading