Refine C++ sdk code (#4979)

Signed-off-by: yhmo <yihua.mo@zilliz.com>
pull/4993/head
groot 2021-04-21 16:50:31 +08:00 committed by GitHub
parent afb53bbb22
commit 30f889ae1b
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
20 changed files with 13 additions and 922 deletions

View File

@ -17,4 +17,3 @@ add_subdirectory(simple)
add_subdirectory(partition)
add_subdirectory(binary_vector)
add_subdirectory(qps)
add_subdirectory(hybrid)

View File

@ -1,27 +0,0 @@
#-------------------------------------------------------------------------------
# Copyright (C) 2019-2020 Zilliz. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance
# with the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software distributed under the License
# is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express
# or implied. See the License for the specific language governing permissions and limitations under the License.
#-------------------------------------------------------------------------------
# Collect every source file found in the src/ sub-directory into `src_files`.
aux_source_directory(src src_files)
# Executable of the hybrid example: its own main.cpp, the sources gathered above,
# and `util_files`, which is not set in this script (presumably by a parent CMakeLists - confirm).
add_executable(sdk_hybrid
main.cpp
${src_files}
${util_files}
)
# Link against the Milvus SDK library and pthread.
target_link_libraries(sdk_hybrid
milvus_sdk
pthread
)
# Install the binary into <prefix>/bin.
install(TARGETS sdk_hybrid DESTINATION bin)

View File

@ -1,73 +0,0 @@
// Copyright (C) 2019-2020 Zilliz. All rights reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance
// with the License. You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software distributed under the License
// is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express
// or implied. See the License for the specific language governing permissions and limitations under the License.
#include <getopt.h>
#include <libgen.h>
#include <cstring>
#include <string>
#include "src/ClientTest.h"
void
print_help(const std::string& app_name);
/**
 * Entry point of the hybrid example client.
 *
 * Recognised options:
 *   -s / --server   server address (default "127.0.0.1")
 *   -p / --port     server port    (default "19530")
 *   -h / --help     print usage and exit
 * Any unrecognised option also prints the usage text and exits with EXIT_SUCCESS.
 *
 * @return EXIT_SUCCESS after printing help, 0 after the test run.
 */
int
main(int argc, char* argv[]) {
    printf("Client start...\n");

    // Name shown in the usage text. The previous code computed basename(argv[0]) and then
    // immediately overwrote it with argv[0]; only the value that was actually used is kept.
    const std::string app_name = argv[0];

    static struct option long_options[] = {{"server", optional_argument, nullptr, 's'},
                                           {"port", optional_argument, nullptr, 'p'},
                                           {"help", no_argument, nullptr, 'h'},
                                           {nullptr, 0, nullptr, 0}};

    int option_index = 0;
    std::string address = "127.0.0.1", port = "19530";

    int value;
    while ((value = getopt_long(argc, argv, "s:p:h", long_options, &option_index)) != -1) {
        switch (value) {
            case 's':
                // The long options are declared optional_argument, so "--server" without
                // "=value" yields a null optarg; keep the default in that case instead of
                // dereferencing a null pointer.
                if (optarg != nullptr) {
                    address = optarg;
                }
                break;
            case 'p':
                if (optarg != nullptr) {
                    port = optarg;
                }
                break;
            case 'h':
            default:
                print_help(app_name);
                return EXIT_SUCCESS;
        }
    }

    ClientTest test(address, port);
    test.TestHybrid();

    printf("Client stop...\n");
    return 0;
}
/**
 * Print the command-line usage text to stdout.
 *
 * @param app_name program name inserted into the "Usage:" line.
 */
void
print_help(const std::string& app_name) {
    printf("\n Usage: %s [OPTIONS]\n\n", app_name.c_str());

    // Fixed lines that follow the usage header, emitted verbatim in order.
    static const char* const kHelpLines[] = {
        " Options:\n",
        " -s --server Server address, default 127.0.0.1\n",
        " -p --port Server port, default 19530\n",
        " -h --help Print help information\n",
        "\n",
    };
    for (const char* help_line : kHelpLines) {
        printf("%s", help_line);
    }
}

View File

@ -1,155 +0,0 @@
// Copyright (C) 2019-2020 Zilliz. All rights reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance
// with the License. You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software distributed under the License
// is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express
// or implied. See the License for the specific language governing permissions and limitations under the License.
#include "examples/hybrid/src/ClientTest.h"
#include "examples/utils/TimeRecorder.h"
#include "examples/utils/Utils.h"
#include "include/BooleanQuery.h"
#include "include/MilvusApi.h"
#include <unistd.h>
#include <iostream>
#include <memory>
#include <utility>
#include <vector>
#include <unordered_map>
namespace {

// File-local defaults for the hybrid example. None of them is referenced by the
// functions visible in this file.

// Generated collection name, owned by value. The previous declaration stored
// `GenCollectionName().c_str()` in a `const char*`; assuming GenCollectionName() returns a
// std::string by value (its declaration is not visible here), that pointer referred to a
// temporary destroyed at the end of the statement.
const std::string COLLECTION_NAME = milvus_sdk::Utils::GenCollectionName();

constexpr int64_t COLLECTION_DIMENSION = 512;
constexpr int64_t COLLECTION_INDEX_FILE_SIZE = 1024;
constexpr milvus::MetricType COLLECTION_METRIC_TYPE = milvus::MetricType::L2;
constexpr int64_t BATCH_ENTITY_COUNT = 100000;
constexpr int64_t NQ = 5;
constexpr int64_t TOP_K = 10;
constexpr int64_t NPROBE = 32;
constexpr int64_t SEARCH_TARGET = BATCH_ENTITY_COUNT / 2;  // change this value, result is different
constexpr int64_t ADD_ENTITY_LOOP = 5;
constexpr milvus::IndexType INDEX_TYPE = milvus::IndexType::IVFSQ8;
constexpr int32_t NLIST = 16384;
constexpr uint64_t FIELD_NUM = 3;

}  // namespace
/**
 * Create a connection object and connect it to the server at address:port.
 * The status message of the Connect call is written to stdout.
 */
ClientTest::ClientTest(const std::string& address, const std::string& port) {
    conn_ = milvus::Connection::Create();

    milvus::ConnectParam connect_param = {address, port};
    milvus::Status connect_status = conn_->Connect(connect_param);

    std::cout << "Connect function call status: " << connect_status.message() << std::endl;
}
/**
 * Hand the connection back to milvus::Connection::Destroy and print the resulting
 * status message to stdout.
 */
ClientTest::~ClientTest() {
    milvus::Status destroy_status = milvus::Connection::Destroy(conn_);

    std::cout << "Destroy connection function call status: " << destroy_status.message() << std::endl;
}
void
ClientTest::CreateHybridCollection(const std::string& collection_name) {
milvus::FieldPtr field_ptr1 = std::make_shared<milvus::Field>();
milvus::FieldPtr field_ptr2 = std::make_shared<milvus::Field>();
milvus::VectorFieldPtr vec_field_ptr = std::make_shared<milvus::VectorField>();
field_ptr1->field_type = milvus::DataType::INT64;
field_ptr1->field_name = "field_1";
field_ptr2->field_type = milvus::DataType::FLOAT;
field_ptr2->field_name = "field_2";
vec_field_ptr->field_type = milvus::DataType::VECTOR;
vec_field_ptr->field_name = "field_3";
vec_field_ptr->dimension = 128;
std::vector<milvus::FieldPtr> numerica_fields;
std::vector<milvus::VectorFieldPtr> vector_fields;
numerica_fields.emplace_back(field_ptr1);
numerica_fields.emplace_back(field_ptr2);
vector_fields.emplace_back(vec_field_ptr);
milvus::HMapping mapping = {collection_name, numerica_fields, vector_fields};
milvus::Status stat = conn_->CreateHybridCollection(mapping);
std::cout << "CreateHybridCollection function call status: " << stat.message() << std::endl;
}
void
ClientTest::Flush(const std::string& collection_name) {
milvus_sdk::TimeRecorder rc("Flush");
std::vector<std::string> collections = {collection_name};
milvus::Status stat = conn_->Flush(collections);
std::cout << "Flush function call status: " << stat.message() << std::endl;
}
/**
 * Insert `row_num` generated rows into a hybrid collection and print the call status.
 *
 * Row i carries:
 *   - "field_1": the int64 value i
 *   - "field_2": the double value i + row_num
 *   - "field_3": a 128-dimensional vector produced by Utils::BuildEntities
 * Scalar columns are handed to the SDK as raw byte buffers.
 *
 * NOTE(review): CreateHybridCollection declares "field_2" as DataType::FLOAT while the values
 * here are packed as 8-byte doubles - confirm the server-side expectation.
 *
 * @param collection_name target collection.
 * @param row_num         number of rows to generate; negative values are rejected.
 */
void
ClientTest::InsertHybridEntities(std::string& collection_name, int64_t row_num) {
    // A negative count would turn into an enormous unsigned size in the allocations below.
    if (row_num < 0) {
        std::cout << "InsertHybridEntities function call status: invalid row number " << row_num << std::endl;
        return;
    }
    const uint64_t count = static_cast<uint64_t>(row_num);

    std::vector<int64_t> value1(count);
    std::vector<double> value2(count);
    for (uint64_t i = 0; i < count; ++i) {
        value1[i] = static_cast<int64_t>(i);
        value2[i] = static_cast<double>(i + count);
    }

    // Re-interpret the typed columns as byte buffers.
    std::vector<int8_t> numerica1(count * sizeof(int64_t), 0);
    std::vector<int8_t> numerica2(count * sizeof(double), 0);
    if (count > 0) {  // data() may be null for empty vectors
        memcpy(numerica1.data(), value1.data(), numerica1.size());
        memcpy(numerica2.data(), value2.data(), numerica2.size());
    }

    std::unordered_map<std::string, std::vector<int8_t>> numerica_value;
    numerica_value.emplace("field_1", std::move(numerica1));
    numerica_value.emplace("field_2", std::move(numerica2));

    std::vector<milvus::Entity> entity_array;
    std::vector<int64_t> record_ids;
    {  // generate vectors
        milvus_sdk::Utils::BuildEntities(0, row_num, entity_array, record_ids, 128);
    }

    std::unordered_map<std::string, std::vector<milvus::Entity>> vector_value;
    vector_value.emplace("field_3", std::move(entity_array));

    milvus::HEntity entity = {row_num, numerica_value, vector_value};
    std::vector<uint64_t> id_array;
    milvus::Status status = conn_->InsertEntity(collection_name, "", entity, id_array);
    std::cout << "InsertHybridEntities function call status: " << status.message() << std::endl;
}
/**
 * Run a hybrid search made of the three leaf queries produced by Utils::GenLeafQuery,
 * all attached to a single MUST clause, and print the first hit of every result row
 * followed by the call status.
 *
 * @param collection_name collection to search.
 */
void
ClientTest::HybridSearch(std::string& collection_name) {
    std::vector<std::string> partition_tags;
    milvus::TopKQueryResult topk_query_result;

    auto leaf_queries = milvus_sdk::Utils::GenLeafQuery();
    if (leaf_queries.size() < 3) {
        // The clause below indexes the first three leaf queries; bail out rather than read past the end.
        std::cout << "HybridSearch function call status: not enough leaf queries generated" << std::endl;
        return;
    }

    // must
    auto must_clause = std::make_shared<milvus::BooleanQuery>(milvus::Occur::MUST);
    must_clause->AddLeafQuery(leaf_queries[0]);
    must_clause->AddLeafQuery(leaf_queries[1]);
    must_clause->AddLeafQuery(leaf_queries[2]);

    auto query_clause = std::make_shared<milvus::BooleanQuery>();
    query_clause->AddBooleanQuery(must_clause);

    std::string extra_params;
    milvus::Status status =
        conn_->HybridSearch(collection_name, partition_tags, query_clause, extra_params, topk_query_result);

    for (const auto& row : topk_query_result) {
        // A row without hits has nothing at index 0 to print.
        if (row.ids.empty() || row.distances.empty()) {
            continue;
        }
        std::cout << row.ids[0] << " --------- " << row.distances[0] << std::endl;
    }
    std::cout << "HybridSearch function call status: " << status.message() << std::endl;
}
void
ClientTest::TestHybrid() {
std::string collection_name = "HYBRID_TEST";
CreateHybridCollection(collection_name);
InsertHybridEntities(collection_name, 1000);
Flush(collection_name);
sleep(2);
HybridSearch(collection_name);
}

View File

@ -1,46 +0,0 @@
// Copyright (C) 2019-2020 Zilliz. All rights reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance
// with the License. You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software distributed under the License
// is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express
// or implied. See the License for the specific language governing permissions and limitations under the License.
#pragma once
#include <string>
#include <memory>
#include <utility>
#include <vector>
#include <MilvusApi.h>
/**
 * Example client that exercises the hybrid (scalar + vector) interface of the SDK.
 * The connection is opened in the constructor and released in the destructor.
 */
class ClientTest {
 public:
    /** Connect to the server reachable at address:port. */
    ClientTest(const std::string& address, const std::string& port);

    ~ClientTest();

    /** Run the whole scenario: create collection, insert, flush, search. */
    void
    TestHybrid();

 private:
    void
    CreateHybridCollection(const std::string& collection_name);

    void
    Flush(const std::string& collection_name);

    void
    InsertHybridEntities(std::string& collection_name, int64_t row_num);

    void
    HybridSearch(std::string& collection_name);

    // Connection shared by every step of the scenario.
    std::shared_ptr<milvus::Connection> conn_;
    // Not referenced by the hybrid example implementation shown alongside this header.
    std::vector<std::pair<int64_t, milvus::Entity>> search_entity_array_;
    std::vector<int64_t> search_id_array_;
};

View File

@ -33,6 +33,7 @@ constexpr int64_t TOP_K = 10;
constexpr int64_t NPROBE = 32;
constexpr int64_t SEARCH_TARGET = 5000; // change this value, result is different
constexpr milvus::IndexType INDEX_TYPE = milvus::IndexType::IVFSQ8;
constexpr int32_t NLIST = 2048;
constexpr int32_t PARTITION_COUNT = 5;
constexpr int32_t TARGET_PARTITION = 3;
@ -53,7 +54,7 @@ BuildPartitionParam(int32_t index) {
milvus::IndexParam
BuildIndexParam() {
JSON json_params = {{"nlist", 16384}};
JSON json_params = {{"nlist", NLIST}};
milvus::IndexParam index_param = {COLLECTION_NAME, INDEX_TYPE, json_params.dump()};
return index_param;
}

View File

@ -26,7 +26,7 @@ struct TestParameters {
// collection parameters, only works when collection_name_ is empty
int64_t index_type_ = (int64_t)milvus::IndexType::IVFSQ8; // sq8
int64_t index_file_size_ = 1024; // 1024 MB
int64_t nlist_ = 16384;
int64_t nlist_ = 2048;
int64_t metric_type_ = (int64_t)milvus::MetricType::L2; // L2
int64_t dimensions_ = 128;
int64_t row_count_ = 1; // 1 million

View File

@ -10,7 +10,6 @@
// or implied. See the License for the specific language governing permissions and limitations under the License.
#include "include/MilvusApi.h"
#include "include/BooleanQuery.h"
#include "examples/utils/TimeRecorder.h"
#include "examples/utils/Utils.h"
#include "examples/simple/src/ClientTest.h"
@ -34,7 +33,7 @@ constexpr int64_t NPROBE = 32;
constexpr int64_t SEARCH_TARGET = BATCH_ENTITY_COUNT / 2; // change this value, result is different
constexpr int64_t ADD_ENTITY_LOOP = 5;
constexpr milvus::IndexType INDEX_TYPE = milvus::IndexType::IVFFLAT;
constexpr int32_t NLIST = 16384;
constexpr int32_t NLIST = 2048;
void
PrintEntity(const std::string& tag, const milvus::Entity& entity) {

View File

@ -243,55 +243,4 @@ void ConstructVector(uint64_t nq, uint64_t dimension, std::vector<milvus::Entity
}
}
std::vector<milvus::LeafQueryPtr>
Utils::GenLeafQuery() {
//Construct TermQuery
uint64_t row_num = 1000;
std::vector<int64_t> field_value;
field_value.resize(row_num);
for (uint64_t i = 0; i < row_num; ++i) {
field_value[i] = i;
}
std::vector<int8_t> term_value(row_num * sizeof(int64_t));
memcpy(term_value.data(), field_value.data(), row_num * sizeof(int64_t));
milvus::TermQueryPtr tq = std::make_shared<milvus::TermQuery>();
tq->field_name = "field_1";
tq->field_value = term_value;
//Construct RangeQuery
milvus::CompareExpr ce1 = {milvus::CompareOperator::LTE, "10000"}, ce2 = {milvus::CompareOperator::GTE, "1"};
std::vector<milvus::CompareExpr> ces{ce1, ce2};
milvus::RangeQueryPtr rq = std::make_shared<milvus::RangeQuery>();
rq->field_name = "field_2";
rq->compare_expr = ces;
//Construct VectorQuery
uint64_t NQ = 10;
uint64_t DIMENSION = 128;
uint64_t NPROBE = 32;
milvus::VectorQueryPtr vq = std::make_shared<milvus::VectorQuery>();
ConstructVector(NQ, DIMENSION, vq->query_vector);
vq->field_name = "field_3";
vq->topk = 10;
JSON json_params = {{"nprobe", NPROBE}};
vq->extra_params = json_params.dump();
std::vector<milvus::LeafQueryPtr> lq;
milvus::LeafQueryPtr lq1 = std::make_shared<milvus::LeafQuery>();
milvus::LeafQueryPtr lq2 = std::make_shared<milvus::LeafQuery>();
milvus::LeafQueryPtr lq3 = std::make_shared<milvus::LeafQuery>();
lq.emplace_back(lq1);
lq.emplace_back(lq2);
lq.emplace_back(lq3);
lq1->term_query_ptr = tq;
lq2->range_query_ptr = rq;
lq3->vector_query_ptr = vq;
lq1->query_boost = 1.0;
lq2->query_boost = 2.0;
lq3->query_boost = 3.0;
return lq;
}
} // namespace milvus_sdk

View File

@ -12,7 +12,6 @@
#pragma once
#include "MilvusApi.h"
#include "BooleanQuery.h"
#include "thirdparty/nlohmann/json.hpp"
#include <memory>
@ -71,9 +70,6 @@ class Utils {
const std::vector<std::pair<int64_t, milvus::Entity>>& entity_array,
milvus::TopKQueryResult& topk_query_result,
milvus::MetricType metric_type = milvus::MetricType::INVALID);
static std::vector<milvus::LeafQueryPtr>
GenLeafQuery();
};
} // namespace milvus_sdk

