mirror of https://github.com/milvus-io/milvus.git
Refactor thread pool to dynamic threads mode (#26407)
Signed-off-by: luzhang <luzhang@zilliz.com> Co-authored-by: luzhang <luzhang@zilliz.com>pull/26381/head
parent
50a77ef1f7
commit
bdc8c507ea
|
@ -20,21 +20,80 @@ namespace milvus {
|
|||
|
||||
void
|
||||
ThreadPool::Init() {
|
||||
for (int i = 0; i < threads_.size(); i++) {
|
||||
threads_[i] = std::thread(Worker(this, i));
|
||||
std::lock_guard<std::mutex> lock(mutex_);
|
||||
for (int i = 0; i < min_threads_size_; i++) {
|
||||
std::thread t(&ThreadPool::Worker, this);
|
||||
assert(threads_.find(t.get_id()) == threads_.end());
|
||||
threads_[t.get_id()] = std::move(t);
|
||||
current_threads_size_++;
|
||||
}
|
||||
}
|
||||
|
||||
void
|
||||
ThreadPool::ShutDown() {
|
||||
LOG_SEGCORE_INFO_ << "Start shutting down " << name_;
|
||||
shutdown_ = true;
|
||||
{
|
||||
std::lock_guard<std::mutex> lock(mutex_);
|
||||
shutdown_ = true;
|
||||
}
|
||||
condition_lock_.notify_all();
|
||||
for (int i = 0; i < threads_.size(); i++) {
|
||||
if (threads_[i].joinable()) {
|
||||
threads_[i].join();
|
||||
for (auto iter = threads_.begin(); iter != threads_.end(); ++iter) {
|
||||
if (iter->second.joinable()) {
|
||||
iter->second.join();
|
||||
}
|
||||
}
|
||||
LOG_SEGCORE_INFO_ << "Finish shutting down " << name_;
|
||||
}
|
||||
|
||||
void
|
||||
ThreadPool::FinishThreads() {
|
||||
while (!need_finish_threads_.empty()) {
|
||||
std::thread::id id;
|
||||
auto dequeue = need_finish_threads_.dequeue(id);
|
||||
if (dequeue) {
|
||||
auto iter = threads_.find(id);
|
||||
assert(iter != threads_.end());
|
||||
if (iter->second.joinable()) {
|
||||
iter->second.join();
|
||||
}
|
||||
threads_.erase(iter);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
void
|
||||
ThreadPool::Worker() {
|
||||
std::function<void()> func;
|
||||
bool dequeue;
|
||||
while (!shutdown_) {
|
||||
std::unique_lock<std::mutex> lock(mutex_);
|
||||
idle_threads_size_++;
|
||||
auto is_timeout = !condition_lock_.wait_for(
|
||||
lock, std::chrono::seconds(WAIT_SECONDS), [this]() {
|
||||
return shutdown_ || !work_queue_.empty();
|
||||
});
|
||||
idle_threads_size_--;
|
||||
if (work_queue_.empty()) {
|
||||
// Dynamic reduce thread number
|
||||
if (shutdown_) {
|
||||
current_threads_size_--;
|
||||
return;
|
||||
}
|
||||
if (is_timeout) {
|
||||
FinishThreads();
|
||||
if (current_threads_size_ > min_threads_size_) {
|
||||
need_finish_threads_.enqueue(std::this_thread::get_id());
|
||||
current_threads_size_--;
|
||||
return;
|
||||
}
|
||||
continue;
|
||||
}
|
||||
}
|
||||
dequeue = work_queue_.dequeue(func);
|
||||
lock.unlock();
|
||||
if (dequeue) {
|
||||
func();
|
||||
}
|
||||
}
|
||||
}
|
||||
}; // namespace milvus
|
||||
|
|
|
@ -36,10 +36,13 @@ class ThreadPool {
|
|||
explicit ThreadPool(const int thread_core_coefficient,
|
||||
const std::string& name)
|
||||
: shutdown_(false), name_(name) {
|
||||
auto thread_num = CPU_NUM * thread_core_coefficient;
|
||||
threads_ = std::vector<std::thread>(thread_num);
|
||||
idle_threads_size_ = 0;
|
||||
current_threads_size_ = 0;
|
||||
min_threads_size_ = CPU_NUM;
|
||||
max_threads_size_ = CPU_NUM * thread_core_coefficient;
|
||||
LOG_SEGCORE_INFO_ << "Init thread pool:" << name_
|
||||
<< " with worker num:" << thread_num;
|
||||
<< " with min worker num:" << min_threads_size_
|
||||
<< " and max worker num:" << max_threads_size_;
|
||||
Init();
|
||||
}
|
||||
|
||||
|
@ -60,6 +63,12 @@ class ThreadPool {
|
|||
void
|
||||
ShutDown();
|
||||
|
||||
size_t
|
||||
GetThreadNum() {
|
||||
std::lock_guard<std::mutex> lock(mutex_);
|
||||
return current_threads_size_;
|
||||
}
|
||||
|
||||
template <typename F, typename... Args>
|
||||
auto
|
||||
// Submit(F&& f, Args&&... args) -> std::future<decltype(f(args...))>;
|
||||
|
@ -73,15 +82,37 @@ class ThreadPool {
|
|||
|
||||
work_queue_.enqueue(wrap_func);
|
||||
|
||||
condition_lock_.notify_one();
|
||||
std::lock_guard<std::mutex> lock(mutex_);
|
||||
|
||||
if (idle_threads_size_ > 0) {
|
||||
condition_lock_.notify_one();
|
||||
} else if (current_threads_size_ < max_threads_size_) {
|
||||
// Dynamic increase thread number
|
||||
std::thread t(&ThreadPool::Worker, this);
|
||||
assert(threads_.find(t.get_id()) == threads_.end());
|
||||
threads_[t.get_id()] = std::move(t);
|
||||
current_threads_size_++;
|
||||
}
|
||||
|
||||
return task_ptr->get_future();
|
||||
}
|
||||
|
||||
void
|
||||
Worker();
|
||||
|
||||
void
|
||||
FinishThreads();
|
||||
|
||||
public:
|
||||
int min_threads_size_;
|
||||
int idle_threads_size_;
|
||||
int current_threads_size_;
|
||||
int max_threads_size_;
|
||||
bool shutdown_;
|
||||
static constexpr size_t WAIT_SECONDS = 2;
|
||||
SafeQueue<std::function<void()>> work_queue_;
|
||||
std::vector<std::thread> threads_;
|
||||
std::unordered_map<std::thread::id, std::thread> threads_;
|
||||
SafeQueue<std::thread::id> need_finish_threads_;
|
||||
std::mutex mutex_;
|
||||
std::condition_variable condition_lock_;
|
||||
std::string name_;
|
||||
|
|
|
@ -106,8 +106,44 @@ test_worker(string s) {
|
|||
return 1;
|
||||
}
|
||||
|
||||
int
|
||||
compute(int a) {
|
||||
return a + 10;
|
||||
}
|
||||
|
||||
TEST_F(DiskAnnFileManagerTest, TestThreadPoolBase) {
|
||||
auto thread_pool = std::make_shared<milvus::ThreadPool>(10, "test1");
|
||||
std::cout << "current thread num" << thread_pool->GetThreadNum()
|
||||
<< std::endl;
|
||||
auto thread_num_1 = thread_pool->GetThreadNum();
|
||||
EXPECT_GT(thread_num_1, 0);
|
||||
|
||||
auto fut = thread_pool->Submit(compute, 10);
|
||||
auto res = fut.get();
|
||||
EXPECT_EQ(res, 20);
|
||||
|
||||
std::vector<std::future<int>> futs;
|
||||
for (int i = 0; i < 10; ++i) {
|
||||
futs.push_back(thread_pool->Submit(compute, i));
|
||||
}
|
||||
std::cout << "current thread num" << thread_pool->GetThreadNum()
|
||||
<< std::endl;
|
||||
auto thread_num_2 = thread_pool->GetThreadNum();
|
||||
EXPECT_GT(thread_num_2, thread_num_1);
|
||||
|
||||
for (int i = 0; i < 10; ++i) {
|
||||
std::cout << futs[i].get() << std::endl;
|
||||
}
|
||||
|
||||
sleep(5);
|
||||
std::cout << "current thread num" << thread_pool->GetThreadNum()
|
||||
<< std::endl;
|
||||
auto thread_num_3 = thread_pool->GetThreadNum();
|
||||
EXPECT_LT(thread_num_3, thread_num_2);
|
||||
}
|
||||
|
||||
TEST_F(DiskAnnFileManagerTest, TestThreadPool) {
|
||||
auto thread_pool = new milvus::ThreadPool(50, "test");
|
||||
auto thread_pool = std::make_shared<milvus::ThreadPool>(50, "test");
|
||||
std::vector<std::future<int>> futures;
|
||||
auto start = chrono::system_clock::now();
|
||||
for (int i = 0; i < 100; i++) {
|
||||
|
@ -121,6 +157,7 @@ TEST_F(DiskAnnFileManagerTest, TestThreadPool) {
|
|||
auto duration = chrono::duration_cast<chrono::microseconds>(end - start);
|
||||
auto second = double(duration.count()) * chrono::microseconds::period::num /
|
||||
chrono::microseconds::period::den;
|
||||
std::cout << "cost time:" << second << std::endl;
|
||||
EXPECT_LT(second, 4 * 100);
|
||||
}
|
||||
|
||||
|
@ -134,7 +171,7 @@ test_exception(string s) {
|
|||
|
||||
TEST_F(DiskAnnFileManagerTest, TestThreadPoolException) {
|
||||
try {
|
||||
auto thread_pool = new milvus::ThreadPool(50, "test");
|
||||
auto thread_pool = std::make_shared<milvus::ThreadPool>(50, "test");
|
||||
std::vector<std::future<int>> futures;
|
||||
for (int i = 0; i < 100; i++) {
|
||||
futures.push_back(thread_pool->Submit(
|
||||
|
|
Loading…
Reference in New Issue