mirror of https://github.com/milvus-io/milvus.git
Cherry pick remove translateHits commit to master (#16436)
Signed-off-by: xige-16 <xi.ge@zilliz.com>
Co-authored-by: bigsheeper <yihao.dai@zilliz.com>
parent a37479d728
commit 27b4cbc098
@@ -117,7 +117,7 @@ datatype_is_floating(DataType datatype) {
 class FieldMeta {
  public:
     static const FieldMeta RowIdMeta;
-    FieldMeta(const FieldMeta&) = delete;
+    FieldMeta(const FieldMeta&) = default;
     FieldMeta(FieldMeta&&) = default;
     FieldMeta&
     operator=(const FieldMeta&) = delete;
@@ -0,0 +1,97 @@
+// Licensed to the LF AI & Data foundation under one
+// or more contributor license agreements. See the NOTICE file
+// distributed with this work for additional information
+// regarding copyright ownership. The ASF licenses this file
+// to you under the Apache License, Version 2.0 (the
+// "License"); you may not use this file except in compliance
+// with the License. You may obtain a copy of the License at
+//
+//     http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#pragma once
+
+#include <memory>
+#include <limits>
+#include <string>
+#include <utility>
+#include <vector>
+#include <boost/align/aligned_allocator.hpp>
+#include <boost/dynamic_bitset.hpp>
+#include <NamedType/named_type.hpp>
+
+#include "pb/schema.pb.h"
+#include "utils/Types.h"
+#include "FieldMeta.h"
+
+namespace milvus {
+struct SearchResult {
+    SearchResult() = default;
+    SearchResult(int64_t num_queries, int64_t topk) : topk_(topk), num_queries_(num_queries) {
+        auto count = get_row_count();
+        distances_.resize(count);
+        ids_.resize(count);
+    }
+
+    int64_t
+    get_row_count() const {
+        return topk_ * num_queries_;
+    }
+
+    // vector type
+    void
+    AddField(const FieldName& name,
+             const FieldId id,
+             DataType data_type,
+             int64_t dim,
+             std::optional<MetricType> metric_type) {
+        this->AddField(FieldMeta(name, id, data_type, dim, metric_type));
+    }
+
+    // scalar type
+    void
+    AddField(const FieldName& name, const FieldId id, DataType data_type) {
+        this->AddField(FieldMeta(name, id, data_type));
+    }
+
+    void
+    AddField(FieldMeta&& field_meta) {
+        output_fields_meta_.emplace_back(std::move(field_meta));
+    }
+
+ public:
+    int64_t num_queries_;
+    int64_t topk_;
+    std::vector<float> distances_;
+    std::vector<int64_t> ids_;  // primary keys
+
+ public:
+    // TODO(gexi): utilize these fields
+    void* segment_;
+    std::vector<int64_t> result_offsets_;
+    std::vector<int64_t> primary_keys_;
+    aligned_vector<char> ids_data_;
+    std::vector<aligned_vector<char>> output_fields_data_;
+    std::vector<FieldMeta> output_fields_meta_;
+};
+
+using SearchResultPtr = std::shared_ptr<SearchResult>;
+using SearchResultOpt = std::optional<SearchResult>;
+
+struct RetrieveResult {
+    RetrieveResult() = default;
+
+ public:
+    void* segment_;
+    std::vector<int64_t> result_offsets_;
+    std::vector<DataArray> field_data_;
+};
+
+using RetrieveResultPtr = std::shared_ptr<RetrieveResult>;
+using RetrieveResultOpt = std::optional<RetrieveResult>;
+}  // namespace milvus
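The new SearchResult stores results column-wise: distances_ and ids_ hold num_queries_ * topk_ entries, and each entry of output_fields_data_ is one flat byte column of row_count * field_size bytes, paired with the matching entry of output_fields_meta_. A minimal Go sketch of that addressing scheme (names and sizes hypothetical, not part of the commit):

package main

import (
    "encoding/binary"
    "fmt"
)

func main() {
    const numQueries, topK = 2, 3
    rowCount := numQueries * topK

    // One flat column per output field; here a single int64 field (8 bytes/row).
    const eleSize = 8
    column := make([]byte, rowCount*eleSize)

    // Row i of the column lives at bytes [i*eleSize, (i+1)*eleSize).
    for i := 0; i < rowCount; i++ {
        binary.LittleEndian.PutUint64(column[i*eleSize:], uint64(100+i))
    }
    fmt.Println(binary.LittleEndian.Uint64(column[4*eleSize:])) // 104
}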
@@ -61,49 +61,6 @@ constexpr std::false_type always_false{};
 template <typename T>
 using aligned_vector = std::vector<T, boost::alignment::aligned_allocator<T, 64>>;

-///////////////////////////////////////////////////////////////////////////////////////////////////
-struct SearchResult {
-    SearchResult() = default;
-    SearchResult(int64_t num_queries, int64_t topk) : topk_(topk), num_queries_(num_queries) {
-        auto count = get_row_count();
-        distances_.resize(count);
-        ids_.resize(count);
-    }
-
-    int64_t
-    get_row_count() const {
-        return topk_ * num_queries_;
-    }
-
- public:
-    int64_t num_queries_;
-    int64_t topk_;
-    std::vector<float> distances_;
-    std::vector<int64_t> ids_;
-
- public:
-    // TODO(gexi): utilize these fields
-    void* segment_;
-    std::vector<int64_t> result_offsets_;
-    std::vector<int64_t> primary_keys_;
-    std::vector<std::vector<char>> row_data_;
-};
-
-using SearchResultPtr = std::shared_ptr<SearchResult>;
-using SearchResultOpt = std::optional<SearchResult>;
-
-struct RetrieveResult {
-    RetrieveResult() = default;
-
- public:
-    void* segment_;
-    std::vector<int64_t> result_offsets_;
-    std::vector<DataArray> field_data_;
-};
-
-using RetrieveResultPtr = std::shared_ptr<RetrieveResult>;
-using RetrieveResultOpt = std::optional<RetrieveResult>;
-
 namespace impl {
 // hide identifier name to make auto-completion happy
 struct FieldIdTag;
@@ -15,8 +15,15 @@
 #include <algorithm>

 #include "utils/Status.h"
+#include "common/type_c.h"

 namespace milvus::segcore {

+// SearchResultDataBlobs contains the marshal blobs of many `milvus::proto::schema::SearchResultData`
+struct SearchResultDataBlobs {
+    std::vector<std::vector<char>> blobs;
+};
+
 Status
 merge_into(int64_t num_queries,
            int64_t topk,
@@ -13,6 +13,7 @@

 #include "common/Consts.h"
 #include "common/Types.h"
+#include "common/QueryResult.h"
 #include "segcore/Reduce.h"

 using milvus::SearchResult;
@@ -44,53 +44,38 @@ SegmentInternalInterface::FillTargetEntry(const query::Plan* plan, SearchResult&
     AssertInfo(plan, "empty plan");
     auto size = results.distances_.size();
     AssertInfo(results.ids_.size() == size, "Size of result distances is not equal to size of ids");
-    Assert(results.row_data_.size() == 0);
-
-    std::vector<int64_t> element_sizeofs;
-    std::vector<aligned_vector<char>> blobs;

     // fill row_ids
     {
-        aligned_vector<char> blob(size * sizeof(int64_t));
+        results.ids_data_.resize(size * sizeof(int64_t));
         if (plan->schema_.get_is_auto_id()) {
-            bulk_subscript(SystemFieldType::RowId, results.ids_.data(), size, blob.data());
+            bulk_subscript(SystemFieldType::RowId, results.ids_.data(), size, results.ids_data_.data());
         } else {
             auto key_offset_opt = get_schema().get_primary_key_offset();
             AssertInfo(key_offset_opt.has_value(), "Cannot get primary key offset from schema");
             auto key_offset = key_offset_opt.value();
             AssertInfo(get_schema()[key_offset].get_data_type() == DataType::INT64,
                        "Primary key field is not INT64 type");
-            bulk_subscript(key_offset, results.ids_.data(), size, blob.data());
+            bulk_subscript(key_offset, results.ids_.data(), size, results.ids_data_.data());
         }
-        blobs.emplace_back(std::move(blob));
-        element_sizeofs.push_back(sizeof(int64_t));
     }

-    // fill other entries except primary key
+    // fill other entries except primary key by result_offset
     for (auto field_offset : plan->target_entries_) {
         auto& field_meta = get_schema()[field_offset];
         auto element_sizeof = field_meta.get_sizeof();
         aligned_vector<char> blob(size * element_sizeof);
         bulk_subscript(field_offset, results.ids_.data(), size, blob.data());
-        blobs.emplace_back(std::move(blob));
-        element_sizeofs.push_back(element_sizeof);
-    }
-
-    auto target_sizeof = std::accumulate(element_sizeofs.begin(), element_sizeofs.end(), 0);
-
-    for (int64_t i = 0; i < size; ++i) {
-        int64_t element_offset = 0;
-        std::vector<char> target(target_sizeof);
-        for (int loc = 0; loc < blobs.size(); ++loc) {
-            auto element_sizeof = element_sizeofs[loc];
-            auto blob_ptr = blobs[loc].data();
-            auto src = blob_ptr + element_sizeof * i;
-            auto dst = target.data() + element_offset;
-            memcpy(dst, src, element_sizeof);
-            element_offset += element_sizeof;
+        results.output_fields_data_.emplace_back(std::move(blob));
+        if (field_meta.is_vector()) {
+            results.AddField(field_meta.get_name(), field_meta.get_id(), field_meta.get_data_type(),
+                             field_meta.get_dim(), field_meta.get_metric_type());
+        } else {
+            results.AddField(field_meta.get_name(), field_meta.get_id(), field_meta.get_data_type());
         }
-        assert(element_offset == target_sizeof);
-        results.row_data_.emplace_back(std::move(target));
     }
 }
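FillTargetEntry now gathers each requested field column-wise: it allocates one size * element_sizeof blob per output field and lets bulk_subscript copy the selected rows into it, instead of re-packing per-row tuples into row_data_. A minimal Go sketch of that gather step (function name hypothetical, assuming a fixed-width field):

package main

import "fmt"

// bulkSubscript copies the rows picked out by offsets from a flat source
// column into dst, mirroring the column-wise fill in FillTargetEntry.
func bulkSubscript(src []byte, eleSize int, offsets []int64, dst []byte) {
    for i, off := range offsets {
        copy(dst[i*eleSize:(i+1)*eleSize], src[off*int64(eleSize):(off+1)*int64(eleSize)])
    }
}

func main() {
    eleSize := 1
    src := []byte{'a', 'b', 'c', 'd', 'e'}
    offsets := []int64{4, 0, 2} // rows selected by the search
    dst := make([]byte, len(offsets)*eleSize)
    bulkSubscript(src, eleSize, offsets, dst)
    fmt.Printf("%s\n", dst) // "eac"
}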
@@ -162,7 +147,7 @@ CreateScalarArrayFrom(const void* data_raw, int64_t count, DataType data_type) {
     return scalar_array;
 }

-static std::unique_ptr<DataArray>
+std::unique_ptr<DataArray>
 CreateDataArrayFrom(const void* data_raw, int64_t count, const FieldMeta& field_meta) {
     auto data_type = field_meta.get_data_type();
     auto data_array = std::make_unique<DataArray>();
@@ -23,6 +23,7 @@
 #include "common/Span.h"
 #include "common/SystemProperty.h"
 #include "common/Types.h"
+#include "common/QueryResult.h"
 #include "knowhere/index/vector_index/VecIndex.h"
 #include "query/Plan.h"
 #include "query/PlanNode.h"
@@ -172,4 +173,10 @@ class SegmentInternalInterface : public SegmentInterface {
     mutable std::shared_mutex mutex_;
 };

+static std::unique_ptr<ScalarArray>
+CreateScalarArrayFrom(const void* data_raw, int64_t count, DataType data_type);
+
+std::unique_ptr<DataArray>
+CreateDataArrayFrom(const void* data_raw, int64_t count, const FieldMeta& field_meta);
+
 }  // namespace milvus::segcore
@@ -13,49 +13,21 @@
 #include <unordered_set>
 #include <vector>

-#include "Reduce.h"
+#include "common/CGoHelper.h"
 #include "common/Consts.h"
 #include "common/Types.h"
+#include "common/QueryResult.h"
 #include "exceptions/EasyAssert.h"
 #include "log/Log.h"
-#include "pb/milvus.pb.h"
 #include "query/Plan.h"
 #include "segcore/Reduce.h"
 #include "segcore/ReduceStructure.h"
+#include "segcore/SegmentInterface.h"
 #include "segcore/reduce_c.h"

 using SearchResult = milvus::SearchResult;

-int
-MergeInto(int64_t num_queries, int64_t topk, float* distances, int64_t* uids, float* new_distances, int64_t* new_uids) {
-    auto status = milvus::segcore::merge_into(num_queries, topk, distances, uids, new_distances, new_uids);
-    return status.code();
-}
-
-struct MarshaledHitsPerGroup {
-    std::vector<std::string> hits_;
-    std::vector<int64_t> blob_length_;
-};
-
-struct MarshaledHits {
-    explicit MarshaledHits(int64_t num_group) {
-        marshaled_hits_.resize(num_group);
-    }
-
-    int
-    get_num_group() {
-        return marshaled_hits_.size();
-    }
-
-    std::vector<MarshaledHitsPerGroup> marshaled_hits_;
-};
-
-void
-DeleteMarshaledHits(CMarshaledHits c_marshaled_hits) {
-    auto hits = (MarshaledHits*)c_marshaled_hits;
-    delete hits;
-}
-
 // void
 // PrintSearchResult(char* buf, const milvus::SearchResult* result, int64_t seg_idx, int64_t from, int64_t to) {
 //     const int64_t MAXLEN = 32;
@@ -154,6 +126,208 @@ ReduceResultData(std::vector<SearchResult*>& search_results, int64_t nq, int64_t
     }
 }

+void
+ReorganizeSearchResults(std::vector<SearchResult*>& search_results,
+                        int32_t nq,
+                        int32_t topK,
+                        milvus::aligned_vector<int64_t>& result_ids,
+                        std::vector<float>& result_distances,
+                        std::vector<milvus::aligned_vector<char>>& result_output_fields_data) {
+    auto num_segments = search_results.size();
+    auto results_count = 0;
+
+    for (int i = 0; i < num_segments; i++) {
+        auto search_result = search_results[i];
+        AssertInfo(search_result != nullptr, "null search result when reorganize");
+        AssertInfo(search_result->output_fields_meta_.size() == result_output_fields_data.size(),
+                   "illegal fields meta size"
+                   ", fields_meta_size = " +
+                       std::to_string(search_result->output_fields_meta_.size()) +
+                       ", expected_size = " + std::to_string(result_output_fields_data.size()));
+        auto num_results = search_result->result_offsets_.size();
+        if (num_results == 0) {
+            continue;
+        }
+#pragma omp parallel for
+        for (int j = 0; j < num_results; j++) {
+            auto loc = search_result->result_offsets_[j];
+            // AssertInfo(loc < nq * topK, "result location of out range, location = " +
+            // std::to_string(loc));
+            // set result ids
+            memcpy(&result_ids[loc], &search_result->ids_data_[j * sizeof(int64_t)], sizeof(int64_t));
+            // set result distances
+            result_distances[loc] = search_result->distances_[j];
+            // set result output fields data
+            for (int k = 0; k < search_result->output_fields_meta_.size(); k++) {
+                auto ele_size = search_result->output_fields_meta_[k].get_sizeof();
+                memcpy(&result_output_fields_data[k][loc * ele_size],
+                       &search_result->output_fields_data_[k][j * ele_size], ele_size);
+            }
+        }
+        results_count += num_results;
+    }
+
+    AssertInfo(results_count == nq * topK,
+               "size of reduce result is less than nq * topK"
+               ", result_count = " +
+                   std::to_string(results_count) + ", nq * topK = " + std::to_string(nq * topK));
+}
+
+std::vector<char>
+GetSearchResultDataSlice(milvus::aligned_vector<int64_t>& result_ids,
+                         std::vector<float>& result_distances,
+                         std::vector<milvus::aligned_vector<char>>& result_output_fields_data,
+                         int32_t nq,
+                         int32_t topK,
+                         int32_t nq_begin,
+                         int32_t nq_end,
+                         std::vector<milvus::FieldMeta>& output_fields_meta) {
+    auto search_result_data = std::make_unique<milvus::proto::schema::SearchResultData>();
+    // set topK and nq
+    search_result_data->set_top_k(topK);
+    search_result_data->set_num_queries(nq);
+
+    auto offset_begin = nq_begin * topK;
+    auto offset_end = nq_end * topK;
+    AssertInfo(offset_begin <= offset_end,
+               "illegal offsets when GetSearchResultDataSlice"
+               ", offset_begin = " +
+                   std::to_string(offset_begin) + ", offset_end = " + std::to_string(offset_end));
+    AssertInfo(offset_end <= topK * nq,
+               "illegal offset_end when GetSearchResultDataSlice"
+               ", offset_end = " +
+                   std::to_string(offset_end) + ", nq = " + std::to_string(nq) + ", topK = " + std::to_string(topK));
+
+    // set ids
+    auto proto_ids = std::make_unique<milvus::proto::schema::IDs>();
+    auto ids = std::make_unique<milvus::proto::schema::LongArray>();
+    *ids->mutable_data() = {result_ids.begin() + offset_begin, result_ids.begin() + offset_end};
+    proto_ids->set_allocated_int_id(ids.release());
+    search_result_data->set_allocated_ids(proto_ids.release());
+    AssertInfo(search_result_data->ids().int_id().data_size() == offset_end - offset_begin,
+               "wrong ids size"
+               ", size = " +
+                   std::to_string(search_result_data->ids().int_id().data_size()) +
+                   ", expected size = " + std::to_string(offset_end - offset_begin));
+
+    // set scores
+    *search_result_data->mutable_scores() = {result_distances.begin() + offset_begin,
+                                             result_distances.begin() + offset_end};
+    AssertInfo(search_result_data->scores_size() == offset_end - offset_begin,
+               "wrong scores size"
+               ", size = " +
+                   std::to_string(search_result_data->scores_size()) +
+                   ", expected size = " + std::to_string(offset_end - offset_begin));
+
+    // set output fields
+    for (int i = 0; i < result_output_fields_data.size(); i++) {
+        auto& field_meta = output_fields_meta[i];
+        auto field_size = field_meta.get_sizeof();
+        auto array = milvus::segcore::CreateDataArrayFrom(
+            result_output_fields_data[i].data() + offset_begin * field_size, offset_end - offset_begin, field_meta);
+        search_result_data->mutable_fields_data()->AddAllocated(array.release());
+    }
+
+    // SearchResultData to blob
+    auto size = search_result_data->ByteSize();
+    auto buffer = std::vector<char>(size);
+    search_result_data->SerializePartialToArray(buffer.data(), size);
+
+    return buffer;
+}
+
+CStatus
+Marshal(CSearchResultDataBlobs* cSearchResultDataBlobs,
+        CSearchResult* c_search_results,
+        int32_t num_segments,
+        int32_t* nq_slice_sizes,
+        int32_t num_slices) {
+    try {
+        // parse search results and get topK, nq
+        std::vector<SearchResult*> search_results(num_segments);
+        for (int i = 0; i < num_segments; ++i) {
+            search_results[i] = static_cast<SearchResult*>(c_search_results[i]);
+        }
+        AssertInfo(search_results.size() > 0, "empty search result when Marshal");
+        auto topK = search_results[0]->topk_;
+        auto nq = search_results[0]->num_queries_;
+
+        // init result ids, distances
+        auto result_ids = milvus::aligned_vector<int64_t>(nq * topK);
+        auto result_distances = std::vector<float>(nq * topK);
+
+        // init result output fields data
+        auto& output_fields_meta = search_results[0]->output_fields_meta_;
+        auto num_output_fields = output_fields_meta.size();
+        auto result_output_fields_data = std::vector<milvus::aligned_vector<char>>(num_output_fields);
+        for (int i = 0; i < num_output_fields; i++) {
+            auto size = output_fields_meta[i].get_sizeof();
+            result_output_fields_data[i].resize(size * nq * topK);
+        }
+
+        // Reorganize search results, get result ids, distances and output fields data
+        ReorganizeSearchResults(search_results, nq, topK, result_ids, result_distances, result_output_fields_data);
+
+        // prefix sum, get slices offsets
+        AssertInfo(num_slices > 0, "empty nq_slice_sizes is not allowed");
+        auto slice_offsets_size = num_slices + 1;
+        auto slice_offsets = std::vector<int32_t>(slice_offsets_size);
+        slice_offsets[0] = 0;
+        slice_offsets[1] = nq_slice_sizes[0];
+        for (int i = 2; i < slice_offsets_size; i++) {
+            slice_offsets[i] = slice_offsets[i - 1] + nq_slice_sizes[i - 1];
+        }
+        AssertInfo(slice_offsets[num_slices] == nq,
+                   "illegal req sizes"
+                   ", slice_offsets[last] = " +
+                       std::to_string(slice_offsets[num_slices]) + ", nq = " + std::to_string(nq));
+
+        // get search result data blobs by slices
+        auto search_result_data_blobs = std::make_unique<milvus::segcore::SearchResultDataBlobs>();
+        search_result_data_blobs->blobs.resize(num_slices);
+#pragma omp parallel for
+        for (int i = 0; i < num_slices; i++) {
+            auto proto = GetSearchResultDataSlice(result_ids, result_distances, result_output_fields_data, nq, topK,
+                                                  slice_offsets[i], slice_offsets[i + 1], output_fields_meta);
+            search_result_data_blobs->blobs[i] = proto;
+        }
+
+        // set final result ptr
+        *cSearchResultDataBlobs = search_result_data_blobs.release();
+        return milvus::SuccessCStatus();
+    } catch (std::exception& e) {
+        DeleteSearchResultDataBlobs(cSearchResultDataBlobs);
+        return milvus::FailureCStatus(UnexpectedError, e.what());
+    }
+}
+
+CStatus
+GetSearchResultDataBlob(CProto* searchResultDataBlob,
+                        CSearchResultDataBlobs cSearchResultDataBlobs,
+                        int32_t blob_index) {
+    try {
+        auto search_result_data_blobs =
+            reinterpret_cast<milvus::segcore::SearchResultDataBlobs*>(cSearchResultDataBlobs);
+        AssertInfo(blob_index < search_result_data_blobs->blobs.size(), "blob_index out of range");
+        searchResultDataBlob->proto_blob = search_result_data_blobs->blobs[blob_index].data();
+        searchResultDataBlob->proto_size = search_result_data_blobs->blobs[blob_index].size();
+        return milvus::SuccessCStatus();
+    } catch (std::exception& e) {
+        searchResultDataBlob->proto_blob = nullptr;
+        searchResultDataBlob->proto_size = 0;
+        return milvus::FailureCStatus(UnexpectedError, e.what());
+    }
+}
+
+void
+DeleteSearchResultDataBlobs(CSearchResultDataBlobs cSearchResultDataBlobs) {
+    if (cSearchResultDataBlobs == nullptr) {
+        return;
+    }
+    auto search_result_data_blobs = reinterpret_cast<milvus::segcore::SearchResultDataBlobs*>(cSearchResultDataBlobs);
+    delete search_result_data_blobs;
+}
+
 CStatus
 ReduceSearchResultsAndFillData(CSearchPlan c_plan, CSearchResult* c_search_results, int64_t num_segments) {
     try {
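Marshal converts the per-slice nq sizes into absolute offsets with a prefix sum, then asserts that the final offset equals the total nq before cutting the reorganized buffers. A standalone Go sketch of that computation (function name hypothetical):

package main

import "fmt"

// sliceOffsets mirrors the prefix sum in Marshal: offsets[i] is the first
// query index of slice i, and offsets[len(sizes)] must equal the total nq.
func sliceOffsets(nqSliceSizes []int32) []int32 {
    offsets := make([]int32, len(nqSliceSizes)+1)
    for i, size := range nqSliceSizes {
        offsets[i+1] = offsets[i] + size
    }
    return offsets
}

func main() {
    fmt.Println(sliceOffsets([]int32{5, 5, 2})) // [0 5 10 12]
}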
@@ -190,121 +364,3 @@ ReduceSearchResultsAndFillData(CSearchPlan c_plan, CSearchResult* c_search_resul
         return status;
     }
 }
-
-CStatus
-ReorganizeSearchResults(CMarshaledHits* c_marshaled_hits, CSearchResult* c_search_results, int64_t num_segments) {
-    try {
-        auto marshaledHits = std::make_unique<MarshaledHits>(1);
-        auto sr = (SearchResult*)c_search_results[0];
-        auto topk = sr->topk_;
-        auto num_queries = sr->num_queries_;
-
-        std::vector<float> result_distances(num_queries * topk);
-        std::vector<std::vector<char>> row_datas(num_queries * topk);
-
-        std::vector<int64_t> counts(num_segments);
-        for (int i = 0; i < num_segments; i++) {
-            auto search_result = (SearchResult*)c_search_results[i];
-            AssertInfo(search_result != nullptr, "search result must not equal to nullptr");
-            auto size = search_result->result_offsets_.size();
-            if (size == 0) {
-                continue;
-            }
-#pragma omp parallel for
-            for (int j = 0; j < size; j++) {
-                auto loc = search_result->result_offsets_[j];
-                result_distances[loc] = search_result->distances_[j];
-                row_datas[loc] = search_result->row_data_[j];
-            }
-            counts[i] = size;
-        }
-
-        int64_t total_count = 0;
-        for (int i = 0; i < num_segments; i++) {
-            total_count += counts[i];
-        }
-        AssertInfo(total_count == num_queries * topk, "the reduces result's size less than total_num_queries*topk");
-
-        MarshaledHitsPerGroup& hits_per_group = (*marshaledHits).marshaled_hits_[0];
-        hits_per_group.hits_.resize(num_queries);
-        hits_per_group.blob_length_.resize(num_queries);
-        std::vector<milvus::proto::milvus::Hits> hits(num_queries);
-#pragma omp parallel for
-        for (int m = 0; m < num_queries; m++) {
-            for (int n = 0; n < topk; n++) {
-                int64_t result_offset = m * topk + n;
-                hits[m].add_scores(result_distances[result_offset]);
-                auto& row_data = row_datas[result_offset];
-                hits[m].add_row_data(row_data.data(), row_data.size());
-                hits[m].add_ids(*(int64_t*)row_data.data());
-            }
-        }
-
-#pragma omp parallel for
-        for (int j = 0; j < num_queries; j++) {
-            auto blob = hits[j].SerializeAsString();
-            hits_per_group.hits_[j] = blob;
-            hits_per_group.blob_length_[j] = blob.size();
-        }
-
-        auto status = CStatus();
-        status.error_code = Success;
-        status.error_msg = "";
-        auto marshaled_res = (CMarshaledHits)marshaledHits.release();
-        *c_marshaled_hits = marshaled_res;
-        return status;
-    } catch (std::exception& e) {
-        auto status = CStatus();
-        status.error_code = UnexpectedError;
-        status.error_msg = strdup(e.what());
-        *c_marshaled_hits = nullptr;
-        return status;
-    }
-}
-
-int64_t
-GetHitsBlobSize(CMarshaledHits c_marshaled_hits) {
-    int64_t total_size = 0;
-    auto marshaled_hits = (MarshaledHits*)c_marshaled_hits;
-    auto num_group = marshaled_hits->get_num_group();
-    for (int i = 0; i < num_group; i++) {
-        auto& length_vector = marshaled_hits->marshaled_hits_[i].blob_length_;
-        for (int j = 0; j < length_vector.size(); j++) {
-            total_size += length_vector[j];
-        }
-    }
-    return total_size;
-}
-
-void
-GetHitsBlob(CMarshaledHits c_marshaled_hits, const void* hits) {
-    auto byte_hits = (char*)hits;
-    auto marshaled_hits = (MarshaledHits*)c_marshaled_hits;
-    auto num_group = marshaled_hits->get_num_group();
-    int offset = 0;
-    for (int i = 0; i < num_group; i++) {
-        auto& hits = marshaled_hits->marshaled_hits_[i];
-        auto num_queries = hits.hits_.size();
-        for (int j = 0; j < num_queries; j++) {
-            auto blob_size = hits.blob_length_[j];
-            memcpy(byte_hits + offset, hits.hits_[j].data(), blob_size);
-            offset += blob_size;
-        }
-    }
-}
-
-int64_t
-GetNumQueriesPerGroup(CMarshaledHits c_marshaled_hits, int64_t group_index) {
-    auto marshaled_hits = (MarshaledHits*)c_marshaled_hits;
-    auto& hits = marshaled_hits->marshaled_hits_[group_index].hits_;
-    return hits.size();
-}
-
-void
-GetHitSizePerQueries(CMarshaledHits c_marshaled_hits, int64_t group_index, int64_t* hit_size_peer_query) {
-    auto marshaled_hits = (MarshaledHits*)c_marshaled_hits;
-    auto& blob_lens = marshaled_hits->marshaled_hits_[group_index].blob_length_;
-    for (int i = 0; i < blob_lens.size(); i++) {
-        hit_size_peer_query[i] = blob_lens[i];
-    }
-}
@@ -13,38 +13,29 @@
 extern "C" {
 #endif

 #include <stdbool.h>
 #include <stdint.h>

 #include "common/type_c.h"
 #include "segcore/plan_c.h"
 #include "segcore/segment_c.h"

-typedef void* CMarshaledHits;
-
-void
-DeleteMarshaledHits(CMarshaledHits c_marshaled_hits);
-
-int
-MergeInto(int64_t num_queries, int64_t topk, float* distances, int64_t* uids, float* new_distances, int64_t* new_uids);
+typedef void* CSearchResultDataBlobs;

 CStatus
 ReduceSearchResultsAndFillData(CSearchPlan c_plan, CSearchResult* search_results, int64_t num_segments);

 CStatus
-ReorganizeSearchResults(CMarshaledHits* c_marshaled_hits, CSearchResult* c_search_results, int64_t num_segments);
+Marshal(CSearchResultDataBlobs* cSearchResultDataBlobs,
+        CSearchResult* c_search_results,
+        int32_t num_segments,
+        int32_t* nq_slice_sizes,
+        int32_t num_slices);

-int64_t
-GetHitsBlobSize(CMarshaledHits c_marshaled_hits);
+CStatus
+GetSearchResultDataBlob(CProto* searchResultDataBlob,
+                        CSearchResultDataBlobs cSearchResultDataBlobs,
+                        int32_t blob_index);

 void
-GetHitsBlob(CMarshaledHits c_marshaled_hits, const void* hits);
-
-int64_t
-GetNumQueriesPerGroup(CMarshaledHits c_marshaled_hits, int64_t group_index);
-
-void
-GetHitSizePerQueries(CMarshaledHits c_marshaled_hits, int64_t group_index, int64_t* hit_size_peer_query);
+DeleteSearchResultDataBlobs(CSearchResultDataBlobs cSearchResultDataBlobs);

 #ifdef __cplusplus
 }
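Taken together, the new header defines a three-step lifecycle: Marshal serializes every slice up front, GetSearchResultDataBlob returns a CProto view into one slice, and DeleteSearchResultDataBlobs frees the whole container, invalidating all views. A rough Go model of that ownership contract, using stand-in types rather than the real cgo handles:

package main

import "fmt"

// blobContainer stands in for CSearchResultDataBlobs: one owner, many views.
type blobContainer struct{ blobs [][]byte }

func marshalAll(sliceSizes []int) *blobContainer {
    c := &blobContainer{blobs: make([][]byte, len(sliceSizes))}
    for i, n := range sliceSizes {
        c.blobs[i] = make([]byte, n) // one serialized SearchResultData per slice
    }
    return c
}

// getBlob mirrors GetSearchResultDataBlob: it hands out a view, not a copy,
// so the view must not outlive the container.
func (c *blobContainer) getBlob(i int) ([]byte, error) {
    if i < 0 || i >= len(c.blobs) {
        return nil, fmt.Errorf("blob_index out of range")
    }
    return c.blobs[i], nil
}

func main() {
    c := marshalAll([]int{128, 128})
    b, _ := c.getBlob(1)
    fmt.Println(len(b)) // consume views before the container is deleted
}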
@@ -27,6 +27,7 @@
 #include "query/ExprImpl.h"
 #include "segcore/Collection.h"
 #include "segcore/reduce_c.h"
+#include "segcore/Reduce.h"
 #include "test_utils/DataGen.h"
 #include "utils/Types.h"
@@ -684,35 +685,6 @@ TEST(CApiTest, GetRowCountTest) {
 //    DeleteSegment(segment);
 //}

-TEST(CApiTest, MergeInto) {
-    std::vector<int64_t> uids;
-    std::vector<float> distance;
-
-    std::vector<int64_t> new_uids;
-    std::vector<float> new_distance;
-
-    int64_t num_queries = 1;
-    int64_t topk = 2;
-
-    uids.push_back(1);
-    uids.push_back(2);
-    distance.push_back(5);
-    distance.push_back(1000);
-
-    new_uids.push_back(3);
-    new_uids.push_back(4);
-    new_distance.push_back(2);
-    new_distance.push_back(6);
-
-    auto res = MergeInto(num_queries, topk, distance.data(), uids.data(), new_distance.data(), new_uids.data());
-
-    ASSERT_EQ(res, 0);
-    ASSERT_EQ(uids[0], 3);
-    ASSERT_EQ(distance[0], 2);
-    ASSERT_EQ(uids[1], 1);
-    ASSERT_EQ(distance[1], 5);
-}
-
 void
 CheckSearchResultDuplicate(const std::vector<CSearchResult>& results) {
     auto sr = (SearchResult*)results[0];
@@ -838,89 +810,6 @@ TEST(CApiTest, ReduceRemoveDuplicates) {
     DeleteSegment(segment);
 }

-TEST(CApiTest, Reduce) {
-    auto collection = NewCollection(get_default_schema_config());
-    auto segment = NewSegment(collection, Growing, -1);
-
-    int N = 10000;
-    auto [raw_data, timestamps, uids] = generate_data(N);
-    auto line_sizeof = (sizeof(int) + sizeof(float) * DIM);
-
-    int64_t offset;
-    PreInsert(segment, N, &offset);
-    auto ins_res = Insert(segment, offset, N, uids.data(), timestamps.data(), raw_data.data(), (int)line_sizeof, N);
-    assert(ins_res.error_code == Success);
-
-    const char* dsl_string = R"(
-    {
-        "bool": {
-            "vector": {
-                "fakevec": {
-                    "metric_type": "L2",
-                    "params": {
-                        "nprobe": 10
-                    },
-                    "query": "$0",
-                    "topk": 10,
-                    "round_decimal": 3
-                }
-            }
-        }
-    })";
-
-    int num_queries = 10;
-    auto blob = generate_query_data(num_queries);
-
-    void* plan = nullptr;
-    auto status = CreateSearchPlan(collection, dsl_string, &plan);
-    assert(status.error_code == Success);
-
-    void* placeholderGroup = nullptr;
-    status = ParsePlaceholderGroup(plan, blob.data(), blob.length(), &placeholderGroup);
-    assert(status.error_code == Success);
-
-    std::vector<CPlaceholderGroup> placeholderGroups;
-    placeholderGroups.push_back(placeholderGroup);
-    timestamps.clear();
-    timestamps.push_back(1);
-
-    std::vector<CSearchResult> results;
-    CSearchResult res1;
-    CSearchResult res2;
-    auto res = Search(segment, plan, placeholderGroup, timestamps[0], &res1, -1);
-    assert(res.error_code == Success);
-    res = Search(segment, plan, placeholderGroup, timestamps[0], &res2, -1);
-    assert(res.error_code == Success);
-    results.push_back(res1);
-    results.push_back(res2);
-
-    status = ReduceSearchResultsAndFillData(plan, results.data(), results.size());
-    assert(status.error_code == Success);
-    void* reorganize_search_result = nullptr;
-    status = ReorganizeSearchResults(&reorganize_search_result, results.data(), results.size());
-    assert(status.error_code == Success);
-    auto hits_blob_size = GetHitsBlobSize(reorganize_search_result);
-    assert(hits_blob_size > 0);
-    std::vector<char> hits_blob;
-    hits_blob.resize(hits_blob_size);
-    GetHitsBlob(reorganize_search_result, hits_blob.data());
-    assert(hits_blob.data() != nullptr);
-    auto num_queries_group = GetNumQueriesPerGroup(reorganize_search_result, 0);
-    assert(num_queries_group == num_queries);
-    std::vector<int64_t> hit_size_per_query;
-    hit_size_per_query.resize(num_queries_group);
-    GetHitSizePerQueries(reorganize_search_result, 0, hit_size_per_query.data());
-    assert(hit_size_per_query[0] > 0);
-
-    DeleteSearchPlan(plan);
-    DeletePlaceholderGroup(placeholderGroup);
-    DeleteSearchResult(res1);
-    DeleteSearchResult(res2);
-    DeleteMarshaledHits(reorganize_search_result);
-    DeleteCollection(collection);
-    DeleteSegment(segment);
-}
-
 TEST(CApiTest, ReduceSearchWithExpr) {
     auto collection = NewCollection(get_default_schema_config());
     auto segment = NewSegment(collection, Growing, -1);
@@ -941,9 +830,10 @@ TEST(CApiTest, ReduceSearchWithExpr) {
                                              metric_type: "L2"
                                              search_params: "{\"nprobe\": 10}"
                                            >
-                                           placeholder_tag: "$0"
-                                          >)";
+                                           placeholder_tag: "$0">
+                                          output_field_ids: 100)";

+    int topK = 10;
     int num_queries = 10;
     auto blob = generate_query_data(num_queries);
@@ -971,29 +861,34 @@ TEST(CApiTest, ReduceSearchWithExpr) {
     results.push_back(res1);
     results.push_back(res2);

+    // 1. reduce
     status = ReduceSearchResultsAndFillData(plan, results.data(), results.size());
     assert(status.error_code == Success);
-    void* reorganize_search_result = nullptr;
-    status = ReorganizeSearchResults(&reorganize_search_result, results.data(), results.size());
-    assert(status.error_code == Success);
-    auto hits_blob_size = GetHitsBlobSize(reorganize_search_result);
-    assert(hits_blob_size > 0);
-    std::vector<char> hits_blob;
-    hits_blob.resize(hits_blob_size);
-    GetHitsBlob(reorganize_search_result, hits_blob.data());
-    assert(hits_blob.data() != nullptr);
-    auto num_queries_group = GetNumQueriesPerGroup(reorganize_search_result, 0);
-    assert(num_queries_group == num_queries);
-    std::vector<int64_t> hit_size_per_query;
-    hit_size_per_query.resize(num_queries_group);
-    GetHitSizePerQueries(reorganize_search_result, 0, hit_size_per_query.data());
-    assert(hit_size_per_query[0] > 0);

+    // 2. marshal
+    CSearchResultDataBlobs cSearchResultData;
+    auto req_sizes = std::vector<int32_t>{5, 5};
+    status = Marshal(&cSearchResultData, results.data(), results.size(), req_sizes.data(), req_sizes.size());
+    assert(status.error_code == Success);
+    auto search_result_data_blobs = reinterpret_cast<milvus::segcore::SearchResultDataBlobs*>(cSearchResultData);
+
+    // check result
+    for (int i = 0; i < req_sizes.size(); i++) {
+        milvus::proto::schema::SearchResultData search_result_data;
+        auto suc = search_result_data.ParseFromArray(search_result_data_blobs->blobs[i].data(),
+                                                     search_result_data_blobs->blobs[i].size());
+        assert(suc);
+        assert(search_result_data.top_k() == topK);
+        assert(search_result_data.num_queries() == num_queries);
+        assert(search_result_data.scores().size() == topK * req_sizes[i]);
+        assert(search_result_data.ids().int_id().data_size() == topK * req_sizes[i]);
+    }
+
+    DeleteSearchResultDataBlobs(cSearchResultData);
     DeleteSearchPlan(plan);
     DeletePlaceholderGroup(placeholderGroup);
     DeleteSearchResult(res1);
     DeleteSearchResult(res2);
-    DeleteMarshaledHits(reorganize_search_result);
     DeleteCollection(collection);
     DeleteSegment(segment);
 }
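The size assertions in that test follow from the slicing arithmetic: each blob covers req_sizes[i] queries, so it carries topK * req_sizes[i] scores and the same number of ids (10 * 5 = 50 per blob here). A tiny Go check of that arithmetic, assuming the test's values:

package main

import "fmt"

func main() {
    // nq=10 split into req_sizes {5, 5}, topK=10, as in the test above.
    topK, reqSizes := 10, []int{5, 5}
    for i, nqSlice := range reqSizes {
        fmt.Printf("blob %d: %d scores, %d ids\n", i, topK*nqSlice, topK*nqSlice) // 50 each
    }
}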
@@ -637,9 +637,9 @@ TEST(Query, FillSegment) {
     // dispatch here
     int N = 100000;
     auto dataset = DataGen(schema, N);
-    const auto std_vec = dataset.get_col<int64_t>(1);
-    const auto std_vfloat_vec = dataset.get_col<float>(0);
-    const auto std_i32_vec = dataset.get_col<int32_t>(2);
+    const auto std_vec = dataset.get_col<int64_t>(1);       // ids field
+    const auto std_vfloat_vec = dataset.get_col<float>(0);  // vector field
+    const auto std_i32_vec = dataset.get_col<int32_t>(2);   // scalar field

     std::vector<std::unique_ptr<SegmentInternalInterface>> segments;
     segments.emplace_back([&] {
@@ -701,16 +701,20 @@ TEST(Query, FillSegment) {
     result->result_offsets_.resize(topk * num_queries);
     segment->FillTargetEntry(plan.get(), *result);

-    auto ans = result->row_data_;
-    ASSERT_EQ(ans.size(), topk * num_queries);
-    int64_t std_index = 0;
+    auto fields_data = result->output_fields_data_;
+    auto fields_meta = result->output_fields_meta_;
+    ASSERT_EQ(fields_data.size(), 2);
+    ASSERT_EQ(fields_data.size(), 2);
+    ASSERT_EQ(fields_meta[0].get_sizeof(), sizeof(float) * dim);
+    ASSERT_EQ(fields_meta[1].get_sizeof(), sizeof(int32_t));
+    ASSERT_EQ(fields_data[0].size(), fields_meta[0].get_sizeof() * topk * num_queries);
+    ASSERT_EQ(fields_data[1].size(), fields_meta[1].get_sizeof() * topk * num_queries);

-    for (auto& vec : ans) {
-        ASSERT_EQ(vec.size(), sizeof(int64_t) + sizeof(float) * dim + sizeof(int32_t));
+    for (int i = 0; i < topk * num_queries; i++) {
         int64_t val;
-        memcpy(&val, vec.data(), sizeof(int64_t));
+        memcpy(&val, &result->ids_data_[i * sizeof(int64_t)], sizeof(int64_t));

-        auto internal_offset = result->ids_[std_index];
+        auto internal_offset = result->ids_[i];
         auto std_val = std_vec[internal_offset];
         auto std_i32 = std_i32_vec[internal_offset];
         std::vector<float> std_vfloat(dim);
@@ -718,14 +722,16 @@

         ASSERT_EQ(val, std_val) << "io:" << internal_offset;
         if (val != -1) {
+            // check vector field
             std::vector<float> vfloat(dim);
+            memcpy(vfloat.data(), &fields_data[0][i * sizeof(float) * dim], dim * sizeof(float));
+            ASSERT_EQ(vfloat, std_vfloat);
+
+            // check int32 field
             int i32;
-            memcpy(vfloat.data(), vec.data() + sizeof(int64_t), dim * sizeof(float));
-            memcpy(&i32, vec.data() + sizeof(int64_t) + dim * sizeof(float), sizeof(int32_t));
-            ASSERT_EQ(vfloat, std_vfloat) << std_index;
-            ASSERT_EQ(i32, std_i32) << std_index;
+            memcpy(&i32, &fields_data[1][i * sizeof(int32_t)], sizeof(int32_t));
+            ASSERT_EQ(i32, std_i32);
         }
-        ++std_index;
     }
 }
@@ -34,6 +34,7 @@ import (

 	"github.com/milvus-io/milvus/internal/log"
 	"github.com/milvus-io/milvus/internal/proto/commonpb"
+	"github.com/milvus-io/milvus/internal/util/cgoconverter"
 )

 // HandleCStatus deals with the error returned from CGO
@@ -54,3 +55,13 @@ func HandleCStatus(status *C.CStatus, extraInfo string) error {
 	log.Warn(logMsg)
 	return errors.New(finalMsg)
 }
+
+func CopyCProtoBlob(cProto *C.CProto) []byte {
+	blob := C.GoBytes(unsafe.Pointer(cProto.proto_blob), C.int32_t(cProto.proto_size))
+	return blob
+}
+
+func GetCProtoBlob(cProto *C.CProto) []byte {
+	_, blob := cgoconverter.UnsafeGoBytes(&cProto.proto_blob, int(cProto.proto_size))
+	return blob
+}
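The two helpers differ in ownership: CopyCProtoBlob duplicates the C buffer into Go-managed memory, while GetCProtoBlob (via cgoconverter.UnsafeGoBytes) aliases the C memory, so its result is only valid until the C side releases the buffer. A pure-Go sketch of the copy-versus-alias distinction, with an ordinary slice standing in for C memory:

package main

import "fmt"

func main() {
    // Stand-in for a C-owned buffer.
    cBuf := []byte{1, 2, 3}

    copied := append([]byte(nil), cBuf...) // like CopyCProtoBlob: independent copy
    aliased := cBuf[:]                     // like GetCProtoBlob: shares the backing memory

    cBuf[0] = 9 // simulate the C side mutating or reusing the buffer
    fmt.Println(copied[0], aliased[0]) // 1 9
}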
|
@ -1617,6 +1617,62 @@ func produceSimpleRetrieveMsg(ctx context.Context, queryChannel Channel) error {
|
|||
return nil
|
||||
}
|
||||
|
||||
func checkSearchResult(nq int64, plan *SearchPlan, searchResult *SearchResult) error {
|
||||
searchResults := make([]*SearchResult, 0)
|
||||
searchResults = append(searchResults, searchResult)
|
||||
|
||||
err := reduceSearchResultsAndFillData(plan, searchResults, 1)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
nqOfReqs := []int64{nq / 5, nq / 5, nq / 5, nq / 5, nq / 5}
|
||||
nqPerSlice := nq / 5
|
||||
|
||||
reqSlices, err := getReqSlices(nqOfReqs, nqPerSlice)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
res, err := marshal(defaultCollectionID, UniqueID(0), searchResults, 1, reqSlices)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
for i := 0; i < len(reqSlices); i++ {
|
||||
blob, err := getSearchResultDataBlob(res, i)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
if len(blob) == 0 {
|
||||
return fmt.Errorf("wrong search result data blobs when checkSearchResult")
|
||||
}
|
||||
|
||||
result := &schemapb.SearchResultData{}
|
||||
err = proto.Unmarshal(blob, result)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
if result.TopK != defaultTopK {
|
||||
return fmt.Errorf("unexpected topK when checkSearchResult")
|
||||
}
|
||||
if result.NumQueries != nq {
|
||||
return fmt.Errorf("unexpected nq when checkSearchResult")
|
||||
}
|
||||
if len(result.Ids.IdField.(*schemapb.IDs_IntId).IntId.Data) != int(defaultTopK*nq/5) {
|
||||
return fmt.Errorf("unexpected Ids when checkSearchResult")
|
||||
}
|
||||
if len(result.Scores) != int(defaultTopK*nq/5) {
|
||||
return fmt.Errorf("unexpected Scores when checkSearchResult")
|
||||
}
|
||||
}
|
||||
|
||||
deleteSearchResults(searchResults)
|
||||
deleteSearchResultDataBlobs(res)
|
||||
return nil
|
||||
}
|
||||
|
||||
func initConsumer(ctx context.Context, queryResultChannel Channel) (msgstream.MsgStream, error) {
|
||||
stream, err := genQueryMsgStream(ctx)
|
||||
if err != nil {
|
||||
|
|
|
@@ -1005,11 +1005,6 @@ func (q *queryCollection) search(msg queryMsg) error {
 		return err
 	}

-	schema, err := typeutil.CreateSchemaHelper(collection.schema)
-	if err != nil {
-		return err
-	}
-
 	var plan *SearchPlan
 	if searchMsg.GetDslType() == commonpb.DslType_BoolExprV1 {
 		expr := searchMsg.SerializedExprPlan
@@ -1041,21 +1036,21 @@ func (q *queryCollection) search(msg queryMsg) error {
 	}
 	defer searchReq.delete()

-	queryNum := searchReq.getNumOfQuery()
+	nq := searchReq.getNumOfQuery()
 	searchRequests := make([]*searchRequest, 0)
 	searchRequests = append(searchRequests, searchReq)

 	if searchMsg.GetDslType() == commonpb.DslType_BoolExprV1 {
 		sp.LogFields(oplog.String("statistical time", "stats start"),
-			oplog.Object("nq", queryNum),
+			oplog.Object("nq", nq),
 			oplog.Object("expr", searchMsg.SerializedExprPlan))
 	} else {
 		sp.LogFields(oplog.String("statistical time", "stats start"),
-			oplog.Object("nq", queryNum),
+			oplog.Object("nq", nq),
 			oplog.Object("dsl", searchMsg.Dsl))
 	}

-	tr := timerecord.NewTimeRecorder(fmt.Sprintf("search %d(nq=%d, k=%d), msgID = %d", searchMsg.CollectionID, queryNum, topK, searchMsg.ID()))
+	tr := timerecord.NewTimeRecorder(fmt.Sprintf("search %d(nq=%d, k=%d), msgID = %d", searchMsg.CollectionID, nq, topK, searchMsg.ID()))

 	// get global sealed segments
 	var globalSealedSegments []UniqueID
@@ -1106,7 +1101,7 @@ func (q *queryCollection) search(msg queryMsg) error {
 		Status:          &commonpb.Status{ErrorCode: commonpb.ErrorCode_Success},
 		ResultChannelID: searchMsg.ResultChannelID,
 		MetricType:      plan.getMetricType(),
-		NumQueries:      queryNum,
+		NumQueries:      nq,
 		TopK:            topK,
 		SlicedBlob:      nil,
 		SlicedOffset:    1,
@@ -1136,7 +1131,7 @@ func (q *queryCollection) search(msg queryMsg) error {
 	}

 	numSegment := int64(len(searchResults))
-	var marshaledHits *MarshaledHits

 	log.Debug("QueryNode reduce data", zap.Int64("msgID", searchMsg.ID()), zap.Int64("numSegment", numSegment))
 	tr.RecordSpan()
 	err = reduceSearchResultsAndFillData(plan, searchResults, numSegment)
@@ -1146,49 +1141,22 @@ func (q *queryCollection) search(msg queryMsg) error {
 		log.Error("QueryNode reduce data failed", zap.Int64("msgID", searchMsg.ID()), zap.Error(err))
 		return err
 	}
-	marshaledHits, err = reorganizeSearchResults(searchResults, numSegment)
-	sp.LogFields(oplog.String("statistical time", "reorganizeSearchResults end"))
+	nqOfReqs := []int64{nq}
+	nqPerSlice := nq
+	reqSlices, err := getReqSlices(nqOfReqs, nqPerSlice)
 	if err != nil {
 		return err
 	}
-	defer deleteMarshaledHits(marshaledHits)
-
-	hitsBlob, err := marshaledHits.getHitsBlob()
-	sp.LogFields(oplog.String("statistical time", "getHitsBlob end"))
+	blobs, err := marshal(collectionID, searchMsg.ID(), searchResults, int(numSegment), reqSlices)
+	defer deleteSearchResultDataBlobs(blobs)
+	sp.LogFields(oplog.String("statistical time", "reorganizeSearchResults end"))
 	if err != nil {
 		return err
 	}
+	metrics.QueryNodeReduceLatency.WithLabelValues(fmt.Sprint(Params.QueryNodeCfg.QueryNodeID), metrics.SearchLabel).Observe(float64(tr.RecordSpan().Milliseconds()))

-	var offset int64
-	for index := range searchRequests {
-		hitBlobSizePeerQuery, err := marshaledHits.hitBlobSizeInGroup(int64(index))
-		if err != nil {
-			return err
-		}
-		hits := make([][]byte, len(hitBlobSizePeerQuery))
-		for i, len := range hitBlobSizePeerQuery {
-			hits[i] = hitsBlob[offset : offset+len]
-			//test code to checkout marshaled hits
-			//marshaledHit := hitsBlob[offset:offset+len]
-			//unMarshaledHit := milvuspb.Hits{}
-			//err = proto.Unmarshal(marshaledHit, &unMarshaledHit)
-			//if err != nil {
-			//	return err
-			//}
-			//log.Debug("hits msg = ", unMarshaledHit)
-			offset += len
-		}
-
-		// TODO: remove inefficient code in cgo and use SearchResultData directly
-		// TODO: Currently add a translate layer from hits to SearchResultData
-		// TODO: hits marshal and unmarshal is likely bottleneck
-
-		transformed, err := translateHits(schema, searchMsg.OutputFieldsId, hits)
-		if err != nil {
-			return err
-		}
-		byteBlobs, err := proto.Marshal(transformed)
+	for i := 0; i < len(reqSlices); i++ {
+		blob, err := getSearchResultDataBlob(blobs, i)
 		if err != nil {
 			return err
 		}
@@ -1206,9 +1174,9 @@ func (q *queryCollection) search(msg queryMsg) error {
 			Status:          &commonpb.Status{ErrorCode: commonpb.ErrorCode_Success},
 			ResultChannelID: searchMsg.ResultChannelID,
 			MetricType:      plan.getMetricType(),
-			NumQueries:      queryNum,
+			NumQueries:      nq,
 			TopK:            topK,
-			SlicedBlob:      byteBlobs,
+			SlicedBlob:      blob,
 			SlicedOffset:    1,
 			SlicedNumCount:  1,
 			SealedSegmentIDsSearched: sealedSegmentSearched,
@@ -28,7 +28,11 @@ package querynode
 import "C"
 import (
-	"errors"
-	"unsafe"
+	"fmt"

+	"go.uber.org/zap"
+
 	"github.com/milvus-io/milvus/internal/log"
 )

 // SearchResult contains a pointer to the search result in C++ memory
@@ -36,10 +40,8 @@ type SearchResult struct {
 	cSearchResult C.CSearchResult
 }

-// MarshaledHits contains a pointer to the marshaled hits in C++ memory
-type MarshaledHits struct {
-	cMarshaledHits C.CMarshaledHits
-}
+// searchResultDataBlobs is the CSearchResultsDataBlobs in C++
+type searchResultDataBlobs = C.CSearchResultDataBlobs

 // RetrieveResult contains a pointer to the retrieve result in C++ memory
 type RetrieveResult struct {
@@ -65,47 +67,59 @@ func reduceSearchResultsAndFillData(plan *SearchPlan, searchResults []*SearchRes
 	return nil
 }

-func reorganizeSearchResults(searchResults []*SearchResult, numSegments int64) (*MarshaledHits, error) {
+func marshal(collectionID UniqueID, msgID UniqueID, searchResults []*SearchResult, numSegments int, reqSlices []int32) (searchResultDataBlobs, error) {
+	log.Debug("start marshal...",
+		zap.Int64("collectionID", collectionID),
+		zap.Int64("msgID", msgID),
+		zap.Int32s("reqSlices", reqSlices))
+
 	cSearchResults := make([]C.CSearchResult, 0)
 	for _, res := range searchResults {
 		cSearchResults = append(cSearchResults, res.cSearchResult)
 	}
 	cSearchResultPtr := (*C.CSearchResult)(&cSearchResults[0])

-	var cNumSegments = C.int64_t(numSegments)
-	var cMarshaledHits C.CMarshaledHits
+	var cNumSegments = C.int32_t(numSegments)
+	var cSlicesPtr = (*C.int32_t)(&reqSlices[0])
+	var cNumSlices = C.int32_t(len(reqSlices))

-	status := C.ReorganizeSearchResults(&cMarshaledHits, cSearchResultPtr, cNumSegments)
+	var cSearchResultDataBlobs searchResultDataBlobs
+
+	status := C.Marshal(&cSearchResultDataBlobs, cSearchResultPtr, cNumSegments, cSlicesPtr, cNumSlices)
 	if err := HandleCStatus(&status, "ReorganizeSearchResults failed"); err != nil {
 		return nil, err
 	}
-	return &MarshaledHits{cMarshaledHits: cMarshaledHits}, nil
+	return cSearchResultDataBlobs, nil
 }

-func (mh *MarshaledHits) getHitsBlobSize() int64 {
-	res := C.GetHitsBlobSize(mh.cMarshaledHits)
-	return int64(res)
+func getReqSlices(nqOfReqs []int64, nqPerSlice int64) ([]int32, error) {
+	if nqPerSlice == 0 {
+		return nil, fmt.Errorf("zero nqPerSlice is not allowed")
+	}
+
+	slices := make([]int32, 0)
+	for i := 0; i < len(nqOfReqs); i++ {
+		for j := 0; j < int(nqOfReqs[i]/nqPerSlice); j++ {
+			slices = append(slices, int32(nqPerSlice))
+		}
+		if tailSliceSize := nqOfReqs[i] % nqPerSlice; tailSliceSize > 0 {
+			slices = append(slices, int32(tailSliceSize))
+		}
+	}
+	return slices, nil
 }

-func (mh *MarshaledHits) getHitsBlob() ([]byte, error) {
-	byteSize := mh.getHitsBlobSize()
-	result := make([]byte, byteSize)
-	cResultPtr := unsafe.Pointer(&result[0])
-	C.GetHitsBlob(mh.cMarshaledHits, cResultPtr)
-	return result, nil
+func getSearchResultDataBlob(cSearchResultDataBlobs searchResultDataBlobs, blobIndex int) ([]byte, error) {
+	var blob C.CProto
+	status := C.GetSearchResultDataBlob(&blob, cSearchResultDataBlobs, C.int32_t(blobIndex))
+	if err := HandleCStatus(&status, "marshal failed"); err != nil {
+		return nil, err
+	}
+	return GetCProtoBlob(&blob), nil
 }

-func (mh *MarshaledHits) hitBlobSizeInGroup(groupOffset int64) ([]int64, error) {
-	cGroupOffset := (C.int64_t)(groupOffset)
-	numQueries := C.GetNumQueriesPerGroup(mh.cMarshaledHits, cGroupOffset)
-	result := make([]int64, int64(numQueries))
-	cResult := (*C.int64_t)(&result[0])
-	C.GetHitSizePerQueries(mh.cMarshaledHits, cGroupOffset, cResult)
-	return result, nil
-}
-
-func deleteMarshaledHits(hits *MarshaledHits) {
-	C.DeleteMarshaledHits(hits.cMarshaledHits)
+func deleteSearchResultDataBlobs(cSearchResultDataBlobs searchResultDataBlobs) {
+	C.DeleteSearchResultDataBlobs(cSearchResultDataBlobs)
 }

 func deleteSearchResults(results []*SearchResult) {
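getReqSlices flattens per-request query counts into slice sizes: each request contributes floor(nq/nqPerSlice) full slices plus a remainder slice when nq does not divide evenly. A self-contained Go example of the same rule with made-up inputs:

package main

import "fmt"

// reqSlices is a standalone copy of getReqSlices' slicing rule, for illustration.
func reqSlices(nqOfReqs []int64, nqPerSlice int64) []int32 {
    slices := make([]int32, 0)
    for _, nq := range nqOfReqs {
        for j := int64(0); j < nq/nqPerSlice; j++ {
            slices = append(slices, int32(nqPerSlice))
        }
        if tail := nq % nqPerSlice; tail > 0 {
            slices = append(slices, int32(tail))
        }
    }
    return slices
}

func main() {
    // Two requests with nq=7 and nq=3, sliced by nqPerSlice=3:
    fmt.Println(reqSlices([]int64{7, 3}, 3)) // [3 3 1 3]
}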
@@ -17,6 +17,7 @@
 package querynode

 import (
+	"context"
 	"log"
 	"math"
 	"testing"
@@ -29,35 +30,36 @@ import (
 )

 func TestReduce_AllFunc(t *testing.T) {
-	collectionID := UniqueID(0)
-	segmentID := UniqueID(0)
-	collectionMeta := genTestCollectionMeta(collectionID, false)
+	nq := int64(10)

-	collection := newCollection(collectionMeta.ID, collectionMeta.Schema)
-	segment, err := newSegment(collection, segmentID, defaultPartitionID, collectionID, "", segmentTypeGrowing, true)
-	assert.Nil(t, err)
+	ctx, cancel := context.WithCancel(context.Background())
+	defer cancel()
+	node, err := genSimpleQueryNode(ctx)
+	assert.NoError(t, err)

-	const DIM = 16
-	var vec = [DIM]float32{1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16}
+	collection, err := node.historical.replica.getCollectionByID(defaultCollectionID)
+	assert.NoError(t, err)

-	// start search service
-	dslString := "{\"bool\": { \n\"vector\": {\n \"vec\": {\n \"metric_type\": \"L2\", \n \"params\": {\n \"nprobe\": 10 \n},\n \"query\": \"$0\",\n \"topk\": 10 \n,\"round_decimal\": 6\n } \n } \n } \n }"
-	var searchRawData1 []byte
-	var searchRawData2 []byte
+	segment, err := node.historical.replica.getSegmentByID(defaultSegmentID)
+	assert.NoError(t, err)

+	// TODO: replace below by genPlaceholderGroup(nq)
+	vec := genSimpleFloatVectors()
+	var searchRawData []byte
 	for i, ele := range vec {
 		buf := make([]byte, 4)
 		common.Endian.PutUint32(buf, math.Float32bits(ele+float32(i*2)))
-		searchRawData1 = append(searchRawData1, buf...)
-	}
-	for i, ele := range vec {
-		buf := make([]byte, 4)
-		common.Endian.PutUint32(buf, math.Float32bits(ele+float32(i*4)))
-		searchRawData2 = append(searchRawData2, buf...)
+		searchRawData = append(searchRawData, buf...)
 	}

 	placeholderValue := milvuspb.PlaceholderValue{
 		Tag:    "$0",
 		Type:   milvuspb.PlaceholderType_FloatVector,
-		Values: [][]byte{searchRawData1, searchRawData2},
+		Values: [][]byte{},
 	}

+	for i := 0; i < int(nq); i++ {
+		placeholderValue.Values = append(placeholderValue.Values, searchRawData)
+	}
+
 	placeholderGroup := milvuspb.PlaceholderGroup{
@@ -69,46 +71,24 @@ func TestReduce_AllFunc(t *testing.T) {
 		log.Print("marshal placeholderGroup failed")
 	}

+	dslString := "{\"bool\": { \n\"vector\": {\n \"vec\": {\n \"metric_type\": \"L2\", \n \"params\": {\n \"nprobe\": 10 \n},\n \"query\": \"$0\",\n \"topk\": 10 \n,\"round_decimal\": 6\n } \n } \n } \n }"
+
 	plan, err := createSearchPlan(collection, dslString)
 	assert.NoError(t, err)
 	holder, err := parseSearchRequest(plan, placeGroupByte)
 	assert.NoError(t, err)

 	placeholderGroups := make([]*searchRequest, 0)
 	placeholderGroups = append(placeholderGroups, holder)

-	searchResults := make([]*SearchResult, 0)
 	searchResult, err := segment.search(plan, placeholderGroups, []Timestamp{0})
-	assert.Nil(t, err)
-	searchResults = append(searchResults, searchResult)
+	assert.NoError(t, err)

-	err = reduceSearchResultsAndFillData(plan, searchResults, 1)
-	assert.Nil(t, err)
-
-	marshaledHits, err := reorganizeSearchResults(searchResults, 1)
-	assert.NotNil(t, marshaledHits)
-	assert.Nil(t, err)
-
-	hitsBlob, err := marshaledHits.getHitsBlob()
-	assert.Nil(t, err)
-
-	var offset int64
-	for index := range placeholderGroups {
-		hitBolbSizePeerQuery, err := marshaledHits.hitBlobSizeInGroup(int64(index))
-		assert.Nil(t, err)
-		for _, len := range hitBolbSizePeerQuery {
-			marshaledHit := hitsBlob[offset : offset+len]
-			unMarshaledHit := milvuspb.Hits{}
-			err = proto.Unmarshal(marshaledHit, &unMarshaledHit)
-			assert.Nil(t, err)
-			log.Println("hits msg = ", unMarshaledHit)
-			offset += len
-		}
-	}
+	err = checkSearchResult(nq, plan, searchResult)
+	assert.NoError(t, err)

 	plan.delete()
 	holder.delete()
-	deleteSearchResults(searchResults)
-	deleteMarshaledHits(marshaledHits)
-	deleteSegment(segment)
-	deleteCollection(collection)
 }
@@ -474,104 +474,61 @@ func TestSegment_segmentDelete(t *testing.T) {
 }

 func TestSegment_segmentSearch(t *testing.T) {
-	collectionID := UniqueID(0)
-	collectionMeta := genTestCollectionMeta(collectionID, false)
-
-	collection := newCollection(collectionMeta.ID, collectionMeta.Schema)
-	assert.Equal(t, collection.ID(), collectionID)
-
-	segmentID := UniqueID(0)
-	segment, err := newSegment(collection, segmentID, defaultPartitionID, collectionID, "", segmentTypeGrowing, true)
-	assert.Equal(t, segmentID, segment.segmentID)
-	assert.Nil(t, err)
-
-	ids := []int64{1, 2, 3}
-	timestamps := []uint64{0, 0, 0}
-
-	const DIM = 16
-	const N = 3
-	var vec = [DIM]float32{1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16}
-	var rawData []byte
-	for _, ele := range vec {
-		buf := make([]byte, 4)
-		common.Endian.PutUint32(buf, math.Float32bits(ele))
-		rawData = append(rawData, buf...)
-	}
-	bs := make([]byte, 4)
-	common.Endian.PutUint32(bs, 1)
-	rawData = append(rawData, bs...)
-	var records []*commonpb.Blob
-	for i := 0; i < N; i++ {
-		blob := &commonpb.Blob{
-			Value: rawData,
-		}
-		records = append(records, blob)
-	}
-
-	offset, err := segment.segmentPreInsert(N)
-	assert.Nil(t, err)
-	assert.GreaterOrEqual(t, offset, int64(0))
-
-	err = segment.segmentInsert(offset, &ids, &timestamps, &records)
+	nq := int64(10)
+	ctx, cancel := context.WithCancel(context.Background())
+	defer cancel()
+	node, err := genSimpleQueryNode(ctx)
 	assert.NoError(t, err)
-	dslString := "{\"bool\": { \n\"vector\": {\n \"vec\": {\n \"metric_type\": \"L2\", \n \"params\": {\n \"nprobe\": 10 \n},\n \"query\": \"$0\",\n \"topk\": 10 \n,\"round_decimal\": 6\n } \n } \n } \n }"

+	collection, err := node.historical.replica.getCollectionByID(defaultCollectionID)
+	assert.NoError(t, err)
+
+	segment, err := node.historical.replica.getSegmentByID(defaultSegmentID)
+	assert.NoError(t, err)
+
+	// TODO: replace below by genPlaceholderGroup(nq)
+	vec := genSimpleFloatVectors()
 	var searchRawData []byte
-	for _, ele := range vec {
+	for i, ele := range vec {
 		buf := make([]byte, 4)
-		common.Endian.PutUint32(buf, math.Float32bits(ele))
+		common.Endian.PutUint32(buf, math.Float32bits(ele+float32(i*2)))
 		searchRawData = append(searchRawData, buf...)
 	}

 	placeholderValue := milvuspb.PlaceholderValue{
 		Tag:    "$0",
 		Type:   milvuspb.PlaceholderType_FloatVector,
-		Values: [][]byte{searchRawData},
+		Values: [][]byte{},
 	}

+	for i := 0; i < int(nq); i++ {
+		placeholderValue.Values = append(placeholderValue.Values, searchRawData)
+	}
+
 	placeholderGroup := milvuspb.PlaceholderGroup{
 		Placeholders: []*milvuspb.PlaceholderValue{&placeholderValue},
 	}

-	placeHolderGroupBlob, err := proto.Marshal(&placeholderGroup)
+	placeGroupByte, err := proto.Marshal(&placeholderGroup)
 	if err != nil {
 		log.Print("marshal placeholderGroup failed")
 	}

-	travelTimestamp := Timestamp(1020)
+	dslString := "{\"bool\": { \n\"vector\": {\n \"vec\": {\n \"metric_type\": \"L2\", \n \"params\": {\n \"nprobe\": 10 \n},\n \"query\": \"$0\",\n \"topk\": 10 \n,\"round_decimal\": 6\n } \n } \n } \n }"

 	plan, err := createSearchPlan(collection, dslString)
 	assert.NoError(t, err)
-	holder, err := parseSearchRequest(plan, placeHolderGroupBlob)
+	holder, err := parseSearchRequest(plan, placeGroupByte)
 	assert.NoError(t, err)

 	placeholderGroups := make([]*searchRequest, 0)
 	placeholderGroups = append(placeholderGroups, holder)

-	searchResults := make([]*SearchResult, 0)
-	searchResult, err := segment.search(plan, placeholderGroups, []Timestamp{travelTimestamp})
-	assert.Nil(t, err)
-	searchResults = append(searchResults, searchResult)
-
-	///////////////////////////////////
-	numSegment := int64(len(searchResults))
-	err = reduceSearchResultsAndFillData(plan, searchResults, numSegment)
-	assert.NoError(t, err)
-	marshaledHits, err := reorganizeSearchResults(searchResults, numSegment)
-	assert.NoError(t, err)
-	hitsBlob, err := marshaledHits.getHitsBlob()
+	searchResult, err := segment.search(plan, placeholderGroups, []Timestamp{0})
 	assert.NoError(t, err)

-	var placeHolderOffset int64
-	for index := range placeholderGroups {
-		hitBlobSizePeerQuery, err := marshaledHits.hitBlobSizeInGroup(int64(index))
-		assert.NoError(t, err)
-		hits := make([][]byte, 0)
-		for _, len := range hitBlobSizePeerQuery {
-			hits = append(hits, hitsBlob[placeHolderOffset:placeHolderOffset+len])
-			placeHolderOffset += len
-		}
-	}
-
-	deleteSearchResults(searchResults)
-	deleteMarshaledHits(marshaledHits)
-	///////////////////////////////////
+	err = checkSearchResult(nq, plan, searchResult)
+	assert.NoError(t, err)

 	plan.delete()
 	holder.delete()