diff --git a/CMakeLists.txt b/CMakeLists.txt index 995a75c342b..e4db14887bd 100644 --- a/CMakeLists.txt +++ b/CMakeLists.txt @@ -716,6 +716,11 @@ if(EXECUTORCH_BUILD_METAL) list(APPEND _executorch_backends metal_backend) endif() +if(EXECUTORCH_BUILD_TENSORRT) + add_subdirectory(${CMAKE_CURRENT_SOURCE_DIR}/backends/nvidia/tensorrt) + list(APPEND _executorch_backends tensorrt_backend) +endif() + if(EXECUTORCH_BUILD_EXTENSION_APPLE) add_subdirectory(${CMAKE_CURRENT_SOURCE_DIR}/extension/apple) endif() @@ -983,6 +988,10 @@ if(EXECUTORCH_BUILD_PYBIND) list(APPEND _dep_libs vulkan_backend) endif() + if(EXECUTORCH_BUILD_TENSORRT) + list(APPEND _dep_libs tensorrt_backend) + endif() + # compile options for pybind set(_pybind_compile_options $<$:/EHsc diff --git a/backends/nvidia/tensorrt/CMakeLists.txt b/backends/nvidia/tensorrt/CMakeLists.txt index 8856f0e4c3e..2b6ae6c414b 100644 --- a/backends/nvidia/tensorrt/CMakeLists.txt +++ b/backends/nvidia/tensorrt/CMakeLists.txt @@ -27,4 +27,127 @@ if(NOT EXECUTORCH_ROOT) set(EXECUTORCH_ROOT ${CMAKE_CURRENT_SOURCE_DIR}/../../..) endif() -include(${EXECUTORCH_ROOT}/tools/cmake/Utils.cmake) +# When built as part of the main ExecuTorch CMake, Utils.cmake is already included. +# Only include it if executorch_target_link_options_shared_lib is not defined. 
# Utils.cmake provides executorch_target_link_options_shared_lib(). When this
# directory is pulled in from the top-level ExecuTorch build the command is
# already defined, so only include the file when it is missing.
if(NOT COMMAND executorch_target_link_options_shared_lib)
  if(EXISTS "${EXECUTORCH_ROOT}/tools/cmake/Utils.cmake")
    include("${EXECUTORCH_ROOT}/tools/cmake/Utils.cmake")
  else()
    # Standalone fallback: a no-op with the same signature so the call below
    # still resolves. Whole-archive linking is then the consumer's job.
    function(executorch_target_link_options_shared_lib target)
    endfunction()
  endif()
endif()

if(EXECUTORCH_BUILD_TENSORRT)
  # Prefer a TensorRT CMake package when one is installed.
  find_package(TensorRT QUIET)

  # Fall back to a manual search whenever the imported target is unavailable.
  # Testing for the target (rather than TensorRT_FOUND) also covers find
  # modules that report success without defining TensorRT::nvinfer.
  if(NOT TARGET TensorRT::nvinfer)
    # Search root priority: -DTENSORRT_HOME, $TENSORRT_HOME, $TENSORRT_DIR,
    # then /usr for system installations (e.g., JetPack on Jetson).
    if(DEFINED TENSORRT_HOME)
      set(TENSORRT_ROOT "${TENSORRT_HOME}")
    elseif(DEFINED ENV{TENSORRT_HOME})
      set(TENSORRT_ROOT "$ENV{TENSORRT_HOME}")
    elseif(DEFINED ENV{TENSORRT_DIR})
      set(TENSORRT_ROOT "$ENV{TENSORRT_DIR}")
    else()
      set(TENSORRT_ROOT "/usr")
    endif()

    message(STATUS "Looking for TensorRT in: ${TENSORRT_ROOT}")

    # HINTS are searched before the system default locations, so an explicit
    # TENSORRT_HOME is not shadowed by a system-wide TensorRT install.
    # Covers both x86_64 and aarch64/Jetson layouts.
    find_path(
      TENSORRT_INCLUDE_DIR NvInfer.h
      HINTS
        "${TENSORRT_ROOT}/include"
        "${TENSORRT_ROOT}/include/aarch64-linux-gnu"
        "${TENSORRT_ROOT}/include/x86_64-linux-gnu"
      PATH_SUFFIXES tensorrt
    )

    find_library(
      TENSORRT_LIBRARY nvinfer
      HINTS
        "${TENSORRT_ROOT}/lib"
        "${TENSORRT_ROOT}/lib/aarch64-linux-gnu"
        "${TENSORRT_ROOT}/lib/x86_64-linux-gnu"
        "${TENSORRT_ROOT}/lib64"
    )

    # Both the headers and the library are required; finding only one of them
    # would otherwise surface later as a confusing compile or link error.
    if(NOT TENSORRT_INCLUDE_DIR OR NOT TENSORRT_LIBRARY)
      message(FATAL_ERROR
        "TensorRT not found. Set TENSORRT_HOME or TENSORRT_DIR environment variable, "
        "or pass -DTENSORRT_HOME=/path/to/tensorrt to CMake. "
        "(include dir: ${TENSORRT_INCLUDE_DIR}, library: ${TENSORRT_LIBRARY})")
    endif()

    message(STATUS "Found TensorRT include: ${TENSORRT_INCLUDE_DIR}")
    message(STATUS "Found TensorRT library: ${TENSORRT_LIBRARY}")

    # Informational only: the library is linked by its full path below, so no
    # directory-scoped link_directories() call is needed.
    get_filename_component(TENSORRT_LIB_DIR "${TENSORRT_LIBRARY}" DIRECTORY)
    message(STATUS "TensorRT library directory: ${TENSORRT_LIB_DIR}")
  endif()

  find_package(CUDAToolkit REQUIRED)

  # Common include root (used by backend and runner/test binaries).
  set(TENSORRT_COMMON_INCLUDE_DIRS "${EXECUTORCH_ROOT}/..")

  add_library(tensorrt_backend STATIC)

  # Enable exceptions and RTTI for the TensorRT backend sources.
  target_compile_options(tensorrt_backend PRIVATE -frtti -fexceptions)

  # NOTE(review): the three generator expressions on this call were lost in
  # the reviewed text; the values below are a reconstruction - confirm against
  # the original change before merging.
  target_include_directories(
    tensorrt_backend
    PUBLIC
      "$<BUILD_INTERFACE:${TENSORRT_COMMON_INCLUDE_DIRS}>"
      "$<BUILD_INTERFACE:${CUDAToolkit_INCLUDE_DIRS}>"
      "$<INSTALL_INTERFACE:${CMAKE_INSTALL_INCLUDEDIR}>"
  )

  target_sources(
    tensorrt_backend
    PRIVATE
      "${CMAKE_CURRENT_LIST_DIR}/runtime/TensorRTBackend.cpp"
      "${CMAKE_CURRENT_LIST_DIR}/runtime/TensorRTExecutor.cpp"
  )

  target_link_libraries(tensorrt_backend PUBLIC executorch_core CUDA::cudart)

  if(TARGET TensorRT::nvinfer)
    # The imported target carries its own include directories.
    target_link_libraries(tensorrt_backend PUBLIC TensorRT::nvinfer)
  else()
    # Manually located TensorRT: wire headers and the library explicitly.
    # NOTE(review): generator expression reconstructed - confirm.
    target_include_directories(
      tensorrt_backend PUBLIC "$<BUILD_INTERFACE:${TENSORRT_INCLUDE_DIR}>"
    )
    target_link_libraries(tensorrt_backend PUBLIC "${TENSORRT_LIBRARY}")
  endif()

  # Force link the whole library to ensure static registration works.
  executorch_target_link_options_shared_lib(tensorrt_backend)

  install(
    TARGETS tensorrt_backend
    EXPORT ExecuTorchTargets
    DESTINATION ${CMAKE_INSTALL_LIBDIR}
    INCLUDES
    DESTINATION ${CMAKE_INSTALL_INCLUDEDIR}
  )

