# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.
"""TensorRT backend implementation for ExecuTorch."""
import logging
from typing import Any, Dict, final, List, Optional, Tuple
import torch
from executorch.exir.backend.backend_details import (
BackendDetails,
CompileSpec,
PreprocessResult,
)
from torch.export.exported_program import ExportedProgram
from executorch.backends.nvidia.tensorrt.compile_spec import (
TensorRTCompileSpec,
TensorRTPrecision,
)
from executorch.backends.nvidia.tensorrt.converter_registry import (
lookup_converter,
needs_edge_program,
)
from executorch.backends.nvidia.tensorrt.serialization import (
serialize_blob,
TensorRTBlobMetadata,
TensorRTIOBinding,
)
from executorch.backends.nvidia.tensorrt.converters import (
clear_converter_weight_storage,
)
logger = logging.getLogger(__name__)
logger.setLevel(logging.WARNING)
@final
class TensorRTBackend(BackendDetails):
"""TensorRT backend for accelerating models on NVIDIA GPUs.
This backend compiles ExecuTorch edge programs to TensorRT engines
for optimized inference on NVIDIA hardware.
"""
@staticmethod
def preprocess(
edge_program: ExportedProgram,
compile_specs: List[CompileSpec],
) -> PreprocessResult:
"""Compile edge program to TensorRT engine.
Args:
edge_program: The edge dialect program to compile.
compile_specs: Backend-specific compilation options.
Returns:
PreprocessResult containing the serialized TensorRT engine.
"""
try:
import tensorrt as trt
except ImportError as e:
raise RuntimeError(
"TensorRT is not available. Please install TensorRT to use this backend."
) from e
# Import converters to trigger registration
from executorch.backends.nvidia.tensorrt import ( # noqa: F401
converters as _converters,
)
from executorch.backends.nvidia.tensorrt.converter_utils import (
ConversionContext,
get_op_name,
get_trt_tensor,
torch_dtype_to_trt,
)
# Parse compile specs
spec = TensorRTCompileSpec.from_compile_specs(compile_specs)
if spec is None:
spec = TensorRTCompileSpec()
graph_module = edge_program.graph_module
# Identify input and output nodes
input_nodes = _get_input_nodes(graph_module, edge_program)
output_nodes = _get_output_nodes(graph_module)
# Create TensorRT builder and network
trt_logger = trt.Logger(trt.Logger.WARNING)
builder = trt.Builder(trt_logger)
network = builder.create_network(
1 << int(trt.NetworkDefinitionCreationFlag.EXPLICIT_BATCH)
)
if network is None:
raise RuntimeError("Failed to create TensorRT network")
# Create conversion context for this build
ctx = ConversionContext(net=network)
# Build the network
input_map = _add_network_inputs(network, input_nodes, torch_dtype_to_trt)
# Add params/buffers as constant tensors
_add_params_to_input_map(
graph_module, edge_program, network, input_map, get_trt_tensor
)
_process_graph_nodes(
graph_module, edge_program, network, input_map, get_trt_tensor, get_op_name, ctx
)
_mark_network_outputs(network, output_nodes, input_map)
# Collect I/O bindings from network
io_bindings = _collect_io_bindings(network)
# Configure and build engine
config = _create_builder_config(builder, spec, trt)
serialized_engine = builder.build_serialized_network(network, config)
if serialized_engine is None:
raise RuntimeError("Failed to build TensorRT engine")
# Serialize with metadata
metadata = TensorRTBlobMetadata(io_bindings=io_bindings)
blob = serialize_blob(bytes(serialized_engine), metadata)
return PreprocessResult(processed_bytes=blob)
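
# A minimal end-to-end lowering sketch (illustrative, not part of the
# backend). It assumes `TensorRTCompileSpec` accepts field kwargs matching
# the attributes read below (`precision`, etc.) and exposes a hypothetical
# `to_compile_specs()` helper, the inverse of the `from_compile_specs()`
# used in `preprocess` above; everything else is the standard ExecuTorch
# lowering flow:
#
#     import torch
#     from executorch.exir import to_edge
#     from executorch.exir.backend.backend_api import to_backend
#
#     ep = torch.export.export(MyModel().eval(), (torch.randn(1, 3, 224, 224),))
#     edge = to_edge(ep)
#     spec = TensorRTCompileSpec(precision=TensorRTPrecision.FP16)
#     lowered = to_backend(
#         "TensorRTBackend", edge.exported_program(), spec.to_compile_specs()
#     )
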
def _get_input_nodes(
graph_module: torch.fx.GraphModule,
exported_program: ExportedProgram,
) -> List[torch.fx.Node]:
"""Get graph input placeholder nodes (excluding parameters/buffers)."""
input_nodes = []
for node in graph_module.graph.nodes:
if node.op == "placeholder":
if not _is_param_or_buffer(node, exported_program):
input_nodes.append(node)
return input_nodes
def _get_output_nodes(graph_module: torch.fx.GraphModule) -> List[torch.fx.Node]:
"""Get nodes that are graph outputs."""
output_nodes = []
for node in graph_module.graph.nodes:
if node.op == "output":
for arg in node.args:
if isinstance(arg, (list, tuple)):
output_nodes.extend(
item for item in arg if isinstance(item, torch.fx.Node)
)
elif isinstance(arg, torch.fx.Node):
output_nodes.append(arg)
return output_nodes
def _is_param_or_buffer(
node: torch.fx.Node, exported_program: ExportedProgram
) -> bool:
"""Check if a placeholder node is a parameter or buffer."""
if node.op != "placeholder":
return False
if hasattr(exported_program, "state_dict"):
if node.name in exported_program.state_dict:
return True
if hasattr(exported_program, "graph_signature"):
sig = exported_program.graph_signature
if hasattr(sig, "inputs_to_parameters"):
if node.name in sig.inputs_to_parameters:
return True
if hasattr(sig, "inputs_to_buffers"):
if node.name in sig.inputs_to_buffers:
return True
return False
def _add_params_to_input_map(
graph_module: torch.fx.GraphModule,
exported_program: ExportedProgram,
network: Any,
input_map: Dict[torch.fx.Node, Any],
get_trt_tensor_fn: Any,
) -> None:
"""Add parameters and buffers as constant TensorRT tensors to input_map.
In ExecuTorch's edge dialect, parameters are often "lifted" as placeholder
inputs rather than get_attr nodes. This function identifies these placeholder
nodes that represent parameters/buffers and adds them to input_map as
TensorRT constant tensors.
"""
for node in graph_module.graph.nodes:
if node.op == "placeholder":
# Skip if already in input_map (it's a real input, not a param)
if node in input_map:
continue
param_tensor = None
# Try to get from state_dict first
if hasattr(exported_program, "state_dict"):
if node.name in exported_program.state_dict:
param_tensor = exported_program.state_dict[node.name]
# Try to get from graph_signature mapping
if param_tensor is None and hasattr(exported_program, "graph_signature"):
sig = exported_program.graph_signature
param_name = None
if hasattr(sig, "inputs_to_parameters"):
param_name = sig.inputs_to_parameters.get(node.name)
if param_name is None and hasattr(sig, "inputs_to_buffers"):
param_name = sig.inputs_to_buffers.get(node.name)
if param_name is not None and hasattr(exported_program, "state_dict"):
param_tensor = exported_program.state_dict.get(param_name)
# If we found a parameter tensor, add it to input_map
if param_tensor is not None:
if isinstance(param_tensor, torch.nn.Parameter):
param_tensor = param_tensor.data
if isinstance(param_tensor, torch.Tensor):
# get_trt_tensor handles dtype conversion (int64→int32, float64→float32)
# via create_constant in converter_utils.py
input_map[node] = get_trt_tensor_fn(
network, param_tensor, f"param_{node.name}"
)
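
# Illustrative note: for a module with a single `nn.Linear`, the exported
# graph's placeholders might be (p_linear_weight, p_linear_bias, x), and the
# graph signature maps the lifted names back to state_dict keys:
#
#     sig = exported_program.graph_signature
#     sig.inputs_to_parameters  # e.g. {"p_linear_weight": "linear.weight", ...}
#
# That mapping is exactly what the lookup above uses to turn lifted
# placeholders into TensorRT constants instead of runtime inputs.
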
def _get_tensor_shape_and_dtype(
node: torch.fx.Node,
) -> Tuple[Optional[Tuple[int, ...]], Optional[torch.dtype]]:
"""Extract tensor shape and dtype from node metadata."""
if "val" in node.meta:
val = node.meta["val"]
if isinstance(val, torch.Tensor):
return tuple(val.shape), val.dtype
if hasattr(val, "shape") and hasattr(val, "dtype"):
return tuple(val.shape), val.dtype
return None, None
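
# For reference: edge programs typically store a FakeTensor in
# `node.meta["val"]`, so for a conv output this helper might return
# ((1, 16, 112, 112), torch.float32) without ever materializing data.
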
def _get_attr_value(
graph_module: torch.fx.GraphModule, attr_name: str
) -> Optional[torch.Tensor]:
"""Get attribute value from graph module."""
try:
parts = attr_name.split(".")
obj = graph_module
for part in parts:
obj = getattr(obj, part)
if isinstance(obj, torch.Tensor):
return obj
if isinstance(obj, torch.nn.Parameter):
return obj.data
return None
except AttributeError:
return None
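
# For reference, `get_attr` targets are dotted paths into submodules: a
# target of "blocks.0.weight" resolves as getattr(graph_module, "blocks"),
# then getattr(..., "0"), then getattr(..., "weight"), which is the
# chained-getattr walk performed above.
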
def _add_network_inputs(
network: Any,
input_nodes: List[torch.fx.Node],
dtype_converter: Any,
) -> Dict[torch.fx.Node, Any]:
"""Add input tensors to TensorRT network."""
input_map: Dict[torch.fx.Node, Any] = {}
for input_node in input_nodes:
shape, dtype = _get_tensor_shape_and_dtype(input_node)
if shape is None:
raise RuntimeError(
f"Cannot determine shape for input node: {input_node.name}"
)
trt_dtype = dtype_converter(dtype if dtype else torch.float32)
trt_input = network.add_input(
name=input_node.name,
dtype=trt_dtype,
shape=shape,
)
if trt_input is None:
raise RuntimeError(f"Failed to add input to network: {input_node.name}")
input_map[input_node] = trt_input
return input_map
def _process_graph_nodes(
graph_module: torch.fx.GraphModule,
exported_program: ExportedProgram,
network: Any,
input_map: Dict[torch.fx.Node, Any],
get_trt_tensor_fn: Any,
get_op_name_fn: Any,
ctx: Any = None,
) -> None:
"""Process graph nodes and convert to TensorRT layers.
Args:
graph_module: The FX graph module to process.
exported_program: The ExportedProgram for weight extraction.
network: TensorRT network definition.
input_map: Mapping from FX nodes to TensorRT tensors.
get_trt_tensor_fn: Function to create TensorRT constant tensors.
get_op_name_fn: Function to extract operation name from nodes.
ctx: Optional ConversionContext for unique layer naming.
"""
for node in graph_module.graph.nodes:
if node.op == "call_function":
op_name = get_op_name_fn(node)
converter = lookup_converter(op_name)
if converter is None:
raise RuntimeError(f"No converter registered for operation: {op_name}")
# Check if converter needs edge_program for weight extraction
if needs_edge_program(op_name):
output_tensor = converter(node, network, input_map, exported_program, ctx)
else:
output_tensor = converter(node, network, input_map, ctx)
input_map[node] = output_tensor
elif node.op == "get_attr":
attr_name = node.target
param = _get_attr_value(graph_module, attr_name)
if param is not None:
input_map[node] = get_trt_tensor_fn(
network, param, f"param_{node.name}"
)
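
# For reference, the converter contract implied by the loop above: a
# converter takes (node, network, input_map, ctx), plus the edge program when
# `needs_edge_program` says so, and returns the TensorRT tensor for the node.
# A minimal sketch, assuming `import tensorrt as trt` (the
# `register_converter` decorator name is hypothetical; only
# `lookup_converter`/`needs_edge_program` appear in this file):
#
#     @register_converter("aten.relu.default")
#     def convert_relu(node, network, input_map, ctx=None):
#         inp = input_map[node.args[0]]
#         layer = network.add_activation(inp, trt.ActivationType.RELU)
#         return layer.get_output(0)
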
def _mark_network_outputs(
network: Any,
output_nodes: List[torch.fx.Node],
input_map: Dict[torch.fx.Node, Any],
) -> None:
"""Mark network outputs in TensorRT network."""
for output_node in output_nodes:
if output_node not in input_map:
raise RuntimeError(
f"Output node not found in input_map: {output_node.name}"
)
output_tensor = input_map[output_node]
if hasattr(output_tensor, "name"):
output_tensor.name = f"output_{output_node.name}"
network.mark_output(output_tensor)
def _trt_dtype_to_string(dtype: Any) -> str:
"""Convert TensorRT DataType to string representation."""
dtype_name = str(dtype)
# dtype looks like "DataType.FLOAT" or "DataType.HALF"
if "." in dtype_name:
dtype_name = dtype_name.split(".")[-1]
dtype_map = {
"FLOAT": "float32",
"HALF": "float16",
"INT8": "int8",
"INT32": "int32",
"INT64": "int64",
"BOOL": "bool",
"UINT8": "uint8",
"FP8": "float8",
"BF16": "bfloat16",
}
return dtype_map.get(dtype_name, "float32")
def _collect_io_bindings(network: Any) -> List[TensorRTIOBinding]:
"""Collect I/O binding information from TensorRT network.
Args:
network: TensorRT network definition.
Returns:
List of TensorRTIOBinding with input/output tensor metadata.
"""
# Import here to avoid circular imports at module level
from executorch.backends.nvidia.tensorrt.converter_utils import get_safe_shape
bindings = []
# Collect inputs
for i in range(network.num_inputs):
tensor = network.get_input(i)
bindings.append(
TensorRTIOBinding(
name=tensor.name,
dtype=_trt_dtype_to_string(tensor.dtype),
shape=get_safe_shape(tensor),
is_input=True,
)
)
# Collect outputs
for i in range(network.num_outputs):
tensor = network.get_output(i)
bindings.append(
TensorRTIOBinding(
name=tensor.name,
dtype=_trt_dtype_to_string(tensor.dtype),
shape=get_safe_shape(tensor),
is_input=False,
)
)
return bindings
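
# Illustrative consumption of this metadata, assuming a `deserialize_blob`
# counterpart (hypothetical name) to the `serialize_blob` imported above:
#
#     engine_bytes, metadata = deserialize_blob(blob)
#     for b in metadata.io_bindings:
#         print(b.name, b.dtype, b.shape, "input" if b.is_input else "output")
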
def _create_builder_config(builder: Any, spec: TensorRTCompileSpec, trt: Any) -> Any:
"""Create and configure TensorRT builder config."""
config = builder.create_builder_config()
if config is None:
raise RuntimeError("Failed to create TensorRT builder config")
config.set_memory_pool_limit(trt.MemoryPoolType.WORKSPACE, spec.workspace_size)
# Disable TF32 for strict FP32 precision on Ampere+ GPUs.
if hasattr(trt.BuilderFlag, "TF32"):
config.clear_flag(trt.BuilderFlag.TF32)
# Report build progress if TRT supports IProgressMonitor.
if hasattr(trt, "IProgressMonitor"):
class _ProgressMonitor(trt.IProgressMonitor):
def __init__(self):
trt.IProgressMonitor.__init__(self)
self._seen = set()
def phase_start(self, phase_name, parent_phase, num_steps):
key = (phase_name, parent_phase)
if key not in self._seen:
self._seen.add(key)
                    indent = "    " if parent_phase else "  "
print(f"{indent}TRT: {phase_name}", flush=True)
def step_complete(self, phase_name, step):
return True
def phase_finish(self, phase_name):
pass
config.progress_monitor = _ProgressMonitor()
# TensorRT 10.6+ enables WEIGHT_STREAMING by default, which generates
# weight-separated plan files that require IStreamReader for deserialization.
# We disable this flag to generate standard plan files that can be
# deserialized with the simpler deserializeCudaEngine(data, size) API.
if hasattr(trt.BuilderFlag, "WEIGHT_STREAMING"):
config.clear_flag(trt.BuilderFlag.WEIGHT_STREAMING)
if spec.precision == TensorRTPrecision.FP16:
if builder.platform_has_fast_fp16:
config.set_flag(trt.BuilderFlag.FP16)
else:
logger.warning("FP16 not supported on this platform, using FP32")
if spec.precision == TensorRTPrecision.INT8:
if builder.platform_has_fast_int8:
config.set_flag(trt.BuilderFlag.INT8)
else:
logger.warning("INT8 not supported on this platform, using FP32")
    # STRICT_TYPES was removed in newer TensorRT releases; guard it like the
    # other version-dependent flags above.
    if spec.strict_type_constraints and hasattr(trt.BuilderFlag, "STRICT_TYPES"):
        config.set_flag(trt.BuilderFlag.STRICT_TYPES)
if spec.dla_core >= 0:
config.default_device_type = trt.DeviceType.DLA
config.DLA_core = spec.dla_core
if spec.allow_gpu_fallback:
config.set_flag(trt.BuilderFlag.GPU_FALLBACK)
return config
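
# Because WEIGHT_STREAMING is cleared above, the serialized plan can be
# loaded with the plain deserialization API. A minimal sketch, where
# `engine_bytes` is the `bytes(serialized_engine)` produced in `preprocess`
# (requires an installed TensorRT runtime and a CUDA device):
#
#     import tensorrt as trt
#
#     runtime = trt.Runtime(trt.Logger(trt.Logger.WARNING))
#     engine = runtime.deserialize_cuda_engine(engine_bytes)
#     assert engine is not None
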