Fix search when there are multiple vector fields (#4420)

Signed-off-by: fishpenguin <kun.yu@zilliz.com>
Signed-off-by: shengjun.li <shengjun.li@zilliz.com>
pull/4485/head
yukun 2020-12-10 10:56:48 +08:00 committed by shengjun.li
parent 5ca3a65eca
commit f96352d374
5 changed files with 66 additions and 9 deletions

View File

@ -39,6 +39,7 @@ Please mark all changes in change log and use the issue from GitHub
- \#4272 Program exit abnormally
- \#4302 Setting DSL fields is invalid in restful api, fields are not returned
- \#4329 C++ sdk sdk_binary needs to update
- \#4418 Fix search when there are multiple vector fields
## Feature
- \#4163 Update C++ sdk search interface

View File

@ -73,18 +73,15 @@ SearchReq::OnExecute() {
// step 4: Get field info
std::unordered_map<std::string, engine::DataType> field_types;
auto vector_query = query_ptr_->vectors.begin()->second;
for (auto& schema : fields_schema) {
auto field = schema.first;
field_types.insert(std::make_pair(field->GetName(), field->GetFtype()));
if (field->GetFtype() == engine::DataType::VECTOR_FLOAT ||
field->GetFtype() == engine::DataType::VECTOR_BINARY) {
if (vector_query->field_name == field->GetName() &&
(field->GetFtype() == engine::DataType::VECTOR_FLOAT ||
field->GetFtype() == engine::DataType::VECTOR_BINARY)) {
// check dim
int64_t dimension = field->GetParams()[engine::PARAM_DIMENSION];
auto vector_query = query_ptr_->vectors.begin()->second;
if (vector_query->field_name != field->GetName()) {
return Status(SERVER_INVALID_ARGUMENT,
"DSL vector query field name: " + vector_query->field_name + " is wrong");
}
if (!vector_query->query_vector.binary_data.empty()) {
if (vector_query->query_vector.binary_data.size() !=

View File

@ -76,6 +76,26 @@ ClientTest::ListCollections(std::vector<std::string>& collection_array) {
}
}
void
ClientTest::CreateMultiVecCollection() {
milvus::FieldPtr field1 = std::make_shared<milvus::Field>("release_year", milvus::DataType::INT32, "");
milvus::FieldPtr field2 = std::make_shared<milvus::Field>("duration", milvus::DataType::INT32, "");
nlohmann::json vector_param = {{"dim", COLLECTION_DIMENSION}};
milvus::FieldPtr field3 =
std::make_shared<milvus::Field>("embedding", milvus::DataType::VECTOR_FLOAT, vector_param.dump());
nlohmann::json vector_param_1 = {{"dim", COLLECTION_DIMENSION + COLLECTION_DIMENSION}};
milvus::FieldPtr field4 =
std::make_shared<milvus::Field>("vec", milvus::DataType::VECTOR_FLOAT, vector_param_1.dump());
nlohmann::json json_param;
json_param = {{"auto_id", false}, {"segment_row_limit", 4096}};
milvus::Mapping mapping = {COLLECTION_NAME, {field1, field2, field3, field4}, json_param.dump()};
milvus::Status status = conn_->CreateCollection(mapping);
std::cout << "CreateCollection function call status: " << status.message() << std::endl;
}
void
ClientTest::CreateCollection() {
milvus::FieldPtr field1 = std::make_shared<milvus::Field>("release_year", milvus::DataType::INT32, "");
@ -138,6 +158,31 @@ ClientTest::InsertEntities() {
std::cout << "InsertEntities function call status: " << status.message() << std::endl;
}
// Inserts three sample entities, with explicit ids 1..3, into COLLECTION_NAME
// under PARTITION_TAG. Each entity has two INT32 columns ("duration",
// "release_year") and two vector columns ("embedding", "vec"), matching the
// schema created by CreateMultiVecCollection(). The resulting status message
// is printed to stdout; the status itself is not checked.
void
ClientTest::InsertMultiEntities() {
    // Scalar columns: one value per entity.
    std::vector<int32_t> duration{208, 226, 252};
    std::vector<int32_t> release_year{2001, 2002, 2003};

    // Vector columns. BuildVectors presumably fills the output with the requested
    // number (3) of vectors of the given dimension -- TODO confirm against Utils.
    std::vector<milvus::VectorData> embedding;
    milvus_sdk::Utils::BuildVectors(COLLECTION_DIMENSION, 3, embedding);

    // "vec" is declared with twice the dimension of "embedding".
    std::vector<milvus::VectorData> vec;
    milvus_sdk::Utils::BuildVectors(COLLECTION_DIMENSION * 2, 3, vec);

    // Group the columns by field name.
    milvus::FieldValue field_value;
    std::unordered_map<std::string, std::vector<int32_t>> int32_value = {{"duration", duration},
                                                                         {"release_year", release_year}};
    std::unordered_map<std::string, std::vector<milvus::VectorData>> vector_value = {{"embedding", embedding},
                                                                                     {"vec", vec}};
    field_value.int32_value = int32_value;
    field_value.vector_value = vector_value;

    std::vector<int64_t> id_array = {1, 2, 3};
    auto status = conn_->Insert(COLLECTION_NAME, PARTITION_TAG, field_value, id_array);

    // Label the output with this function's own name so it can be told apart
    // from the output of InsertEntities().
    std::cout << "InsertMultiEntities function call status: " << status.message() << std::endl;
}
void
ClientTest::CountEntities(int64_t& entity_count) {
auto status = conn_->CountEntities(COLLECTION_NAME, entity_count);
@ -215,8 +260,7 @@ ClientTest::SearchEntities() {
std::vector<std::string> partition_tags;
milvus::TopKQueryResult topk_query_result;
auto status = conn_->Search(COLLECTION_NAME, partition_tags, dsl_json, json_params.dump(),
topk_query_result);
auto status = conn_->Search(COLLECTION_NAME, partition_tags, dsl_json, json_params.dump(), topk_query_result);
std::cout << " Search function call result: " << std::endl;
milvus_sdk::Utils::PrintTopKQueryResult(topk_query_result);
@ -288,6 +332,7 @@ ClientTest::Test() {
}
CreateCollection();
// CreateMultiVecCollection();
CreatePartition();
std::cout << "--------get collection info--------" << std::endl;
@ -297,6 +342,7 @@ ClientTest::Test() {
std::cout << "\n----------insert----------" << std::endl;
InsertEntities();
// InsertMultiEntities();
int64_t before_flush_counts = 0;
int64_t after_flush_counts = 0;

View File

@ -33,6 +33,12 @@ class ClientTest {
void
CreateCollection();
void
CreateMultiVecCollection();
void
InsertMultiEntities();
void
CreatePartition();

View File

@ -496,6 +496,13 @@ Utils::PrintTopKQueryResult(milvus::TopKQueryResult& topk_query_result) {
}
std::cout << std::endl;
}
if (data.first == "vec") {
std::cout << "- " << data.first << ": ";
for (const auto& v : data.second.float_data) {
std::cout << v << " ";
}
std::cout << std::endl;
}
}
}