Fix CreateIndex bug in C++ sdk (#3216)

* Fix TestSearchDSL level 2 bugs

Signed-off-by: fishpenguin <kun.yu@zilliz.com>

* Fix QueryTest

Signed-off-by: fishpenguin <kun.yu@zilliz.com>

* Add annotation in milvus.proto

Signed-off-by: fishpenguin <kun.yu@zilliz.com>

* Fix CreateIndex in C++ sdk

Signed-off-by: fishpenguin <kun.yu@zilliz.com>

* Fix C++ sdk range test

Signed-off-by: fishpenguin <kun.yu@zilliz.com>

* Fix test_search_ip_index_partitions

Signed-off-by: fishpenguin <kun.yu@zilliz.com>

Co-authored-by: quicksilver <zhifeng.zhang@zilliz.com>
pull/3212/head
yukun 2020-08-11 17:16:54 +08:00 committed by GitHub
parent 6b3402434b
commit 0b25f4c015
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
7 changed files with 45 additions and 36 deletions

View File

@ -202,9 +202,9 @@ message VectorParam {
* "vector": {
* "face_img": {
* "topk": 10,
* "metric_type": "L2",
* "query": [],
* "params": {
* "metric_type": "L2",
* "nprobe": 10
* }
* }

View File

@ -88,7 +88,8 @@ SearchTask::OnLoad(LoadType type, uint8_t device_id) {
s = Status(SERVER_UNEXPECTED_ERROR, error_msg);
}
return s;
job_->status() = s;
return Status::OK();
}
std::string info = "Search task load segment id: " + std::to_string(segment_id_) + " " + type_str + " totally cost";
@ -133,6 +134,9 @@ SearchTask::OnExecute() {
search_job->query_result() = std::make_shared<engine::QueryResult>();
search_job->query_result()->row_num_ = nq;
}
if (vector_param->metric_type == "IP") {
ascending_reduce_ = false;
}
SearchTask::MergeTopkToResultSet(context.query_result_->result_ids_,
context.query_result_->result_distances_, spec_k, nq, topk,
ascending_reduce_, search_job->query_result());

View File

@ -202,7 +202,8 @@ ClientTest::GetEntityByID(const std::string& collection_name, const std::vector<
}
void
ClientTest::SearchEntities(const std::string& collection_name, int64_t topk, int64_t nprobe, const std::string metric_type) {
ClientTest::SearchEntities(const std::string& collection_name, int64_t topk, int64_t nprobe,
const std::string metric_type) {
nlohmann::json dsl_json, vector_param_json;
milvus_sdk::Utils::GenDSLJson(dsl_json, vector_param_json, metric_type);
@ -262,8 +263,8 @@ void
ClientTest::CreateIndex(const std::string& collection_name, int64_t nlist) {
milvus_sdk::TimeRecorder rc("Create index");
std::cout << "Wait until create all index done" << std::endl;
JSON json_params = {{"nlist", nlist}, {"index_type", "IVF_FLAT"}};
milvus::IndexParam index1 = {collection_name, "field_vec", "index_3", json_params.dump()};
JSON json_params = {{"index_type", "IVF_FLAT"}, {"metric_type", "L2"}, {"params", {{"nlist", nlist}}}};
milvus::IndexParam index1 = {collection_name, "field_vec", json_params.dump()};
milvus_sdk::Utils::PrintIndexParam(index1);
milvus::Status stat = conn_->CreateIndex(index1);
std::cout << "CreateIndex function call status: " << stat.message() << std::endl;
@ -334,6 +335,7 @@ ClientTest::Test() {
InsertEntities(collection_name);
Flush(collection_name);
CountEntities(collection_name);
CreateIndex(collection_name, 1024);
// GetCollectionStats(collection_name);
//
BuildVectors(NQ, COLLECTION_DIMENSION);

View File

@ -154,8 +154,7 @@ Utils::PrintIndexParam(const milvus::IndexParam& index_param) {
BLOCK_SPLITER
std::cout << "Index collection name: " << index_param.collection_name << std::endl;
std::cout << "Index field name: " << index_param.field_name << std::endl;
std::cout << "Index name: " << index_param.index_name << std::endl;
std::cout << "Index extra_params: " << index_param.extra_params << std::endl;
std::cout << "Index extra_params: " << index_param.index_params << std::endl;
BLOCK_SPLITER
}
@ -261,26 +260,26 @@ Utils::DoSearch(std::shared_ptr<milvus::Connection> conn, const std::string& col
const std::vector<std::string>& partition_tags, int64_t top_k, int64_t nprobe,
std::vector<std::pair<int64_t, milvus::VectorData>> entity_array,
milvus::TopKQueryResult& topk_query_result) {
/*
topk_query_result.clear();
/*
topk_query_result.clear();
nlohmann::json dsl_json, vector_param_json;
GenDSLJson(dsl_json, vector_param_json);
nlohmann::json dsl_json, vector_param_json;
GenDSLJson(dsl_json, vector_param_json);
std::vector<milvus::VectorData> temp_entity_array;
for (auto& pair : entity_array) {
temp_entity_array.push_back(pair.second);
}
milvus::VectorParam vector_param = {vector_param_json.dump(), temp_entity_array};
std::vector<milvus::VectorData> temp_entity_array;
for (auto& pair : entity_array) {
temp_entity_array.push_back(pair.second);
}
milvus::VectorParam vector_param = {vector_param_json.dump(), temp_entity_array};
JSON json_params = {{"nprobe", nprobe}};
milvus_sdk::TimeRecorder rc("Search");
JSON json_params = {{"nprobe", nprobe}};
milvus_sdk::TimeRecorder rc("Search");
auto status = conn->Search(collection_name, partition_tags, dsl_json.dump(), vector_param, topk_query_result);
auto status = conn->Search(collection_name, partition_tags, dsl_json.dump(), vector_param, topk_query_result);
PrintTopKQueryResult(topk_query_result);
// PrintSearchResult(entity_array, topk_query_result);
*/
PrintTopKQueryResult(topk_query_result);
// PrintSearchResult(entity_array, topk_query_result);
*/
}
void
@ -378,8 +377,8 @@ Utils::GenDSLJson(nlohmann::json& dsl_json, nlohmann::json& vector_param_json, c
bool_json["must"].push_back(term_json);
nlohmann::json comp_json;
comp_json["GTE"] = "0";
comp_json["LTE"] = "100000";
comp_json["GT"] = 0;
comp_json["LT"] = 100000;
range_json["range"]["field_1"] = comp_json;
bool_json["must"].push_back(range_json);
@ -438,4 +437,3 @@ Utils::PrintTopKQueryResult(milvus::TopKQueryResult& topk_query_result) {
}
} // namespace milvus_sdk

View File

@ -626,10 +626,16 @@ ClientProxy::CreateIndex(const IndexParam& index_param) {
::milvus::grpc::IndexParam grpc_index_param;
grpc_index_param.set_collection_name(index_param.collection_name);
grpc_index_param.set_field_name(index_param.field_name);
milvus::grpc::KeyValuePair* kv = grpc_index_param.add_extra_params();
grpc_index_param.set_index_name(index_param.index_name);
kv->set_key(EXTRA_PARAM_KEY);
kv->set_value(index_param.extra_params);
JSON json_param = JSON::parse(index_param.index_params);
for (auto& item : json_param.items()) {
milvus::grpc::KeyValuePair* kv = grpc_index_param.add_extra_params();
kv->set_key(item.key());
if (item.value().is_object()) {
kv->set_value(item.value().dump());
} else {
kv->set_value(item.value());
}
}
return client_ptr_->CreateIndex(grpc_index_param);
} catch (std::exception& ex) {
return Status(StatusCode::UnknownError, "Failed to build index: " + std::string(ex.what()));

View File

@ -124,8 +124,7 @@ using TopKQueryResult = std::vector<QueryResult>; ///< Topk hybrid query result
struct IndexParam {
std::string collection_name; ///< Collection name for create index
std::string field_name; ///< Field name
std::string index_name; ///< Index name
std::string extra_params; ///< Extra parameters according to different index type, must be json format
std::string index_params; ///< Extra parameters according to different index type, must be json format
};
/**

View File

@ -443,7 +443,7 @@ class TestSearchBase:
assert len(res) == nq
@pytest.mark.level(2)
def _test_search_ip_index_partitions(self, connect, collection, get_simple_index, get_top_k):
def test_search_ip_index_partitions(self, connect, collection, get_simple_index, get_top_k):
'''
target: test basic search fuction, all the search params is corrent, test all index params, and build
method: search collection with the given vectors and tags, check the result
@ -919,7 +919,7 @@ class TestSearchDSL(object):
# TODO:
@pytest.mark.level(2)
def _test_query_term_value_all_in(self, connect, collection):
def test_query_term_value_all_in(self, connect, collection):
'''
method: build query with vector and term expr, with all term can be filtered
expected: filter pass
@ -934,7 +934,7 @@ class TestSearchDSL(object):
# TODO:
@pytest.mark.level(2)
def _test_query_term_values_not_in(self, connect, collection):
def test_query_term_values_not_in(self, connect, collection):
'''
method: build query with vector and term expr, with no term can be filtered
expected: filter pass
@ -977,7 +977,7 @@ class TestSearchDSL(object):
# TODO:
@pytest.mark.level(2)
def _test_query_term_values_repeat(self, connect, collection):
def test_query_term_values_repeat(self, connect, collection):
'''
method: build query with vector and term expr, with the same values
expected: filter pass
@ -1030,7 +1030,7 @@ class TestSearchDSL(object):
# TODO
@pytest.mark.level(2)
def test_query_term_wrong_format(self, connect, collection, get_invalid_term):
def _test_query_term_wrong_format(self, connect, collection, get_invalid_term):
'''
method: build query with wrong format term
expected: Exception raised