Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
9 changes: 9 additions & 0 deletions CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -716,6 +716,11 @@ if(EXECUTORCH_BUILD_METAL)
list(APPEND _executorch_backends metal_backend)
endif()

# Build the NVIDIA TensorRT delegate and add it to the list of backend
# libraries linked into the ExecuTorch runtime.
if(EXECUTORCH_BUILD_TENSORRT)
add_subdirectory(${CMAKE_CURRENT_SOURCE_DIR}/backends/nvidia/tensorrt)
list(APPEND _executorch_backends tensorrt_backend)
endif()

if(EXECUTORCH_BUILD_EXTENSION_APPLE)
add_subdirectory(${CMAKE_CURRENT_SOURCE_DIR}/extension/apple)
endif()
Expand Down Expand Up @@ -983,6 +988,10 @@ if(EXECUTORCH_BUILD_PYBIND)
list(APPEND _dep_libs vulkan_backend)
endif()

# Expose the TensorRT backend to the Python bindings so delegated programs
# can be executed through pybind.
if(EXECUTORCH_BUILD_TENSORRT)
list(APPEND _dep_libs tensorrt_backend)
endif()

# compile options for pybind
set(_pybind_compile_options
$<$<CXX_COMPILER_ID:MSVC>:/EHsc
Expand Down
125 changes: 124 additions & 1 deletion backends/nvidia/tensorrt/CMakeLists.txt
Original file line number Diff line number Diff line change
@@ -1,3 +1,3 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
Expand Down Expand Up @@ -27,4 +27,127 @@
set(EXECUTORCH_ROOT ${CMAKE_CURRENT_SOURCE_DIR}/../../..)
endif()

include(${EXECUTORCH_ROOT}/tools/cmake/Utils.cmake)
# When built as part of the main ExecuTorch CMake, Utils.cmake is already included.
# Only include it if executorch_target_link_options_shared_lib is not defined.
# Utils.cmake supplies executorch_target_link_options_shared_lib (whole-archive
# linking so static backend registration is not stripped). When this directory
# is configured as part of the main ExecuTorch build the helper already exists;
# otherwise pull it in, or fall back to a no-op for standalone builds.
if(NOT COMMAND executorch_target_link_options_shared_lib)
  if(EXISTS "${EXECUTORCH_ROOT}/tools/cmake/Utils.cmake")
    include("${EXECUTORCH_ROOT}/tools/cmake/Utils.cmake")
  else()
    # Standalone build without the helper: whole-archive linking is handled
    # separately for the runner, so an empty stub keeps configuration going.
    function(executorch_target_link_options_shared_lib target)
    endfunction()
  endif()
endif()

if(EXECUTORCH_BUILD_TENSORRT)
  # Locate TensorRT. Prefer an installed CMake package config; otherwise fall
  # back to a manual header/library search rooted at TENSORRT_HOME.
  find_package(TensorRT QUIET)

  if(NOT TensorRT_FOUND)
    # Resolution order: -DTENSORRT_HOME CMake variable, then TENSORRT_HOME /
    # TENSORRT_DIR environment variables, then /usr (the system layout used
    # by e.g. JetPack on Jetson devices).
    if(DEFINED TENSORRT_HOME)
      set(TENSORRT_ROOT "${TENSORRT_HOME}")
    elseif(DEFINED ENV{TENSORRT_HOME})
      set(TENSORRT_ROOT "$ENV{TENSORRT_HOME}")
    elseif(DEFINED ENV{TENSORRT_DIR})
      set(TENSORRT_ROOT "$ENV{TENSORRT_DIR}")
    else()
      set(TENSORRT_ROOT "/usr")
    endif()

    message(STATUS "Looking for TensorRT in: ${TENSORRT_ROOT}")

    # Headers may live under an arch-qualified subdirectory on Debian-style
    # multiarch installs (both x86_64 and aarch64/Jetson are supported).
    find_path(
      TENSORRT_INCLUDE_DIR NvInfer.h
      PATHS
        "${TENSORRT_ROOT}/include"
        "${TENSORRT_ROOT}/include/aarch64-linux-gnu"
        "${TENSORRT_ROOT}/include/x86_64-linux-gnu"
      PATH_SUFFIXES tensorrt
    )

    find_library(
      TENSORRT_LIBRARY nvinfer
      PATHS
        "${TENSORRT_ROOT}/lib"
        "${TENSORRT_ROOT}/lib/aarch64-linux-gnu"
        "${TENSORRT_ROOT}/lib/x86_64-linux-gnu"
        "${TENSORRT_ROOT}/lib64"
    )

    if(TENSORRT_INCLUDE_DIR AND TENSORRT_LIBRARY)
      message(STATUS "Found TensorRT include: ${TENSORRT_INCLUDE_DIR}")
      message(STATUS "Found TensorRT library: ${TENSORRT_LIBRARY}")
      get_filename_component(TENSORRT_LIB_DIR "${TENSORRT_LIBRARY}" DIRECTORY)
      message(STATUS "TensorRT library directory: ${TENSORRT_LIB_DIR}")
      # NOTE: the library is linked by absolute path below, so no
      # directory-scoped link_directories() call is needed — it would leak
      # linker search paths into every target defined in this directory.
    endif()
  endif()

  # Fail fast with an actionable message when neither discovery path worked.
  if(NOT TensorRT_FOUND AND NOT TENSORRT_LIBRARY)
    message(FATAL_ERROR
      "TensorRT not found. Set TENSORRT_HOME or TENSORRT_DIR environment variable, "
      "or pass -DTENSORRT_HOME=/path/to/tensorrt to CMake.")
  endif()

  # CUDA runtime is required by the backend (CUDA::cudart below).
  find_package(CUDAToolkit REQUIRED)

  # Include directories shared by the backend and runner/test binaries:
  # the parent of EXECUTORCH_ROOT so <executorch/...> includes resolve.
  set(TENSORRT_COMMON_INCLUDE_DIRS "${EXECUTORCH_ROOT}/..")

  # TensorRT backend static library. Registration happens via a static
  # initializer in TensorRTBackend.cpp, hence the whole-archive linking below.
  add_library(tensorrt_backend STATIC)

  # The backend sources rely on RTTI and exceptions, which some ExecuTorch
  # configurations disable by default.
  target_compile_options(tensorrt_backend PRIVATE -frtti -fexceptions)

  target_include_directories(
    tensorrt_backend
    PUBLIC $<BUILD_INTERFACE:${EXECUTORCH_ROOT}>
           $<BUILD_INTERFACE:${TENSORRT_COMMON_INCLUDE_DIRS}>
           $<INSTALL_INTERFACE:include>
  )

  # Only set by the manual search path; the package-config route carries its
  # include directories on the TensorRT::nvinfer imported target instead.
  if(TENSORRT_INCLUDE_DIR)
    target_include_directories(
      tensorrt_backend PUBLIC $<BUILD_INTERFACE:${TENSORRT_INCLUDE_DIR}>
    )
  endif()

  target_sources(
    tensorrt_backend
    PRIVATE ${CMAKE_CURRENT_LIST_DIR}/runtime/TensorRTBackend.cpp
            ${CMAKE_CURRENT_LIST_DIR}/runtime/TensorRTExecutor.cpp
  )

  target_link_libraries(tensorrt_backend PUBLIC executorch_core CUDA::cudart)

  # Link by absolute path when found manually; prefer the imported target
  # exported by the TensorRT package config when available.
  if(TENSORRT_LIBRARY)
    target_link_libraries(tensorrt_backend PUBLIC ${TENSORRT_LIBRARY})
  elseif(TensorRT_FOUND)
    target_link_libraries(tensorrt_backend PUBLIC TensorRT::nvinfer)
  endif()

  # Force whole-archive linking so the static registration object file is not
  # discarded by the linker.
  executorch_target_link_options_shared_lib(tensorrt_backend)

  # NOTE(review): assumes CMAKE_INSTALL_LIBDIR / CMAKE_INSTALL_INCLUDEDIR are
  # defined by the enclosing build (GNUInstallDirs) — confirm for standalone
  # configuration of this directory.
  install(
    TARGETS tensorrt_backend
    EXPORT ExecuTorchTargets
    DESTINATION ${CMAKE_INSTALL_LIBDIR}
    INCLUDES
    DESTINATION ${CMAKE_INSTALL_INCLUDEDIR}
  )

endif()
11 changes: 11 additions & 0 deletions backends/nvidia/tensorrt/README.md
Original file line number Diff line number Diff line change
Expand Up @@ -106,3 +106,14 @@ The TensorRT delegate uses a custom binary blob format:
- CUDA Toolkit 11.x or 12.x
- cuDNN 8.x
- PyTorch 2.x with CUDA support (for export)

## Build Instructions

If TensorRT is not installed in a standard system location, point CMake at it
with `-DTENSORRT_HOME=/path/to/tensorrt` (or set the `TENSORRT_HOME`
environment variable); otherwise configuration fails with "TensorRT not found".

```bash
cd executorch
mkdir -p cmake-out && cd cmake-out

cmake .. -DEXECUTORCH_BUILD_TENSORRT=ON

cmake --build . --target tensorrt_backend tensorrt_executor_runner
```
5 changes: 5 additions & 0 deletions backends/nvidia/tensorrt/runtime/TARGETS
Original file line number Diff line number Diff line change
@@ -0,0 +1,5 @@
# Buck build entry point for the TensorRT runtime sources; the actual target
# definitions live in targets.bzl.
load(":targets.bzl", "define_common_targets")

oncall("executorch")

define_common_targets()
190 changes: 190 additions & 0 deletions backends/nvidia/tensorrt/runtime/TensorRTBackend.cpp
Original file line number Diff line number Diff line change
@@ -0,0 +1,190 @@
/*
* Copyright (c) Meta Platforms, Inc. and affiliates.
* All rights reserved.
*
* This source code is licensed under the BSD-style license found in the
* LICENSE file in the root directory of this source tree.
*/

#include <executorch/backends/nvidia/tensorrt/runtime/TensorRTBackend.h>
#include <executorch/backends/nvidia/tensorrt/runtime/TensorRTBlobHeader.h>
#include <executorch/backends/nvidia/tensorrt/runtime/TensorRTExecutor.h>
#include <executorch/runtime/backend/interface.h>
#include <executorch/runtime/platform/log.h>

namespace executorch {
namespace backends {
namespace tensorrt {

using executorch::runtime::ArrayRef;
using executorch::runtime::Backend;
using executorch::runtime::BackendExecutionContext;
using executorch::runtime::BackendInitContext;
using executorch::runtime::CompileSpec;
using executorch::runtime::DelegateHandle;
using executorch::runtime::Error;
using executorch::runtime::EValue;
using executorch::runtime::FreeableBuffer;
using executorch::runtime::MemoryAllocator;
using executorch::runtime::register_backend;
using executorch::runtime::Result;
using executorch::runtime::Span;

namespace {

// Translation-unit-local availability probe. This file is compiled only when
// the TensorRT libraries are linked in, so the probe always succeeds.
inline bool is_tensorrt_available() {
  return true;
}

} // namespace

// Availability as reported to the ExecuTorch runtime; defers to the
// translation-unit-local probe.
bool TensorRTBackend::is_available() const {
  const bool available = is_tensorrt_available();
  return available;
}

/**
 * Deserializes the delegate payload and constructs a TensorRT executor.
 *
 * The executor is placement-new'ed into memory obtained from the runtime
 * allocator: the allocator owns the memory, so destroy() must invoke the
 * destructor explicitly but never free the storage.
 *
 * @param context Provides the runtime allocator used for the executor.
 * @param processed Delegate blob: a TensorRTBlobHeader-prefixed payload
 *        produced at export time (layout defined in TensorRTBlobHeader.h).
 * @param compile_specs Unused by this backend.
 * @returns The executor as an opaque DelegateHandle*, or an Error
 *          (NotSupported, InvalidArgument, InvalidState,
 *          MemoryAllocationFailed, or whatever initialize() reports).
 */
Result<DelegateHandle*> TensorRTBackend::init(
    BackendInitContext& context,
    FreeableBuffer* processed,
    ArrayRef<CompileSpec> compile_specs) const {
  // This backend does not consume compile specs.
  (void)compile_specs;

  if (!is_available()) {
    ET_LOG(Error, "TensorRT backend is not available");
    return Error::NotSupported;
  }

  // Reject a missing or empty payload before touching it.
  if (processed == nullptr || processed->data() == nullptr) {
    ET_LOG(Error, "Invalid processed buffer");
    return Error::InvalidArgument;
  }

  const void* blob_data = processed->data();
  const size_t blob_size = processed->size();

  // Validate the blob header up front. NOTE: the parsed header fields are not
  // consumed here — parsing is used purely as a sanity check; initialize()
  // re-reads the blob itself.
  TensorRTBlobHeader header{};
  if (!parse_blob_header(blob_data, blob_size, header)) {
    ET_LOG(Error, "Failed to parse TensorRT blob header");
    return Error::InvalidArgument;
  }

  MemoryAllocator* allocator =
      context.get_runtime_allocator();
  if (allocator == nullptr) {
    ET_LOG(Error, "Failed to get runtime allocator");
    return Error::InvalidState;
  }

  // allocateInstance only reserves storage; construction happens below via
  // placement new.
  TensorRTExecutor* executor =
      allocator->allocateInstance<TensorRTExecutor>();
  if (executor == nullptr) {
    ET_LOG(Error, "Failed to allocate TensorRT executor");
    return Error::MemoryAllocationFailed;
  }

  new (executor) TensorRTExecutor();

  Error err = executor->initialize(blob_data, blob_size);
  if (err != Error::Ok) {
    ET_LOG(Error, "Failed to initialize TensorRT executor");
    // Destroy the partially-built executor; its storage stays with the
    // allocator and is reclaimed when the allocator is reset.
    executor->~TensorRTExecutor();
    return err;
  }

  // Release the serialized blob — assumes initialize() has copied or fully
  // consumed the data it needs. TODO(review): confirm TensorRTExecutor does
  // not retain pointers into the processed buffer.
  processed->Free();

  return static_cast<DelegateHandle*>(executor);
}

/**
 * Runs one inference through the TensorRT executor created by init().
 *
 * Tensor arguments in `args` are partitioned positionally: the first
 * `num_inputs` tensors are treated as inputs and the remainder as outputs.
 * NOTE(review): this assumes the runtime always orders delegate args as
 * inputs-before-outputs — confirm against the delegate calling convention.
 * Non-tensor args are skipped.
 *
 * @param context Unused.
 * @param handle Executor produced by init().
 * @param args Mixed input/output EValue pointers from the runtime.
 * @returns Error::Ok on success; InvalidArgument / InvalidState on a bad
 *          handle or a tensor-count mismatch; otherwise the executor's error.
 */
Error TensorRTBackend::execute(
    BackendExecutionContext& context,
    DelegateHandle* handle,
    Span<EValue*> args) const {
  (void)context;

  if (handle == nullptr) {
    ET_LOG(Error, "Invalid delegate handle");
    return Error::InvalidArgument;
  }

  auto* executor = static_cast<TensorRTExecutor*>(handle);

  if (!executor->is_initialized()) {
    ET_LOG(Error, "Executor not initialized");
    return Error::InvalidState;
  }

  size_t num_inputs = executor->get_num_inputs();
  size_t num_outputs = executor->get_num_outputs();

  if (num_inputs + num_outputs == 0) {
    ET_LOG(Error, "No inputs or outputs found");
    return Error::InvalidState;
  }

  // Collect raw data pointers for the executor; inputs first, then outputs.
  std::vector<void*> input_buffers;
  std::vector<void*> output_buffers;
  input_buffers.reserve(num_inputs);
  output_buffers.reserve(num_outputs);

  // tensor_idx counts only tensor-valued args, so interleaved non-tensor
  // args do not shift the input/output split point.
  size_t tensor_idx = 0;
  for (size_t i = 0; i < args.size(); ++i) {
    EValue* arg = args[i];
    if (arg == nullptr || !arg->isTensor()) {
      continue;
    }

    ::executorch::aten::Tensor tensor = arg->toTensor();
    void* data_ptr = tensor.mutable_data_ptr();

    if (tensor_idx < num_inputs) {
      input_buffers.push_back(data_ptr);
    } else {
      output_buffers.push_back(data_ptr);
    }
    ++tensor_idx;
  }

  // Verify the arg list supplied exactly the tensor counts the engine expects.
  if (input_buffers.size() != num_inputs) {
    ET_LOG(
        Error,
        "Input buffer count mismatch: expected %zu, got %zu",
        num_inputs,
        input_buffers.size());
    return Error::InvalidArgument;
  }

  if (output_buffers.size() != num_outputs) {
    ET_LOG(
        Error,
        "Output buffer count mismatch: expected %zu, got %zu",
        num_outputs,
        output_buffers.size());
    return Error::InvalidArgument;
  }

  return executor->execute(
      input_buffers.data(),
      input_buffers.size(),
      output_buffers.data(),
      output_buffers.size());
}

// Tears down an executor created by init(). Only the destructor runs here:
// the storage came from the runtime allocator, which reclaims it separately.
void TensorRTBackend::destroy(DelegateHandle* handle) const {
  if (handle == nullptr) {
    return;
  }
  static_cast<TensorRTExecutor*>(handle)->~TensorRTExecutor();
}

} // namespace tensorrt
} // namespace backends
} // namespace executorch

namespace {
// Meyers-singleton accessor: the backend instance is constructed on first
// use, sidestepping cross-TU static initialization order issues.
executorch::backends::tensorrt::TensorRTBackend& get_backend() {
  static executorch::backends::tensorrt::TensorRTBackend backend;
  return backend;
}
// Registers the backend under the id "TensorRTBackend" during static
// initialization. This object file must be force-linked (whole-archive, see
// the backend's CMakeLists) or the linker may drop the registration.
const executorch::runtime::Backend backend_id{"TensorRTBackend", &get_backend()};
const auto registered = executorch::runtime::register_backend(backend_id);
} // namespace
Loading
Loading