endif()
Build Instructions + +```bash +cd executorch +mkdir -p cmake-out && cd cmake-out + +cmake .. -DEXECUTORCH_BUILD_TENSORRT=ON + +cmake --build . --target tensorrt_backend tensorrt_executor_runner +``` diff --git a/backends/nvidia/tensorrt/runtime/TARGETS b/backends/nvidia/tensorrt/runtime/TARGETS new file mode 100644 index 00000000000..0a42614a385 --- /dev/null +++ b/backends/nvidia/tensorrt/runtime/TARGETS @@ -0,0 +1,5 @@ +load(":targets.bzl", "define_common_targets") + +oncall("executorch") + +define_common_targets() diff --git a/backends/nvidia/tensorrt/runtime/TensorRTBackend.cpp b/backends/nvidia/tensorrt/runtime/TensorRTBackend.cpp new file mode 100644 index 00000000000..95e969d2126 --- /dev/null +++ b/backends/nvidia/tensorrt/runtime/TensorRTBackend.cpp @@ -0,0 +1,190 @@ +/* + * Copyright (c) Meta Platforms, Inc. and affiliates. + * All rights reserved. + * + * This source code is licensed under the BSD-style license found in the + * LICENSE file in the root directory of this source tree. 
+ */ + +#include +#include +#include +#include +#include + +namespace executorch { +namespace backends { +namespace tensorrt { + +using executorch::runtime::ArrayRef; +using executorch::runtime::Backend; +using executorch::runtime::BackendExecutionContext; +using executorch::runtime::BackendInitContext; +using executorch::runtime::CompileSpec; +using executorch::runtime::DelegateHandle; +using executorch::runtime::Error; +using executorch::runtime::EValue; +using executorch::runtime::FreeableBuffer; +using executorch::runtime::MemoryAllocator; +using executorch::runtime::register_backend; +using executorch::runtime::Result; +using executorch::runtime::Span; + +namespace { + +bool is_tensorrt_available() { + return true; +} + +} // namespace + +bool TensorRTBackend::is_available() const { + return is_tensorrt_available(); +} + +Result TensorRTBackend::init( + BackendInitContext& context, + FreeableBuffer* processed, + ArrayRef compile_specs) const { + (void)compile_specs; + + if (!is_available()) { + ET_LOG(Error, "TensorRT backend is not available"); + return Error::NotSupported; + } + + if (processed == nullptr || processed->data() == nullptr) { + ET_LOG(Error, "Invalid processed buffer"); + return Error::InvalidArgument; + } + + const void* blob_data = processed->data(); + const size_t blob_size = processed->size(); + + TensorRTBlobHeader header{}; + if (!parse_blob_header(blob_data, blob_size, header)) { + ET_LOG(Error, "Failed to parse TensorRT blob header"); + return Error::InvalidArgument; + } + + MemoryAllocator* allocator = + context.get_runtime_allocator(); + if (allocator == nullptr) { + ET_LOG(Error, "Failed to get runtime allocator"); + return Error::InvalidState; + } + + TensorRTExecutor* executor = + allocator->allocateInstance(); + if (executor == nullptr) { + ET_LOG(Error, "Failed to allocate TensorRT executor"); + return Error::MemoryAllocationFailed; + } + + new (executor) TensorRTExecutor(); + + Error err = executor->initialize(blob_data, 
blob_size); + if (err != Error::Ok) { + ET_LOG(Error, "Failed to initialize TensorRT executor"); + executor->~TensorRTExecutor(); + return err; + } + + processed->Free(); + + return static_cast(executor); +} + +Error TensorRTBackend::execute( + BackendExecutionContext& context, + DelegateHandle* handle, + Span args) const { + (void)context; + + if (handle == nullptr) { + ET_LOG(Error, "Invalid delegate handle"); + return Error::InvalidArgument; + } + + auto* executor = static_cast(handle); + + if (!executor->is_initialized()) { + ET_LOG(Error, "Executor not initialized"); + return Error::InvalidState; + } + + size_t num_inputs = executor->get_num_inputs(); + size_t num_outputs = executor->get_num_outputs(); + + if (num_inputs + num_outputs == 0) { + ET_LOG(Error, "No inputs or outputs found"); + return Error::InvalidState; + } + + std::vector input_buffers; + std::vector output_buffers; + input_buffers.reserve(num_inputs); + output_buffers.reserve(num_outputs); + + size_t tensor_idx = 0; + for (size_t i = 0; i < args.size(); ++i) { + EValue* arg = args[i]; + if (arg == nullptr || !arg->isTensor()) { + continue; + } + + ::executorch::aten::Tensor tensor = arg->toTensor(); + void* data_ptr = tensor.mutable_data_ptr(); + + if (tensor_idx < num_inputs) { + input_buffers.push_back(data_ptr); + } else { + output_buffers.push_back(data_ptr); + } + ++tensor_idx; + } + + if (input_buffers.size() != num_inputs) { + ET_LOG( + Error, + "Input buffer count mismatch: expected %zu, got %zu", + num_inputs, + input_buffers.size()); + return Error::InvalidArgument; + } + + if (output_buffers.size() != num_outputs) { + ET_LOG( + Error, + "Output buffer count mismatch: expected %zu, got %zu", + num_outputs, + output_buffers.size()); + return Error::InvalidArgument; + } + + return executor->execute( + input_buffers.data(), + input_buffers.size(), + output_buffers.data(), + output_buffers.size()); +} + +void TensorRTBackend::destroy(DelegateHandle* handle) const { + if (handle != 
nullptr) { + auto* executor = static_cast(handle); + executor->~TensorRTExecutor(); + } +} + +} // namespace tensorrt +} // namespace backends +} // namespace executorch + +namespace { +executorch::backends::tensorrt::TensorRTBackend& get_backend() { + static executorch::backends::tensorrt::TensorRTBackend backend; + return backend; +} +const executorch::runtime::Backend backend_id{"TensorRTBackend", &get_backend()}; +const auto registered = executorch::runtime::register_backend(backend_id); +} // namespace diff --git a/backends/nvidia/tensorrt/runtime/TensorRTBackend.h b/backends/nvidia/tensorrt/runtime/TensorRTBackend.h new file mode 100644 index 00000000000..67cb7481939 --- /dev/null +++ b/backends/nvidia/tensorrt/runtime/TensorRTBackend.h @@ -0,0 +1,81 @@ +/* + * Copyright (c) Meta Platforms, Inc. and affiliates. + * All rights reserved. + * + * This source code is licensed under the BSD-style license found in the + * LICENSE file in the root directory of this source tree. + */ + +#pragma once + +#include +#include +#include + +namespace executorch { +namespace backends { +namespace tensorrt { + +/** + * TensorRT backend for executing models on NVIDIA GPUs. + * + * This backend deserializes TensorRT engines from blobs created by the + * Python preprocess() function and executes them using the TensorRT runtime. + */ +class TensorRTBackend final : public runtime::BackendInterface { + public: + TensorRTBackend() = default; + ~TensorRTBackend() override = default; + + /** + * Returns true if TensorRT is available on this device. + * + * Checks for: + * - TensorRT runtime library availability + * - CUDA device availability + */ + bool is_available() const override; + + /** + * Initialize the TensorRT backend with a serialized engine blob. + * + * Parses the blob header, extracts I/O binding metadata, and deserializes + * the TensorRT engine. Creates an execution context for inference. + * + * @param context Backend initialization context. 
#pragma once