View File

@ -17,7 +17,7 @@
#include "grpc-gen/gen-milvus/milvus.grpc.pb.h"
#define MILVUS_SDK_VERSION "0.10.0";
#define MILVUS_SDK_VERSION "1.1.0";
namespace milvus {
@ -602,198 +602,4 @@ ClientProxy::Compact(const std::string& collection_name) {
}
}
/*******************************New Interface**********************************/
/**
 * Translate an HMapping into its gRPC Mapping message and send it to the server.
 *
 * Scalar fields are written first, then vector fields (which additionally carry their
 * dimension). Each field gets one extra-param entry with key "params".
 *
 * @return the status of the RPC, or StatusCode::UnknownError if a std::exception is thrown.
 */
Status
ClientProxy::CreateHybridCollection(const HMapping& mapping) {
    try {
        ::milvus::grpc::Mapping grpc_mapping;
        grpc_mapping.set_collection_name(mapping.collection_name);

        for (const auto& scalar : mapping.numerica_fields) {
            auto* grpc_field = grpc_mapping.add_fields();
            grpc_field->set_name(scalar->field_name);
            grpc_field->mutable_type()->set_data_type(static_cast< ::milvus::grpc::DataType>(scalar->field_type));

            auto* params_entry = grpc_field->add_extra_params();
            params_entry->set_key("params");
            params_entry->set_value(scalar->extram_params);
        }

        for (const auto& vec : mapping.vector_fields) {
            auto* grpc_field = grpc_mapping.add_fields();
            grpc_field->set_name(vec->field_name);

            auto* grpc_type = grpc_field->mutable_type();
            grpc_type->set_data_type(static_cast< ::milvus::grpc::DataType>(vec->field_type));
            grpc_type->mutable_vector_param()->set_dimension(vec->dimension);

            auto* params_entry = grpc_field->add_extra_params();
            params_entry->set_key("params");
            params_entry->set_value(vec->extram_params);
        }

        return client_ptr_->CreateHybridCollection(grpc_mapping);
    } catch (std::exception& error) {
        return Status(StatusCode::UnknownError, "Failed to create collection: " + std::string(error.what()));
    }
}
/**
 * Copy the vector payload of an Entity into a gRPC RowRecord.
 * Float data and binary data are handled independently; an empty part is left untouched.
 */
