Search with big nq and big topk (#5748)

* fix #5747 #5115

Signed-off-by: yhmo <yihua.mo@zilliz.com>
pull/5762/head
groot 2021-06-15 09:53:12 +08:00 committed by GitHub
parent 73190282bd
commit a2ff18c40d
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
7 changed files with 49 additions and 22 deletions

View File

@ -25,7 +25,7 @@ namespace milvus {
namespace knowhere {
// Integer bounds and defaults declared in namespace milvus::knowhere.
static const int64_t MIN_K = 0;
// NOTE(review): MAX_K is declared twice below (16384 and 1024 * 1024 = 1048576). This looks like the
// before/after sides of a change shown together — confirm that only one declaration exists in the real file.
static const int64_t MAX_K = 16384;
static const int64_t MAX_K = 1024 * 1024;
static const int64_t MIN_NBITS = 1;
static const int64_t MAX_NBITS = 16;
static const int64_t DEFAULT_NBITS = 8;

View File

@ -71,7 +71,7 @@ Server::Start() {
std::string meta_uri;
STATUS_CHECK(config.GetGeneralConfigMetaURI(meta_uri));
if (meta_uri.length() > 6 && strcasecmp("sqlite", meta_uri.substr(0, 6).c_str()) == 0) {
std::cout << "NOTICE: You are using SQLite as the meta data management, "
std::cout << "NOTICE: You are using SQLite as the meta data management. "
"We recommend change it to MySQL."
<< std::endl;
}

View File

@ -73,6 +73,12 @@ SearchRequest::OnPreExecute() {
return status;
}
status = ValidationUtil::ValidateResultSize(vectors_data_.vector_count_, topk_);
if (!status.ok()) {
LOG_SERVER_ERROR_ << LogOut("[%s][%ld] %s", "search", 0, status.message().c_str());
return status;
}
// step 3: check partition tags
status = ValidationUtil::ValidatePartitionTags(partition_list_);
fiu_do_on("SearchRequest.OnExecute.invalid_partition_tags", status = Status(milvus::SERVER_UNEXPECTED_ERROR, ""));

View File

@ -47,7 +47,11 @@ constexpr int32_t INDEX_FILE_SIZE_LIMIT = 65536; // due to max size memory of f
constexpr int32_t INDEX_FILE_SIZE_LIMIT = 131072; // index trigger size max = 128G
#endif
// Byte-size units: M_BYTE is 2^20 bytes, G_BYTE is 2^30 bytes.
constexpr int64_t M_BYTE = 1024 * 1024;
constexpr int64_t G_BYTE = M_BYTE * 1024;
// 256 MB.
constexpr int64_t MAX_INSERT_DATA_SIZE = 256 * M_BYTE;
// The search result size is limited by the gRPC message size.
// The result struct also carries members such as row_count/status, so 1 MB is subtracted
// from the 2 GB ceiling to leave room for them.
constexpr int64_t MAX_SEARCH_RESULT_SIZE = 2 * G_BYTE - M_BYTE;
Status
CheckParameterRange(const milvus::json& json_params, const std::string& param_name, int64_t min, int64_t max,
@ -420,8 +424,22 @@ ValidationUtil::ValidateCollectionIndexMetricType(int32_t metric_type) {
// Checks that the requested topk lies within 1 ~ QUERY_MAX_TOPK.
// Returns Status(SERVER_INVALID_TOPK, msg) when top_k <= 0 or top_k > QUERY_MAX_TOPK, after logging msg
// through LOG_SERVER_ERROR_; returns Status::OK() otherwise.
Status
ValidationUtil::ValidateSearchTopk(int64_t top_k) {
if (top_k <= 0 || top_k > QUERY_MAX_TOPK) {
// NOTE(review): `msg` is initialized twice below — once with a hard-coded "16384" upper bound and once
// with the bound derived from QUERY_MAX_TOPK. This looks like the before/after sides of a change shown
// together; confirm that only one of them is present in the compiled file.
std::string msg =
"Invalid topk: " + std::to_string(top_k) + ". " + "The topk must be within the range of 1 ~ 16384.";
std::string msg = "Invalid topk: " + std::to_string(top_k) + ". " +
"The topk must be within the range of 1 ~ " + std::to_string(QUERY_MAX_TOPK) + ".";
LOG_SERVER_ERROR_ << msg;
return Status(SERVER_INVALID_TOPK, msg);
}
return Status::OK();
}
Status
ValidationUtil::ValidateResultSize(int64_t vector_count, int64_t top_k) {
// each id-distance pair is 12 bytes (sizeof(int64) + sizeof(float))
int64_t result_size = vector_count * top_k * 12;
if (result_size >= MAX_SEARCH_RESULT_SIZE) {
std::string msg = "Invalid nq " + std::to_string(vector_count) + " topk " + std::to_string(top_k) +
". The search result size may exceed the RPC transmission limit.";
LOG_SERVER_ERROR_ << msg;
return Status(SERVER_INVALID_TOPK, msg);
}

View File

@ -22,7 +22,7 @@
namespace milvus {
namespace server {
// Upper bound that ValidateSearchTopk enforces on the requested topk.
// NOTE(review): QUERY_MAX_TOPK is declared twice below (16384 and 1024 * 1024 = 1048576). This looks like
// the before/after sides of a change shown together — confirm that only one declaration exists in the header.
constexpr int64_t QUERY_MAX_TOPK = 16384;
constexpr int64_t QUERY_MAX_TOPK = 1024 * 1024;
// GPU-specific limits; their use is not visible in this excerpt.
constexpr int64_t GPU_QUERY_MAX_TOPK = 16384;
constexpr int64_t GPU_QUERY_MAX_NPROBE = 2048;
@ -63,6 +63,9 @@ class ValidationUtil {
static Status
ValidateSearchTopk(int64_t top_k);
static Status
ValidateResultSize(int64_t vector_count, int64_t top_k);
static Status
ValidatePartitionName(const std::string& partition_name);

View File

@ -702,8 +702,14 @@ TEST(ValidationUtilTest, VALIDATE_VECTOR_DATA_TEST) {
// Exercises ValidateSearchTopk around its lower and upper bounds, then ValidateResultSize at the
// nq threshold for the maximum topk.
TEST(ValidationUtilTest, VALIDATE_TOPK_TEST) {
ASSERT_EQ(milvus::server::ValidationUtil::ValidateSearchTopk(10).code(), milvus::SERVER_SUCCESS);
// NOTE(review): this expects 65536 to be rejected, while QUERY_MAX_TOPK itself is expected to be accepted
// further down; both can hold only if QUERY_MAX_TOPK < 65536. This assertion presumably belongs to the
// pre-change side of the listing — confirm.
ASSERT_NE(milvus::server::ValidationUtil::ValidateSearchTopk(65536).code(), milvus::SERVER_SUCCESS);
ASSERT_NE(milvus::server::ValidationUtil::ValidateSearchTopk(0).code(), milvus::SERVER_SUCCESS);
// Boundary check: exactly QUERY_MAX_TOPK passes, one above it fails.
int64_t max_topk = milvus::server::QUERY_MAX_TOPK;
ASSERT_EQ(milvus::server::ValidationUtil::ValidateSearchTopk(max_topk).code(), milvus::SERVER_SUCCESS);
ASSERT_NE(milvus::server::ValidationUtil::ValidateSearchTopk(max_topk + 1).code(), milvus::SERVER_SUCCESS);
// With QUERY_MAX_TOPK = 1024 * 1024, 171 is the smallest nq for which nq * max_topk * 12 bytes
// reaches the (2GB - 1MB) result-size limit; 170 stays below it.
int64_t count = 171;
ASSERT_EQ(milvus::server::ValidationUtil::ValidateResultSize(count - 1, max_topk).code(), milvus::SERVER_SUCCESS);
ASSERT_NE(milvus::server::ValidationUtil::ValidateResultSize(count, max_topk).code(), milvus::SERVER_SUCCESS);
}
TEST(ValidationUtilTest, VALIDATE_PARTITION_TAGS) {

View File

@ -147,15 +147,12 @@ class TestSearchBase:
query_vec = [vectors[0]]
top_k = get_top_k
status, result = connect.search(collection, top_k, query_vec)
if top_k <= 16384:
assert status.OK()
assert len(result[0]) == min(len(vectors), top_k)
assert result[0][0].distance <= epsilon
assert check_result(result[0], ids[0])
else:
assert not status.OK()
assert status.OK()
assert len(result[0]) == min(len(vectors), top_k)
assert result[0][0].distance <= epsilon
assert check_result(result[0], ids[0])
@pytest.mark.level(2)
@pytest.mark.level(1)
def test_search_top_max_nq(self, connect, collection):
'''
target: test basic search function, assert fail if nq * topk is larger than max_value
@ -211,14 +208,11 @@ class TestSearchBase:
search_param = get_search_param(index_type)
status, result = connect.search(collection, top_k, query_vec, params=search_param)
logging.getLogger().info(result)
if top_k <= 1024:
assert status.OK()
assert len(result[0]) == min(len(vectors), top_k)
assert check_result(result[0], ids[0])
assert result[0][0].distance < result[0][1].distance
assert result[1][0].distance < result[1][1].distance
else:
assert not status.OK()
assert status.OK()
assert len(result[0]) == min(len(vectors), top_k)
assert check_result(result[0], ids[0])
assert result[0][0].distance < result[0][1].distance
assert result[1][0].distance < result[1][1].distance
def test_search_l2_large_nq_index_params(self, connect, collection, get_simple_index):
'''