diff --git a/dali/core/exec/thread_pool_base.cc b/dali/core/exec/thread_pool_base.cc
new file mode 100644
index 00000000000..b729f89fe52
--- /dev/null
+++ b/dali/core/exec/thread_pool_base.cc
@@ -0,0 +1,276 @@
+// Copyright (c) 2022-2026, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+//     http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#include "dali/core/exec/thread_pool_base.h"
+#include <cassert>
+#include <chrono>
+#include <stdexcept>
+#include <utility>
+#include <vector>
+
+namespace dali {
+
+// Throws if a non-empty job was neither waited for nor discarded. Otherwise it spins until the
+// last task has left DoNotify, so that no task touches `this` after the object is gone.
+JobBase::~JobBase() noexcept(false) {
+  if (total_tasks_ > 0 && !wait_completed_) {
+    throw std::logic_error("The job is not empty, but hasn't been discarded or waited for.");
+  }
+  while (running_)
+    std::this_thread::yield();
+}
+
+// Blocks until num_pending_tasks_ drops to zero. When called from a thread that belongs to a
+// thread pool, that pool's queued tasks are executed while waiting.
+void JobBase::DoWait() {
+  if (wait_started_)
+    throw std::logic_error("This job has already been waited for.");
+  wait_started_ = true;
+
+  if (total_tasks_ == 0) {
+    wait_completed_ = true;
+    return;
+  }
+
+  if (executor_ == nullptr)
+    throw std::logic_error("This job hasn't been run - cannot wait for it.");
+
+  auto ready = [&]() { return num_pending_tasks_ == 0; };
+  if (ThreadPoolBase::this_thread_pool() != nullptr) {
+    bool result = ThreadPoolBase::this_thread_pool()->WaitOrRunTasks(cv_, ready);
+    wait_completed_ = true;
+    if (!result)
+      throw std::logic_error("The thread pool was stopped");
+  } else {
+    int old = num_pending_tasks_.load();
+    while (old != 0) {
+      num_pending_tasks_.wait(old);
+      old = num_pending_tasks_.load();
+      assert(old >= 0);
+    }
+    wait_completed_ = true;
+  }
+}
+
+// Wakes the waiters of both kinds (atomic wait and condition variable) and clears running_.
+void JobBase::DoNotify() {
+  num_pending_tasks_.notify_all();
+  (void)std::lock_guard(mtx_);  // the temporary guard locks and immediately unlocks mtx_
+  cv_.notify_all();
+  // We need this second flag to avoid a race condition where the destructor is called between
+  // decrementing num_pending_tasks_ and the notifications, without excessive use of mutexes.
+  // This must be the very last operation in the task function that touches `this`.
+  running_ = false;
+}
+
+// Job ////////////////////////////////////////////////////////////////////
+
+// Submits all tasks to the thread pool as one batch and optionally waits for them.
+// running_ is raised only after tasks were actually queued: if the pool rejects the batch, no
+// task would ever clear the flag and the destructor would spin forever.
+void Job::Run(ThreadPoolBase &tp, bool wait) {
+  if (executor_ != nullptr)
+    throw std::logic_error("This job has already been started.");
+  executor_ = &tp;
+  {
+    auto batch = tp.BeginBulkAdd();
+    for (auto &x : tasks_) {
+      batch.Add(std::move(x.second.func));
+    }
+    int added = batch.Size();
+    if (added) {
+      num_pending_tasks_ += added;
+      running_ = true;
+    }
+    batch.Submit();
+  }
+  if (wait && !tasks_.empty())
+    Wait();
+}
+
+// Waits for completion, then rethrows the task error (MultipleErrors if several tasks failed).
+void Job::Wait() {
+  DoWait();
+
+  // note - this vector is not allocated unless there were exceptions thrown
+  std::vector<std::exception_ptr> errors;
+  for (auto &x : tasks_) {
+    if (x.second.error)
+      errors.push_back(std::move(x.second.error));
+  }
+  if (errors.size() == 1)
+    std::rethrow_exception(errors[0]);
+  else if (errors.size() > 1)
+    throw MultipleErrors(std::move(errors));
+}
+
+// Removes all tasks from a job that has not been started.
+void Job::Discard() {
+  if (executor_ != nullptr)
+    throw std::logic_error("Cannot discard a job that has already been started");
+  tasks_.clear();
+  total_tasks_ = 0;
+}
+
+// IncrementalJob /////////////////////////////////////////////////////////
+
+// Submits the tasks added since the previous call to Run and optionally waits for all tasks.
+void IncrementalJob::Run(ThreadPoolBase &tp, bool wait) {
+  if (executor_ && executor_ != &tp)
+    throw std::logic_error("This job is already running in a different executor.");
+  executor_ = &tp;
+  {
+    auto it = last_task_run_.has_value() ? std::next(*last_task_run_) : tasks_.begin();
+    auto batch = tp.BeginBulkAdd();
+    for (; it != tasks_.end(); ++it) {
+      batch.Add(std::move(it->func));
+      last_task_run_ = it;
+    }
+    int added = batch.Size();
+    if (added) {
+      num_pending_tasks_ += added;
+      running_ = true;
+    }
+    batch.Submit();
+  }
+  if (wait && !tasks_.empty())
+    Wait();
+}
+
+// Removes all tasks from a job that has not been started.
+void IncrementalJob::Discard() {
+  if (executor_)
+    throw std::logic_error("Cannot discard a job that has already been started");
+  tasks_.clear();
+  total_tasks_ = 0;
+}
+
+// Waits for completion, then rethrows the task error (MultipleErrors if several tasks failed).
+void IncrementalJob::Wait() {
+  DoWait();
+  // note - this vector is not allocated unless there were exceptions thrown
+  std::vector<std::exception_ptr> errors;
+  for (auto &x : tasks_) {
+    if (x.error)
+      errors.push_back(std::move(x.error));
+  }
+  if (errors.size() == 1)
+    std::rethrow_exception(errors[0]);
+  else if (errors.size() > 1)
+    throw MultipleErrors(std::move(errors));
+}
+
+///////////////////////////////////////////////////////////////////////////
+
+thread_local ThreadPoolBase *ThreadPoolBase::this_thread_pool_ = nullptr;
+thread_local int ThreadPoolBase::this_thread_idx_ = -1;
+
+// Starts the worker threads. Throws if the pool is already running or has been shut down.
+void ThreadPoolBase::Init(int num_threads, const std::function<OnThreadStartFn> &on_thread_start) {
+  std::lock_guard g(mtx_);
+  if (shutdown_pending_)
+    throw std::logic_error("The thread pool is being shut down.");
+  if (!threads_.empty())
+    throw std::logic_error("The thread pool is already started!");
+  threads_.reserve(num_threads);
+  for (int i = 0; i < num_threads; i++)
+    threads_.push_back(std::thread(&ThreadPoolBase::Run, this, i, on_thread_start));
+}
+
+// Stops the pool: sets shutdown_pending_, releases one permit per worker and joins the threads.
+void ThreadPoolBase::Shutdown(bool join) {
+  if ((shutdown_pending_ && !join) || threads_.empty())
+    return;
+  {
+    std::lock_guard g(mtx_);
+    if (shutdown_pending_ && !join)
+      return;
+    shutdown_pending_ = true;
+    sem_.release(threads_.size());
+  }
+
+  for (auto &t : threads_)
+    t.join();
+  threads_.clear();
+}
+
+// Enqueues a task; the caller must hold mtx_ and release the semaphore afterwards.
+void ThreadPoolBase::AddTaskNoLock(TaskFunc &&f) {
+  if (shutdown_pending_)
+    throw std::logic_error("The thread pool is stopped and no longer accepts new tasks.");
+  tasks_.push(std::move(f));
+}
+
+void ThreadPoolBase::AddTask(TaskFunc &&f) {
+  {
+    std::lock_guard g(mtx_);
+    AddTaskNoLock(std::move(f));
+  }
+  sem_.release(1);
+}
+
+// Worker thread body: publishes the thread-local pool/index, then runs tasks until shutdown.
+// Shared state is inspected only with mtx_ held - Shutdown releases one permit per worker.
+void ThreadPoolBase::Run(
+      int index,
+      const std::function<OnThreadStartFn> &on_thread_start) noexcept {
+  this_thread_pool_ = this;
+  this_thread_idx_ = index;
+  std::any scope;
+  if (on_thread_start)
+    scope = on_thread_start(index);
+  for (;;) {
+    sem_.acquire();
+    std::unique_lock lock(mtx_);
+    if (shutdown_pending_)
+      break;
+    assert(!tasks_.empty() && "Semaphore acquired but no tasks present.");
+    PopAndRunTask(lock);
+  }
+}
+
+// Pops the first task and runs it with `lock` released; `lock` is held again on return.
+void ThreadPoolBase::PopAndRunTask(std::unique_lock<std::mutex> &lock) {
+  TaskFunc t = std::move(tasks_.front());
+  tasks_.pop();
+  lock.unlock();
+  t();
+  lock.lock();
+}
+
+// Waits until `condition` is met, running queued tasks in the calling thread in the meantime.
+// Returns the final value of condition(); it can be false only if the pool is being shut down.
+template <typename Condition>
+bool ThreadPoolBase::WaitOrRunTasks(std::condition_variable &cv, Condition &&condition) {
+  assert(this_thread_pool() == this);
+  std::unique_lock lock(mtx_);
+  while (!shutdown_pending_ || !tasks_.empty()) {
+    bool ret;
+    while (!(ret = condition()) && tasks_.empty())
+      cv.wait_for(lock, std::chrono::microseconds(100));
+
+    if (ret || condition())  // re-evaluate the condition, just in case
+      return true;
+    if (shutdown_pending_)
+      return condition();
+    if (!sem_.try_acquire())
+      continue;
+
+    assert(!tasks_.empty() && "Semaphore acquired but no tasks present.");
+    PopAndRunTask(lock);
+  }
+  return condition();
+}
+
+}  // namespace dali
diff --git a/dali/core/exec/thread_pool_base_test.cc b/dali/core/exec/thread_pool_base_test.cc
new file mode 100644
index 00000000000..fa005c9a24a
--- /dev/null
+++ b/dali/core/exec/thread_pool_base_test.cc
@@ -0,0 +1,301 @@
+// Copyright (c) 2022-2026, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+//     http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#include <gtest/gtest.h>
+#include <atomic>
+#include <chrono>
+#include <cstring>
+#include <functional>
+#include <iostream>
+#include <optional>
+#include <stdexcept>
+#include <thread>
+#include <type_traits>
+#include <vector>
+#include "dali/core/exec/thread_pool_base.h"
+#include "dali/core/format.h"
+#include "dali/test/timing.h"
+
+namespace dali {
+
+struct SerialExecutor {
+  template <typename Runnable>
+  std::enable_if_t<std::is_convertible_v<Runnable, std::function<void()>>>
+  AddTask(Runnable &&runnable) {
+    runnable();
+  }
+};
+
+TEST(NewThreadPool, AddTask) {
+  ThreadPoolBase tp(4);
+  std::atomic_int flag{0};
+  for (int i = 0; i < 16; i++)
+    tp.AddTask([&, i]() {
+      int f = (flag |= (1 << i));
+      if (f == 0xffff)
+        flag.notify_all();
+    });
+
+  int f = flag.load();
+  while (f != 0xffff) {
+    flag.wait(f);
+    f = flag.load();
+  }
+  // No conditions - this test succeeds if it doesn't hang
+}
+
+TEST(NewThreadPool, BulkAddTask) {
+  ThreadPoolBase tp(4);
+  std::atomic_int flag{0};
+  {
+    ThreadPoolBase::TaskBulkAdd bulk = tp.BeginBulkAdd();
+    for (int i = 0; i < 16; i++)
+      bulk.Add([&, i]() {
+        int f = (flag |= (1 << i));
+        if (f == 0xffff)
+          flag.notify_all();
+      });
+    EXPECT_EQ(bulk.Size(), 16);
+    // submitted automatically on destruction
+  }
+
+  int f = flag.load();
+  while (f != 0xffff) {
+    flag.wait(f);
+    f = flag.load();
+  }
+  // No conditions - this test succeeds if it doesn't hang
+}
+
+TEST(NewThreadPool, RunJobInThreadPool) {
+  Job job;
+  ThreadPoolBase tp(4);
+  int a = 0, b = 0, c = 0;
+  job.AddTask([&]() {
+    a = 1;
+  });
+  job.AddTask([&]() {
+    b = 2;
+  });
+  job.AddTask([&]() {
+    c = 3;
+  });
+  job.Run(tp, true);
+  EXPECT_EQ(a, 1);
+  EXPECT_EQ(b, 2);
+  EXPECT_EQ(c, 3);
+}
+
+TEST(NewThreadPool, RunIncrementalJobInThreadPool) {
+  ThreadPoolBase tp(4);
+  IncrementalJob job;
+  std::atomic_int a = 0, b = 0, c = 0;
+  job.AddTask([&]() {
+    a += 1;
+  });
+  job.AddTask([&]() {
+    b += 2;
+  });
+  job.Run(tp, false);
+
+  for (int i = 0; (a.load() != 1 || b.load() != 2) && i < 100000; i++)
+    std::this_thread::sleep_for(std::chrono::microseconds(10));
+  ASSERT_TRUE(a.load() == 1 && b.load() == 2) << "The job didn't start.";
+
+  job.AddTask([&]() {
+    c += 3;
+  });
+  job.Run(tp, true);
+  EXPECT_EQ(a.load(), 1);
+  EXPECT_EQ(b.load(), 2);
+  EXPECT_EQ(c.load(), 3);
+}
+
+TEST(NewThreadPool, RunLargeIncrementalJobInThreadPool) {
+  ThreadPoolBase tp(4);
+  const int max_attempts = 10;
+  for (int attempt = 0; attempt < max_attempts; attempt++) {
+    IncrementalJob job;
+    std::atomic_int acc = 0;
+    const int total_tasks = 40000;
+    const int batch_size = 100;
+    for (int i = 0; i < total_tasks; i += batch_size) {
+      for (int j = i; j < i + batch_size; j++) {
+        job.AddTask([&, j] {
+          acc += j;
+        });
+      }
+      job.Run(tp, false);
+      if (i == 0) {
+        for (int spin = 0; acc.load() == 0 && spin < 100000; spin++)
+          std::this_thread::sleep_for(std::chrono::microseconds(10));
+        ASSERT_NE(acc.load(), 0) << "The job isn't running in the background.";
+      }
+    }
+    int target_value = total_tasks * (total_tasks - 1) / 2;
+    if (acc.load() == target_value) {
+      if (attempt == max_attempts - 1) {
+        FAIL() << "The job always finishes before a call to wait.";
+      } else {
+        std::cerr << "The job shouldn't have completed yet - retrying.\n";
+      }
+      job.Wait();
+      continue;
+    }
+    job.Run(tp, true);
+    EXPECT_EQ(acc.load(), target_value);
+    break;
+  }
+}
+
+template <typename JobType>
+class NewThreadPoolJobTest : public ::testing::Test {};
+
+using JobTypes = ::testing::Types<Job, IncrementalJob>;
+TYPED_TEST_SUITE(NewThreadPoolJobTest, JobTypes);
+
+
+TYPED_TEST(NewThreadPoolJobTest, RunJobInSeries) {
+  TypeParam job;
+  SerialExecutor tp;
+  int a = 0, b = 0, c = 0;
+  job.AddTask([&]() {
+    a = 1;
+  });
+  job.AddTask([&]() {
+    b = 2;
+  });
+  job.AddTask([&]() {
+    c = 3;
+  });
+  job.Run(tp, true);
+  EXPECT_EQ(a, 1);
+  EXPECT_EQ(b, 2);
+  EXPECT_EQ(c, 3);
+}
+
+TYPED_TEST(NewThreadPoolJobTest, Discard) {
+  EXPECT_NO_THROW({
+    TypeParam job;
+    job.AddTask([]() {});
+    job.Discard();
+  });
+}
+
+TYPED_TEST(NewThreadPoolJobTest, ErrorIncrementalJobNotStarted) {
+  try {
+    TypeParam job;
+    job.AddTask([]() {});
+  } catch (std::logic_error &e) {
+    EXPECT_NE(nullptr, strstr(e.what(), "The job is not empty"));
+    return;
+  }
+  GTEST_FAIL() << "Expected a logic error.";
+}
+
+TYPED_TEST(NewThreadPoolJobTest, RethrowMultipleErrors) {
+  TypeParam job;
+  ThreadPoolBase tp(4);
+  job.AddTask([&]() {
+    throw std::runtime_error("Runtime");
+  });
+  job.AddTask([&]() {
+    // do nothing
+  });
+  job.AddTask([&]() {
+    throw std::logic_error("Logic");
+  });
+  EXPECT_THROW(job.Run(tp, true), MultipleErrors);
+}
+
+TYPED_TEST(NewThreadPoolJobTest, Reentrant) {
+  TypeParam job;
+  ThreadPoolBase tp(1);  // must not hang with just one thread
+  std::atomic_int outer{0}, inner{0};
+  for (int i = 0; i < 10; i++) {
+    job.AddTask([&, i]() {
+      outer |= (1 << i);
+    });
+  }
+
+  job.AddTask([&]() {
+    Job innerJob;
+
+    for (int i = 0; i < 10; i++)
+      innerJob.AddTask([&, i]() {
+        inner |= (1 << i);
+      });
+
+    innerJob.Run(tp, false);
+    innerJob.Wait();
+    outer |= (1 << 10);
+  });
+
+  for (int i = 11; i < 20; i++) {
+    job.AddTask([&, i]() {
+      outer |= (1 << i);
+    });
+  }
+  job.Run(tp, true);
+  EXPECT_EQ(outer.load(), 0xfffff) << "Not all tasks of the outer job have run.";
+  EXPECT_EQ(inner.load(), 0x3ff) << "Not all tasks of the inner job have run.";
+}
+
+TYPED_TEST(NewThreadPoolJobTest, JobPerf) {
+  using JobType = TypeParam;
+  ThreadPoolBase tp(4);
+  auto do_test = [&](int jobs, int tasks) {
+    std::vector<int> v(tasks);
+    auto start = test::perf_timer::now();
+    std::optional<JobType> j;
+    for (int i = 0; i < jobs; i++) {
+      j.emplace();
+      for (int t = 1; t < tasks; t++) {
+        j->AddTask([&, t]() {
+          v[t]++;
+        });
+      }
+      j->Run(tp, false);
+      v[0]++;
+      j->Wait();
+      j.reset();
+    }
+    auto end = test::perf_timer::now();
+
+    for (int t = 0; t < tasks; t++)
+      EXPECT_EQ(v[t], jobs) << "Tasks didn't do their job";
+    print(
+      std::cout, "Ran ", jobs, " jobs of ", tasks, " tasks each in ",
+      test::format_time(end - start), "\n");
+
+    return end - start;
+  };
+
+  int total_tasks = 100000;
+  int jobs0 = 10000, tasks0 = total_tasks / jobs0;
+  auto time0 = do_test(jobs0, tasks0);
+  int jobs1 = 100, tasks1 = total_tasks / jobs1;
+  auto time1 = do_test(jobs1, tasks1);
+
+  // time0 = task_time * total_tasks + job_overhead * jobs0
+  // time1 = task_time * total_tasks + job_overhead * jobs1
+  // hence
+  // time0 - time1 = job_overhead * (jobs0 - jobs1)
+  // job_overhead = (time0 - time1) / (jobs0 - jobs1)
+
+  double job_overhead = test::seconds(time0 - time1) / (jobs0 - jobs1);
+  print(std::cout, "Job overhead ", test::format_time(job_overhead), "\n");
+}
+
+}  // namespace dali
diff --git a/include/dali/core/exec/engine.h b/include/dali/core/exec/engine.h
index 38b9ffbb050..a7d90c36d49 100644
--- a/include/dali/core/exec/engine.h
+++ b/include/dali/core/exec/engine.h
@@ -1,4 +1,4 @@
-// Copyright (c) 2021, NVIDIA CORPORATION. All rights reserved.
+// Copyright (c) 2021-2023, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
 //
 // Licensed under the Apache License, Version 2.0 (the "License");
 // you may not use this file except in compliance with the License.
