mirror of https://github.com/milvus-io/milvus.git
Fix CreateIndex bug in C++ sdk (#3216)
* Fix TestSearchDSL level 2 bugs Signed-off-by: fishpenguin <kun.yu@zilliz.com> * Fix QueryTest Signed-off-by: fishpenguin <kun.yu@zilliz.com> * Add annotation in milvus.proto Signed-off-by: fishpenguin <kun.yu@zilliz.com> * Fix CreateIndex in C++ sdk Signed-off-by: fishpenguin <kun.yu@zilliz.com> * Fix C++ sdk range test Signed-off-by: fishpenguin <kun.yu@zilliz.com> * Fix test_search_ip_index_partitions Signed-off-by: fishpenguin <kun.yu@zilliz.com> Co-authored-by: quicksilver <zhifeng.zhang@zilliz.com>pull/3212/head
parent
6b3402434b
commit
0b25f4c015
|
@ -202,9 +202,9 @@ message VectorParam {
|
|||
* "vector": {
|
||||
* "face_img": {
|
||||
* "topk": 10,
|
||||
* "metric_type": "L2",
|
||||
* "query": [],
|
||||
* "params": {
|
||||
* "metric_type": "L2",
|
||||
* "nprobe": 10
|
||||
* }
|
||||
* }
|
||||
|
|
|
@ -88,7 +88,8 @@ SearchTask::OnLoad(LoadType type, uint8_t device_id) {
|
|||
s = Status(SERVER_UNEXPECTED_ERROR, error_msg);
|
||||
}
|
||||
|
||||
return s;
|
||||
job_->status() = s;
|
||||
return Status::OK();
|
||||
}
|
||||
|
||||
std::string info = "Search task load segment id: " + std::to_string(segment_id_) + " " + type_str + " totally cost";
|
||||
|
@ -133,6 +134,9 @@ SearchTask::OnExecute() {
|
|||
search_job->query_result() = std::make_shared<engine::QueryResult>();
|
||||
search_job->query_result()->row_num_ = nq;
|
||||
}
|
||||
if (vector_param->metric_type == "IP") {
|
||||
ascending_reduce_ = false;
|
||||
}
|
||||
SearchTask::MergeTopkToResultSet(context.query_result_->result_ids_,
|
||||
context.query_result_->result_distances_, spec_k, nq, topk,
|
||||
ascending_reduce_, search_job->query_result());
|
||||
|
|
|
@ -202,7 +202,8 @@ ClientTest::GetEntityByID(const std::string& collection_name, const std::vector<
|
|||
}
|
||||
|
||||
void
|
||||
ClientTest::SearchEntities(const std::string& collection_name, int64_t topk, int64_t nprobe, const std::string metric_type) {
|
||||
ClientTest::SearchEntities(const std::string& collection_name, int64_t topk, int64_t nprobe,
|
||||
const std::string metric_type) {
|
||||
nlohmann::json dsl_json, vector_param_json;
|
||||
milvus_sdk::Utils::GenDSLJson(dsl_json, vector_param_json, metric_type);
|
||||
|
||||
|
@ -262,8 +263,8 @@ void
|
|||
ClientTest::CreateIndex(const std::string& collection_name, int64_t nlist) {
|
||||
milvus_sdk::TimeRecorder rc("Create index");
|
||||
std::cout << "Wait until create all index done" << std::endl;
|
||||
JSON json_params = {{"nlist", nlist}, {"index_type", "IVF_FLAT"}};
|
||||
milvus::IndexParam index1 = {collection_name, "field_vec", "index_3", json_params.dump()};
|
||||
JSON json_params = {{"index_type", "IVF_FLAT"}, {"metric_type", "L2"}, {"params", {{"nlist", nlist}}}};
|
||||
milvus::IndexParam index1 = {collection_name, "field_vec", json_params.dump()};
|
||||
milvus_sdk::Utils::PrintIndexParam(index1);
|
||||
milvus::Status stat = conn_->CreateIndex(index1);
|
||||
std::cout << "CreateIndex function call status: " << stat.message() << std::endl;
|
||||
|
@ -334,6 +335,7 @@ ClientTest::Test() {
|
|||
InsertEntities(collection_name);
|
||||
Flush(collection_name);
|
||||
CountEntities(collection_name);
|
||||
CreateIndex(collection_name, 1024);
|
||||
// GetCollectionStats(collection_name);
|
||||
//
|
||||
BuildVectors(NQ, COLLECTION_DIMENSION);
|
||||
|
|
|
@ -154,8 +154,7 @@ Utils::PrintIndexParam(const milvus::IndexParam& index_param) {
|
|||
BLOCK_SPLITER
|
||||
std::cout << "Index collection name: " << index_param.collection_name << std::endl;
|
||||
std::cout << "Index field name: " << index_param.field_name << std::endl;
|
||||
std::cout << "Index name: " << index_param.index_name << std::endl;
|
||||
std::cout << "Index extra_params: " << index_param.extra_params << std::endl;
|
||||
std::cout << "Index extra_params: " << index_param.index_params << std::endl;
|
||||
BLOCK_SPLITER
|
||||
}
|
||||
|
||||
|
@ -261,26 +260,26 @@ Utils::DoSearch(std::shared_ptr<milvus::Connection> conn, const std::string& col
|
|||
const std::vector<std::string>& partition_tags, int64_t top_k, int64_t nprobe,
|
||||
std::vector<std::pair<int64_t, milvus::VectorData>> entity_array,
|
||||
milvus::TopKQueryResult& topk_query_result) {
|
||||
/*
|
||||
topk_query_result.clear();
|
||||
/*
|
||||
topk_query_result.clear();
|
||||
|
||||
nlohmann::json dsl_json, vector_param_json;
|
||||
GenDSLJson(dsl_json, vector_param_json);
|
||||
nlohmann::json dsl_json, vector_param_json;
|
||||
GenDSLJson(dsl_json, vector_param_json);
|
||||
|
||||
std::vector<milvus::VectorData> temp_entity_array;
|
||||
for (auto& pair : entity_array) {
|
||||
temp_entity_array.push_back(pair.second);
|
||||
}
|
||||
milvus::VectorParam vector_param = {vector_param_json.dump(), temp_entity_array};
|
||||
std::vector<milvus::VectorData> temp_entity_array;
|
||||
for (auto& pair : entity_array) {
|
||||
temp_entity_array.push_back(pair.second);
|
||||
}
|
||||
milvus::VectorParam vector_param = {vector_param_json.dump(), temp_entity_array};
|
||||
|
||||
JSON json_params = {{"nprobe", nprobe}};
|
||||
milvus_sdk::TimeRecorder rc("Search");
|
||||
JSON json_params = {{"nprobe", nprobe}};
|
||||
milvus_sdk::TimeRecorder rc("Search");
|
||||
|
||||
auto status = conn->Search(collection_name, partition_tags, dsl_json.dump(), vector_param, topk_query_result);
|
||||
auto status = conn->Search(collection_name, partition_tags, dsl_json.dump(), vector_param, topk_query_result);
|
||||
|
||||
PrintTopKQueryResult(topk_query_result);
|
||||
// PrintSearchResult(entity_array, topk_query_result);
|
||||
*/
|
||||
PrintTopKQueryResult(topk_query_result);
|
||||
// PrintSearchResult(entity_array, topk_query_result);
|
||||
*/
|
||||
}
|
||||
|
||||
void
|
||||
|
@ -378,8 +377,8 @@ Utils::GenDSLJson(nlohmann::json& dsl_json, nlohmann::json& vector_param_json, c
|
|||
bool_json["must"].push_back(term_json);
|
||||
|
||||
nlohmann::json comp_json;
|
||||
comp_json["GTE"] = "0";
|
||||
comp_json["LTE"] = "100000";
|
||||
comp_json["GT"] = 0;
|
||||
comp_json["LT"] = 100000;
|
||||
range_json["range"]["field_1"] = comp_json;
|
||||
bool_json["must"].push_back(range_json);
|
||||
|
||||
|
@ -438,4 +437,3 @@ Utils::PrintTopKQueryResult(milvus::TopKQueryResult& topk_query_result) {
|
|||
}
|
||||
|
||||
} // namespace milvus_sdk
|
||||
|
||||
|
|
|
@ -626,10 +626,16 @@ ClientProxy::CreateIndex(const IndexParam& index_param) {
|
|||
::milvus::grpc::IndexParam grpc_index_param;
|
||||
grpc_index_param.set_collection_name(index_param.collection_name);
|
||||
grpc_index_param.set_field_name(index_param.field_name);
|
||||
milvus::grpc::KeyValuePair* kv = grpc_index_param.add_extra_params();
|
||||
grpc_index_param.set_index_name(index_param.index_name);
|
||||
kv->set_key(EXTRA_PARAM_KEY);
|
||||
kv->set_value(index_param.extra_params);
|
||||
JSON json_param = JSON::parse(index_param.index_params);
|
||||
for (auto& item : json_param.items()) {
|
||||
milvus::grpc::KeyValuePair* kv = grpc_index_param.add_extra_params();
|
||||
kv->set_key(item.key());
|
||||
if (item.value().is_object()) {
|
||||
kv->set_value(item.value().dump());
|
||||
} else {
|
||||
kv->set_value(item.value());
|
||||
}
|
||||
}
|
||||
return client_ptr_->CreateIndex(grpc_index_param);
|
||||
} catch (std::exception& ex) {
|
||||
return Status(StatusCode::UnknownError, "Failed to build index: " + std::string(ex.what()));
|
||||
|
|
|
@ -124,8 +124,7 @@ using TopKQueryResult = std::vector<QueryResult>; ///< Topk hybrid query result
|
|||
struct IndexParam {
|
||||
std::string collection_name; ///< Collection name for create index
|
||||
std::string field_name; ///< Field name
|
||||
std::string index_name; ///< Index name
|
||||
std::string extra_params; ///< Extra parameters according to different index type, must be json format
|
||||
std::string index_params; ///< Extra parameters according to different index type, must be json format
|
||||
};
|
||||
|
||||
/**
|
||||
|
|
|
@ -443,7 +443,7 @@ class TestSearchBase:
|
|||
assert len(res) == nq
|
||||
|
||||
@pytest.mark.level(2)
|
||||
def _test_search_ip_index_partitions(self, connect, collection, get_simple_index, get_top_k):
|
||||
def test_search_ip_index_partitions(self, connect, collection, get_simple_index, get_top_k):
|
||||
'''
|
||||
target: test basic search fuction, all the search params is corrent, test all index params, and build
|
||||
method: search collection with the given vectors and tags, check the result
|
||||
|
@ -919,7 +919,7 @@ class TestSearchDSL(object):
|
|||
|
||||
# TODO:
|
||||
@pytest.mark.level(2)
|
||||
def _test_query_term_value_all_in(self, connect, collection):
|
||||
def test_query_term_value_all_in(self, connect, collection):
|
||||
'''
|
||||
method: build query with vector and term expr, with all term can be filtered
|
||||
expected: filter pass
|
||||
|
@ -934,7 +934,7 @@ class TestSearchDSL(object):
|
|||
|
||||
# TODO:
|
||||
@pytest.mark.level(2)
|
||||
def _test_query_term_values_not_in(self, connect, collection):
|
||||
def test_query_term_values_not_in(self, connect, collection):
|
||||
'''
|
||||
method: build query with vector and term expr, with no term can be filtered
|
||||
expected: filter pass
|
||||
|
@ -977,7 +977,7 @@ class TestSearchDSL(object):
|
|||
|
||||
# TODO:
|
||||
@pytest.mark.level(2)
|
||||
def _test_query_term_values_repeat(self, connect, collection):
|
||||
def test_query_term_values_repeat(self, connect, collection):
|
||||
'''
|
||||
method: build query with vector and term expr, with the same values
|
||||
expected: filter pass
|
||||
|
@ -1030,7 +1030,7 @@ class TestSearchDSL(object):
|
|||
|
||||
# TODO
|
||||
@pytest.mark.level(2)
|
||||
def test_query_term_wrong_format(self, connect, collection, get_invalid_term):
|
||||
def _test_query_term_wrong_format(self, connect, collection, get_invalid_term):
|
||||
'''
|
||||
method: build query with wrong format term
|
||||
expected: Exception raised
|
||||
|
|
Loading…
Reference in New Issue