void
CopyVectorField(::milvus::grpc::RowRecord* target, const Entity& src) {
    const auto& floats = src.float_data;
    if (!floats.empty()) {
        auto* float_field = target->mutable_float_data();
        float_field->Resize(static_cast<int>(floats.size()), 0.0);
        memcpy(float_field->mutable_data(), floats.data(), floats.size() * sizeof(float));
    }

    const auto& bytes = src.binary_data;
    if (!bytes.empty()) {
        target->set_binary_data(bytes.data(), bytes.size());
    }
}
/**
 * Insert hybrid entities (scalar attributes plus vectors) into a collection.
 *
 * Scalar columns are concatenated, in the iteration order of `entities.numerica_value`,
 * into a single attribute byte buffer; their names are added to the field-name list in the
 * same order. Vector columns follow, one FieldValue per vector field.
 *
 * @param collection_name target collection.
 * @param partition_tag   target partition tag.
 * @param entities        rows to insert.
 * @param id_array        if non-empty, the ids are sent with the request; if empty, the ids
 *                        returned in the response are appended to it.
 * @return the status of the RPC, or StatusCode::UnknownError if a std::exception is thrown.
 */
Status
ClientProxy::InsertEntity(const std::string& collection_name, const std::string& partition_tag, HEntity& entities,
                          std::vector<uint64_t>& id_array) {
    Status status;
    try {
        ::milvus::grpc::HInsertParam grpc_param;
        grpc_param.set_collection_name(collection_name);
        grpc_param.set_partition_tag(partition_tag);

        auto* grpc_entity = grpc_param.mutable_entities();
        grpc_entity->set_row_num(entities.row_num);

        // Append every scalar column to one contiguous buffer (no intermediate copies).
        std::vector<int8_t> attr_data;
        for (const auto& attr : entities.numerica_value) {
            *grpc_entity->add_field_names() = attr.first;
            attr_data.insert(attr_data.end(), attr.second.begin(), attr.second.end());
        }
        grpc_entity->set_attr_records(attr_data.data(), attr_data.size());

        for (const auto& vector_column : entities.vector_value) {
            *grpc_entity->add_field_names() = vector_column.first;
            ::milvus::grpc::FieldValue* vector_field = grpc_entity->add_result_values();
            for (const auto& entity : vector_column.second) {
                ::milvus::grpc::RowRecord* record = vector_field->mutable_vector_value()->add_value();
                CopyVectorField(record, entity);
            }
        }

        const bool ids_supplied = !id_array.empty();
        if (ids_supplied) {
            auto* row_ids = grpc_param.mutable_entity_id_array();
            row_ids->Resize(static_cast<int>(id_array.size()), -1);
            memcpy(row_ids->mutable_data(), id_array.data(), id_array.size() * sizeof(int64_t));
        }

        ::milvus::grpc::HEntityIDs entity_ids;
        status = client_ptr_->InsertEntities(grpc_param, entity_ids);

        if (!ids_supplied) {
            // Report the ids carried by the response back to the caller.
            id_array.insert(id_array.end(), entity_ids.entity_id_array().begin(), entity_ids.entity_id_array().end());
        }
    } catch (std::exception& exception) {
        return Status(StatusCode::UnknownError, "Failed to insert entities: " + std::string(exception.what()));
    }
    return status;
}
// Serialize a client-side boolean query tree into a gRPC GeneralQuery message.
//
// For every child boolean query of `boolean_query`, the child's occur value is written to the
// boolean_query sub-message of `general_query`, and each of the child's leaf queries is appended
// there as one general_query entry (term, range and/or vector part, depending on which of the
// leaf's pointers are non-null). If the child has nested boolean queries of its own, one more
// general_query entry is appended and filled by a recursive call.
// Leaf queries attached directly to `boolean_query` itself are not visited here.
// NOTE(review): every child writes into the same mutable_boolean_query() of `general_query`, so
// with more than one child the occur value set last presumably wins - confirm this is intended.
void
WriteQueryToProto(::milvus::grpc::GeneralQuery* general_query, BooleanQueryPtr boolean_query) {
if (!boolean_query->GetBooleanQueries().empty()) {
for (auto query : boolean_query->GetBooleanQueries()) {
auto grpc_boolean_query = general_query->mutable_boolean_query();
grpc_boolean_query->set_occur((::milvus::grpc::Occur)query->GetOccur());
for (auto leaf_query : query->GetLeafQueries()) {
// One gRPC general_query entry per leaf query.
auto grpc_query = grpc_boolean_query->add_general_query();
if (leaf_query->term_query_ptr != nullptr) {
// Term part: field name, boost and the raw value bytes.
auto term_query = grpc_query->mutable_term_query();
term_query->set_field_name(leaf_query->term_query_ptr->field_name);
term_query->set_boost(leaf_query->query_boost);
term_query->set_values(leaf_query->term_query_ptr->field_value.data(),
leaf_query->term_query_ptr->field_value.size());
}
if (leaf_query->range_query_ptr != nullptr) {
// Range part: one operand entry per compare expression.
auto range_query = grpc_query->mutable_range_query();
range_query->set_boost(leaf_query->query_boost);
range_query->set_field_name(leaf_query->range_query_ptr->field_name);
for (auto com_expr : leaf_query->range_query_ptr->compare_expr) {
auto grpc_com_expr = range_query->add_operand();
grpc_com_expr->set_operand(com_expr.operand);
grpc_com_expr->set_operator_((milvus::grpc::CompareOperator)com_expr.compare_operator);
}
}
if (leaf_query->vector_query_ptr != nullptr) {
// Vector part: the boost comes from the leaf, not from VectorQuery::query_boost.
auto vector_query = grpc_query->mutable_vector_query();
vector_query->set_field_name(leaf_query->vector_query_ptr->field_name);
vector_query->set_query_boost(leaf_query->query_boost);
vector_query->set_topk(leaf_query->vector_query_ptr->topk);
for (auto record : leaf_query->vector_query_ptr->query_vector) {
::milvus::grpc::RowRecord* row_record = vector_query->add_records();
CopyRowRecord(row_record, record);
}
auto extra_param = vector_query->add_extra_params();
extra_param->set_key(EXTRA_PARAM_KEY);
extra_param->set_value(leaf_query->vector_query_ptr->extra_params);
}
}
if (!query->GetBooleanQueries().empty()) {
// Descend into the nested boolean queries of this child.
::milvus::grpc::GeneralQuery* next_query = grpc_boolean_query->add_general_query();
WriteQueryToProto(next_query, query);
}
}
}
}
/**
 * Execute a hybrid search and convert the flat gRPC result into per-query rows.
 *
 * @param collection_name   collection to search.
 * @param partition_list    partition tags sent with the request.
 * @param boolean_query     root of the query tree, serialized via WriteQueryToProto.
 * @param extra_params      sent as the value of the "params" extra parameter.
 * @param topk_query_result receives one QueryResult per result row; the ids and distances of the
 *                          response are split evenly (ids count / row count per row).
 * @return the status of the RPC, or StatusCode::UnknownError if a std::exception is thrown.
 */