#include <cstddef>
#include <cstdint>
#include <cstring>

namespace executorch {
namespace backends {
namespace tensorrt {

/**
 * Magic bytes identifying a TensorRT blob.
 * "TR01" = TensorRT version 1 format with I/O metadata.
 */
constexpr char kTensorRTMagic[4] = {'T', 'R', '0', '1'};

/**
 * Header size in bytes (32 bytes, 16-byte aligned).
 */
constexpr uint32_t kHeaderSize = 32;

/**
 * TensorRT blob header structure.
 *
 * Layout (little-endian, 32 bytes total):
 *   magic (4 bytes)           - "TR01"
 *   metadata_offset (4 bytes) - offset to metadata JSON from start
 *   metadata_size (4 bytes)   - size of metadata JSON in bytes
 *   engine_offset (4 bytes)   - offset to engine data from start
 *   engine_size (8 bytes)     - size of engine data in bytes
 *   reserved (8 bytes)        - for future use
 */
struct TensorRTBlobHeader {
  char magic[4];
  uint32_t metadata_offset;
  uint32_t metadata_size;
  uint32_t engine_offset;
  uint64_t engine_size;
  uint8_t reserved[8];

  /**
   * Check if this is a valid TensorRT blob header.
   *
   * @return true if magic bytes match "TR01".
   */
  bool is_valid() const {
    return std::memcmp(magic, kTensorRTMagic, sizeof(kTensorRTMagic)) == 0;
  }
};

static_assert(sizeof(TensorRTBlobHeader) == 32, "Header must be 32 bytes");

/**
 * Parse a TensorRT blob header from raw bytes.
 *
 * The header is copied byte-for-byte, so fields are interpreted in host byte
 * order.
 *
 * @param data Pointer to blob data (must be at least kHeaderSize bytes).
 * @param data_size Size of data buffer in bytes.
 * @param out_header Output header structure.
 * @return true if header was parsed successfully.
 */
inline bool parse_blob_header(
    const void* data,
    size_t data_size,
    TensorRTBlobHeader& out_header) {
  if (data == nullptr || data_size < kHeaderSize) {
    return false;
  }

  std::memcpy(&out_header, data, sizeof(TensorRTBlobHeader));
  return out_header.is_valid();
}

/**
 * Overflow-safe check that [offset, offset + size) lies inside a buffer of
 * data_size bytes.
 *
 * Written as two comparisons so that no addition of untrusted header fields
 * is ever performed; a plain `offset + size > data_size` test can wrap around
 * and accept an out-of-range section.
 */
inline bool blob_section_in_bounds(
    uint64_t offset,
    uint64_t size,
    size_t data_size) {
  const uint64_t total = static_cast<uint64_t>(data_size);
  return offset <= total && size <= total - offset;
}

/**
 * Get a pointer to the engine data within a blob.
 *
 * @param data Pointer to blob data.
 * @param data_size Size of data buffer in bytes.
 * @param header Parsed header from parse_blob_header().
 * @param out_engine Output pointer to engine data.
 * @param out_engine_size Output size of engine data.
 * @return true if engine data was located successfully; the outputs are left
 *     untouched on failure.
 */
inline bool get_engine_from_blob(
    const void* data,
    size_t data_size,
    const TensorRTBlobHeader& header,
    const void*& out_engine,
    size_t& out_engine_size) {
  if (data == nullptr || !header.is_valid()) {
    return false;
  }

  if (!blob_section_in_bounds(
          header.engine_offset, header.engine_size, data_size)) {
    return false;
  }

  out_engine = static_cast<const uint8_t*>(data) + header.engine_offset;
  // Safe narrowing: the bounds check guarantees engine_size <= data_size.
  out_engine_size = static_cast<size_t>(header.engine_size);
  return true;
}

/**
 * Get a pointer to the metadata JSON within a blob.
 *
 * @param data Pointer to blob data.
 * @param data_size Size of data buffer in bytes.
 * @param header Parsed header from parse_blob_header().
 * @param out_metadata Output pointer to metadata JSON.
 * @param out_metadata_size Output size of metadata JSON.
 * @return true if metadata was located successfully; the outputs are left
 *     untouched on failure.
 */
inline bool get_metadata_from_blob(
    const void* data,
    size_t data_size,
    const TensorRTBlobHeader& header,
    const void*& out_metadata,
    size_t& out_metadata_size) {
  if (data == nullptr || !header.is_valid()) {
    return false;
  }

  // Both fields are 32-bit; widening before the check prevents the sum from
  // wrapping in 32-bit arithmetic.
  if (!blob_section_in_bounds(
          header.metadata_offset, header.metadata_size, data_size)) {
    return false;
  }

  out_metadata = static_cast<const uint8_t*>(data) + header.metadata_offset;
  out_metadata_size = static_cast<size_t>(header.metadata_size);
  return true;
}

} // namespace tensorrt
} // namespace backends
} // namespace executorch
+ */ + +#include + +#include + +#include + +#include + +namespace executorch { +namespace backends { +namespace tensorrt { + +using executorch::runtime::Error; + +namespace { + +class TensorRTLogger : public nvinfer1::ILogger { + public: + void log(Severity severity, const char* msg) noexcept override { + switch (severity) { + case Severity::kINTERNAL_ERROR: + case Severity::kERROR: + ET_LOG(Error, "TensorRT: %s", msg); + break; + case Severity::kWARNING: + case Severity::kINFO: + case Severity::kVERBOSE: + ET_LOG(Info, "TensorRT: %s", msg); + break; + default: + break; + } + } +}; + +TensorRTLogger& get_logger() { + static TensorRTLogger logger; + return logger; +} + +size_t get_dtype_size(nvinfer1::DataType dtype) { + switch (dtype) { + case nvinfer1::DataType::kFLOAT: + case nvinfer1::DataType::kINT32: + return 4; + case nvinfer1::DataType::kHALF: + case nvinfer1::DataType::kBF16: + return 2; + case nvinfer1::DataType::kINT8: + case nvinfer1::DataType::kBOOL: + case nvinfer1::DataType::kUINT8: + case nvinfer1::DataType::kFP8: + return 1; + case nvinfer1::DataType::kINT64: + return 8; + default: + return 4; + } +} + +} // namespace + +TensorRTExecutor::~TensorRTExecutor() { + free_gpu_buffers(); + if (stream_) { + cudaStreamDestroy(stream_); + stream_ = nullptr; + } + context_.reset(); + engine_.reset(); + runtime_.reset(); +} + +TensorRTExecutor::TensorRTExecutor(TensorRTExecutor&& other) noexcept + : runtime_(std::move(other.runtime_)), + engine_(std::move(other.engine_)), + context_(std::move(other.context_)), + stream_(other.stream_), + io_bindings_(std::move(other.io_bindings_)), + gpu_buffers_(std::move(other.gpu_buffers_)), + uses_unified_memory_(other.uses_unified_memory_) { + other.stream_ = nullptr; + other.uses_unified_memory_ = false; +} + +TensorRTExecutor& TensorRTExecutor::operator=(TensorRTExecutor&& other) noexcept { + if (this != &other) { + free_gpu_buffers(); + if (stream_) { + cudaStreamDestroy(stream_); + } + runtime_ = 
std::move(other.runtime_); + engine_ = std::move(other.engine_); + context_ = std::move(other.context_); + stream_ = other.stream_; + io_bindings_ = std::move(other.io_bindings_); + gpu_buffers_ = std::move(other.gpu_buffers_); + uses_unified_memory_ = other.uses_unified_memory_; + + other.stream_ = nullptr; + other.uses_unified_memory_ = false; + } + return *this; +} + +Error TensorRTExecutor::initialize( + const void* blob_data, + size_t blob_size) { + TensorRTBlobHeader header{}; + if (!parse_blob_header(blob_data, blob_size, header)) { + ET_LOG(Error, "Failed to parse TensorRT blob header"); + return Error::InvalidArgument; + } + + const void* engine_data = nullptr; + size_t engine_size = 0; + if (!get_engine_from_blob(blob_data, blob_size, header, engine_data, engine_size)) { + ET_LOG(Error, "Failed to extract engine from blob"); + return Error::InvalidArgument; + } + + const void* metadata_data = nullptr; + size_t metadata_size = 0; + if (header.metadata_size > 0) { + if (!get_metadata_from_blob(blob_data, blob_size, header, metadata_data, metadata_size)) { + ET_LOG(Info, "Failed to extract metadata from blob"); + } else if (metadata_data != nullptr && metadata_size > 0) { + parse_io_bindings(metadata_data, metadata_size); + } else { + ET_LOG(Info, "TensorRT blob metadata is empty"); + } + } + + // Initialize CUDA device before TensorRT + int device_count = 0; + cudaError_t cuda_err = cudaGetDeviceCount(&device_count); + if (cuda_err != cudaSuccess) { + ET_LOG(Error, "Failed to get CUDA device count: %s", cudaGetErrorString(cuda_err)); + return Error::InvalidState; + } + if (device_count == 0) { + ET_LOG(Error, "No CUDA devices available"); + return Error::InvalidState; + } + + cuda_err = cudaSetDevice(0); + if (cuda_err != cudaSuccess) { + ET_LOG(Error, "Failed to set CUDA device: %s", cudaGetErrorString(cuda_err)); + return Error::InvalidState; + } + + ET_LOG(Info, "CUDA initialized with %d device(s)", device_count); + + nvinfer1::IRuntime* runtime = 
nvinfer1::createInferRuntime(get_logger()); + if (!runtime) { + ET_LOG(Error, "Failed to create TensorRT runtime"); + return Error::InvalidState; + } + runtime_.reset(runtime); + + if (engine_data == nullptr || engine_size == 0) { + ET_LOG(Error, "TensorRT engine data is invalid"); + return Error::InvalidArgument; + } + + nvinfer1::ICudaEngine* engine = + runtime->deserializeCudaEngine(engine_data, engine_size); + if (!engine) { + ET_LOG(Error, "Failed to deserialize TensorRT engine"); + return Error::InvalidState; + } + engine_.reset(engine); + + nvinfer1::IExecutionContext* context = engine->createExecutionContext(); + if (!context) { + ET_LOG(Error, "Failed to create TensorRT execution context"); + return Error::InvalidState; + } + context_.reset(context); + + // Detect unified memory (Jetson and other integrated GPUs) + cudaDeviceProp prop{}; + cuda_err = cudaGetDeviceProperties(&prop, 0); + if (cuda_err == cudaSuccess) { + uses_unified_memory_ = prop.integrated != 0; + if (uses_unified_memory_) { + ET_LOG(Info, "Detected integrated GPU with unified memory - skipping CPU-GPU copies"); + } + } + + // Create persistent CUDA stream + cudaStream_t stream; + cuda_err = cudaStreamCreate(&stream); + if (cuda_err != cudaSuccess) { + ET_LOG(Error, "Failed to create CUDA stream: %s", cudaGetErrorString(cuda_err)); + return Error::InvalidState; + } + stream_ = stream; + + // Pre-allocate GPU buffers + // For unified memory (Jetson): use cudaMallocManaged + // For discrete GPUs: use cudaMalloc + auto alloc_err = allocate_gpu_buffers(); + if (alloc_err != Error::Ok) { + return alloc_err; + } + + ET_LOG(Info, "TensorRT executor initialized successfully"); + return Error::Ok; +} + +Error TensorRTExecutor::allocate_gpu_buffers() { + if (!engine_) { + return Error::InvalidState; + } + + const int32_t num_io_tensors = engine_->getNbIOTensors(); + + gpu_buffers_.clear(); + gpu_buffers_.reserve(static_cast(num_io_tensors)); + + size_t input_idx = 0; + size_t output_idx = 0; + + 
for (int32_t i = 0; i < num_io_tensors; ++i) { + const char* name = engine_->getIOTensorName(i); + const auto mode = engine_->getTensorIOMode(name); + const auto dims = engine_->getTensorShape(name); + const auto dtype = engine_->getTensorDataType(name); + + size_t num_elems = 1; + for (int d = 0; d < dims.nbDims; ++d) { + num_elems *= static_cast(dims.d[d]); + } + const size_t buffer_size = num_elems * get_dtype_size(dtype); + + void* gpu_buffer = nullptr; + cudaError_t cuda_err; + if (uses_unified_memory_) { + // Use managed memory for unified memory systems (Jetson) + // This memory is accessible from both CPU and GPU + cuda_err = cudaMallocManaged(&gpu_buffer, buffer_size); + } else { + // Use device memory for discrete GPUs + cuda_err = cudaMalloc(&gpu_buffer, buffer_size); + } + if (cuda_err != cudaSuccess) { + ET_LOG(Error, "Failed to allocate GPU memory: %s", cudaGetErrorString(cuda_err)); + free_gpu_buffers(); + return Error::MemoryAllocationFailed; + } + + GPUBuffer buf; + buf.ptr = gpu_buffer; + buf.size = buffer_size; + buf.is_input = (mode == nvinfer1::TensorIOMode::kINPUT); + buf.tensor_index = i; + // Pre-compute the I/O index mapping to avoid runtime allocation + if (buf.is_input) { + buf.io_index = input_idx++; + } else { + buf.io_index = output_idx++; + } + gpu_buffers_.push_back(buf); + } + + ET_LOG(Info, "Pre-allocated %zu %s buffers", + gpu_buffers_.size(), + uses_unified_memory_ ? 
"managed memory" : "GPU"); + return Error::Ok; +} + +void TensorRTExecutor::free_gpu_buffers() { + for (auto& buf : gpu_buffers_) { + if (buf.ptr) { + cudaFree(buf.ptr); + buf.ptr = nullptr; + } + } + gpu_buffers_.clear(); +} + +Error TensorRTExecutor::execute( + void* const* input_buffers, + size_t num_inputs, + void* const* output_buffers, + size_t num_outputs) { + if (!is_initialized()) { + ET_LOG(Error, "Executor not initialized"); + return Error::InvalidState; + } + + // Validate buffer counts match expected (pre-computed during init) + size_t expected_inputs = 0; + size_t expected_outputs = 0; + for (const auto& buf : gpu_buffers_) { + if (buf.is_input) { + ++expected_inputs; + } else { + ++expected_outputs; + } + } + if (num_inputs < expected_inputs) { + ET_LOG(Error, "Not enough input buffers: got %zu, expected %zu", num_inputs, expected_inputs); + return Error::InvalidArgument; + } + if (num_outputs < expected_outputs) { + ET_LOG(Error, "Not enough output buffers: got %zu, expected %zu", num_outputs, expected_outputs); + return Error::InvalidArgument; + } + + if (uses_unified_memory_) { + // Unified memory path (Jetson): use cudaMallocManaged buffers + // We must copy data to/from managed memory because ExecuTorch's planned + // buffers are not CUDA-accessible. On unified memory systems, memcpy + // is very fast as it's just a CPU-side copy within shared physical memory. 
+ for (const auto& buf : gpu_buffers_) { + const char* name = engine_->getIOTensorName(buf.tensor_index); + if (buf.is_input) { + std::memcpy(buf.ptr, input_buffers[buf.io_index], buf.size); + } + context_->setTensorAddress(name, buf.ptr); + } + + // Execute inference + bool success = context_->enqueueV3(stream_); + if (!success) { + ET_LOG(Error, "TensorRT inference failed"); + return Error::InvalidState; + } + + // Synchronize before reading outputs + cudaError_t cuda_err = cudaStreamSynchronize(stream_); + if (cuda_err != cudaSuccess) { + ET_LOG(Error, "CUDA synchronization failed: %s", cudaGetErrorString(cuda_err)); + return Error::InvalidState; + } + + // Copy outputs from managed memory + for (const auto& buf : gpu_buffers_) { + if (!buf.is_input) { + std::memcpy(output_buffers[buf.io_index], buf.ptr, buf.size); + } + } + } else { + // Discrete GPU path: use pre-allocated GPU buffers with async copies + for (const auto& buf : gpu_buffers_) { + const char* name = engine_->getIOTensorName(buf.tensor_index); + if (buf.is_input) { + cudaError_t cuda_err = cudaMemcpyAsync( + buf.ptr, + input_buffers[buf.io_index], + buf.size, + cudaMemcpyHostToDevice, + stream_); + if (cuda_err != cudaSuccess) { + ET_LOG(Error, "Failed to copy input to GPU: %s", cudaGetErrorString(cuda_err)); + return Error::InvalidState; + } + } + context_->setTensorAddress(name, buf.ptr); + } + + // Execute inference + bool success = context_->enqueueV3(stream_); + if (!success) { + ET_LOG(Error, "TensorRT inference failed"); + return Error::InvalidState; + } + + // Copy outputs from GPU to CPU + for (const auto& buf : gpu_buffers_) { + if (!buf.is_input) { + cudaError_t cuda_err = cudaMemcpyAsync( + output_buffers[buf.io_index], + buf.ptr, + buf.size, + cudaMemcpyDeviceToHost, + stream_); + if (cuda_err != cudaSuccess) { + ET_LOG(Error, "Failed to copy output from GPU: %s", cudaGetErrorString(cuda_err)); + return Error::InvalidState; + } + } + } + + // Synchronize + cudaError_t cuda_err = 
cudaStreamSynchronize(stream_); + if (cuda_err != cudaSuccess) { + ET_LOG(Error, "CUDA synchronization failed: %s", cudaGetErrorString(cuda_err)); + return Error::InvalidState; + } + } + + return Error::Ok; +} + +bool TensorRTExecutor::parse_io_bindings(const void* json_data, size_t json_size) { + (void)json_data; + (void)json_size; + // TODO: Implement JSON parsing for I/O bindings; this stub ignores both arguments and always returns true + return true; +} + +size_t TensorRTExecutor::get_num_inputs() const { + if (!engine_) { + return 0; + } + const int32_t num_io_tensors = engine_->getNbIOTensors(); + size_t count = 0; + for (int32_t i = 0; i < num_io_tensors; ++i) { + const char* name = engine_->getIOTensorName(i); + if (engine_->getTensorIOMode(name) == nvinfer1::TensorIOMode::kINPUT) { + ++count; + } + } + return count; +} + +size_t TensorRTExecutor::get_num_outputs() const { + if (!engine_) { + return 0; + } + const int32_t num_io_tensors = engine_->getNbIOTensors(); + size_t count = 0; + for (int32_t i = 0; i < num_io_tensors; ++i) { + const char* name = engine_->getIOTensorName(i); + if (engine_->getTensorIOMode(name) == nvinfer1::TensorIOMode::kOUTPUT) { + ++count; + } + } + return count; +} + +} // namespace tensorrt +} // namespace backends +} // namespace executorch diff --git a/backends/nvidia/tensorrt/runtime/TensorRTExecutor.h b/backends/nvidia/tensorrt/runtime/TensorRTExecutor.h new file mode 100644 index 00000000000..20620fcfa59 --- /dev/null +++ b/backends/nvidia/tensorrt/runtime/TensorRTExecutor.h @@ -0,0 +1,176 @@ +/* + * Copyright (c) Meta Platforms, Inc. and affiliates. + * All rights reserved. + * + * This source code is licensed under the BSD-style license found in the + * LICENSE file in the root directory of this source tree. + */ + +#pragma once + +#include +#include + +#include +#include +#include +#include + +#include +#include + +namespace executorch { +namespace backends { +namespace tensorrt { + +/** + * I/O binding information for a TensorRT tensor.
+ */ +struct IOBinding { + std::string name; + std::string dtype; + std::vector shape; + bool is_input; +}; + +/** + * GPU buffer information for pre-allocated memory. + */ +struct GPUBuffer { + void* ptr{nullptr}; + size_t size{0}; + bool is_input{false}; + int32_t tensor_index{-1}; + size_t io_index{0}; // Index in input_buffers or output_buffers array +}; + +/** + * TensorRT executor for running inference with deserialized engines. + * + * This class wraps TensorRT runtime objects (engine, context) and provides + * a simple interface for executing inference. The class is move-only. + * + * Memory management patterns: + * - TensorRT objects (IRuntime, ICudaEngine, IExecutionContext) use unique_ptr + * - CUDA stream uses raw pointer with explicit cleanup + * - GPU buffers are pre-allocated during initialize() + * - On unified memory systems (Jetson), CPU-GPU copies are skipped + */ +class TensorRTExecutor { + public: + TensorRTExecutor() = default; + ~TensorRTExecutor(); + + TensorRTExecutor(const TensorRTExecutor&) = delete; + TensorRTExecutor& operator=(const TensorRTExecutor&) = delete; + TensorRTExecutor(TensorRTExecutor&&) noexcept; + TensorRTExecutor& operator=(TensorRTExecutor&&) noexcept; + + /** + * Initialize the executor with a serialized blob. + * + * Parses the blob header, deserializes the TensorRT engine, creates + * an execution context, and pre-allocates GPU buffers. + * + * @param blob_data Pointer to the serialized blob. + * @param blob_size Size of the blob in bytes. + * @return Error::Ok on success. + */ + runtime::Error initialize( + const void* blob_data, + size_t blob_size); + + /** + * Execute inference with the given input/output buffers. + * + * On discrete GPUs: copies inputs to pre-allocated GPU memory, executes, + * and copies outputs back. + * On unified memory (Jetson): uses buffers directly without copies. + * + * @param input_buffers Array of pointers to input data buffers. + * @param num_inputs Number of input buffers.
+ * @param output_buffers Array of pointers to output data buffers. + * @param num_outputs Number of output buffers. + * @return Error::Ok on success. + */ + runtime::Error execute( + void* const* input_buffers, + size_t num_inputs, + void* const* output_buffers, + size_t num_outputs); + + /** + * Get I/O binding information. + * + * @return Reference to the stored vector of IOBinding structs describing inputs and outputs. + */ + const std::vector& get_io_bindings() const { + return io_bindings_; + } + + /** + * Check if the executor is initialized. + * + * @return true if an engine is currently held (engine_ is non-null). + */ + bool is_initialized() const { + return engine_ != nullptr; + } + + /** + * Get the number of input tensors. + * + * @return Number of input tensors in the engine, or 0 if no engine is loaded. + */ + size_t get_num_inputs() const; + + /** + * Get the number of output tensors. + * + * @return Number of output tensors in the engine, or 0 if no engine is loaded. + */ + size_t get_num_outputs() const; + + /** + * Check if running on unified memory system (e.g., Jetson). + * + * @return true if CPU and GPU share memory. + */ + bool uses_unified_memory() const { + return uses_unified_memory_; + } + + private: + /** + * Parse I/O binding metadata from JSON. + * + * @param json_data Pointer to JSON data. + * @param json_size Size of JSON data in bytes. + * @return true on success; the current stub implementation always returns true. + */ + bool parse_io_bindings(const void* json_data, size_t json_size); + + /** + * Pre-allocate GPU buffers for all I/O tensors. + * + * @return Error::Ok on success. + */ + runtime::Error allocate_gpu_buffers(); + + /** + * Free all pre-allocated GPU buffers.
+ */ + void free_gpu_buffers(); + + std::unique_ptr runtime_; + std::unique_ptr engine_; + std::unique_ptr context_; + ::cudaStream_t stream_{nullptr}; + std::vector io_bindings_; + std::vector gpu_buffers_; + bool uses_unified_memory_{false}; +}; + +} // namespace tensorrt +} // namespace backends +} // namespace executorch diff --git a/backends/nvidia/tensorrt/runtime/targets.bzl b/backends/nvidia/tensorrt/runtime/targets.bzl new file mode 100644 index 00000000000..b13203f0ab8 --- /dev/null +++ b/backends/nvidia/tensorrt/runtime/targets.bzl @@ -0,0 +1,56 @@ +load("@fbsource//xplat/executorch/build:runtime_wrapper.bzl", "runtime") + +def define_common_targets(): + """Defines the executor and tensorrt_backend cxx_library targets for the TensorRT C++ runtime backend. + + The directory containing this targets.bzl file should also contain both + TARGETS and BUCK files that call this function. + """ + + runtime.cxx_library( + name = "executor", + srcs = [ + "TensorRTExecutor.cpp", + ], + exported_headers = [ + "TensorRTBlobHeader.h", + "TensorRTExecutor.h", + ], + visibility = ["PUBLIC"], + compiler_flags = [ + "-frtti", + "-fexceptions", + ], + deps = [ + "//executorch/runtime/core:core", + "//executorch/runtime/platform:platform", + ], + fbcode_deps = [ + "fbsource//third-party/TensorRT:nvinfer-lazy", + "//third-party-buck/platform010/build/cuda:cuda", + ], + ) + + runtime.cxx_library( + name = "tensorrt_backend", + srcs = [ + "TensorRTBackend.cpp", + ], + exported_headers = [ + "TensorRTBackend.h", + ], + visibility = ["PUBLIC"], + # Force include all symbols so the static backend registration runs.
+ link_whole = True, + compiler_flags = [ + "-Wno-global-constructors", + ], + deps = [ + "//executorch/runtime/backend:interface", + ":executor", + ], + fbcode_deps = [ + "fbsource//third-party/TensorRT:nvinfer-lazy", + "//third-party-buck/platform010/build/cuda:cuda", + ], + ) diff --git a/install_utils.py b/install_utils.py index 9ffaf8c33ff..3051f7f31fd 100644 --- a/install_utils.py +++ b/install_utils.py @@ -56,6 +56,35 @@ def is_cuda_available() -> bool: return False +def is_tensorrt_available() -> bool: + """ + Check if TensorRT is available on the system. + + Returns: + True if the tensorrt Python package imports or NvInfer.h exists under a known /usr/include directory, False otherwise. + """ + # Check for TensorRT Python package. NOTE(review): the pip package alone may not ship NvInfer.h for the CMake build - confirm. + try: + import tensorrt # noqa: F401 + + return True + except ImportError: + pass + + # Check for TensorRT headers (e.g. JetPack system install without pip package in venv) + import os + + for include_dir in [ + "/usr/include/aarch64-linux-gnu", + "/usr/include/x86_64-linux-gnu", + "/usr/include", + ]: + if os.path.exists(os.path.join(include_dir, "NvInfer.h")): + return True + + return False + + @functools.lru_cache(maxsize=1) def _get_cuda_version(): """ diff --git a/setup.py b/setup.py index f05951012e3..f54383ca243 100644 --- a/setup.py +++ b/setup.py @@ -717,6 +717,13 @@ def run(self): # noqa C901 f"-DQNN_SDK_ROOT={qnn_sdk_root}", ] + # Check if TensorRT is available, and if so, enable building the TRT + # backend by default. + if install_utils.is_tensorrt_available() and install_utils.is_cmake_option_on( + cmake_configuration_args, "EXECUTORCH_BUILD_TENSORRT", default=True + ): + cmake_configuration_args += ["-DEXECUTORCH_BUILD_TENSORRT=ON"] + with Buck2EnvironmentFixer(): # Generate the cmake cache from scratch to ensure that the cache state # is predictable.