0.8.0 id=-1 is returned when total count < topk (#2263)

* 0.8.0 id=-1 is returned when total count < topk

Signed-off-by: fishpenguin <kun.yu@zilliz.com>

* Fix for comments

Signed-off-by: fishpenguin <kun.yu@zilliz.com>
pull/2271/head
yukun 2020-05-09 10:03:41 +08:00 committed by GitHub
parent cf6be092ab
commit 3852d4ca71
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
2 changed files with 23 additions and 28 deletions

View File

@ -23,6 +23,7 @@ Please mark all change in change log and use the issue from GitHub
- \#2169 Fix SingleIndexTest.IVFSQHybrid unittest
- \#2194 Fix get collection info failed
- \#2196 Fix server start failed if wal is disabled
- \#2203 0.8.0 id=-1 is returned when total count < topk
- \#2231 Use server_config to define hard-delete delay time for segment files
## Feature

View File

@ -29,13 +29,10 @@ UriCheck(const std::string& uri) {
return (index != std::string::npos);
}
template<typename T>
template <typename T>
void
ConstructSearchParam(const std::string& collection_name,
const std::vector<std::string>& partition_tag_array,
int64_t topk,
const std::string& extra_params,
T& search_param) {
ConstructSearchParam(const std::string& collection_name, const std::vector<std::string>& partition_tag_array,
int64_t topk, const std::string& extra_params, T& search_param) {
search_param.set_collection_name(collection_name);
search_param.set_topk(topk);
milvus::grpc::KeyValuePair* kv = search_param.add_extra_params();
@ -65,12 +62,22 @@ ConstructTopkResult(const ::milvus::grpc::TopKQueryResult& grpc_result, TopKQuer
topk_query_result.reserve(grpc_result.row_num());
int64_t nq = grpc_result.row_num();
int64_t topk = grpc_result.ids().size() / nq;
for (int64_t i = 0; i < grpc_result.row_num(); i++) {
for (int64_t i = 0; i < nq; i++) {
milvus::QueryResult one_result;
one_result.ids.resize(topk);
one_result.distances.resize(topk);
memcpy(one_result.ids.data(), grpc_result.ids().data() + topk * i, topk * sizeof(int64_t));
memcpy(one_result.distances.data(), grpc_result.distances().data() + topk * i, topk * sizeof(float));
int valid_size = one_result.ids.size();
while (valid_size > 0 && one_result.ids[valid_size - 1] == -1) {
valid_size--;
}
if (valid_size != topk) {
one_result.ids.resize(valid_size);
one_result.distances.resize(valid_size);
}
topk_query_result.emplace_back(one_result);
}
}
@ -286,8 +293,7 @@ ClientProxy::GetEntityByID(const std::string& collection_name, int64_t entity_id
}
Status
ClientProxy::GetEntitiesByID(const std::string& collection_name,
const std::vector<int64_t>& id_array,
ClientProxy::GetEntitiesByID(const std::string& collection_name, const std::vector<int64_t>& id_array,
std::vector<Entity>& entities_data) {
try {
entities_data.clear();
@ -358,11 +364,7 @@ ClientProxy::Search(const std::string& collection_name, const std::vector<std::s
try {
// step 1: convert vectors data
::milvus::grpc::SearchParam search_param;
ConstructSearchParam(collection_name,
partition_tag_array,
topk,
extra_params,
search_param);
ConstructSearchParam(collection_name, partition_tag_array, topk, extra_params, search_param);
for (auto& entity : entity_array) {
::milvus::grpc::RowRecord* row_record = search_param.add_query_record_array();
@ -387,16 +389,12 @@ ClientProxy::Search(const std::string& collection_name, const std::vector<std::s
Status
ClientProxy::SearchByID(const std::string& collection_name, const PartitionTagList& partition_tag_array,
const std::vector<int64_t>& id_array, int64_t topk,
const std::string& extra_params, TopKQueryResult& topk_query_result) {
const std::vector<int64_t>& id_array, int64_t topk, const std::string& extra_params,
TopKQueryResult& topk_query_result) {
try {
// step 1: convert vectors data
::milvus::grpc::SearchByIDParam search_param;
ConstructSearchParam(collection_name,
partition_tag_array,
topk,
extra_params,
search_param);
ConstructSearchParam(collection_name, partition_tag_array, topk, extra_params, search_param);
for (auto& id : id_array) {
search_param.add_id_array(id);
@ -664,9 +662,7 @@ CopyVectorField(::milvus::grpc::RowRecord* target, const Entity& src) {
}
Status
ClientProxy::InsertEntity(const std::string& collection_name,
const std::string& partition_tag,
HEntity& entities,
ClientProxy::InsertEntity(const std::string& collection_name, const std::string& partition_tag, HEntity& entities,
std::vector<uint64_t>& id_array) {
Status status;
try {
@ -774,10 +770,8 @@ WriteQueryToProto(::milvus::grpc::GeneralQuery* general_query, BooleanQueryPtr b
}
Status
ClientProxy::HybridSearch(const std::string& collection_name,
const std::vector<std::string>& partition_list,
BooleanQueryPtr& boolean_query,
const std::string& extra_params,
ClientProxy::HybridSearch(const std::string& collection_name, const std::vector<std::string>& partition_list,
BooleanQueryPtr& boolean_query, const std::string& extra_params,
TopKQueryResult& topk_query_result) {
try {
// convert boolean_query to proto