Refactor thread pool to dynamic threads mode (#26407)

Signed-off-by: luzhang <luzhang@zilliz.com>
Co-authored-by: luzhang <luzhang@zilliz.com>
pull/26381/head
zhagnlu 2023-08-18 11:58:19 +08:00 committed by GitHub
parent 50a77ef1f7
commit bdc8c507ea
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
3 changed files with 140 additions and 13 deletions

View File

@ -20,21 +20,80 @@ namespace milvus {
void
ThreadPool::Init() {
for (int i = 0; i < threads_.size(); i++) {
threads_[i] = std::thread(Worker(this, i));
std::lock_guard<std::mutex> lock(mutex_);
for (int i = 0; i < min_threads_size_; i++) {
std::thread t(&ThreadPool::Worker, this);
assert(threads_.find(t.get_id()) == threads_.end());
threads_[t.get_id()] = std::move(t);
current_threads_size_++;
}
}
void
ThreadPool::ShutDown() {
LOG_SEGCORE_INFO_ << "Start shutting down " << name_;
shutdown_ = true;
{
std::lock_guard<std::mutex> lock(mutex_);
shutdown_ = true;
}
condition_lock_.notify_all();
for (int i = 0; i < threads_.size(); i++) {
if (threads_[i].joinable()) {
threads_[i].join();
for (auto iter = threads_.begin(); iter != threads_.end(); ++iter) {
if (iter->second.joinable()) {
iter->second.join();
}
}
LOG_SEGCORE_INFO_ << "Finish shutting down " << name_;
}
void
ThreadPool::FinishThreads() {
while (!need_finish_threads_.empty()) {
std::thread::id id;
auto dequeue = need_finish_threads_.dequeue(id);
if (dequeue) {
auto iter = threads_.find(id);
assert(iter != threads_.end());
if (iter->second.joinable()) {
iter->second.join();
}
threads_.erase(iter);
}
}
}
void
ThreadPool::Worker() {
std::function<void()> func;
bool dequeue;
while (!shutdown_) {
std::unique_lock<std::mutex> lock(mutex_);
idle_threads_size_++;
auto is_timeout = !condition_lock_.wait_for(
lock, std::chrono::seconds(WAIT_SECONDS), [this]() {
return shutdown_ || !work_queue_.empty();
});
idle_threads_size_--;
if (work_queue_.empty()) {
// Dynamic reduce thread number
if (shutdown_) {
current_threads_size_--;
return;
}
if (is_timeout) {
FinishThreads();
if (current_threads_size_ > min_threads_size_) {
need_finish_threads_.enqueue(std::this_thread::get_id());
current_threads_size_--;
return;
}
continue;
}
}
dequeue = work_queue_.dequeue(func);
lock.unlock();
if (dequeue) {
func();
}
}
}
}; // namespace milvus

View File

@ -36,10 +36,13 @@ class ThreadPool {
explicit ThreadPool(const int thread_core_coefficient,
const std::string& name)
: shutdown_(false), name_(name) {
auto thread_num = CPU_NUM * thread_core_coefficient;
threads_ = std::vector<std::thread>(thread_num);
idle_threads_size_ = 0;
current_threads_size_ = 0;
min_threads_size_ = CPU_NUM;
max_threads_size_ = CPU_NUM * thread_core_coefficient;
LOG_SEGCORE_INFO_ << "Init thread pool:" << name_
<< " with worker num:" << thread_num;
<< " with min worker num:" << min_threads_size_
<< " and max worker num:" << max_threads_size_;
Init();
}
@ -60,6 +63,12 @@ class ThreadPool {
void
ShutDown();
size_t
GetThreadNum() {
std::lock_guard<std::mutex> lock(mutex_);
return current_threads_size_;
}
template <typename F, typename... Args>
auto
// Submit(F&& f, Args&&... args) -> std::future<decltype(f(args...))>;
@ -73,15 +82,37 @@ class ThreadPool {
work_queue_.enqueue(wrap_func);
condition_lock_.notify_one();
std::lock_guard<std::mutex> lock(mutex_);
if (idle_threads_size_ > 0) {
condition_lock_.notify_one();
} else if (current_threads_size_ < max_threads_size_) {
// Dynamic increase thread number
std::thread t(&ThreadPool::Worker, this);
assert(threads_.find(t.get_id()) == threads_.end());
threads_[t.get_id()] = std::move(t);
current_threads_size_++;
}
return task_ptr->get_future();
}
void
Worker();
void
FinishThreads();
public:
int min_threads_size_;
int idle_threads_size_;
int current_threads_size_;
int max_threads_size_;
bool shutdown_;
static constexpr size_t WAIT_SECONDS = 2;
SafeQueue<std::function<void()>> work_queue_;
std::vector<std::thread> threads_;
std::unordered_map<std::thread::id, std::thread> threads_;
SafeQueue<std::thread::id> need_finish_threads_;
std::mutex mutex_;
std::condition_variable condition_lock_;
std::string name_;

View File

@ -106,8 +106,44 @@ test_worker(string s) {
return 1;
}
int
compute(int a) {
return a + 10;
}
TEST_F(DiskAnnFileManagerTest, TestThreadPoolBase) {
auto thread_pool = std::make_shared<milvus::ThreadPool>(10, "test1");
std::cout << "current thread num" << thread_pool->GetThreadNum()
<< std::endl;
auto thread_num_1 = thread_pool->GetThreadNum();
EXPECT_GT(thread_num_1, 0);
auto fut = thread_pool->Submit(compute, 10);
auto res = fut.get();
EXPECT_EQ(res, 20);
std::vector<std::future<int>> futs;
for (int i = 0; i < 10; ++i) {
futs.push_back(thread_pool->Submit(compute, i));
}
std::cout << "current thread num" << thread_pool->GetThreadNum()
<< std::endl;
auto thread_num_2 = thread_pool->GetThreadNum();
EXPECT_GT(thread_num_2, thread_num_1);
for (int i = 0; i < 10; ++i) {
std::cout << futs[i].get() << std::endl;
}
sleep(5);
std::cout << "current thread num" << thread_pool->GetThreadNum()
<< std::endl;
auto thread_num_3 = thread_pool->GetThreadNum();
EXPECT_LT(thread_num_3, thread_num_2);
}
TEST_F(DiskAnnFileManagerTest, TestThreadPool) {
auto thread_pool = new milvus::ThreadPool(50, "test");
auto thread_pool = std::make_shared<milvus::ThreadPool>(50, "test");
std::vector<std::future<int>> futures;
auto start = chrono::system_clock::now();
for (int i = 0; i < 100; i++) {
@ -121,6 +157,7 @@ TEST_F(DiskAnnFileManagerTest, TestThreadPool) {
auto duration = chrono::duration_cast<chrono::microseconds>(end - start);
auto second = double(duration.count()) * chrono::microseconds::period::num /
chrono::microseconds::period::den;
std::cout << "cost time:" << second << std::endl;
EXPECT_LT(second, 4 * 100);
}
@ -134,7 +171,7 @@ test_exception(string s) {
TEST_F(DiskAnnFileManagerTest, TestThreadPoolException) {
try {
auto thread_pool = new milvus::ThreadPool(50, "test");
auto thread_pool = std::make_shared<milvus::ThreadPool>(50, "test");
std::vector<std::future<int>> futures;
for (int i = 0; i < 100; i++) {
futures.push_back(thread_pool->Submit(