From 4035db7465723a4037a6bc7358f067006ec3f371 Mon Sep 17 00:00:00 2001 From: shoumikhin Date: Thu, 5 Mar 2026 09:58:16 -0800 Subject: [PATCH] [executorch][nvidia][tensorrt][11/n] Complete preprocess integration with serialization Complete preprocess integration with blob serialization for TensorRT engine compilation. Differential Revision: [D93275051](https://our.internmc.facebook.com/intern/diff/D93275051/) [ghstack-poisoned] --- backends/nvidia/tensorrt/backend.py | 73 +++++++++++++++++++++++++++- backends/nvidia/tensorrt/targets.bzl | 1 + 2 files changed, 73 insertions(+), 1 deletion(-) diff --git a/backends/nvidia/tensorrt/backend.py b/backends/nvidia/tensorrt/backend.py index 3484ee9ce30..29185b8559a 100644 --- a/backends/nvidia/tensorrt/backend.py +++ b/backends/nvidia/tensorrt/backend.py @@ -25,6 +25,11 @@ lookup_converter, needs_edge_program, ) +from executorch.backends.nvidia.tensorrt.serialization import ( + serialize_blob, + TensorRTBlobMetadata, + TensorRTIOBinding, +) logger = logging.getLogger(__name__) logger.setLevel(logging.WARNING) @@ -100,6 +105,9 @@ def preprocess( ) _mark_network_outputs(network, output_nodes, input_map) + # Collect I/O bindings from network + io_bindings = _collect_io_bindings(network) + # Configure and build engine config = _create_builder_config(builder, spec, trt) serialized_engine = builder.build_serialized_network(network, config) @@ -107,7 +115,11 @@ def preprocess( if serialized_engine is None: raise RuntimeError("Failed to build TensorRT engine") - return PreprocessResult(processed_bytes=bytes(serialized_engine)) + # Serialize with metadata + metadata = TensorRTBlobMetadata(io_bindings=io_bindings) + blob = serialize_blob(bytes(serialized_engine), metadata) + + return PreprocessResult(processed_bytes=blob) def _get_input_nodes( @@ -284,6 +296,65 @@ def _mark_network_outputs( network.mark_output(output_tensor) +def _trt_dtype_to_string(dtype: Any) -> str: + """Convert TensorRT DataType to string representation.""" + dtype_name = str(dtype) + # dtype looks like "DataType.FLOAT" or "DataType.HALF" + if "." in dtype_name: + dtype_name = dtype_name.split(".")[-1] + + dtype_map = { + "FLOAT": "float32", + "HALF": "float16", + "INT8": "int8", + "INT32": "int32", + "INT64": "int64", + "BOOL": "bool", + "UINT8": "uint8", + "FP8": "float8", + "BF16": "bfloat16", + } + return dtype_map.get(dtype_name, "float32") + + +def _collect_io_bindings(network: Any) -> List[TensorRTIOBinding]: + """Collect I/O binding information from TensorRT network. + + Args: + network: TensorRT network definition. + + Returns: + List of TensorRTIOBinding with input/output tensor metadata. + """ + bindings = [] + + # Collect inputs + for i in range(network.num_inputs): + tensor = network.get_input(i) + bindings.append( + TensorRTIOBinding( + name=tensor.name, + dtype=_trt_dtype_to_string(tensor.dtype), + shape=list(tensor.shape), + is_input=True, + ) + ) + + # Collect outputs + for i in range(network.num_outputs): + tensor = network.get_output(i) + bindings.append( + TensorRTIOBinding( + name=tensor.name, + dtype=_trt_dtype_to_string(tensor.dtype), + shape=list(tensor.shape), + is_input=False, + ) + ) + + return bindings + + def _create_builder_config(builder: Any, spec: TensorRTCompileSpec, trt: Any) -> Any: """Create and configure TensorRT builder config.""" config = builder.create_builder_config() diff --git a/backends/nvidia/tensorrt/targets.bzl b/backends/nvidia/tensorrt/targets.bzl index ad42dca0152..689b55668e5 100644 --- a/backends/nvidia/tensorrt/targets.bzl +++ b/backends/nvidia/tensorrt/targets.bzl @@ -19,6 +19,7 @@ def define_common_targets(): "//executorch/backends/nvidia/tensorrt:compile_spec", "//executorch/backends/nvidia/tensorrt:converter_registry", "//executorch/backends/nvidia/tensorrt:converter_utils", + "//executorch/backends/nvidia/tensorrt:serialization", "//executorch/backends/nvidia/tensorrt/converters:converters", "//executorch/exir/backend:backend_details", ],