Status
ClientProxy::HybridSearch(const std::string& collection_name, const std::vector<std::string>& partition_list,
                          BooleanQueryPtr& boolean_query, const std::string& extra_params,
                          TopKQueryResult& topk_query_result) {
    try {
        // step 1: build the request (convert boolean_query to proto)
        ::milvus::grpc::HSearchParam search_param;
        search_param.set_collection_name(collection_name);
        for (const auto& tag : partition_list) {
            *search_param.add_partition_tag_array() = tag;
        }

        auto* params_entry = search_param.add_extra_params();
        params_entry->set_key("params");
        params_entry->set_value(extra_params);

        WriteQueryToProto(search_param.mutable_general_query(), boolean_query);

        // step 2: search vectors
        ::milvus::grpc::TopKQueryResult result;
        Status status = client_ptr_->HybridSearch(search_param, result);

        // step 3: convert result array
        const int64_t nq = result.row_num();
        topk_query_result.reserve(nq);
        if (nq == 0) {
            return status;
        }

        const int64_t topk = result.ids().size() / nq;
        for (int64_t row = 0; row < nq; ++row) {
            const int64_t offset = topk * row;

            milvus::QueryResult row_result;
            row_result.ids.resize(topk);
            row_result.distances.resize(topk);
            memcpy(row_result.ids.data(), result.ids().data() + offset, topk * sizeof(int64_t));
            memcpy(row_result.distances.data(), result.distances().data() + offset, topk * sizeof(float));

            topk_query_result.emplace_back(row_result);
        }
        return status;
    } catch (std::exception& error) {
        return Status(StatusCode::UnknownError, "Failed to search entities: " + std::string(error.what()));
    }
}
} // namespace milvus

