New thread pool #4635
@@ -0,0 +1,253 @@
// Copyright (c) 2022-2026, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

#include "dali/core/exec/thread_pool_base.h"
#include <stdexcept>
#include <thread>

namespace dali {

JobBase::~JobBase() noexcept(false) {
  if (total_tasks_ > 0 && !wait_completed_) {
    throw std::logic_error("The job is not empty, but hasn't been discarded or waited for.");
  }
  while (running_)
    std::this_thread::yield();
}

void JobBase::DoWait() {
  if (wait_started_)
    throw std::logic_error("This job has already been waited for.");
  wait_started_ = true;

  if (total_tasks_ == 0) {
    wait_completed_ = true;
    return;
  }

  if (executor_ == nullptr)
    throw std::logic_error("This job hasn't been run - cannot wait for it.");

  auto ready = [&]() { return num_pending_tasks_ == 0; };
  if (ThreadPoolBase::this_thread_pool() != nullptr) {
    bool result = ThreadPoolBase::this_thread_pool()->WaitOrRunTasks(cv_, ready);
    wait_completed_ = true;
    if (!result)
      throw std::logic_error("The thread pool was stopped");
  } else {
    int old = num_pending_tasks_.load();
    while (old != 0) {
      num_pending_tasks_.wait(old);
      old = num_pending_tasks_.load();
      assert(old >= 0);
    }
    wait_completed_ = true;
  }
}

void JobBase::DoNotify() {
  num_pending_tasks_.notify_all();
  (void)std::lock_guard(mtx_);

Contributor (Author):
@greptileai The intention was to acquire the lock and immediately release it, thereby establishing proper sequencing for the reentrant variant of the wait function.

Contributor:
I understand - this is a deliberate memory ordering technique. However, the pattern is quite unusual and may confuse other developers. Consider adding a comment clarifying the intent; that makes it clear to future maintainers that the immediate release is intentional for synchronization purposes.

  cv_.notify_all();

Comment on lines +61 to +62

Contributor:
Unusual memory ordering pattern - consider adding a comment. This intentional lock/unlock pattern for establishing a happens-before relationship is uncommon. Add a comment clarifying the intent. (Note: if this suggestion doesn't match your team's coding style, reply to this and let me know - I'll remember it for next time!)
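
For readers unfamiliar with the idiom, here is a minimal, self-contained sketch of what the empty critical section buys. This is illustrative only and not part of the PR; `waiter`, `notifier` and `pending` are made-up names. By acquiring and immediately releasing the mutex after updating the flag, the notifier is serialized against the waiter's predicate check, so the subsequent notification cannot be lost.

```cpp
#include <atomic>
#include <condition_variable>
#include <mutex>
#include <thread>

std::mutex m;
std::condition_variable cv;
std::atomic<int> pending{1};   // stands in for a pending-task counter

void waiter() {
  std::unique_lock lock(m);
  cv.wait(lock, [] { return pending.load() == 0; });  // predicate is evaluated under `m`
}

void notifier() {
  pending.store(0);            // the flag is atomic, so no lock is needed for the store itself
  { std::lock_guard g(m); }    // empty critical section: acquiring `m` after the store means the
                               // waiter has either already blocked (the notify below wakes it) or
                               // has not yet evaluated the predicate (and will now see 0)
  cv.notify_all();             // notifying without holding the mutex is fine after the step above
}

int main() {
  std::thread w(waiter);
  std::thread n(notifier);
  w.join();
  n.join();
}
```

Without the empty `std::lock_guard`, the store and the notification could both land in the window after the waiter has evaluated the predicate but before it has blocked, and the wakeup would be missed.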

  // We need this second flag to avoid a race condition where the destructor is called between
  // decrementing num_pending_tasks_ and notification_ without excessive use of mutexes.
  // This must be the very last operation in the task function that touches `this`.
  running_ = false;
}

// Job ////////////////////////////////////////////////////////////////////

void Job::Run(ThreadPoolBase &tp, bool wait) {
  if (executor_ != nullptr)
    throw std::logic_error("This job has already been started.");
  executor_ = &tp;
  running_ = !tasks_.empty();
  {
    auto batch = tp.BeginBulkAdd();
    for (auto &x : tasks_) {
      batch.Add(std::move(x.second.func));
    }
    int added = batch.Size();
    if (added) {
      num_pending_tasks_ += added;
      running_ = true;
    }
    batch.Submit();
  }
  if (wait && !tasks_.empty())
    Wait();
}

void Job::Wait() {
  DoWait();

  // note - this vector is not allocated unless there were exceptions thrown
  std::vector<std::exception_ptr> errors;
  for (auto &x : tasks_) {
    if (x.second.error)
      errors.push_back(std::move(x.second.error));
  }
  if (errors.size() == 1)
    std::rethrow_exception(errors[0]);
  else if (errors.size() > 1)
    throw MultipleErrors(std::move(errors));
}

void Job::Discard() {
  if (executor_ != nullptr)
    throw std::logic_error("Cannot discard a job that has already been started");
  tasks_.clear();
  total_tasks_ = 0;
}

// IncrementalJob /////////////////////////////////////////////////////////

void IncrementalJob::Run(ThreadPoolBase &tp, bool wait) {
  if (executor_ && executor_ != &tp)
    throw std::logic_error("This job is already running in a different executor.");
  executor_ = &tp;
  {
    auto it = last_task_run_.has_value() ? std::next(*last_task_run_) : tasks_.begin();
    auto batch = tp.BeginBulkAdd();
    for (; it != tasks_.end(); ++it) {
      batch.Add(std::move(it->func));
      last_task_run_ = it;
    }
    int added = batch.Size();
    if (added) {
      num_pending_tasks_ += added;
      running_ = true;
    }
    batch.Submit();
  }
  if (wait && !tasks_.empty())
    Wait();
}

void IncrementalJob::Discard() {
  if (executor_)
    throw std::logic_error("Cannot discard a job that has already been started");
  tasks_.clear();
  total_tasks_ = 0;
}

void IncrementalJob::Wait() {
  DoWait();
  // note - this vector is not allocated unless there were exceptions thrown
  std::vector<std::exception_ptr> errors;
  for (auto &x : tasks_) {
    if (x.error)
      errors.push_back(std::move(x.error));
  }
  if (errors.size() == 1)
    std::rethrow_exception(errors[0]);
  else if (errors.size() > 1)
    throw MultipleErrors(std::move(errors));
}

///////////////////////////////////////////////////////////////////////////

thread_local ThreadPoolBase *ThreadPoolBase::this_thread_pool_ = nullptr;
thread_local int ThreadPoolBase::this_thread_idx_ = -1;

void ThreadPoolBase::Init(int num_threads, const std::function<OnThreadStartFn> &on_thread_start) {
  if (shutdown_pending_)
    throw std::logic_error("The thread pool is being shut down.");
  std::lock_guard<std::mutex> g(mtx_);
  if (!threads_.empty())
    throw std::logic_error("The thread pool is already started!");
  threads_.reserve(num_threads);
  for (int i = 0; i < num_threads; i++)
    threads_.push_back(std::thread(&ThreadPoolBase::Run, this, i, on_thread_start));
}

void ThreadPoolBase::Shutdown(bool join) {
  if ((shutdown_pending_ && !join) || threads_.empty())
    return;
  {
    std::lock_guard<std::mutex> g(mtx_);
    if (shutdown_pending_ && !join)
      return;
    shutdown_pending_ = true;
    sem_.release(threads_.size());
  }

  for (auto &t : threads_)
    t.join();
  threads_.clear();
}

void ThreadPoolBase::AddTaskNoLock(TaskFunc &&f) {
  if (shutdown_pending_)
    throw std::logic_error("The thread pool is stopped and no longer accepts new tasks.");
  tasks_.push(std::move(f));
}

void ThreadPoolBase::AddTask(TaskFunc &&f) {
  {
    std::lock_guard<std::mutex> g(mtx_);
    AddTaskNoLock(std::move(f));
  }
  sem_.release(1);
}

void ThreadPoolBase::Run(
    int index,
    const std::function<OnThreadStartFn> &on_thread_start) noexcept {
  this_thread_pool_ = this;
  this_thread_idx_ = index;
  std::any scope;
  if (on_thread_start)
    scope = on_thread_start(index);
  while (!shutdown_pending_ || !tasks_.empty()) {
    sem_.acquire();
    std::unique_lock lock(mtx_);
    if (shutdown_pending_)
      break;
    assert(!tasks_.empty() && "Semaphore acquired but no tasks present.");
    PopAndRunTask(lock);
  }
}

void ThreadPoolBase::PopAndRunTask(std::unique_lock<std::mutex> &lock) {
  TaskFunc t = std::move(tasks_.front());
  tasks_.pop();
  lock.unlock();
  t();
  lock.lock();
}

template <typename Condition>
bool ThreadPoolBase::WaitOrRunTasks(std::condition_variable &cv, Condition &&condition) {
  assert(this_thread_pool() == this);
  std::unique_lock lock(mtx_);
  while (!shutdown_pending_ || !tasks_.empty()) {
    bool ret;
    while (!(ret = condition()) && tasks_.empty())
      cv.wait_for(lock, std::chrono::microseconds(100));

    if (ret || condition())  // re-evaluate the condition, just in case
      return true;
    if (shutdown_pending_)
      return condition();
    if (!sem_.try_acquire())
      continue;

    assert(!tasks_.empty() && "Semaphore acquired but no tasks present.");
    PopAndRunTask(lock);
  }
  return condition();
}

}  // namespace dali

Contributor:
Busy-wait in destructor is CPU-intensive. The destructor spins while `running_` is true; for long-running tasks, this burns CPU. Consider using a timed wait or documenting the expected task duration.

Contributor (Author):
@greptileai This never waits for more than the time between the notification of the atomic variable `num_pending_tasks_` and setting the `running_` flag in `DoNotify`.

Contributor:
You're absolutely correct - the busy-wait in the destructor only spins for the extremely brief window between the atomic notification of `num_pending_tasks_` and the setting of `running_ = false` at the very end of `DoNotify()`.

Looking at the implementation, this design is actually quite clever. The `running_` flag serves as a secondary synchronization mechanism to avoid a race condition where the destructor could be called right after `num_pending_tasks_` is decremented but before all the notifications are sent.

The busy-wait duration is bounded - it covers only the time it takes to execute the notification calls and the final `running_ = false` assignment. This makes the spin time predictable and minimal, so the CPU overhead is negligible compared to the complexity of adding additional synchronization primitives.
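
To make the pattern concrete, here is a stripped-down, hypothetical sketch of the counter-plus-last-touch-flag idea discussed above (not the actual DALI classes; `Completion`, `Finish`, and the other names are invented). It assumes a C++20 standard library for `std::atomic::wait`/`notify_all`, the same facility the PR code uses.

```cpp
#include <atomic>
#include <thread>

struct Completion {
  std::atomic<int> pending{1};
  std::atomic<bool> running{true};

  void Finish() {                 // called by the worker when its task is done
    pending.fetch_sub(1);         // (1) mark the task as finished
    pending.notify_all();         // (2) wake a thread blocked in Wait()
    running.store(false);         // (3) the very last access to *this made by the worker
  }

  void Wait() {                   // block until all tasks have finished
    int old = pending.load();
    while (old != 0) {
      pending.wait(old);          // C++20: sleeps until the value changes
      old = pending.load();
    }
  }

  ~Completion() {
    // Wait() may return between (1) and (3); spinning on `running` keeps the
    // object alive until the worker has performed its final access in Finish().
    while (running.load())
      std::this_thread::yield();
  }
};

int main() {
  {
    Completion c;
    std::thread([&c] { c.Finish(); }).detach();  // detached: nothing joins the worker
    c.Wait();   // may return while Finish() is still between steps (1) and (3)
  }             // ~Completion() spins on `running`, so Finish() never touches a destroyed object
}
```

The waiter may return as soon as it observes the counter at zero; the destructor's short spin on `running` covers exactly the window between steps (1) and (3) in `Finish()`, which is what keeps that spin brief and bounded.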