@@ -46,7 +46,7 @@ class SequentialExecutionEngine {
    * @brief Immediately execute a callable object `f` with thread index 0.
    */
   template <typename FunctionLike>
-  void AddWork(FunctionLike &&f, int64_t priority = 0, bool start_immediately = true) {
+  void AddWork(FunctionLike &&f, int64_t priority = 0) {
     const int idx = 0;  // use of 0 literal would successfully call f expecting a pointer
     f(idx);
   }
diff --git a/include/dali/core/exec/thread_pool_base.h b/include/dali/core/exec/thread_pool_base.h
new file mode 100644
index 00000000000..93d81705d8e
--- /dev/null
+++ b/include/dali/core/exec/thread_pool_base.h
@@ -0,0 +1,401 @@
+// Copyright (c) 2022-2026, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+//     http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#ifndef DALI_CORE_EXEC_THREAD_POOL_BASE_H_
+#define DALI_CORE_EXEC_THREAD_POOL_BASE_H_
+
+#include <any>
+#include <atomic>
+#include <cassert>
+#include <condition_variable>
+#include <cstdint>
+#include <exception>
+#include <functional>
+#include <list>
+#include <map>
+#include <mutex>
+#include <optional>
+#include <queue>
+#include <stdexcept>
+#include <thread>
+#include <type_traits>
+#include <utility>
+#include <vector>
+#include "dali/core/api_helper.h"
+#include "dali/core/format.h"
+#include "dali/core/multi_error.h"
+#include "dali/core/semaphore.h"
+#include "dali/core/mm/detail/aux_alloc.h"
+
+namespace dali {
+
+class ThreadPoolBase;
+
+/** A base class for various job types. It defines common infrastructure. */
+class DLL_PUBLIC JobBase {
+ protected:
+  JobBase() = default;
+  ~JobBase() noexcept(false);
+
+  /** Waits for all tasks to complete. Errors are NOT rethrown.
+   *
+   * NOTE: This function must not be inline and must be defined in the same dynamic shared object
+   *       as the DoNotify function.
+   */
+  void DoWait();
+
+  /** Notifies the job that all pending tasks have completed
+   *
+   * NOTE: This function must not be inline and must be defined in the same dynamic shared object
+   *       as the DoWait function.
+   */
+  void DoNotify();
+
+  // atomic wait has no timeout, so we're stuck with condvar for reentrance
+  std::mutex mtx_;
+  std::condition_variable cv_;
+  std::atomic_int num_pending_tasks_{0};
+  std::atomic_bool running_{false};
+  int total_tasks_ = 0;
+  bool wait_started_ = false;
+  bool wait_completed_ = false;
+  const void *executor_ = nullptr;
+
+  struct Task {
+    std::function<void()> func;
+    std::exception_ptr error;
+  };
+};
+
+/**
+ * @brief A collection of tasks, ordered by priority
+ *
+ * Tasks are added to a job first and then the entire work is scheduled as a whole.
+ * Once at least one task has been added, Run and Wait (or Discard) must be called
+ * before the job is destroyed.
+ */
+class DLL_PUBLIC Job final : public JobBase {
+ public:
+  ~Job() noexcept(false) = default;
+
+  using priority_t = int64_t;
+
+  /** Adds a task with the given priority. Tasks cannot be added after Run or Wait was called.
+   *
+   * The callable is stored by value; an exception thrown by it is captured and rethrown by Wait.
+   */
+  template <typename Runnable>
+  std::enable_if_t<std::is_convertible_v<Runnable, std::function<void()>>>
+  AddTask(Runnable &&runnable, priority_t priority = {}) {
+    if (wait_started_)
+      throw std::logic_error("This job has already been waited for - cannot add more tasks to it");
+
+    if (executor_ != nullptr)
+      throw std::logic_error("This job has already been started - cannot add more tasks to it");
+
+    auto it = tasks_.emplace(priority, Task());
+    try {
+      it->second.func = [this, task = &it->second,
+                         f = std::forward<Runnable>(runnable)]() noexcept {
+        try {
+          f();
+        } catch (...) {
+          task->error = std::current_exception();
+        }
+        if (--num_pending_tasks_ == 0)
+          DoNotify();
+      };
+      total_tasks_++;
+    } catch (...) {  // if, for whatever reason, we cannot initialize the task, we should erase it
+      tasks_.erase(it);
+      throw;
+    }
+  }
+
+  template <typename Executor>
+  void Run(Executor &executor, bool wait);
+
+  void Run(ThreadPoolBase &tp, bool wait);
+
+  /** Waits for the job to complete. This function must be called only once. */
+  void Wait();
+
+  void Discard();
+
+ private:
+  // This needs to be a container which never invalidates references when inserting new items.
+  std::multimap<priority_t, Task, std::greater<priority_t>,
+                mm::detail::object_pool_allocator<std::pair<const priority_t, Task>>> tasks_;
+};
+
+/** A job which can be extended with new tasks while already running.
+ *
+ * Unlike the regular `Job`, this job class doesn't prohibit adding new tasks after
+ * calling `Run`. It's still illegal to add new tasks while already waiting for completion.
+ *
+ * In this job, the tasks are processed strictly in FIFO order - there are no priorities.
+ *
+ * Calls to AddTask, Run and Wait are not thread safe and require external synchronization if
+ * called from different threads.
+ */
+class DLL_PUBLIC IncrementalJob final : public JobBase {
+ public:
+  ~IncrementalJob() noexcept(false) = default;
+
+  template <typename Runnable>
+  std::enable_if_t<std::is_convertible_v<Runnable, std::function<void()>>>
+  AddTask(Runnable &&runnable);
+
+  template <typename Executor>
+  void Run(Executor &executor, bool wait);
+
+  void Run(ThreadPoolBase &tp, bool wait);
+
+  /** Waits for the job to complete. This function must be called only once.
+   *
+   * After this call, adding more tasks is illegal.
+   */
+  void Wait();
+
+  void Discard();
+
+ private:
+  using task_list_t = std::list<Task, mm::detail::object_pool_allocator<Task>>;
+  task_list_t tasks_;
+  std::optional<task_list_t::iterator> last_task_run_;
+};
+
+
+class DLL_PUBLIC ThreadPoolBase {
+ public:
+  using TaskFunc = std::function<void()>;
+
+  ThreadPoolBase() = default;
+  explicit ThreadPoolBase(int num_threads) {
+    Init(num_threads);
+  }
+
+  /** A function called upon thread start.
+   *
+   * @param thread_idx  Index of the thread within this thread pool.
+   * @return A RAII object that lives until the thread's processing loop runs
+   *
+   * @note This callback doesn't explicitly take `this` pointer - if necessary, a lambda function
+   *       can be used that captures the current thread pool instance.
+   */
+  using OnThreadStartFn = std::any(int thread_idx);
+
+  virtual void Init(int num_threads, const std::function<OnThreadStartFn> &on_thread_start = {});
+
+  virtual ~ThreadPoolBase() {
+    Shutdown(true);
+  }
+
+  void AddTask(TaskFunc &&f);
+
+  void AddTaskNoLock(TaskFunc &&f);
+
+  /** Adds multiple tasks under a single lock; the workers are woken up by Submit.
+   *
+   * The pool's mutex is locked by the first call to Add and unlocked by Submit or the destructor.
+   */
+  class TaskBulkAdd {
+   public:
+    void Add(TaskFunc &&f) {
+      if (!lock.owns_lock())
+        lock.lock();
+      owner->AddTaskNoLock(std::move(f));
+      tasks_added++;
+    }
+
+    ~TaskBulkAdd() {
+      Submit();
+    }
+
+    void Submit() {
+      if (lock.owns_lock()) {
+        lock.unlock();
+        // wake the workers only for the tasks added since the previous Submit
+        owner->sem_.release(tasks_added - tasks_submitted);
+        tasks_submitted = tasks_added;
+      }
+    }
+
+    int Size() const {
+      return tasks_added;
+    }
+
+   private:
+    friend class ThreadPoolBase;
+    explicit TaskBulkAdd(ThreadPoolBase *o) : owner(o), lock(o->mtx_, std::defer_lock) {}
+    ThreadPoolBase *owner = nullptr;
+    std::unique_lock<std::mutex> lock;
+    int tasks_added = 0;
+    int tasks_submitted = 0;
+  };
+  friend class TaskBulkAdd;
+
+  TaskBulkAdd BeginBulkAdd() & { return TaskBulkAdd(this); }
+
+  int NumThreads() const {
+    return static_cast<int>(threads_.size());
+  }
+
+  /**
+   * @brief Returns the thread pool that owns the calling thread (or nullptr)
+   */
+  static ThreadPoolBase *this_thread_pool() {
+    return this_thread_pool_;
+  }
+
+  /**
+   * @brief Returns the index of the current thread within the current thread pool
+   *
+   * @return the thread index or -1 if the calling thread does not belong to a thread pool
+   */
+  static int this_thread_idx() {
+    return this_thread_idx_;
+  }
+
+ protected:
+  void Shutdown(bool join);
+
+ private:
+  friend class JobBase;
+
+  template <typename Condition>
+  bool WaitOrRunTasks(std::condition_variable &cv, Condition &&condition);
+
+  void PopAndRunTask(std::unique_lock<std::mutex> &mtx);
+
+  static thread_local ThreadPoolBase *this_thread_pool_;
+  static thread_local int this_thread_idx_;
+
+  void Run(int index, const std::function<OnThreadStartFn> &on_thread_start) noexcept;
+
+  std::mutex mtx_;
+  counting_semaphore sem_{0};
+  bool shutdown_pending_ = false;
+  std::queue<TaskFunc> tasks_;
+  std::vector<std::thread> threads_;
+};
+
+
+template <typename ThreadPool>
+class ThreadedExecutionEngine {
+ public:
+  ThreadedExecutionEngine(ThreadPool &tp) : tp_(tp) {}  // NOLINT
+
+  template <typename FunctionLike>
+  void AddWork(FunctionLike &&f, int64_t priority = 0) {
+    job_.AddTask(std::forward<FunctionLike>(f), priority);
+  }
+
+  void RunAll() {
+    job_.Run(tp_, true);
+  }
+
+  int NumThreads() const noexcept {
+    return tp_.NumThreads();
+  }
+
+  ThreadPool &GetThreadPool() const noexcept {
+    return tp_;
+  }
+
+ private:
+  ThreadPool &tp_;
+  Job job_;
+};
+
+template <typename Executor>
+void Job::Run(Executor &executor, bool wait) {
+  if constexpr (std::is_base_of_v<ThreadPoolBase, Executor>) {
+    Run(static_cast<ThreadPoolBase &>(executor), wait);
+  } else {
+    if (executor_ != nullptr)
+      throw std::logic_error("This job has already been started.");
+    executor_ = &executor;
+    running_ = !tasks_.empty();
+    for (auto &x : tasks_) {
+      num_pending_tasks_++;
+      try {
+        executor.AddTask(std::move(x.second.func));
+      } catch (...) {
+        if (--num_pending_tasks_ == 0)
+          DoNotify();
+        throw;
+      }
+    }
+    if (wait && !tasks_.empty())
+      Wait();
+  }
+}
+
+template <typename Runnable>
+std::enable_if_t<std::is_convertible_v<Runnable, std::function<void()>>>
+IncrementalJob::AddTask(Runnable &&runnable) {
+  if (wait_started_)
+    throw std::logic_error("This job has already been waited for - cannot add more tasks to it");
+
+  assert(executor_ == nullptr || executor_ != ThreadPoolBase::this_thread_pool());
+
+  auto it = tasks_.emplace(tasks_.end(), Task());
+  try {
+    it->func = [this, task = &*it, f = std::forward<Runnable>(runnable)]() noexcept {
+      try {
+        f();
+      } catch (...) {
+        task->error = std::current_exception();
+      }
+
+      if (--num_pending_tasks_ == 0)
+        DoNotify();
+    };
+    total_tasks_++;
+  } catch (...) {  // if, for whatever reason, we cannot initialize the task, we should erase it
+    tasks_.erase(it);
+    throw;
+  }
+}
+
+template <typename Executor>
+void IncrementalJob::Run(Executor &executor, bool wait) {
+  if constexpr (std::is_base_of_v<ThreadPoolBase, Executor>) {
+    Run(static_cast<ThreadPoolBase &>(executor), wait);
+  } else {
+    if (executor_ && executor_ != &executor)
+      throw std::logic_error("This job is already running in a different executor.");
+    executor_ = &executor;
+    auto it = last_task_run_.has_value() ? std::next(*last_task_run_) : tasks_.begin();
+    for (; it != tasks_.end(); ++it) {
+      running_ = true;
+      num_pending_tasks_++;
+      try {
+        executor.AddTask(std::move(it->func));
+      } catch (...) {
+        if (--num_pending_tasks_ == 0)
+          DoNotify();
+        throw;
+      }
+      last_task_run_ = it;
+    }
+    if (wait && !tasks_.empty())
+      Wait();
+  }
+}
+
+}  // namespace dali
+
+#endif  // DALI_CORE_EXEC_THREAD_POOL_BASE_H_
diff --git a/include/dali/core/multi_error.h b/include/dali/core/multi_error.h
new file mode 100644
index 00000000000..40c96a44743
--- /dev/null
+++ b/include/dali/core/multi_error.h
@@ -0,0 +1,66 @@
+// Copyright (c) 2022-2026, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+//     http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#ifndef DALI_CORE_MULTI_ERROR_H_
+#define DALI_CORE_MULTI_ERROR_H_
+
+#include <exception>
+#include <sstream>
+#include <stdexcept>
+#include <string>
+#include <typeinfo>
+#include <utility>
+#include <vector>
+
+namespace dali {
+
+/** An exception that aggregates several exceptions; what() lists their types and messages. */
+class MultipleErrors : public std::runtime_error {
+ public:
+  explicit MultipleErrors(std::vector<std::exception_ptr> errors)
+  : runtime_error(""), errors_(std::move(errors)) {
+    compose_message();
+  }
+
+  const char *what() const noexcept override {
+    return message_.c_str();
+  }
+
+  const std::vector<std::exception_ptr> &errors() const {
+    return errors_;
+  }
+
+ private:
+  void compose_message() {
+    std::stringstream ss;
+    ss << "Multiple exceptions:\n";
+    for (const auto &e : errors_) {
+      try {
+        std::rethrow_exception(e);
+      } catch (const std::exception &ex) {
+        ss << typeid(ex).name() << ": " << ex.what() << "\n";
+      } catch (...) {
+        ss << "Unknown exception\n";
+      }
+    }
+    message_ = ss.str();
+  }
+
+  std::vector<std::exception_ptr> errors_;
+  std::string message_;
+};
+
+}  // namespace dali
+
+#endif  // DALI_CORE_MULTI_ERROR_H_