View File

@ -125,20 +125,6 @@ class ClientProxy : public Connection {
Status
Compact(const std::string& collection_name) override;
/*******************************New Interface**********************************/
Status
CreateHybridCollection(const HMapping& mapping) override;
Status
InsertEntity(const std::string& collection_name, const std::string& partition_tag, HEntity& entities,
std::vector<uint64_t>& id_array) override;
Status
HybridSearch(const std::string& collection_name, const std::vector<std::string>& partition_list,
BooleanQueryPtr& boolean_query, const std::string& extra_params,
TopKQueryResult& topk_query_result) override;
private:
std::shared_ptr<::grpc::Channel> channel_;
std::shared_ptr<GrpcClient> client_ptr_;

View File

@ -491,56 +491,4 @@ GrpcClient::Disconnect() {
return Status::OK();
}
/**
 * Issue the CreateHybridCollection RPC.
 *
 * @return Status::OK() on success; RPCFailed if the gRPC call itself failed;
 *         ServerFailed (with the server's reason) if the response carries an error code.
 *         Failures are also logged to stderr.
 */
Status
GrpcClient::CreateHybridCollection(milvus::grpc::Mapping& mapping) {
    ClientContext context;
    ::milvus::grpc::Status response;
    ::grpc::Status rpc_status = stub_->CreateHybridCollection(&context, mapping, &response);

    if (rpc_status.ok()) {
        if (response.error_code() == grpc::SUCCESS) {
            return Status::OK();
        }
        // Transport succeeded but the server rejected the request.
        std::cerr << response.reason() << std::endl;
        return Status(StatusCode::ServerFailed, response.reason());
    }

    std::cerr << "CreateHybridCollection gRPC failed!" << std::endl;
    return Status(StatusCode::RPCFailed, rpc_status.error_message());
}
/**
 * Issue the InsertEntity RPC.
 *
 * @param entities request message.
 * @param ids      response message; also carries the server status.
 * @return Status::OK() on success; RPCFailed if the gRPC call itself failed;
 *         ServerFailed (with the server's reason) if the response carries an error code.
 *         Failures are also logged to stderr.
 */
Status
GrpcClient::InsertEntities(milvus::grpc::HInsertParam& entities, milvus::grpc::HEntityIDs& ids) {
    ClientContext context;
    ::grpc::Status rpc_status = stub_->InsertEntity(&context, entities, &ids);

    if (rpc_status.ok()) {
        if (ids.status().error_code() == grpc::SUCCESS) {
            return Status::OK();
        }
        // Transport succeeded but the server rejected the request.
        std::cerr << ids.status().reason() << std::endl;
        return Status(StatusCode::ServerFailed, ids.status().reason());
    }

    std::cerr << "InsertEntities gRPC failed!" << std::endl;
    return Status(StatusCode::RPCFailed, rpc_status.error_message());
}
/**
 * Issue the HybridSearch RPC.
 *
 * @param search_param request message.
 * @param result       response message; also carries the server status.
 * @return Status::OK() on success; RPCFailed if the gRPC call itself failed;
 *         ServerFailed (with the server's reason) if the response carries an error code.
 *         Failures are also logged to stderr.
 */
Status
GrpcClient::HybridSearch(milvus::grpc::HSearchParam& search_param, milvus::grpc::TopKQueryResult& result) {
    ClientContext context;
    ::grpc::Status rpc_status = stub_->HybridSearch(&context, search_param, &result);

    if (rpc_status.ok()) {
        if (result.status().error_code() == grpc::SUCCESS) {
            return Status::OK();
        }
        // Transport succeeded but the server rejected the request.
        std::cerr << result.status().reason() << std::endl;
        return Status(StatusCode::ServerFailed, result.status().reason());
    }

    std::cerr << "HybridSearch gRPC failed!" << std::endl;
    return Status(StatusCode::RPCFailed, rpc_status.error_message());
}
} // namespace milvus

