Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 2 additions & 0 deletions backends/nvidia/tensorrt/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -27,7 +27,9 @@
sys.path.append(_pkg_dir)

from executorch.backends.nvidia.tensorrt.backend import TensorRTBackend
from executorch.backends.nvidia.tensorrt.partitioner import TensorRTPartitioner

__all__ = [
"TensorRTBackend",
"TensorRTPartitioner",
]
5 changes: 5 additions & 0 deletions backends/nvidia/tensorrt/partitioner/TARGETS
Original file line number Diff line number Diff line change
@@ -0,0 +1,5 @@
load(":targets.bzl", "define_common_targets")

oncall("executorch")

define_common_targets()
45 changes: 45 additions & 0 deletions backends/nvidia/tensorrt/partitioner/__init__.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,45 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.

"""TensorRT partitioner for ExecuTorch."""

from typing import Dict, List, Optional

from executorch.backends.nvidia.tensorrt.backend import TensorRTBackend
from executorch.exir.backend.compile_spec_schema import CompileSpec
from executorch.exir.backend.partitioner import (
DelegationSpec,
Partitioner,
PartitionResult,
)
from torch.export.exported_program import ExportedProgram


class TensorRTPartitioner(Partitioner):
"""Partitioner for TensorRT backend.
"""

def __init__(
self,
compile_specs: Optional[List[CompileSpec]] = None,
) -> None:
super().__init__()
self.compile_specs = compile_specs or []
self.delegation_spec = DelegationSpec(
backend_id=TensorRTBackend.__name__,
compile_specs=self.compile_specs,
)

def partition(self, exported_program: ExportedProgram) -> PartitionResult:
"""Partition the graph for TensorRT delegation.

Identifies subgraphs that can be lowered to the TensorRT backend.
"""
partition_tags: Dict[str, DelegationSpec] = {}
return PartitionResult(
tagged_exported_program=exported_program,
partition_tags=partition_tags,
)
21 changes: 21 additions & 0 deletions backends/nvidia/tensorrt/partitioner/targets.bzl
Original file line number Diff line number Diff line change
@@ -0,0 +1,21 @@
load("@fbsource//xplat/executorch/build:runtime_wrapper.bzl", "runtime")

def define_common_targets():
"""Defines targets that should be shared between fbcode and xplat.

The directory containing this targets.bzl file should also contain both
TARGETS and BUCK files that call this function.
"""

runtime.python_library(
name = "partitioner",
srcs = [
"__init__.py",
],
visibility = ["PUBLIC"],
deps = [
"//caffe2:torch",
"//executorch/backends/nvidia/tensorrt:backend",
"//executorch/exir/backend:partitioner",
],
)
Loading