diff --git a/core/src/config/Config.cpp b/core/src/config/Config.cpp index 63688ae493..fb70c4b779 100644 --- a/core/src/config/Config.cpp +++ b/core/src/config/Config.cpp @@ -113,6 +113,8 @@ const char* CONFIG_ENGINE_OMP_THREAD_NUM = "omp_thread_num"; const char* CONFIG_ENGINE_OMP_THREAD_NUM_DEFAULT = "0"; const char* CONFIG_ENGINE_SIMD_TYPE = "simd_type"; const char* CONFIG_ENGINE_SIMD_TYPE_DEFAULT = "auto"; +const char* CONFIG_ENGINE_SEARCH_COMBINE_MAX_NQ = "search_combine_nq"; +const char* CONFIG_ENGINE_SEARCH_COMBINE_MAX_NQ_DEFAULT = "64"; /* gpu resource config */ const char* CONFIG_GPU_RESOURCE = "gpu"; @@ -198,6 +200,9 @@ Config::Config() { std::string node_blas_threshold = std::string(CONFIG_ENGINE) + "." + CONFIG_ENGINE_USE_BLAS_THRESHOLD; config_callback_[node_blas_threshold] = empty_map; + std::string node_search_combine = std::string(CONFIG_ENGINE) + "." + CONFIG_ENGINE_SEARCH_COMBINE_MAX_NQ; + config_callback_[node_search_combine] = empty_map; + // gpu resources config std::string node_gpu_enable = std::string(CONFIG_GPU_RESOURCE) + "." + CONFIG_GPU_RESOURCE_ENABLE; config_callback_[node_gpu_enable] = empty_map; @@ -451,6 +456,7 @@ Config::ResetDefaultConfig() { STATUS_CHECK(SetEngineConfigUseBlasThreshold(CONFIG_ENGINE_USE_BLAS_THRESHOLD_DEFAULT)); STATUS_CHECK(SetEngineConfigOmpThreadNum(CONFIG_ENGINE_OMP_THREAD_NUM_DEFAULT)); STATUS_CHECK(SetEngineConfigSimdType(CONFIG_ENGINE_SIMD_TYPE_DEFAULT)); + STATUS_CHECK(SetEngineSearchCombineMaxNq(CONFIG_ENGINE_SEARCH_COMBINE_MAX_NQ_DEFAULT)); /* gpu resource config */ #ifdef MILVUS_GPU_VERSION @@ -578,6 +584,8 @@ Config::SetConfigCli(const std::string& parent_key, const std::string& child_key status = SetEngineConfigOmpThreadNum(value); } else if (child_key == CONFIG_ENGINE_SIMD_TYPE) { status = SetEngineConfigSimdType(value); + } else if (child_key == CONFIG_ENGINE_SEARCH_COMBINE_MAX_NQ) { + status = SetEngineSearchCombineMaxNq(value); } else { status = Status(SERVER_UNEXPECTED_ERROR, invalid_node_str); } @@ -1344,6 +1352,18 @@ Config::CheckEngineConfigSimdType(const std::string& value) { return Status::OK(); } +Status +Config::CheckEngineSearchCombineMaxNq(const std::string& value) { + fiu_return_on("check_config_search_combine_nq_fail", Status(SERVER_INVALID_ARGUMENT, "")); + + if (!ValidationUtil::ValidateStringIsNumber(value).ok()) { + std::string msg = "Invalid omp thread num: " + value + + ". Possible reason: engine_config.omp_thread_num is not a positive integer."; + return Status(SERVER_INVALID_ARGUMENT, msg); + } + return Status::OK(); +} + /* gpu resource config */ #ifdef MILVUS_GPU_VERSION Status @@ -1967,6 +1987,15 @@ Config::GetEngineConfigSimdType(std::string& value) { return CheckEngineConfigSimdType(value); } +Status +Config::GetEngineSearchCombineMaxNq(int64_t& value) { + std::string str = + GetConfigStr(CONFIG_ENGINE, CONFIG_ENGINE_SEARCH_COMBINE_MAX_NQ, CONFIG_ENGINE_SEARCH_COMBINE_MAX_NQ_DEFAULT); + // STATUS_CHECK(CheckEngineSearchCombineMaxNq(str)); + value = std::stoll(str); + return Status::OK(); +} + /* gpu resource config */ #ifdef MILVUS_GPU_VERSION Status @@ -2361,8 +2390,16 @@ Config::SetEngineConfigSimdType(const std::string& value) { return SetConfigValueInMem(CONFIG_ENGINE, CONFIG_ENGINE_SIMD_TYPE, value); } +Status +Config::SetEngineSearchCombineMaxNq(const std::string& value) { + STATUS_CHECK(CheckEngineSearchCombineMaxNq(value)); + STATUS_CHECK(SetConfigValueInMem(CONFIG_ENGINE, CONFIG_ENGINE_SEARCH_COMBINE_MAX_NQ, value)); + return ExecCallBacks(CONFIG_ENGINE, CONFIG_ENGINE_SEARCH_COMBINE_MAX_NQ, value); +} + /* gpu resource config */ #ifdef MILVUS_GPU_VERSION + Status Config::SetGpuResourceConfigEnable(const std::string& value) { STATUS_CHECK(CheckGpuResourceConfigEnable(value)); @@ -2407,6 +2444,7 @@ Config::SetGpuResourceConfigBuildIndexResources(const std::string& value) { STATUS_CHECK(SetConfigValueInMem(CONFIG_GPU_RESOURCE, CONFIG_GPU_RESOURCE_BUILD_INDEX_RESOURCES, value)); return ExecCallBacks(CONFIG_GPU_RESOURCE, CONFIG_GPU_RESOURCE_BUILD_INDEX_RESOURCES, value); } + #endif /* tracing config */ diff --git a/core/src/config/Config.h b/core/src/config/Config.h index 1819c0c1e1..563b2eb487 100644 --- a/core/src/config/Config.h +++ b/core/src/config/Config.h @@ -100,6 +100,8 @@ extern const char* CONFIG_ENGINE_OMP_THREAD_NUM; extern const char* CONFIG_ENGINE_OMP_THREAD_NUM_DEFAULT; extern const char* CONFIG_ENGINE_SIMD_TYPE; extern const char* CONFIG_ENGINE_SIMD_TYPE_DEFAULT; +extern const char* CONFIG_ENGINE_SEARCH_COMBINE_MAX_NQ; +extern const char* CONFIG_ENGINE_SEARCH_COMBINE_MAX_NQ_DEFAULT; /* gpu resource config */ extern const char* CONFIG_GPU_RESOURCE; @@ -264,6 +266,8 @@ class Config { CheckEngineConfigOmpThreadNum(const std::string& value); Status CheckEngineConfigSimdType(const std::string& value); + Status + CheckEngineSearchCombineMaxNq(const std::string& value); /* gpu resource config */ #ifdef MILVUS_GPU_VERSION @@ -382,6 +386,8 @@ class Config { GetEngineConfigOmpThreadNum(int64_t& value); Status GetEngineConfigSimdType(std::string& value); + Status + GetEngineSearchCombineMaxNq(int64_t& value); /* gpu resource config */ #ifdef MILVUS_GPU_VERSION @@ -492,6 +498,8 @@ class Config { SetEngineConfigOmpThreadNum(const std::string& value); Status SetEngineConfigSimdType(const std::string& value); + Status + SetEngineSearchCombineMaxNq(const std::string& value); /* gpu resource config */ #ifdef MILVUS_GPU_VERSION diff --git a/core/src/config/handler/EngineConfigHandler.cpp b/core/src/config/handler/EngineConfigHandler.cpp index 51b08e17ba..e838bc0773 100644 --- a/core/src/config/handler/EngineConfigHandler.cpp +++ b/core/src/config/handler/EngineConfigHandler.cpp @@ -19,10 +19,12 @@ namespace server { EngineConfigHandler::EngineConfigHandler() { auto& config = Config::GetInstance(); config.GetEngineConfigUseBlasThreshold(use_blas_threshold_); + config.GetEngineSearchCombineMaxNq(search_combine_nq_); } EngineConfigHandler::~EngineConfigHandler() { RemoveUseBlasThresholdListener(); + RemoveSearchCombineMaxNqListener(); } //////////////////////////// Listener methods ////////////////////////////////// @@ -48,5 +50,27 @@ EngineConfigHandler::RemoveUseBlasThresholdListener() { config.CancelCallBack(CONFIG_ENGINE, CONFIG_ENGINE_USE_BLAS_THRESHOLD, identity_); } +void +EngineConfigHandler::AddSearchCombineMaxNqListener() { + ConfigCallBackF lambda = [this](const std::string& value) -> Status { + auto& config = server::Config::GetInstance(); + auto status = config.GetEngineSearchCombineMaxNq(search_combine_nq_); + if (status.ok()) { + OnSearchCombineMaxNqChanged(search_combine_nq_); + } + + return status; + }; + + auto& config = Config::GetInstance(); + config.RegisterCallBack(CONFIG_ENGINE, CONFIG_ENGINE_SEARCH_COMBINE_MAX_NQ, identity_, lambda); +} + +void +EngineConfigHandler::RemoveSearchCombineMaxNqListener() { + auto& config = Config::GetInstance(); + config.CancelCallBack(CONFIG_ENGINE, CONFIG_ENGINE_SEARCH_COMBINE_MAX_NQ, identity_); +} + } // namespace server } // namespace milvus diff --git a/core/src/config/handler/EngineConfigHandler.h b/core/src/config/handler/EngineConfigHandler.h index ebc055c7c8..3fed6c9847 100644 --- a/core/src/config/handler/EngineConfigHandler.h +++ b/core/src/config/handler/EngineConfigHandler.h @@ -28,16 +28,27 @@ class EngineConfigHandler : virtual public ConfigHandler { OnUseBlasThresholdChanged(int64_t threshold) { } + virtual void + OnSearchCombineMaxNqChanged(int64_t nq) { + search_combine_nq_ = nq; + } + protected: void AddUseBlasThresholdListener(); - protected: void RemoveUseBlasThresholdListener(); + void + AddSearchCombineMaxNqListener(); + + void + RemoveSearchCombineMaxNqListener(); + protected: int64_t use_blas_threshold_ = std::stoll(CONFIG_ENGINE_USE_BLAS_THRESHOLD_DEFAULT); + int64_t search_combine_nq_ = std::stoll(CONFIG_ENGINE_SEARCH_COMBINE_MAX_NQ_DEFAULT); }; } // namespace server diff --git a/core/src/server/delivery/request/SearchCombineRequest.cpp b/core/src/server/delivery/request/SearchCombineRequest.cpp index 6dca3dd96b..d53e3e6394 100644 --- a/core/src/server/delivery/request/SearchCombineRequest.cpp +++ b/core/src/server/delivery/request/SearchCombineRequest.cpp @@ -27,7 +27,6 @@ namespace server { namespace { constexpr int64_t MAX_TOPK_GAP = 200; -constexpr uint64_t MAX_NQ = 200; void GetUniqueList(const std::vector& list, std::set& unique_list) { @@ -93,7 +92,8 @@ class TracingContextList { } // namespace -SearchCombineRequest::SearchCombineRequest() : BaseRequest(nullptr, BaseRequest::kSearchCombine) { +SearchCombineRequest::SearchCombineRequest(int64_t max_nq) + : BaseRequest(nullptr, BaseRequest::kSearchCombine), combine_max_nq_(max_nq) { } Status @@ -133,6 +133,8 @@ SearchCombineRequest::Combine(const SearchRequestPtr& request) { } request_list_.push_back(request); + vectors_data_.vector_count_ += request->VectorsData().vector_count_; + return Status::OK(); } @@ -152,11 +154,11 @@ SearchCombineRequest::CanCombine(const SearchRequestPtr& request) { } // sum of nq must less-equal than MAX_NQ - if (vectors_data_.vector_count_ > MAX_NQ || request->VectorsData().vector_count_ > MAX_NQ) { + if (vectors_data_.vector_count_ > combine_max_nq_ || request->VectorsData().vector_count_ > combine_max_nq_) { return false; } uint64_t total_nq = vectors_data_.vector_count_ + request->VectorsData().vector_count_; - if (total_nq > MAX_NQ) { + if (total_nq > combine_max_nq_) { return false; } @@ -178,7 +180,7 @@ SearchCombineRequest::CanCombine(const SearchRequestPtr& request) { } bool -SearchCombineRequest::CanCombine(const SearchRequestPtr& left, const SearchRequestPtr& right) { +SearchCombineRequest::CanCombine(const SearchRequestPtr& left, const SearchRequestPtr& right, int64_t max_nq) { if (left->CollectionName() != right->CollectionName()) { return false; } @@ -193,11 +195,11 @@ SearchCombineRequest::CanCombine(const SearchRequestPtr& left, const SearchReque } // sum of nq must less-equal than MAX_NQ - if (left->VectorsData().vector_count_ > MAX_NQ || right->VectorsData().vector_count_ > MAX_NQ) { + if (left->VectorsData().vector_count_ > max_nq || right->VectorsData().vector_count_ > max_nq) { return false; } uint64_t total_nq = left->VectorsData().vector_count_ + right->VectorsData().vector_count_; - if (total_nq > MAX_NQ) { + if (total_nq > max_nq) { return false; } diff --git a/core/src/server/delivery/request/SearchCombineRequest.h b/core/src/server/delivery/request/SearchCombineRequest.h index 3aa24bb928..a455c130d9 100644 --- a/core/src/server/delivery/request/SearchCombineRequest.h +++ b/core/src/server/delivery/request/SearchCombineRequest.h @@ -22,9 +22,11 @@ namespace milvus { namespace server { +constexpr int64_t COMBINE_MAX_NQ = 64; + class SearchCombineRequest : public BaseRequest { public: - SearchCombineRequest(); + SearchCombineRequest(int64_t max_nq = COMBINE_MAX_NQ); Status Combine(const SearchRequestPtr& request); @@ -33,7 +35,7 @@ class SearchCombineRequest : public BaseRequest { CanCombine(const SearchRequestPtr& request); static bool - CanCombine(const SearchRequestPtr& left, const SearchRequestPtr& right); + CanCombine(const SearchRequestPtr& left, const SearchRequestPtr& right, int64_t max_nq = COMBINE_MAX_NQ); protected: Status @@ -54,6 +56,8 @@ class SearchCombineRequest : public BaseRequest { std::set file_id_list_; std::vector request_list_; + + int64_t combine_max_nq_ = COMBINE_MAX_NQ; }; using SearchCombineRequestPtr = std::shared_ptr; diff --git a/core/src/server/delivery/strategy/SearchReqStrategy.cpp b/core/src/server/delivery/strategy/SearchReqStrategy.cpp index 3b49ed6964..0b66ca7b57 100644 --- a/core/src/server/delivery/strategy/SearchReqStrategy.cpp +++ b/core/src/server/delivery/strategy/SearchReqStrategy.cpp @@ -10,6 +10,7 @@ // or implied. See the License for the specific language governing permissions and limitations under the License. #include "server/delivery/strategy/SearchReqStrategy.h" +#include "config/Config.h" #include "server/delivery/request/SearchCombineRequest.h" #include "server/delivery/request/SearchRequest.h" #include "utils/CommonUtil.h" @@ -24,6 +25,8 @@ namespace milvus { namespace server { SearchReqStrategy::SearchReqStrategy() { + SetIdentity("SearchReqStrategy"); + AddSearchCombineMaxNqListener(); } Status @@ -34,15 +37,21 @@ SearchReqStrategy::ReScheduleQueue(const BaseRequestPtr& request, std::queue(request); BaseRequestPtr last_req = queue.back(); if (last_req->GetRequestType() == BaseRequest::kSearch) { SearchRequestPtr last_search_req = std::static_pointer_cast(last_req); - if (SearchCombineRequest::CanCombine(last_search_req, new_search_req)) { + if (SearchCombineRequest::CanCombine(last_search_req, new_search_req, search_combine_nq_)) { // combine request - SearchCombineRequestPtr combine_request = std::make_shared(); + SearchCombineRequestPtr combine_request = std::make_shared(search_combine_nq_); combine_request->Combine(last_search_req); combine_request->Combine(new_search_req); queue.back() = combine_request; // replace the last request to combine request diff --git a/core/src/server/delivery/strategy/SearchReqStrategy.h b/core/src/server/delivery/strategy/SearchReqStrategy.h index 20093c66c2..3d2c3de03b 100644 --- a/core/src/server/delivery/strategy/SearchReqStrategy.h +++ b/core/src/server/delivery/strategy/SearchReqStrategy.h @@ -11,6 +11,7 @@ #pragma once +#include "config/handler/EngineConfigHandler.h" #include "server/delivery/strategy/RequestStrategy.h" #include "utils/Status.h" @@ -20,7 +21,7 @@ namespace milvus { namespace server { -class SearchReqStrategy : public RequestStrategy { +class SearchReqStrategy : public RequestStrategy, public EngineConfigHandler { public: SearchReqStrategy();