View File

@ -110,16 +110,6 @@ class GrpcClient {
Status
Disconnect();
/*******************************New Interface**********************************/
Status
CreateHybridCollection(milvus::grpc::Mapping& mapping);
Status
InsertEntities(milvus::grpc::HInsertParam& entities, milvus::grpc::HEntityIDs& ids);
Status
HybridSearch(milvus::grpc::HSearchParam& search_param, milvus::grpc::TopKQueryResult& result);
private:
std::unique_ptr<grpc::MilvusService::Stub> stub_;
};

View File

@ -1,68 +0,0 @@
// Copyright (C) 2019-2020 Zilliz. All rights reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance
// with the License. You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software distributed under the License
// is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express
// or implied. See the License for the specific language governing permissions and limitations under the License.
#pragma once
#include <iostream>
#include <memory>
#include <vector>
#include <string>
# include "GeneralQuery.h"
namespace milvus {
// How a boolean clause takes part in a query. The numeric values are spelled out
// because they are part of the contract of this enum.
enum class Occur {
    INVALID = 0,
    MUST = 1,
    MUST_NOT = 2,
    SHOULD = 3,
};
/**
 * Node of a boolean query tree: an occur value, a list of nested boolean queries
 * and a list of leaf queries.
 */
class BooleanQuery {
 public:
    /** Create a node whose occur value is Occur::INVALID. */
    BooleanQuery() = default;

