From 0b25f4c0152ec0d50917a15b95d023db3ce0677f Mon Sep 17 00:00:00 2001 From: yukun Date: Tue, 11 Aug 2020 17:16:54 +0800 Subject: [PATCH] Fix CreateIndex bug in C++ sdk (#3216) * Fix TestSearchDSL level 2 bugs Signed-off-by: fishpenguin * Fix QueryTest Signed-off-by: fishpenguin * Add annotation in milvus.proto Signed-off-by: fishpenguin * Fix CreateIndex in C++ sdk Signed-off-by: fishpenguin * Fix C++ sdk range test Signed-off-by: fishpenguin * Fix test_search_ip_index_partitions Signed-off-by: fishpenguin Co-authored-by: quicksilver --- core/src/grpc/milvus.proto | 2 +- core/src/scheduler/task/SearchTask.cpp | 6 ++- sdk/examples/simple/src/ClientTest.cpp | 8 ++-- sdk/examples/utils/Utils.cpp | 38 +++++++++---------- sdk/grpc/ClientProxy.cpp | 14 +++++-- sdk/include/MilvusApi.h | 3 +- .../milvus_python_test/entity/test_search.py | 10 ++--- 7 files changed, 45 insertions(+), 36 deletions(-) diff --git a/core/src/grpc/milvus.proto b/core/src/grpc/milvus.proto index ef5bf40d7e..eec68d486b 100644 --- a/core/src/grpc/milvus.proto +++ b/core/src/grpc/milvus.proto @@ -202,9 +202,9 @@ message VectorParam { * "vector": { * "face_img": { * "topk": 10, + * "metric_type": "L2", * "query": [], * "params": { - * "metric_type": "L2", * "nprobe": 10 * } * } diff --git a/core/src/scheduler/task/SearchTask.cpp b/core/src/scheduler/task/SearchTask.cpp index 0ea75ccaa8..152f8d0f2d 100644 --- a/core/src/scheduler/task/SearchTask.cpp +++ b/core/src/scheduler/task/SearchTask.cpp @@ -88,7 +88,8 @@ SearchTask::OnLoad(LoadType type, uint8_t device_id) { s = Status(SERVER_UNEXPECTED_ERROR, error_msg); } - return s; + job_->status() = s; + return Status::OK(); } std::string info = "Search task load segment id: " + std::to_string(segment_id_) + " " + type_str + " totally cost"; @@ -133,6 +134,9 @@ SearchTask::OnExecute() { search_job->query_result() = std::make_shared(); search_job->query_result()->row_num_ = nq; } + if (vector_param->metric_type == "IP") { + ascending_reduce_ = false; + } SearchTask::MergeTopkToResultSet(context.query_result_->result_ids_, context.query_result_->result_distances_, spec_k, nq, topk, ascending_reduce_, search_job->query_result()); diff --git a/sdk/examples/simple/src/ClientTest.cpp b/sdk/examples/simple/src/ClientTest.cpp index 1944b9cfad..e1dc892dd4 100644 --- a/sdk/examples/simple/src/ClientTest.cpp +++ b/sdk/examples/simple/src/ClientTest.cpp @@ -202,7 +202,8 @@ ClientTest::GetEntityByID(const std::string& collection_name, const std::vector< } void -ClientTest::SearchEntities(const std::string& collection_name, int64_t topk, int64_t nprobe, const std::string metric_type) { +ClientTest::SearchEntities(const std::string& collection_name, int64_t topk, int64_t nprobe, + const std::string metric_type) { nlohmann::json dsl_json, vector_param_json; milvus_sdk::Utils::GenDSLJson(dsl_json, vector_param_json, metric_type); @@ -262,8 +263,8 @@ void ClientTest::CreateIndex(const std::string& collection_name, int64_t nlist) { milvus_sdk::TimeRecorder rc("Create index"); std::cout << "Wait until create all index done" << std::endl; - JSON json_params = {{"nlist", nlist}, {"index_type", "IVF_FLAT"}}; - milvus::IndexParam index1 = {collection_name, "field_vec", "index_3", json_params.dump()}; + JSON json_params = {{"index_type", "IVF_FLAT"}, {"metric_type", "L2"}, {"params", {{"nlist", nlist}}}}; + milvus::IndexParam index1 = {collection_name, "field_vec", json_params.dump()}; milvus_sdk::Utils::PrintIndexParam(index1); milvus::Status stat = conn_->CreateIndex(index1); std::cout << "CreateIndex function call status: " << stat.message() << std::endl; @@ -334,6 +335,7 @@ ClientTest::Test() { InsertEntities(collection_name); Flush(collection_name); CountEntities(collection_name); + CreateIndex(collection_name, 1024); // GetCollectionStats(collection_name); // BuildVectors(NQ, COLLECTION_DIMENSION); diff --git a/sdk/examples/utils/Utils.cpp b/sdk/examples/utils/Utils.cpp index ac16baf076..27b0d02ea1 100644 --- a/sdk/examples/utils/Utils.cpp +++ b/sdk/examples/utils/Utils.cpp @@ -154,8 +154,7 @@ Utils::PrintIndexParam(const milvus::IndexParam& index_param) { BLOCK_SPLITER std::cout << "Index collection name: " << index_param.collection_name << std::endl; std::cout << "Index field name: " << index_param.field_name << std::endl; - std::cout << "Index name: " << index_param.index_name << std::endl; - std::cout << "Index extra_params: " << index_param.extra_params << std::endl; + std::cout << "Index extra_params: " << index_param.index_params << std::endl; BLOCK_SPLITER } @@ -261,26 +260,26 @@ Utils::DoSearch(std::shared_ptr conn, const std::string& col const std::vector& partition_tags, int64_t top_k, int64_t nprobe, std::vector> entity_array, milvus::TopKQueryResult& topk_query_result) { -/* - topk_query_result.clear(); + /* + topk_query_result.clear(); - nlohmann::json dsl_json, vector_param_json; - GenDSLJson(dsl_json, vector_param_json); + nlohmann::json dsl_json, vector_param_json; + GenDSLJson(dsl_json, vector_param_json); - std::vector temp_entity_array; - for (auto& pair : entity_array) { - temp_entity_array.push_back(pair.second); - } - milvus::VectorParam vector_param = {vector_param_json.dump(), temp_entity_array}; + std::vector temp_entity_array; + for (auto& pair : entity_array) { + temp_entity_array.push_back(pair.second); + } + milvus::VectorParam vector_param = {vector_param_json.dump(), temp_entity_array}; - JSON json_params = {{"nprobe", nprobe}}; - milvus_sdk::TimeRecorder rc("Search"); + JSON json_params = {{"nprobe", nprobe}}; + milvus_sdk::TimeRecorder rc("Search"); - auto status = conn->Search(collection_name, partition_tags, dsl_json.dump(), vector_param, topk_query_result); + auto status = conn->Search(collection_name, partition_tags, dsl_json.dump(), vector_param, topk_query_result); - PrintTopKQueryResult(topk_query_result); - // PrintSearchResult(entity_array, topk_query_result); -*/ + PrintTopKQueryResult(topk_query_result); + // PrintSearchResult(entity_array, topk_query_result); + */ } void @@ -378,8 +377,8 @@ Utils::GenDSLJson(nlohmann::json& dsl_json, nlohmann::json& vector_param_json, c bool_json["must"].push_back(term_json); nlohmann::json comp_json; - comp_json["GTE"] = "0"; - comp_json["LTE"] = "100000"; + comp_json["GT"] = 0; + comp_json["LT"] = 100000; range_json["range"]["field_1"] = comp_json; bool_json["must"].push_back(range_json); @@ -438,4 +437,3 @@ Utils::PrintTopKQueryResult(milvus::TopKQueryResult& topk_query_result) { } } // namespace milvus_sdk - diff --git a/sdk/grpc/ClientProxy.cpp b/sdk/grpc/ClientProxy.cpp index b83b3c0e30..351ed0ccdb 100644 --- a/sdk/grpc/ClientProxy.cpp +++ b/sdk/grpc/ClientProxy.cpp @@ -626,10 +626,16 @@ ClientProxy::CreateIndex(const IndexParam& index_param) { ::milvus::grpc::IndexParam grpc_index_param; grpc_index_param.set_collection_name(index_param.collection_name); grpc_index_param.set_field_name(index_param.field_name); - milvus::grpc::KeyValuePair* kv = grpc_index_param.add_extra_params(); - grpc_index_param.set_index_name(index_param.index_name); - kv->set_key(EXTRA_PARAM_KEY); - kv->set_value(index_param.extra_params); + JSON json_param = JSON::parse(index_param.index_params); + for (auto& item : json_param.items()) { + milvus::grpc::KeyValuePair* kv = grpc_index_param.add_extra_params(); + kv->set_key(item.key()); + if (item.value().is_object()) { + kv->set_value(item.value().dump()); + } else { + kv->set_value(item.value()); + } + } return client_ptr_->CreateIndex(grpc_index_param); } catch (std::exception& ex) { return Status(StatusCode::UnknownError, "Failed to build index: " + std::string(ex.what())); diff --git a/sdk/include/MilvusApi.h b/sdk/include/MilvusApi.h index bc0e4a2498..affca34fe8 100644 --- a/sdk/include/MilvusApi.h +++ b/sdk/include/MilvusApi.h @@ -124,8 +124,7 @@ using TopKQueryResult = std::vector; ///< Topk hybrid query result struct IndexParam { std::string collection_name; ///< Collection name for create index std::string field_name; ///< Field name - std::string index_name; ///< Index name - std::string extra_params; ///< Extra parameters according to different index type, must be json format + std::string index_params; ///< Extra parameters according to different index type, must be json format }; /** diff --git a/tests/milvus_python_test/entity/test_search.py b/tests/milvus_python_test/entity/test_search.py index 788ddb8317..e28ddc70be 100644 --- a/tests/milvus_python_test/entity/test_search.py +++ b/tests/milvus_python_test/entity/test_search.py @@ -443,7 +443,7 @@ class TestSearchBase: assert len(res) == nq @pytest.mark.level(2) - def _test_search_ip_index_partitions(self, connect, collection, get_simple_index, get_top_k): + def test_search_ip_index_partitions(self, connect, collection, get_simple_index, get_top_k): ''' target: test basic search fuction, all the search params is corrent, test all index params, and build method: search collection with the given vectors and tags, check the result @@ -919,7 +919,7 @@ class TestSearchDSL(object): # TODO: @pytest.mark.level(2) - def _test_query_term_value_all_in(self, connect, collection): + def test_query_term_value_all_in(self, connect, collection): ''' method: build query with vector and term expr, with all term can be filtered expected: filter pass @@ -934,7 +934,7 @@ class TestSearchDSL(object): # TODO: @pytest.mark.level(2) - def _test_query_term_values_not_in(self, connect, collection): + def test_query_term_values_not_in(self, connect, collection): ''' method: build query with vector and term expr, with no term can be filtered expected: filter pass @@ -977,7 +977,7 @@ class TestSearchDSL(object): # TODO: @pytest.mark.level(2) - def _test_query_term_values_repeat(self, connect, collection): + def test_query_term_values_repeat(self, connect, collection): ''' method: build query with vector and term expr, with the same values expected: filter pass @@ -1030,7 +1030,7 @@ class TestSearchDSL(object): # TODO @pytest.mark.level(2) - def test_query_term_wrong_format(self, connect, collection, get_invalid_term): + def _test_query_term_wrong_format(self, connect, collection, get_invalid_term): ''' method: build query with wrong format term expected: Exception raised