mirror of https://github.com/milvus-io/milvus.git
Search with big nq and big topk (#5748)
* fix #5747 #5115
Signed-off-by: yhmo <yihua.mo@zilliz.com>
pull/5762/head
parent
73190282bd
commit
a2ff18c40d
|
@ -25,7 +25,7 @@ namespace milvus {
|
|||
namespace knowhere {
|
||||
|
||||
static const int64_t MIN_K = 0;
|
||||
static const int64_t MAX_K = 16384;
|
||||
static const int64_t MAX_K = 1024 * 1024;
|
||||
static const int64_t MIN_NBITS = 1;
|
||||
static const int64_t MAX_NBITS = 16;
|
||||
static const int64_t DEFAULT_NBITS = 8;
|
||||
|
|
|
@ -71,7 +71,7 @@ Server::Start() {
|
|||
std::string meta_uri;
|
||||
STATUS_CHECK(config.GetGeneralConfigMetaURI(meta_uri));
|
||||
if (meta_uri.length() > 6 && strcasecmp("sqlite", meta_uri.substr(0, 6).c_str()) == 0) {
|
||||
std::cout << "NOTICE: You are using SQLite as the meta data management, "
|
||||
std::cout << "NOTICE: You are using SQLite as the meta data management. "
|
||||
"We recommend change it to MySQL."
|
||||
<< std::endl;
|
||||
}
|
||||
|
|
|
@ -73,6 +73,12 @@ SearchRequest::OnPreExecute() {
|
|||
return status;
|
||||
}
|
||||
|
||||
status = ValidationUtil::ValidateResultSize(vectors_data_.vector_count_, topk_);
|
||||
if (!status.ok()) {
|
||||
LOG_SERVER_ERROR_ << LogOut("[%s][%ld] %s", "search", 0, status.message().c_str());
|
||||
return status;
|
||||
}
|
||||
|
||||
// step 3: check partition tags
|
||||
status = ValidationUtil::ValidatePartitionTags(partition_list_);
|
||||
fiu_do_on("SearchRequest.OnExecute.invalid_partition_tags", status = Status(milvus::SERVER_UNEXPECTED_ERROR, ""));
|
||||
|
|
|
@ -47,7 +47,11 @@ constexpr int32_t INDEX_FILE_SIZE_LIMIT = 65536; // due to max size memory of f
|
|||
constexpr int32_t INDEX_FILE_SIZE_LIMIT = 131072; // index trigger size max = 128G
|
||||
#endif
|
||||
constexpr int64_t M_BYTE = 1024 * 1024;
|
||||
constexpr int64_t G_BYTE = M_BYTE * 1024;
|
||||
constexpr int64_t MAX_INSERT_DATA_SIZE = 256 * M_BYTE;
|
||||
// the search result size is limited by the gRPC message size
|
||||
// the result struct also contains members such as row_count/status, so subtract 1MB as headroom for them
|
||||
constexpr int64_t MAX_SEARCH_RESULT_SIZE = 2 * G_BYTE - M_BYTE;
|
||||
|
||||
Status
|
||||
CheckParameterRange(const milvus::json& json_params, const std::string& param_name, int64_t min, int64_t max,
|
||||
|
@ -420,8 +424,22 @@ ValidationUtil::ValidateCollectionIndexMetricType(int32_t metric_type) {
|
|||
Status
|
||||
ValidationUtil::ValidateSearchTopk(int64_t top_k) {
|
||||
if (top_k <= 0 || top_k > QUERY_MAX_TOPK) {
|
||||
std::string msg =
|
||||
"Invalid topk: " + std::to_string(top_k) + ". " + "The topk must be within the range of 1 ~ 16384.";
|
||||
std::string msg = "Invalid topk: " + std::to_string(top_k) + ". " +
|
||||
"The topk must be within the range of 1 ~ " + std::to_string(QUERY_MAX_TOPK) + ".";
|
||||
LOG_SERVER_ERROR_ << msg;
|
||||
return Status(SERVER_INVALID_TOPK, msg);
|
||||
}
|
||||
|
||||
return Status::OK();
|
||||
}
|
||||
|
||||
Status
|
||||
ValidationUtil::ValidateResultSize(int64_t vector_count, int64_t top_k) {
|
||||
// each id-distance pair is 12 bytes (sizeof(int64) + sizeof(float))
|
||||
int64_t result_size = vector_count * top_k * 12;
|
||||
if (result_size >= MAX_SEARCH_RESULT_SIZE) {
|
||||
std::string msg = "Invalid nq " + std::to_string(vector_count) + " topk " + std::to_string(top_k) +
|
||||
". The search result size may exceed the RPC transmission limit.";
|
||||
LOG_SERVER_ERROR_ << msg;
|
||||
return Status(SERVER_INVALID_TOPK, msg);
|
||||
}
|
||||
|
|
|
@ -22,7 +22,7 @@
|
|||
namespace milvus {
|
||||
namespace server {
|
||||
|
||||
constexpr int64_t QUERY_MAX_TOPK = 16384;
|
||||
constexpr int64_t QUERY_MAX_TOPK = 1024 * 1024;
|
||||
constexpr int64_t GPU_QUERY_MAX_TOPK = 16384;
|
||||
constexpr int64_t GPU_QUERY_MAX_NPROBE = 2048;
|
||||
|
||||
|
@ -63,6 +63,9 @@ class ValidationUtil {
|
|||
static Status
|
||||
ValidateSearchTopk(int64_t top_k);
|
||||
|
||||
static Status
|
||||
ValidateResultSize(int64_t vector_count, int64_t top_k);
|
||||
|
||||
static Status
|
||||
ValidatePartitionName(const std::string& partition_name);
|
||||
|
||||
|
|
|
@ -702,8 +702,14 @@ TEST(ValidationUtilTest, VALIDATE_VECTOR_DATA_TEST) {
|
|||
|
||||
TEST(ValidationUtilTest, VALIDATE_TOPK_TEST) {
|
||||
ASSERT_EQ(milvus::server::ValidationUtil::ValidateSearchTopk(10).code(), milvus::SERVER_SUCCESS);
|
||||
ASSERT_NE(milvus::server::ValidationUtil::ValidateSearchTopk(65536).code(), milvus::SERVER_SUCCESS);
|
||||
ASSERT_NE(milvus::server::ValidationUtil::ValidateSearchTopk(0).code(), milvus::SERVER_SUCCESS);
|
||||
int64_t max_topk = milvus::server::QUERY_MAX_TOPK;
|
||||
ASSERT_EQ(milvus::server::ValidationUtil::ValidateSearchTopk(max_topk).code(), milvus::SERVER_SUCCESS);
|
||||
ASSERT_NE(milvus::server::ValidationUtil::ValidateSearchTopk(max_topk + 1).code(), milvus::SERVER_SUCCESS);
|
||||
|
||||
int64_t count = 171; // 2GB/QUERY_MAX_TOPK/12 rounded up: the smallest nq rejected at max topk
|
||||
ASSERT_EQ(milvus::server::ValidationUtil::ValidateResultSize(count - 1, max_topk).code(), milvus::SERVER_SUCCESS);
|
||||
ASSERT_NE(milvus::server::ValidationUtil::ValidateResultSize(count, max_topk).code(), milvus::SERVER_SUCCESS);
|
||||
}
|
||||
|
||||
TEST(ValidationUtilTest, VALIDATE_PARTITION_TAGS) {
|
||||
|
|
|
@ -147,15 +147,12 @@ class TestSearchBase:
|
|||
query_vec = [vectors[0]]
|
||||
top_k = get_top_k
|
||||
status, result = connect.search(collection, top_k, query_vec)
|
||||
if top_k <= 16384:
|
||||
assert status.OK()
|
||||
assert len(result[0]) == min(len(vectors), top_k)
|
||||
assert result[0][0].distance <= epsilon
|
||||
assert check_result(result[0], ids[0])
|
||||
else:
|
||||
assert not status.OK()
|
||||
assert status.OK()
|
||||
assert len(result[0]) == min(len(vectors), top_k)
|
||||
assert result[0][0].distance <= epsilon
|
||||
assert check_result(result[0], ids[0])
|
||||
|
||||
@pytest.mark.level(2)
|
||||
@pytest.mark.level(1)
|
||||
def test_search_top_max_nq(self, connect, collection):
|
||||
'''
|
||||
target: test basic search function, assert fail if nq * topk is larger than max_value
|
||||
|
@ -211,14 +208,11 @@ class TestSearchBase:
|
|||
search_param = get_search_param(index_type)
|
||||
status, result = connect.search(collection, top_k, query_vec, params=search_param)
|
||||
logging.getLogger().info(result)
|
||||
if top_k <= 1024:
|
||||
assert status.OK()
|
||||
assert len(result[0]) == min(len(vectors), top_k)
|
||||
assert check_result(result[0], ids[0])
|
||||
assert result[0][0].distance < result[0][1].distance
|
||||
assert result[1][0].distance < result[1][1].distance
|
||||
else:
|
||||
assert not status.OK()
|
||||
assert status.OK()
|
||||
assert len(result[0]) == min(len(vectors), top_k)
|
||||
assert check_result(result[0], ids[0])
|
||||
assert result[0][0].distance < result[0][1].distance
|
||||
assert result[1][0].distance < result[1][1].distance
|
||||
|
||||
def test_search_l2_large_nq_index_params(self, connect, collection, get_simple_index):
|
||||
'''
|
||||
|
|
Loading…
Reference in New Issue