    /** Create a node with the given occur value. */
    explicit BooleanQuery(Occur occur) : occur_(occur) {
    }

    /** Append a leaf query to this node. */
    void
    AddLeafQuery(LeafQueryPtr leaf_query) {
        leaf_queries_.emplace_back(leaf_query);
    }

    /** Append a nested boolean query to this node. */
    void
    AddBooleanQuery(std::shared_ptr<BooleanQuery> boolean_query) {
        boolean_queries_.emplace_back(boolean_query);
    }

    /** Nested boolean queries, in insertion order. */
    std::vector<std::shared_ptr<BooleanQuery>>&
    GetBooleanQueries() {
        return boolean_queries_;
    }

    /** Leaf queries, in insertion order. */
    std::vector<LeafQueryPtr>&
    GetLeafQueries() {
        return leaf_queries_;
    }

    /** Occur value of this node. */
    Occur
    GetOccur() const {
        return occur_;
    }

 private:
    // Initialized in-class so a default-constructed node never exposes an indeterminate value
    // through GetOccur().
    Occur occur_ = Occur::INVALID;
    std::vector<std::shared_ptr<BooleanQuery>> boolean_queries_;
    std::vector<LeafQueryPtr> leaf_queries_;
};
using BooleanQueryPtr = std::shared_ptr<BooleanQuery>;
} // namespace milvus

View File

@ -1,64 +0,0 @@
// Copyright (C) 2019-2020 Zilliz. All rights reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance
// with the License. You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software distributed under the License
// is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express
// or implied. See the License for the specific language governing permissions and limitations under the License.
#pragma once
#include <memory>
#include <string>
#include <vector>
#include "Status.h"
namespace milvus {
// Type tag of a collection field. Values are explicit and non-contiguous.
enum class DataType {
    // integers
    INT8 = 1,
    INT16 = 2,
    INT32 = 3,
    INT64 = 4,

    STRING = 20,

    BOOL = 30,

    // floating point
    FLOAT = 40,
    DOUBLE = 41,

    VECTOR = 100,

    UNKNOWN = 9999,
};
// Base struct of all fields.
// The numeric members have no default initializers, so a default-initialized Field
// holds indeterminate values until they are assigned.
struct Field {
uint64_t field_id; ///< read-only
std::string field_name;  ///< Name of the field
DataType field_type;  ///< Type tag of the field
float boost;  ///< NOTE(review): not read by the client code shown with this header - confirm usage
std::string extram_params;  ///< Extra parameters string ("extram" is a typo kept because the name is public)
};
using FieldPtr = std::shared_ptr<Field>;
// Distance metric used to compare vectors.
enum class DistanceMetric {
L2 = 1, // Euclidean Distance
IP = 2, // Inner Product (matches cosine similarity only for normalized vectors)
HAMMING = 3, // Hamming Distance
JACCARD = 4, // Jaccard Distance
TANIMOTO = 5, // Tanimoto Distance
};
// Vector field: a Field that additionally carries the vector dimension.
struct VectorField : Field {
uint64_t dimension;  ///< Number of components per vector; no default initializer
};
using VectorFieldPtr = std::shared_ptr<VectorField>;
} // namespace milvus

View File

