diff --git a/CMakeLists.txt b/CMakeLists.txt index e4db14887bd..4119db77fc5 100644 --- a/CMakeLists.txt +++ b/CMakeLists.txt @@ -807,6 +807,12 @@ if(EXECUTORCH_BUILD_PTHREADPOOL AND EXECUTORCH_BUILD_CPUINFO) add_subdirectory(${CMAKE_CURRENT_SOURCE_DIR}/extension/threadpool) endif() +# TensorRT examples (benchmark, etc.) need extension_module and extension_tensor, +# so they must be included after those targets are defined above. +if(EXECUTORCH_BUILD_TENSORRT) + add_subdirectory(${CMAKE_CURRENT_SOURCE_DIR}/examples/nvidia) +endif() + if(EXECUTORCH_BUILD_KERNELS_TORCHAO) if(NOT TARGET cpuinfo) message( diff --git a/examples/nvidia/CMakeLists.txt b/examples/nvidia/CMakeLists.txt new file mode 100644 index 00000000000..088602e20d2 --- /dev/null +++ b/examples/nvidia/CMakeLists.txt @@ -0,0 +1,41 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the BSD-style license found in the +# LICENSE file in the root directory of this source tree. + +# NVIDIA Backend Examples +# +# This CMakeLists.txt includes the TensorRT examples subdirectory. +# +# Supported platforms: +# - Linux x86_64 with NVIDIA GPU (devgpu, workstations) +# - NVIDIA Jetson (Orin Nano, AGX Orin, etc.) +# +# Build instructions: +# cmake .. -DEXECUTORCH_BUILD_TENSORRT=ON +# cmake --build . --target benchmark + +cmake_minimum_required(VERSION 3.19) + +project(nvidia_examples) + +set(CMAKE_CXX_STANDARD 17) +set(CMAKE_CXX_STANDARD_REQUIRED ON) + +# Ensure compile_commands.json is generated for tooling +set(CMAKE_EXPORT_COMPILE_COMMANDS ON) + +# Source root directory for executorch +if(NOT EXECUTORCH_ROOT) + set(EXECUTORCH_ROOT ${CMAKE_CURRENT_SOURCE_DIR}/../..) 
+endif() + +# Include utility CMake scripts from ExecuTorch +include(${EXECUTORCH_ROOT}/tools/cmake/Utils.cmake) + +# Find CUDA (optional - needed for TensorRT backend) +find_package(CUDAToolkit QUIET) + +# Add TensorRT examples subdirectory +add_subdirectory(tensorrt) diff --git a/examples/nvidia/tensorrt/CMakeLists.txt b/examples/nvidia/tensorrt/CMakeLists.txt new file mode 100644 index 00000000000..1e40f64325e --- /dev/null +++ b/examples/nvidia/tensorrt/CMakeLists.txt @@ -0,0 +1,80 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the BSD-style license found in the +# LICENSE file in the root directory of this source tree. + +# TensorRT Examples - Benchmark runner +# +# Build: +# cmake -DEXECUTORCH_BUILD_TENSORRT=ON ... +# cmake --build . --target benchmark +# +# Usage: +# ./benchmark # all .pte/.onnx in current dir +# ./benchmark -m mv3 # mv3 .pte and .onnx in current dir +# ./benchmark -d /tmp/trt -n 200 # all models in /tmp/trt, 200 iterations + +cmake_minimum_required(VERSION 3.19) + +if(NOT EXECUTORCH_ROOT) + set(EXECUTORCH_ROOT ${CMAKE_CURRENT_SOURCE_DIR}/../../..) +endif() + +set(COMMON_INCLUDE_DIRS ${EXECUTORCH_ROOT}/..) + +if(EXECUTORCH_BUILD_TENSORRT) + find_library(NVONNXPARSER_LIBRARY nvonnxparser + HINTS ${TENSORRT_HOME}/lib ${TENSORRT_HOME}/lib64 + /usr/lib /usr/lib/x86_64-linux-gnu /usr/lib/aarch64-linux-gnu + ) + + add_executable(benchmark ${CMAKE_CURRENT_SOURCE_DIR}/benchmark.cpp) + + target_include_directories( + benchmark + PUBLIC $ + $ + ) + + # extension_module builds as extension_module_static in OSS CMake. + if(TARGET extension_module_static) + set(_extension_module extension_module_static) + elseif(TARGET extension_module) + set(_extension_module extension_module) + else() + message(FATAL_ERROR + "extension_module not found. 
Build with -DEXECUTORCH_BUILD_EXTENSION_MODULE=ON")
+  endif()
+
+  if(NOT TARGET extension_tensor)
+    message(FATAL_ERROR
+      "extension_tensor not found. Build with -DEXECUTORCH_BUILD_EXTENSION_TENSOR=ON")
+  endif()
+
+  target_link_libraries(
+    benchmark
+    PRIVATE executorch
+            ${_extension_module}
+            extension_tensor
+            portable_kernels
+  )
+
+  # Link every object of the tensorrt_backend archive, not only the ones whose
+  # symbols benchmark.cpp references directly.
+  target_link_options(
+    benchmark
+    PRIVATE
+      "SHELL:LINKER:--whole-archive $<TARGET_FILE:tensorrt_backend> LINKER:--no-whole-archive"
+  )
+
+  # benchmark.cpp calls the CUDA runtime directly, so the toolkit is mandatory
+  # for this target. CUDA::cudart only exists after a successful
+  # find_package(CUDAToolkit); require it here so a missing toolkit is reported
+  # by find_package instead of as an unresolved imported target.
+  if(NOT TARGET CUDA::cudart)
+    find_package(CUDAToolkit REQUIRED)
+  endif()
+  target_link_libraries(benchmark PRIVATE CUDA::cudart)
+
+  if(TENSORRT_LIBRARY)
+    target_link_libraries(benchmark PRIVATE ${TENSORRT_LIBRARY})
+  endif()
+
+  # benchmark.cpp uses the ONNX parser unconditionally; fail at configure time
+  # rather than with undefined references at link time.
+  if(NOT NVONNXPARSER_LIBRARY)
+    message(FATAL_ERROR
+      "nvonnxparser not found. Set TENSORRT_HOME to the TensorRT install prefix.")
+  endif()
+  target_link_libraries(benchmark PRIVATE ${NVONNXPARSER_LIBRARY})
+
+  # The archive is referenced through link options only, so order the build
+  # explicitly: tensorrt_backend must exist before benchmark is linked.
+  add_dependencies(benchmark tensorrt_backend)
+
+  target_compile_options(benchmark PRIVATE -frtti -fexceptions)
+
+  install(TARGETS benchmark DESTINATION ${CMAKE_INSTALL_BINDIR})
+endif()
diff --git a/examples/nvidia/tensorrt/README.md b/examples/nvidia/tensorrt/README.md index dd7ba1a99b8..4ec66f24e9c 100644 --- a/examples/nvidia/tensorrt/README.md +++ b/examples/nvidia/tensorrt/README.md @@ -18,8 +18,8 @@ Export a supported model to ExecuTorch format with TensorRT delegation: # Export the add model python -m executorch.examples.nvidia.tensorrt.export -m add -# Export with validation test -python -m executorch.examples.nvidia.tensorrt.export -m add --test +# Export all supported models to a directory +python -m executorch.examples.nvidia.tensorrt.export -o /tmp/trt # Export to a specific directory python -m executorch.examples.nvidia.tensorrt.export -m add -o ./output @@ -59,6 +59,7 @@ python -m executorch.examples.nvidia.tensorrt.export --help - `export.py` - Main export script for converting models to TensorRT format - `runner.py` - Python utilities for running and testing exported models +- `benchmark.cpp` - C++ benchmark runner for performance measurement - `tensorrt_executor_runner.cpp` - C++ executor runner for TensorRT models - `__init__.py` - Package 
initialization @@ -85,13 +86,31 @@ python -m executorch.examples.nvidia.tensorrt.export -m add --help Show help message ``` -### Validation Testing +## Benchmarking -The `--test` flag runs the exported model through the ExecuTorch runtime -and compares outputs against the PyTorch reference model: +Export models then benchmark with the C++ runner: ```bash -python -m executorch.examples.nvidia.tensorrt.export -m add --test +# Step 1: Export models +python -m executorch.examples.nvidia.tensorrt.export -o /tmp/trt + +# Step 2: Benchmark all exported models +./benchmark -d /tmp/trt + +# Benchmark a specific model +./benchmark -d /tmp/trt -m mv3 + +# Benchmark with custom iterations +./benchmark -d /tmp/trt -n 200 -w 5 +``` + +**Benchmark Options:** +``` +-d, --model_dir DIR Directory with .pte/.onnx files (default: current dir) +-m, --model_name NAME Run only NAME (NAME_tensorrt.pte, NAME.pte, NAME.onnx) from the directory +-n, --num_executions N Number of timed iterations (default: 100) +-w, --warmup N Number of warmup runs (default: 3) +-v, --verbose Enable verbose logging ``` ## Adding New Models @@ -109,7 +128,10 @@ To add support for a new model: examples/nvidia/tensorrt/ ├── export.py # CLI export script using MODEL_NAME_TO_MODEL registry ├── runner.py # Python runtime utilities for testing +├── benchmark.cpp # C++ benchmark runner binary ├── tensorrt_executor_runner.cpp # C++ executor runner binary +├── tests/ # Correctness tests +│ └── test_export.py # Export + inference verification ├── __init__.py # Package exports └── README.md # This file ``` diff --git a/examples/nvidia/tensorrt/benchmark.cpp b/examples/nvidia/tensorrt/benchmark.cpp new file mode 100644 index 00000000000..0d9ec09e135 --- /dev/null +++ b/examples/nvidia/tensorrt/benchmark.cpp @@ -0,0 +1,685 @@ +/* + * Copyright (c) Meta Platforms, Inc. and affiliates. + * All rights reserved. + * + * This source code is licensed under the BSD-style license found in the + * LICENSE file in the root directory of this source tree. 
+ */ + +/** + * Benchmark runner for ExecuTorch models with TensorRT delegation. + * + * Discovers .pte and .onnx files in a directory and benchmarks each one. + * For .pte files, uses ExecuTorch Module API with the TensorRT backend. + * For .onnx files, compiles to a TRT engine and benchmarks natively. + * + * Usage: + * benchmark # all models in current dir + * benchmark -m mv3 # mv3_tensorrt.pte in current dir + * benchmark -d /tmp/trt # all models in /tmp/trt + * benchmark -d /tmp/trt -m mv3 # mv3 .pte and .onnx in /tmp/trt + * benchmark -n 200 -w 5 # 200 iterations, 5 warmup + */ + +#include +#include +#include +#include +#include +#include +#include +#include + +#include + +#include +#include + +#include +#include +#include +#include +#include + +using executorch::aten::ScalarType; +using executorch::extension::from_blob; +using executorch::extension::Module; +using executorch::runtime::Error; +using executorch::runtime::EValue; +using executorch::runtime::MethodMeta; +using executorch::runtime::Tag; + +namespace { + +constexpr uint32_t kDefaultIterations = 100; +constexpr uint32_t kDefaultWarmup = 3; +constexpr const char* kPteSuffix = ".pte"; +constexpr const char* kOnnxSuffix = ".onnx"; +constexpr const char* kTrtSuffix = "_tensorrt.pte"; + +struct Args { + std::string model_dir = "."; + std::string model_name; + uint32_t iterations = kDefaultIterations; + uint32_t warmup = kDefaultWarmup; + bool verbose = false; +}; + +struct BenchmarkResult { + std::string name; + std::string format; + uint32_t iterations; + double avg_ms; + double total_ms; + bool success; + std::string error; +}; + +// --------------------------------------------------------------------------- +// TRT helpers +// --------------------------------------------------------------------------- + +class TrtLogger : public nvinfer1::ILogger { + public: + void log(Severity severity, const char* msg) noexcept override { + if (severity <= Severity::kWARNING) { + ET_LOG(Info, "TensorRT: %s", msg); 
+ } + } +}; + +int64_t volume(const nvinfer1::Dims& dims) { + int64_t v = 1; + for (int i = 0; i < dims.nbDims; ++i) { + v *= dims.d[i]; + } + return v; +} + +size_t dtype_size(nvinfer1::DataType dt) { + switch (dt) { + case nvinfer1::DataType::kFLOAT: + return 4; + case nvinfer1::DataType::kHALF: + return 2; + case nvinfer1::DataType::kINT32: + return 4; + case nvinfer1::DataType::kINT8: + return 1; + case nvinfer1::DataType::kBOOL: + return 1; + default: + return 4; + } +} + +// --------------------------------------------------------------------------- +// CLI +// --------------------------------------------------------------------------- + +void print_usage() { + printf( + "Usage: benchmark [options]\n" + "\n" + "Options:\n" + " -d, --model_dir DIR Directory with .pte/.onnx files (default: .)\n" + " -m, --model_name NAME Run only this model\n" + " -n, --num_executions N Timed iterations (default: %u)\n" + " -w, --warmup N Warmup runs (default: %u)\n" + " -v, --verbose Verbose logging\n" + " -h, --help Show this message\n", + kDefaultIterations, + kDefaultWarmup); +} + +bool parse_args(int argc, char** argv, Args& args) { + for (int i = 1; i < argc; ++i) { + std::string arg = argv[i]; + if (arg == "-h" || arg == "--help") { + print_usage(); + return false; + } else if (arg == "-v" || arg == "--verbose") { + args.verbose = true; + } else if ((arg == "-d" || arg == "--model_dir") && i + 1 < argc) { + args.model_dir = argv[++i]; + } else if ((arg == "-m" || arg == "--model_name") && i + 1 < argc) { + args.model_name = argv[++i]; + } else if ((arg == "-n" || arg == "--num_executions") && i + 1 < argc) { + args.iterations = static_cast(std::stoul(argv[++i])); + } else if ((arg == "-w" || arg == "--warmup") && i + 1 < argc) { + args.warmup = static_cast(std::stoul(argv[++i])); + } else { + fprintf(stderr, "Error: unknown argument '%s'\n", arg.c_str()); + print_usage(); + return false; + } + } + return true; +} + +// 
--------------------------------------------------------------------------- +// File discovery +// --------------------------------------------------------------------------- + +bool ends_with(const std::string& s, const char* suffix) { + size_t len = strlen(suffix); + return s.size() >= len && s.compare(s.size() - len, len, suffix) == 0; +} + +std::string stem(const std::string& path) { + auto slash = path.rfind('/'); + std::string filename = + (slash != std::string::npos) ? path.substr(slash + 1) : path; + if (ends_with(filename, kTrtSuffix)) { + return filename.substr(0, filename.size() - strlen(kTrtSuffix)); + } + if (ends_with(filename, kPteSuffix)) { + return filename.substr(0, filename.size() - strlen(kPteSuffix)); + } + if (ends_with(filename, kOnnxSuffix)) { + return filename.substr(0, filename.size() - strlen(kOnnxSuffix)); + } + return filename; +} + +std::vector find_models( + const std::string& dir, + const std::string& name) { + std::vector paths; + + if (!name.empty()) { + // Try specific suffixes for the named model. + for (const char* suffix : {kTrtSuffix, kPteSuffix, kOnnxSuffix}) { + std::string path = dir + "/" + name + suffix; + if (FILE* f = fopen(path.c_str(), "r")) { + fclose(f); + paths.push_back(path); + } + } + return paths; + } + + auto scan = [&](const std::string& d) { + DIR* dp = opendir(d.c_str()); + if (!dp) { + return; + } + while (auto* entry = readdir(dp)) { + std::string f = entry->d_name; + if (ends_with(f, kPteSuffix) || ends_with(f, kOnnxSuffix)) + paths.push_back(d + "/" + f); + } + closedir(dp); + }; + + scan(dir); + // buck2 test runs from fbcode/, so models may land there. 
+ if (paths.empty()) { + scan(dir + "/fbcode"); + } + + std::sort(paths.begin(), paths.end()); + return paths; +} + +// --------------------------------------------------------------------------- +// PTE benchmark (ExecuTorch Module API) +// --------------------------------------------------------------------------- + +BenchmarkResult benchmark_pte( + const std::string& path, + uint32_t iterations, + uint32_t warmup, + bool verbose) { + BenchmarkResult result{stem(path), "pte", iterations, 0, 0, false, ""}; + + Module module(path, Module::LoadMode::File); + + auto meta_result = module.method_meta("forward"); + if (!meta_result.ok()) { + result.error = "Failed to get method metadata"; + return result; + } + auto meta = meta_result.get(); + + // Create ones-filled inputs matching the model's expected shapes and dtypes. + std::vector input_tensors; + std::vector> float_data; + std::vector> int64_data; + std::vector> bf16_data; + std::vector inputs; + for (size_t i = 0; i < meta.num_inputs(); ++i) { + auto tag = meta.input_tag(i); + if (tag.ok() && tag.get() == Tag::Tensor) { + auto tmeta = meta.input_tensor_meta(i); + if (tmeta.ok()) { + auto sizes = tmeta.get().sizes(); + auto dtype = tmeta.get().scalar_type(); + auto sizes_vec = std::vector( + sizes.begin(), sizes.end()); + int64_t numel = 1; + for (auto s : sizes_vec) { + numel *= s; + } + + executorch::extension::TensorPtr tensor; + if (dtype == ScalarType::Long || dtype == ScalarType::Int) { + int64_data.emplace_back(numel, 1); + tensor = executorch::extension::make_tensor_ptr( + sizes_vec, int64_data.back()); + } else if (dtype == ScalarType::BFloat16) { + // BFloat16 1.0f = 0x3F80 + bf16_data.emplace_back(numel, 0x3F80); + tensor = executorch::extension::make_tensor_ptr( + sizes_vec, + bf16_data.back().data(), + executorch::aten::ScalarType::BFloat16); + } else { + float_data.emplace_back(numel, 1.0f); + tensor = executorch::extension::make_tensor_ptr( + sizes_vec, float_data.back()); + } + 
inputs.push_back(EValue(*tensor)); + input_tensors.push_back(std::move(tensor)); + } + } + } + + printf(" warming up ...\r"); + fflush(stdout); + for (uint32_t i = 0; i < warmup; ++i) { + auto r = module.forward(inputs); + if (!r.ok()) { + result.error = "Forward failed during warmup"; + return result; + } + } + + et_timestamp_t total = 0; + for (uint32_t i = 0; i < iterations; ++i) { + printf(" [%u/%u]\r", i + 1, iterations); + fflush(stdout); + auto start = executorch::runtime::pal_current_ticks(); + auto r = module.forward(inputs); + auto end = executorch::runtime::pal_current_ticks(); + if (!r.ok()) { + result.error = "Forward failed at iteration " + std::to_string(i); + return result; + } + total += end - start; + } + printf(" \r"); + + auto ratio = et_pal_ticks_to_ns_multiplier(); + result.total_ms = static_cast(total) * ratio.numerator / + ratio.denominator / 1000000.0; + result.avg_ms = result.total_ms / iterations; + result.success = true; + return result; +} + +// --------------------------------------------------------------------------- +// Raw TRT benchmark from PTE (extracts engine from delegate blob) +// --------------------------------------------------------------------------- + +BenchmarkResult benchmark_pte_raw_trt( + const std::string& path, + uint32_t iterations, + uint32_t warmup, + bool verbose) { + BenchmarkResult result{stem(path), "pte-raw", iterations, 0, 0, false, ""}; + + FILE* f = fopen(path.c_str(), "rb"); + if (!f) { + result.error = "Cannot open file"; + return result; + } + fseek(f, 0, SEEK_END); + size_t file_size = ftell(f); + fseek(f, 0, SEEK_SET); + std::vector file_data(file_size); + size_t bytes_read = fread(file_data.data(), 1, file_size, f); + if (bytes_read != file_size) { + fclose(f); + return {}; + } + fclose(f); + + // Search for our TRT blob header magic "TR01" in the PTE flatbuffer. + // Blob: magic(4) + meta_offset(4) + meta_size(4) + engine_offset(4) + + // engine_size(8) + reserved(8) = 32 bytes. 
+ const char kMagic[4] = {'T', 'R', '0', '1'}; + const void* engine_data = nullptr; + size_t engine_size = 0; + + for (size_t i = 0; i + 32 < file_size; ++i) { + if (memcmp(file_data.data() + i, kMagic, 4) == 0) { + const auto* hdr = reinterpret_cast(file_data.data() + i); + uint32_t eng_offset = 0; + uint64_t eng_size = 0; + memcpy(&eng_offset, hdr + 12, 4); + memcpy(&eng_size, hdr + 16, 8); + if (eng_size > 0 && i + eng_offset + eng_size <= file_size) { + engine_data = file_data.data() + i + eng_offset; + engine_size = static_cast(eng_size); + break; + } + } + } + + if (!engine_data || engine_size == 0) { + result.error = "Cannot find TRT engine in PTE file"; + return result; + } + + TrtLogger logger; + auto runtime = std::unique_ptr( + nvinfer1::createInferRuntime(logger)); + auto engine = std::unique_ptr( + runtime->deserializeCudaEngine(engine_data, engine_size)); + if (!engine) { + result.error = "Failed to deserialize TRT engine from PTE"; + return result; + } + + auto context = std::unique_ptr( + engine->createExecutionContext()); + + cudaStream_t stream; + cudaStreamCreate(&stream); + std::vector buffers; + + for (int i = 0; i < engine->getNbIOTensors(); ++i) { + const auto* name = engine->getIOTensorName(i); + auto shape = engine->getTensorShape(name); + auto dt = engine->getTensorDataType(name); + size_t bytes = static_cast(volume(shape)) * dtype_size(dt); + void* buf = nullptr; + cudaMalloc(&buf, bytes); + cudaMemset(buf, 0, bytes); + context->setTensorAddress(name, buf); + buffers.push_back(buf); + } + + printf(" warming up ...\r"); + fflush(stdout); + for (uint32_t i = 0; i < warmup; ++i) { + context->enqueueV3(stream); + cudaStreamSynchronize(stream); + } + + et_timestamp_t total = 0; + for (uint32_t i = 0; i < iterations; ++i) { + printf(" [%u/%u]\r", i + 1, iterations); + fflush(stdout); + auto start = executorch::runtime::pal_current_ticks(); + context->enqueueV3(stream); + cudaStreamSynchronize(stream); + auto end = 
executorch::runtime::pal_current_ticks(); + total += end - start; + } + printf(" \r"); + + for (const auto& buf : buffers) { + cudaFree(buf); + } + cudaStreamDestroy(stream); + + auto ratio = et_pal_ticks_to_ns_multiplier(); + result.total_ms = static_cast(total) * ratio.numerator / + ratio.denominator / 1000000.0; + result.avg_ms = result.total_ms / iterations; + result.success = true; + return result; +} + +// --------------------------------------------------------------------------- +// ONNX benchmark (TensorRT native) +// --------------------------------------------------------------------------- + +BenchmarkResult benchmark_onnx( + const std::string& path, + uint32_t iterations, + uint32_t warmup, + bool verbose) { + BenchmarkResult result{stem(path), "onnx-trt", iterations, 0, 0, false, ""}; + + TrtLogger logger; + + auto builder = + std::unique_ptr(nvinfer1::createInferBuilder(logger)); + if (!builder) { + result.error = "Failed to create TRT builder"; + return result; + } + + auto network = std::unique_ptr( + builder->createNetworkV2( + 1 << static_cast( + nvinfer1::NetworkDefinitionCreationFlag::kEXPLICIT_BATCH))); + auto parser = std::unique_ptr( + nvonnxparser::createParser(*network, logger)); + + if (verbose) { + + ET_LOG(Info, "Parsing ONNX: %s", path.c_str()); + + } + + if (!parser->parseFromFile( + path.c_str(), + static_cast(nvinfer1::ILogger::Severity::kWARNING))) { + result.error = "Failed to parse ONNX file"; + return result; + } + + auto config = std::unique_ptr( + builder->createBuilderConfig()); + config->setMemoryPoolLimit( + nvinfer1::MemoryPoolType::kWORKSPACE, 1ULL << 30); + // Match our backend's precision: strict FP32, no TF32. 
+ if (config->getFlag(nvinfer1::BuilderFlag::kTF32)) { + config->clearFlag(nvinfer1::BuilderFlag::kTF32); + } + + if (verbose) { + + ET_LOG(Info, "Building TRT engine from ONNX..."); + + } + + auto plan = std::unique_ptr( + builder->buildSerializedNetwork(*network, *config)); + if (!plan) { + result.error = "Failed to build TRT engine"; + return result; + } + + auto runtime = std::unique_ptr( + nvinfer1::createInferRuntime(logger)); + auto engine = std::unique_ptr( + runtime->deserializeCudaEngine(plan->data(), plan->size())); + if (!engine) { + result.error = "Failed to deserialize engine"; + return result; + } + + auto context = std::unique_ptr( + engine->createExecutionContext()); + + // Allocate GPU buffers for all I/O tensors. + cudaStream_t stream; + cudaStreamCreate(&stream); + std::vector buffers; + + for (int i = 0; i < engine->getNbIOTensors(); ++i) { + auto name = engine->getIOTensorName(i); + auto shape = engine->getTensorShape(name); + auto dt = engine->getTensorDataType(name); + size_t bytes = static_cast(volume(shape)) * dtype_size(dt); + void* buf = nullptr; + cudaMalloc(&buf, bytes); + cudaMemset(buf, 0, bytes); + context->setTensorAddress(name, buf); + buffers.push_back(buf); + } + + printf(" warming up ...\r"); + fflush(stdout); + for (uint32_t i = 0; i < warmup; ++i) { + context->enqueueV3(stream); + cudaStreamSynchronize(stream); + } + + et_timestamp_t total = 0; + for (uint32_t i = 0; i < iterations; ++i) { + printf(" [%u/%u]\r", i + 1, iterations); + fflush(stdout); + auto start = executorch::runtime::pal_current_ticks(); + context->enqueueV3(stream); + cudaStreamSynchronize(stream); + auto end = executorch::runtime::pal_current_ticks(); + total += end - start; + } + printf(" \r"); + + for (auto buf : buffers) { + + cudaFree(buf); + + } + cudaStreamDestroy(stream); + + auto ratio = et_pal_ticks_to_ns_multiplier(); + result.total_ms = static_cast(total) * ratio.numerator / + ratio.denominator / 1000000.0; + result.avg_ms = result.total_ms / 
iterations; + result.success = true; + return result; +} + +// --------------------------------------------------------------------------- +// Output +// --------------------------------------------------------------------------- + +void print_summary(const std::vector& results) { + printf("\n"); + printf( + "%-20s %-10s %6s %10s %10s %s\n", + "MODEL", + "FORMAT", + "RUNS", + "AVG (ms)", + "TOTAL (ms)", + "STATUS"); + printf( + "%-20s %-10s %6s %10s %10s %s\n", + "--------------------", + "----------", + "------", + "----------", + "----------", + "------"); + + for (const auto& r : results) { + if (r.success) { + printf( + "%-20s %-10s %6u %10.3f %10.3f OK\n", + r.name.c_str(), + r.format.c_str(), + r.iterations, + r.avg_ms, + r.total_ms); + } else { + printf( + "%-20s %-10s %6s %10s %10s FAIL: %s\n", + r.name.c_str(), + r.format.c_str(), + "-", + "-", + "-", + r.error.c_str()); + } + } + printf("\n"); +} + +} // namespace + +int main(int argc, char** argv) { + executorch::runtime::runtime_init(); + + Args args; + if (!parse_args(argc, argv, args)) { + return 1; + } + + auto files = find_models(args.model_dir, args.model_name); + if (files.empty()) { + if (!args.model_name.empty()) { + fprintf( + stderr, + "Error: model '%s' not found in '%s'\n", + args.model_name.c_str(), + args.model_dir.c_str()); + } else { + fprintf( + stderr, + "Error: no .pte/.onnx files found in '%s'\n", + args.model_dir.c_str()); + } + return 1; + } + + if (args.verbose) { + ET_LOG( + Info, + "Found %zu model(s), warmup=%u, iterations=%u", + files.size(), + args.warmup, + args.iterations); + } + + std::vector results; + for (const auto& path : files) { + printf("Benchmarking: %s ...\n", path.c_str()); + + BenchmarkResult result; + if (ends_with(path, kOnnxSuffix)) { + result = benchmark_onnx(path, args.iterations, args.warmup, args.verbose); + } else { + result = benchmark_pte(path, args.iterations, args.warmup, args.verbose); + } + + if (result.success) { + printf(" %.3f ms avg (%u 
iterations)\n", result.avg_ms, result.iterations); + } else { + printf(" FAILED: %s\n", result.error.c_str()); + } + results.push_back(std::move(result)); + + // Also benchmark raw TRT engine extracted from PTE for overhead analysis. + if (ends_with(path, kPteSuffix)) { + printf("Benchmarking: %s (raw TRT) ...\n", path.c_str()); + auto raw_result = benchmark_pte_raw_trt( + path, args.iterations, args.warmup, args.verbose); + if (raw_result.success) { + printf( + " %.3f ms avg (%u iterations)\n", + raw_result.avg_ms, + raw_result.iterations); + } else { + printf(" FAILED: %s\n", raw_result.error.c_str()); + } + results.push_back(std::move(raw_result)); + } + } + + if (results.size() > 1) { + + print_summary(results); + + } + + return 0; +} diff --git a/examples/nvidia/tensorrt/export.py b/examples/nvidia/tensorrt/export.py index 8099f376280..5cb0f8607c0 100644 --- a/examples/nvidia/tensorrt/export.py +++ b/examples/nvidia/tensorrt/export.py @@ -110,6 +110,36 @@ def export_model( return model, example_inputs, exec_prog +def export_onnx( + model: torch.nn.Module, + example_inputs: tuple, + model_name: str, + output_dir: str, + logger: logging.Logger, +) -> None: + """Export model to ONNX format for baseline TRT benchmarking.""" + import os + try: + import onnx as _onnx # noqa: F401 + except ImportError: + raise RuntimeError( + "ONNX export requires the 'onnx' package. Install with: pip install onnx" + ) + onnx_path = os.path.join(output_dir, f"{model_name}.onnx") + os.makedirs(output_dir, exist_ok=True) + logging.info(f"Exporting {model_name} to ONNX: {onnx_path}") + # dynamo=False uses the legacy TorchScript-based exporter which doesn't + # require the onnxscript package. 
+ torch.onnx.export( + model, + example_inputs, + onnx_path, + opset_version=17, + dynamo=False, + ) + logger.info(f"ONNX model saved to {onnx_path}") + + # --------------------------------------------------------------------------- # Correctness verification (used by test_export.py via buck test) # --------------------------------------------------------------------------- @@ -247,6 +277,18 @@ def main() -> None: default=True, help="Disable strict mode for export (default: strict mode enabled)", ) + parser.add_argument( + "--onnx", + action="store_true", + default=False, + help="Also export models to ONNX format for baseline TRT benchmarking", + ) + parser.add_argument( + "--onnx-only", + action="store_true", + default=False, + help="Export only ONNX format (skip .pte export)", + ) parser.add_argument( "-v", "--verbose", @@ -264,7 +306,20 @@ def main() -> None: failed = [] for model_name in models: try: - export_model(model_name, args.output_dir, args.strict, logger) + if args.onnx_only: + logging.info(f"Creating model: {model_name}") + torch.manual_seed(0) + model, example_inputs, _, _ = EagerModelFactory.create_model( + *MODEL_NAME_TO_MODEL[model_name] + ) + model.eval() + export_onnx(model, example_inputs, model_name, args.output_dir, logger) + else: + model, example_inputs, _ = export_model( + model_name, args.output_dir, args.strict, logger + ) + if args.onnx: + export_onnx(model, example_inputs, model_name, args.output_dir, logger) except Exception as e: logging.error(f"Failed to export {model_name}: {e}") failed.append(model_name) diff --git a/examples/nvidia/tensorrt/targets.bzl b/examples/nvidia/tensorrt/targets.bzl index dd3c15358ff..fb1c8c74b5b 100644 --- a/examples/nvidia/tensorrt/targets.bzl +++ b/examples/nvidia/tensorrt/targets.bzl @@ -1,4 +1,5 @@ load("@fbsource//xplat/executorch/build:runtime_wrapper.bzl", "get_oss_build_kwargs", "runtime") +load("@fbcode_macros//build_defs:cpp_binary.bzl", "cpp_binary") def define_common_targets(): """Defines 
targets that should be shared between fbcode and xplat. @@ -39,6 +40,7 @@ def define_common_targets(): ], deps = [ ":tensorrt_example_lib", + "fbsource//third-party/onnx:onnx_py", # @manual for --onnx flag ], visibility = ["PUBLIC"], ) @@ -66,3 +68,24 @@ def define_common_targets(): "//executorch/kernels/portable:generated_lib", ], ) + + # Benchmarks .pte files (ExecuTorch Module API) and .onnx files (TRT native). + # benchmark # all models in current dir + # benchmark -m mv3 # mv3 .pte and .onnx in current dir + # benchmark -d /tmp/trt -n 200 # all models in /tmp/trt, 200 iterations + cpp_binary( + name = "benchmark", + srcs = ["benchmark.cpp"], + compiler_flags = ["-Wno-global-constructors"], + deps = [ + "//executorch/extension/module:module", + "//executorch/extension/tensor:tensor", + "//executorch/kernels/portable:generated_lib", + "//executorch/backends/nvidia/tensorrt/runtime:tensorrt_backend", + "fbsource//third-party/TensorRT:nvinfer-lazy", + "fbsource//third-party/TensorRT:nvonnxparser-lazy", + ], + external_deps = [ + ("cuda", None, "cuda-lazy"), + ], + ) diff --git a/examples/nvidia/tensorrt/tests/test_export.py b/examples/nvidia/tensorrt/tests/test_export.py index 81d50b6782b..2906f56b01e 100644 --- a/examples/nvidia/tensorrt/tests/test_export.py +++ b/examples/nvidia/tensorrt/tests/test_export.py @@ -23,12 +23,9 @@ # Mapping from env var to expected cache filename. # The test TARGETS provides these via manifold_get + $(location). -# Entries are added as models are enabled in later commits. _WEIGHT_ENV_VARS = { - "DOG_JPG": "dog.jpg", "EDSR_WEIGHTS": "edsr64_x2.pt", - "IC4_WEIGHTS": "inceptionv4-8e4777a0.pth", - "MV2_WEIGHTS": "mobilenet_v2-b0353104.pth", + "MV3_WEIGHTS": "mobilenet_v3_small-047dcff4.pth", }