Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
73 changes: 72 additions & 1 deletion backends/nvidia/tensorrt/backend.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,3 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
Expand Down Expand Up @@ -25,6 +25,11 @@
lookup_converter,
needs_edge_program,
)
from executorch.backends.nvidia.tensorrt.serialization import (
serialize_blob,
TensorRTBlobMetadata,
TensorRTIOBinding,
)

logger = logging.getLogger(__name__)
logger.setLevel(logging.WARNING)
Expand Down Expand Up @@ -100,14 +105,21 @@
)
_mark_network_outputs(network, output_nodes, input_map)

# Collect I/O bindings from network
io_bindings = _collect_io_bindings(network)

# Configure and build engine
config = _create_builder_config(builder, spec, trt)
serialized_engine = builder.build_serialized_network(network, config)

if serialized_engine is None:
raise RuntimeError("Failed to build TensorRT engine")

return PreprocessResult(processed_bytes=bytes(serialized_engine))
# Serialize with metadata
metadata = TensorRTBlobMetadata(io_bindings=io_bindings)
blob = serialize_blob(bytes(serialized_engine), metadata)

return PreprocessResult(processed_bytes=blob)


def _get_input_nodes(
Expand Down Expand Up @@ -284,7 +296,66 @@
network.mark_output(output_tensor)


def _trt_dtype_to_string(dtype: Any) -> str:
"""Convert TensorRT DataType to string representation."""
dtype_name = str(dtype)
# dtype looks like "DataType.FLOAT" or "DataType.HALF"
if "." in dtype_name:
dtype_name = dtype_name.split(".")[-1]

dtype_map = {
"FLOAT": "float32",
"HALF": "float16",
"INT8": "int8",
"INT32": "int32",
"INT64": "int64",
"BOOL": "bool",
"UINT8": "uint8",
"FP8": "float8",
"BF16": "bfloat16",
}
return dtype_map.get(dtype_name, "float32")


def _collect_io_bindings(network: Any) -> List[TensorRTIOBinding]:
    """Collect I/O binding metadata from a TensorRT network.

    Args:
        network: TensorRT network definition.

    Returns:
        List of TensorRTIOBinding records — all network inputs first,
        followed by all network outputs, in network index order.
    """

    def _make_binding(tensor: Any, is_input: bool) -> TensorRTIOBinding:
        # One metadata record per network tensor: name, dtype string,
        # static shape (as reported by the network), and direction flag.
        return TensorRTIOBinding(
            name=tensor.name,
            dtype=_trt_dtype_to_string(tensor.dtype),
            shape=list(tensor.shape),
            is_input=is_input,
        )

    inputs = [
        _make_binding(network.get_input(i), True) for i in range(network.num_inputs)
    ]
    outputs = [
        _make_binding(network.get_output(i), False)
        for i in range(network.num_outputs)
    ]
    return inputs + outputs


def _create_builder_config(builder: Any, spec: TensorRTCompileSpec, trt: Any) -> Any:

Check warning on line 358 in backends/nvidia/tensorrt/backend.py

View workflow job for this annotation

GitHub Actions / lintrunner

FLAKE8 C901

'_create_builder_config' is too complex (19) See https://www.flake8rules.com/rules/C901.html.
"""Create and configure TensorRT builder config."""
config = builder.create_builder_config()
if config is None:
Expand All @@ -305,7 +376,7 @@
# Report build progress if TRT supports IProgressMonitor.
if hasattr(trt, "IProgressMonitor"):

class _ProgressMonitor(trt.IProgressMonitor):

Check warning on line 379 in backends/nvidia/tensorrt/backend.py

View workflow job for this annotation

GitHub Actions / lintrunner

FLAKE8 F811

redefinition of unused '_ProgressMonitor' from line 373 See https://www.flake8rules.com/rules/F811.html.
def __init__(self):
trt.IProgressMonitor.__init__(self)
self._seen = set()
Expand Down
1 change: 1 addition & 0 deletions backends/nvidia/tensorrt/targets.bzl
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,7 @@ def define_common_targets():
"//executorch/backends/nvidia/tensorrt:compile_spec",
"//executorch/backends/nvidia/tensorrt:converter_registry",
"//executorch/backends/nvidia/tensorrt:converter_utils",
"//executorch/backends/nvidia/tensorrt:serialization",
"//executorch/backends/nvidia/tensorrt/converters:converters",
"//executorch/exir/backend:backend_details",
],
Expand Down
Loading