@ -1,96 +0,0 @@
// Copyright (C) 2019-2020 Zilliz. All rights reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance
// with the License. You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software distributed under the License
// is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express
// or implied. See the License for the specific language governing permissions and limitations under the License.
#pragma once
#include <iostream>
#include <memory>
#include <vector>
#include <string>
namespace milvus {
/**
 * @brief Entity to insert; currently each entity represents a single vector,
 *        stored either as float data or as binary data.
 */
struct Entity {
std::vector<float> float_data; ///< Vector raw float data
std::vector<uint8_t> binary_data; ///< Vector raw binary data
};
// Sort specification: the field to sort by and the direction.
struct Sort {
std::string field_name;  // field the results are ordered by
int64_t rules; // 0 is inc, 1 is dec
};
// Base struct of all queries (TermQuery, RangeQuery and VectorQuery derive from it).
// Only field_name is read by the client serialization code shown with this header;
// the remaining members have no default initializers.
struct Query {
std::string field_name;  // name of the field the query applies to
int64_t from;  // NOTE(review): presumably a paging offset - confirm
int64_t size;  // NOTE(review): presumably a paging size - confirm
Sort sort;
float min_score;
float boost;
};
// Comparison operator of a range-query operand. The numeric values are spelled out
// because they are part of the contract of this enum.
enum class CompareOperator {
    LT = 0,   // less than
    LTE = 1,  // less than or equal
    EQ = 2,   // equal
    GT = 3,   // greater than
    GTE = 4,  // greater than or equal
    NE = 5,   // not equal
};
// A column name paired with a value, both as strings.
// NOTE(review): not referenced by the client code shown with this header - confirm usage.
struct QueryColumn {
std::string name;
std::string column_value;
};
// Term query: matches a field against a set of values.
struct TermQuery : Query {
std::vector<int8_t> field_value;  // candidate values as one raw byte buffer
};
using TermQueryPtr = std::shared_ptr<TermQuery>;
// One comparison of a range query: an operator and its operand.
struct CompareExpr {
CompareOperator compare_operator;
std::string operand;  // operand in textual form
};
// Range query: a list of comparisons applied to one field.
struct RangeQuery : Query {
std::vector<CompareExpr> compare_expr;
};
using RangeQueryPtr = std::shared_ptr<RangeQuery>;
// A single vector record holding float and/or binary data (same members as Entity).
struct RowRecord {
std::vector<float> float_data;  // raw float components
std::vector<uint8_t> binary_data;  // raw binary data
};
// Vector query: similarity search for a set of query vectors on one field.
struct VectorQuery : Query {
uint64_t topk;  // number of results requested per query vector
float distance_limitation;  // NOTE(review): meaning not evident from this header - confirm
float query_boost;  // NOTE(review): Query already declares `boost`; confirm which one is authoritative
std::vector<Entity> query_vector;  // the query vectors
std::string extra_params;  // extra search parameters as a string
};
using VectorQueryPtr = std::shared_ptr<VectorQuery>;
// Leaf of a boolean query tree. Holds optional term, range and vector parts
// (a null pointer means the part is absent) plus a boost value.
struct LeafQuery {
TermQueryPtr term_query_ptr;
RangeQueryPtr range_query_ptr;
VectorQueryPtr vector_query_ptr;
float query_boost;  // no default initializer
};
using LeafQueryPtr = std::shared_ptr<LeafQuery>;
} // namespace milvus

View File

@ -16,14 +16,20 @@
#include <unordered_map>
#include <vector>
#include "BooleanQuery.h"
#include "Field.h"
#include "Status.h"
/** \brief Milvus SDK namespace
*/
namespace milvus {
/**
 * @brief Entity to insert; currently each entity represents a single vector,
 *        stored either as float data or as binary data.
 */
struct Entity {
std::vector<float> float_data; ///< Vector raw float data
std::vector<uint8_t> binary_data; ///< Vector raw binary data
};
/**
* @brief Index Type
*/
@ -113,18 +119,6 @@ struct PartitionParam {
using PartitionTagList = std::vector<std::string>;
// Schema of a hybrid collection: its name plus its scalar and vector fields.
struct HMapping {
std::string collection_name;
std::vector<FieldPtr> numerica_fields;  // scalar (non-vector) fields
std::vector<VectorFieldPtr> vector_fields;  // vector fields
};
// A batch of hybrid rows, stored column-wise and keyed by field name.
struct HEntity {
int64_t row_num;  // number of rows in the batch
std::unordered_map<std::string, std::vector<int8_t>> numerica_value;  // scalar columns as raw byte buffers
std::unordered_map<std::string, std::vector<Entity>> vector_value;  // vector columns, one Entity per row
};
/**
* @brief SDK main class
*/
@ -574,20 +568,6 @@ class Connection {
*/
virtual Status
Compact(const std::string& collection_name) = 0;
/*******************************New Interface**********************************/
virtual Status
CreateHybridCollection(const HMapping& mapping) = 0;
virtual Status
InsertEntity(const std::string& collection_name, const std::string& partition_tag, HEntity& entities,
std::vector<uint64_t>& id_array) = 0;
virtual Status
HybridSearch(const std::string& collection_name, const std::vector<std::string>& partition_list,
BooleanQueryPtr& boolean_query, const std::string& extra_params,
TopKQueryResult& topk_query_result) = 0;
};
} // namespace milvus

View File

@ -198,24 +198,4 @@ ConnectionImpl::Compact(const std::string& collection_name) {
return client_proxy_->Compact(collection_name);
}
/*******************************New Interface**********************************/
// Forwards the request unchanged to the client proxy and returns its status.
Status
ConnectionImpl::CreateHybridCollection(const HMapping& mapping) {
    Status proxy_status = client_proxy_->CreateHybridCollection(mapping);
    return proxy_status;
}
// Forwards the request unchanged to the client proxy and returns its status.
Status
ConnectionImpl::InsertEntity(const std::string& collection_name, const std::string& partition_tag, HEntity& entities,
                             std::vector<uint64_t>& id_array) {
    Status proxy_status = client_proxy_->InsertEntity(collection_name, partition_tag, entities, id_array);
    return proxy_status;
}
// Forwards the request unchanged to the client proxy and returns its status.
Status
ConnectionImpl::HybridSearch(const std::string& collection_name, const std::vector<std::string>& partition_list,
                             BooleanQueryPtr& boolean_query, const std::string& extra_params,
                             TopKQueryResult& topk_query_result) {
    Status proxy_status =
        client_proxy_->HybridSearch(collection_name, partition_list, boolean_query, extra_params, topk_query_result);
    return proxy_status;
}
} // namespace milvus

View File

@ -127,20 +127,6 @@ class ConnectionImpl : public Connection {
Status
Compact(const std::string& collection_name) override;
/*******************************New Interface**********************************/
Status
CreateHybridCollection(const HMapping& mapping) override;
Status
InsertEntity(const std::string& collection_name, const std::string& partition_tag, HEntity& entities,
std::vector<uint64_t>& id_array) override;
Status
HybridSearch(const std::string& collection_name, const std::vector<std::string>& partition_list,
BooleanQueryPtr& boolean_query, const std::string& extra_params,
TopKQueryResult& topk_query_result) override;
private:
std::shared_ptr<ClientProxy> client_proxy_;
};