mirror of https://github.com/milvus-io/milvus.git
related: #25324
Search GroupBy function, used to aggregate result entities based on a specific scalar column. Several points to mention:
1. For the first period, the whole groupby is implemented separately from the iterative expr framework.
2. In the long term, the groupBy operation will be incorporated into the iterative expr framework: https://github.com/milvus-io/milvus/pull/28166
3. This pr includes some unrelated mocked interfaces regarding alterIndex, for reasons not worth detailing. All this unassociated content will be removed before the final pr is merged; this version of the pr is only for review.
4. All other related details are commented in the files comparison.
Signed-off-by: MrPresent-Han <chun.han@zilliz.com>
pull/29933/head
parent
22bb84fa9d
commit
9e2e7157e9
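
Reviewer note: for anyone new to the feature, here is a minimal standalone sketch (std-only, hypothetical names, not milvus code) of the semantics that the group-by search below implements: among all retrieved candidates, keep only the closest hit per distinct scalar value, then order the surviving hits by distance.

#include <algorithm>
#include <cstdint>
#include <iostream>
#include <unordered_map>
#include <vector>

struct Candidate {
    int64_t offset;   // row offset inside the segment
    float distance;   // L2 distance: smaller is closer
    int64_t group;    // value of the group_by scalar column
};

int main() {
    std::vector<Candidate> candidates = {
        {0, 0.1f, 7}, {1, 0.2f, 7}, {2, 0.3f, 9}, {3, 0.4f, 9}, {4, 0.5f, 8}};
    // Keep only the closest candidate per group value.
    std::unordered_map<int64_t, Candidate> best;
    for (const auto& c : candidates) {
        auto it = best.find(c.group);
        if (it == best.end() || c.distance < it->second.distance) {
            best[c.group] = c;
        }
    }
    // Order the surviving hits by distance, as the reduce stage does.
    std::vector<Candidate> hits;
    for (const auto& kv : best) hits.push_back(kv.second);
    std::sort(hits.begin(), hits.end(),
              [](const Candidate& l, const Candidate& r) {
                  return l.distance < r.distance;
              });
    for (const auto& c : hits)  // exactly one hit per group value
        std::cout << "group=" << c.group << " offset=" << c.offset
                  << " distance=" << c.distance << "\n";
}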
@@ -21,12 +21,14 @@
#include "common/Types.h"
#include "knowhere/config.h"

namespace milvus {

struct SearchInfo {
    int64_t topk_;
    int64_t round_decimal_;
    FieldId field_id_;
    MetricType metric_type_;
    knowhere::Json search_params_;
    std::optional<FieldId> group_by_field_id_;
};

using SearchInfoPtr = std::shared_ptr<SearchInfo>;
@@ -28,6 +28,7 @@
#include "common/FieldMeta.h"
#include "pb/schema.pb.h"
#include "knowhere/index_node.h"

namespace milvus {
struct SearchResult {
@@ -52,6 +53,7 @@ struct SearchResult {
    // first fill data during search, and then update data after reducing search results
    std::vector<float> distances_;
    std::vector<int64_t> seg_offsets_;
    std::vector<GroupByValueType> group_by_values_;

    // first fill data during fillPrimaryKey, and then update data after reducing search results
    std::vector<PkType> primary_keys_;
@@ -67,6 +69,10 @@ struct SearchResult {
    // used for reduce, filter invalid pk, get real topks count
    std::vector<size_t> topk_per_nq_prefix_sum_;

    // knowhere iterators, used for group by or other operators in the future
    std::optional<std::vector<std::shared_ptr<knowhere::IndexNode::iterator>>>
        iterators;
};

using SearchResultPtr = std::shared_ptr<SearchResult>;
@@ -116,6 +116,13 @@ using VectorArray = proto::schema::VectorField;
using IdArray = proto::schema::IDs;
using InsertData = proto::segcore::InsertRecord;
using PkType = std::variant<std::monostate, int64_t, std::string>;
using GroupByValueType = std::variant<std::monostate,
                                      int8_t,
                                      int16_t,
                                      int32_t,
                                      int64_t,
                                      bool,
                                      std::string_view>;
using ContainsType = proto::plan::JSONContainsExpr_JSONOp;

inline bool
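
Since GroupByValueType is a std::variant, downstream code has to dispatch on the alternative actually held; the new test file does this with std::holds_alternative/std::get. A minimal standalone sketch of that dispatch, using a local alias with the same shape:

#include <cstdint>
#include <iostream>
#include <string_view>
#include <variant>
#include <vector>

// Local stand-in with the same alternatives as milvus::GroupByValueType.
using GroupByValueType = std::variant<std::monostate, int8_t, int16_t,
                                      int32_t, int64_t, bool, std::string_view>;

int main() {
    std::vector<GroupByValueType> vals = {int64_t{42}, std::string_view{"abc"},
                                          std::monostate{}};
    for (const auto& v : vals) {
        if (std::holds_alternative<std::monostate>(v)) {
            std::cout << "padding entry\n";  // monostate marks padded slots
        } else if (std::holds_alternative<int64_t>(v)) {
            std::cout << "int64: " << std::get<int64_t>(v) << "\n";
        } else if (std::holds_alternative<std::string_view>(v)) {
            std::cout << "string: " << std::get<std::string_view>(v) << "\n";
        }
    }
}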
@@ -522,6 +522,35 @@ VectorMemIndex<T>::Query(const DatasetPtr dataset,
    auto num_queries = dataset->GetRows();
    knowhere::Json search_conf = search_info.search_params_;
    if (search_info.group_by_field_id_.has_value()) {
        auto result = std::make_unique<SearchResult>();
        try {
            knowhere::expected<
                std::vector<std::shared_ptr<knowhere::IndexNode::iterator>>>
                iterators_val =
                    index_.AnnIterator(*dataset, search_conf, bitset);
            if (iterators_val.has_value()) {
                result->iterators = iterators_val.value();
            } else {
                LOG_ERROR(
                    "Returned knowhere iterator has non-ready iterators "
                    "inside, terminate group_by operation");
                PanicInfo(ErrorCode::Unsupported,
                          "Returned knowhere iterator has non-ready iterators "
                          "inside, terminate group_by operation");
            }
        } catch (const std::runtime_error& e) {
            LOG_ERROR(
                "Caught error:{} when trying to initialize ann iterators for "
                "group_by; the group_by operation will be terminated",
                e.what());
            throw;  // rethrow the original exception rather than a sliced copy
        }
        return result;
        // if the target index doesn't support iterators, an empty search
        // result is returned here and the reduce process filters it out
    }
    auto topk = search_info.topk_;
    // TODO :: check dim of search data
    auto final = [&] {
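
The group-by path hands the raw AnnIterator handles to the caller instead of materializing a top-k list; consumers later drain them with HasNext()/Next(). A standalone sketch of that draining pattern, using a hypothetical stand-in class that mirrors the (offset, distance) iterator interface used in this PR:

#include <cstdint>
#include <iostream>
#include <utility>
#include <vector>

// Hypothetical stand-in; not the knowhere type, just the same two methods.
class FakeAnnIterator {
 public:
    explicit FakeAnnIterator(std::vector<std::pair<int64_t, float>> hits)
        : hits_(std::move(hits)) {}
    bool HasNext() const { return pos_ < hits_.size(); }
    std::pair<int64_t, float> Next() { return hits_[pos_++]; }

 private:
    std::vector<std::pair<int64_t, float>> hits_;
    size_t pos_ = 0;
};

int main() {
    FakeAnnIterator it({{3, 0.12f}, {7, 0.34f}, {1, 0.56f}});
    // Drain candidates in ascending-distance order until a budget is met.
    while (it.HasNext()) {
        auto [offset, dis] = it.Next();
        std::cout << "offset=" << offset << " distance=" << dis << "\n";
    }
}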
@@ -26,6 +26,7 @@ set(MILVUS_QUERY_SRCS
        SearchOnIndex.cpp
        SearchBruteForce.cpp
        SubSearchResult.cpp
        GroupByOperator.cpp
        PlanProto.cpp
        )
add_library(milvus_query ${MILVUS_QUERY_SRCS})
@@ -0,0 +1,211 @@
// Licensed to the LF AI & Data foundation under one
// or more contributor license agreements. See the NOTICE file
// distributed with this work for additional information
// regarding copyright ownership. The ASF licenses this file
// to you under the Apache License, Version 2.0 (the
// "License"); you may not use this file except in compliance
// with the License. You may obtain a copy of the License at
//
//     http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "GroupByOperator.h"
#include "common/Consts.h"
#include "segcore/SegmentSealedImpl.h"
#include "Utils.h"

namespace milvus {
namespace query {

void
GroupBy(const std::vector<std::shared_ptr<knowhere::IndexNode::iterator>>&
            iterators,
        const SearchInfo& search_info,
        std::vector<GroupByValueType>& group_by_values,
        const segcore::SegmentInternalInterface& segment,
        std::vector<int64_t>& seg_offsets,
        std::vector<float>& distances) {
    // 0. check the segment type; in the first period, group_by is only
    // supported for sealed segments
    if (!dynamic_cast<const segcore::SegmentSealedImpl*>(&segment)) {
        LOG_ERROR(
            "The group_by operation is not supported for non-sealed segments, "
            "segment_id:{}",
            segment.get_segment_id());
        return;
    }

    // 1. get search meta
    FieldId group_by_field_id = search_info.group_by_field_id_.value();
    auto data_type = segment.GetFieldDataType(group_by_field_id);

    switch (data_type) {
        case DataType::INT8: {
            auto field_data = segment.chunk_data<int8_t>(group_by_field_id, 0);
            GroupIteratorsByType<int8_t>(iterators,
                                         group_by_field_id,
                                         search_info.topk_,
                                         field_data,
                                         group_by_values,
                                         seg_offsets,
                                         distances,
                                         search_info.metric_type_);
            break;
        }
        case DataType::INT16: {
            auto field_data = segment.chunk_data<int16_t>(group_by_field_id, 0);
            GroupIteratorsByType<int16_t>(iterators,
                                          group_by_field_id,
                                          search_info.topk_,
                                          field_data,
                                          group_by_values,
                                          seg_offsets,
                                          distances,
                                          search_info.metric_type_);
            break;
        }
        case DataType::INT32: {
            auto field_data = segment.chunk_data<int32_t>(group_by_field_id, 0);
            GroupIteratorsByType<int32_t>(iterators,
                                          group_by_field_id,
                                          search_info.topk_,
                                          field_data,
                                          group_by_values,
                                          seg_offsets,
                                          distances,
                                          search_info.metric_type_);
            break;
        }
        case DataType::INT64: {
            auto field_data = segment.chunk_data<int64_t>(group_by_field_id, 0);
            GroupIteratorsByType<int64_t>(iterators,
                                          group_by_field_id,
                                          search_info.topk_,
                                          field_data,
                                          group_by_values,
                                          seg_offsets,
                                          distances,
                                          search_info.metric_type_);
            break;
        }
        case DataType::BOOL: {
            auto field_data = segment.chunk_data<bool>(group_by_field_id, 0);
            GroupIteratorsByType<bool>(iterators,
                                       group_by_field_id,
                                       search_info.topk_,
                                       field_data,
                                       group_by_values,
                                       seg_offsets,
                                       distances,
                                       search_info.metric_type_);
            break;
        }
        case DataType::VARCHAR: {
            auto field_data =
                segment.chunk_data<std::string_view>(group_by_field_id, 0);
            GroupIteratorsByType<std::string_view>(iterators,
                                                   group_by_field_id,
                                                   search_info.topk_,
                                                   field_data,
                                                   group_by_values,
                                                   seg_offsets,
                                                   distances,
                                                   search_info.metric_type_);
            break;
        }
        default: {
            PanicInfo(
                DataTypeInvalid,
                fmt::format("unsupported data type {} for group by operator",
                            data_type));
        }
    }
}

template <typename T>
void
GroupIteratorsByType(
    const std::vector<std::shared_ptr<knowhere::IndexNode::iterator>>&
        iterators,
    FieldId field_id,
    int64_t topK,
    Span<T> field_data,
    std::vector<GroupByValueType>& group_by_values,
    std::vector<int64_t>& seg_offsets,
    std::vector<float>& distances,
    const knowhere::MetricType& metrics_type) {
    for (auto& iterator : iterators) {
        GroupIteratorResult<T>(iterator,
                               field_id,
                               topK,
                               field_data,
                               group_by_values,
                               seg_offsets,
                               distances,
                               metrics_type);
    }
}

template <typename T>
void
GroupIteratorResult(
    const std::shared_ptr<knowhere::IndexNode::iterator>& iterator,
    FieldId field_id,
    int64_t topK,
    Span<T> field_data,
    std::vector<GroupByValueType>& group_by_values,
    std::vector<int64_t>& offsets,
    std::vector<float>& distances,
    const knowhere::MetricType& metrics_type) {
    // 1. map each group value to the best (offset, distance) pair seen so far
    std::unordered_map<T, std::pair<int64_t, float>> groupMap;

    // 2. iterate until the map holds topK groups or all data is exhausted;
    // note this may enumerate all data inside a segment and can possibly
    // block subsequent queries and searches
    auto dis_closer = [&](float l, float r) {
        if (PositivelyRelated(metrics_type))
            return l > r;
        return l < r;  // strict comparison: std::sort needs a strict weak ordering
    };
    while (iterator->HasNext() &&
           groupMap.size() < static_cast<size_t>(topK)) {
        auto [offset, dis] = iterator->Next();
        const T& row_data = field_data.operator[](offset);
        auto it = groupMap.find(row_data);
        if (it == groupMap.end()) {
            groupMap.insert(
                std::make_pair(row_data, std::make_pair(offset, dis)));
        } else if (dis_closer(dis, it->second.second)) {
            it->second = {offset, dis};
        }
    }

    // 3. sort the groups by distance under the given metric
    std::vector<std::pair<T, std::pair<int64_t, float>>> sortedGroupVals(
        groupMap.begin(), groupMap.end());
    auto customComparator = [&](const auto& lhs, const auto& rhs) {
        return dis_closer(lhs.second.second, rhs.second.second);
    };
    std::sort(sortedGroupVals.begin(), sortedGroupVals.end(), customComparator);

    // 4. save groupBy results
    for (auto iter = sortedGroupVals.cbegin(); iter != sortedGroupVals.cend();
         iter++) {
        group_by_values.emplace_back(iter->first);
        offsets.push_back(iter->second.first);
        distances.push_back(iter->second.second);
    }

    // 5. pad up to topK results; the extra memory is reclaimed when reducing
    for (std::size_t idx = groupMap.size(); idx < static_cast<size_t>(topK);
         idx++) {
        offsets.push_back(INVALID_SEG_OFFSET);
        distances.push_back(0.0);
        group_by_values.emplace_back(std::monostate{});
    }
}

}  // namespace query
}  // namespace milvus
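
One subtlety worth flagging for review: dis_closer doubles as the comparator handed to std::sort, and std::sort requires a strict weak ordering, which is why the non-positively-related branch above uses < rather than <= (a comparator for which comp(a, a) is true yields undefined behavior). A standalone sketch of the corrected comparator in both metric directions:

#include <algorithm>
#include <iostream>
#include <vector>

int main() {
    std::vector<float> distances = {0.5f, 0.1f, 0.9f, 0.3f};

    // positively_related = true mimics IP (larger score is closer);
    // false mimics L2 (smaller distance is closer).
    auto make_closer = [](bool positively_related) {
        return [positively_related](float l, float r) {
            return positively_related ? l > r : l < r;  // strict, never <=
        };
    };

    auto l2 = distances;
    std::sort(l2.begin(), l2.end(), make_closer(false));  // ascending
    auto ip = distances;
    std::sort(ip.begin(), ip.end(), make_closer(true));   // descending

    for (float d : l2) std::cout << d << ' ';  // 0.1 0.3 0.5 0.9
    std::cout << '\n';
    for (float d : ip) std::cout << d << ' ';  // 0.9 0.5 0.3 0.1
    std::cout << '\n';
}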
@@ -0,0 +1,61 @@
// Licensed to the LF AI & Data foundation under one
// or more contributor license agreements. See the NOTICE file
// distributed with this work for additional information
// regarding copyright ownership. The ASF licenses this file
// to you under the Apache License, Version 2.0 (the
// "License"); you may not use this file except in compliance
// with the License. You may obtain a copy of the License at
//
//     http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

#pragma once

#include "common/QueryInfo.h"
#include "knowhere/index_node.h"
#include "segcore/SegmentInterface.h"
#include "common/Span.h"

namespace milvus {
namespace query {
void
GroupBy(const std::vector<std::shared_ptr<knowhere::IndexNode::iterator>>&
            iterators,
        const SearchInfo& searchInfo,
        std::vector<GroupByValueType>& group_by_values,
        const segcore::SegmentInternalInterface& segment,
        std::vector<int64_t>& seg_offsets,
        std::vector<float>& distances);

template <typename T>
void
GroupIteratorsByType(
    const std::vector<std::shared_ptr<knowhere::IndexNode::iterator>>&
        iterators,
    FieldId field_id,
    int64_t topK,
    Span<T> field_data,
    std::vector<GroupByValueType>& group_by_values,
    std::vector<int64_t>& seg_offsets,
    std::vector<float>& distances,
    const knowhere::MetricType& metrics_type);

template <typename T>
void
GroupIteratorResult(
    const std::shared_ptr<knowhere::IndexNode::iterator>& iterator,
    FieldId field_id,
    int64_t topK,
    Span<T> field_data,
    std::vector<GroupByValueType>& group_by_values,
    std::vector<int64_t>& offsets,
    std::vector<float>& distances,
    const knowhere::MetricType& metrics_type);

}  // namespace query
}  // namespace milvus
@@ -203,6 +203,10 @@ ProtoParser::PlanNodeFromProto(const planpb::PlanNode& plan_node_proto) {
    search_info.search_params_ =
        nlohmann::json::parse(query_info_proto.search_params());

    // proto3 scalars default to 0, so a zero id means group_by was not requested
    if (query_info_proto.group_by_field_id() != 0) {
        auto group_by_field_id = FieldId(query_info_proto.group_by_field_id());
        search_info.group_by_field_id_ = group_by_field_id;
    }
    auto plan_node = [&]() -> std::unique_ptr<VectorPlanNode> {
        if (anns_proto.vector_type() ==
            milvus::proto::plan::VectorType::BinaryVector) {
@@ -49,18 +49,22 @@ SearchOnSealedIndex(const Schema& schema,
         auto index_type = vec_index->GetIndexType();
         return vec_index->Query(ds, search_info, bitset);
     }();
-    float* distances = final->distances_.data();
-
-    auto total_num = num_queries * topk;
-    if (round_decimal != -1) {
-        const float multiplier = pow(10.0, round_decimal);
-        for (int i = 0; i < total_num; i++) {
-            distances[i] = std::round(distances[i] * multiplier) / multiplier;
-        }
-    }
-    result.seg_offsets_ = std::move(final->seg_offsets_);
-    result.distances_ = std::move(final->distances_);
+    if (final->iterators.has_value()) {
+        result.iterators = std::move(final->iterators);
+    } else {
+        float* distances = final->distances_.data();
+
+        auto total_num = num_queries * topk;
+        if (round_decimal != -1) {
+            const float multiplier = pow(10.0, round_decimal);
+            for (int i = 0; i < total_num; i++) {
+                distances[i] =
+                    std::round(distances[i] * multiplier) / multiplier;
+            }
+        }
+        result.seg_offsets_ = std::move(final->seg_offsets_);
+        result.distances_ = std::move(final->distances_);
+    }
     result.total_nq_ = num_queries;
     result.unity_topK_ = topk;
 }
@@ -22,6 +22,8 @@
#include "log/Log.h"
#include "plan/PlanNode.h"
#include "exec/Task.h"
#include "segcore/SegmentInterface.h"
#include "query/GroupByOperator.h"

namespace milvus::query {

@@ -188,7 +190,20 @@ ExecPlanNodeVisitor::VectorVisitorImpl(VectorPlanNode& node) {
                         timestamp_,
                         final_view,
                         search_result);

    if (search_result.iterators.has_value()) {
        GroupBy(search_result.iterators.value(),
                node.search_info_,
                search_result.group_by_values_,
                *segment,
                search_result.seg_offsets_,
                search_result.distances_);
        AssertInfo(search_result.seg_offsets_.size() ==
                       search_result.group_by_values_.size(),
                   "Wrong state! search_result.group_by_values_ size:{} is not "
                   "equal to search_result.seg_offsets_.size:{}",
                   search_result.group_by_values_.size(),
                   search_result.seg_offsets_.size());
    }
    search_result_opt_ = std::move(search_result);
}
@@ -12,8 +12,6 @@
#include "Reduce.h"

#include <log/Log.h>

#include <algorithm>
#include <cstdint>
#include <vector>

@@ -90,6 +88,15 @@ ReduceHelper::FilterInvalidSearchResult(SearchResult* search_result) {
    auto segment = static_cast<SegmentInterface*>(search_result->segment_);
    auto& offsets = search_result->seg_offsets_;
    auto& distances = search_result->distances_;
    bool need_filter_group_by = !search_result->group_by_values_.empty();
    if (need_filter_group_by) {
        AssertInfo(search_result->distances_.size() ==
                       search_result->group_by_values_.size(),
                   "wrong group_by_values size, size:{}, expected size:{}",
                   search_result->group_by_values_.size(),
                   search_result->distances_.size());
    }

    for (auto i = 0; i < nq; ++i) {
        for (auto j = 0; j < topK; ++j) {
            auto index = i * topK + j;

@@ -104,12 +111,17 @@ ReduceHelper::FilterInvalidSearchResult(SearchResult* search_result) {
                real_topks[i]++;
                offsets[valid_index] = offsets[index];
                distances[valid_index] = distances[index];
                if (need_filter_group_by)
                    search_result->group_by_values_[valid_index] =
                        search_result->group_by_values_[index];
                valid_index++;
            }
        }
    }
    offsets.resize(valid_index);
    distances.resize(valid_index);
    if (need_filter_group_by)
        search_result->group_by_values_.resize(valid_index);

    search_result->topk_per_nq_prefix_sum_.resize(nq + 1);
    std::partial_sum(real_topks.begin(),
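
FilterInvalidSearchResult compacts three (now optionally four) parallel arrays in place: valid entries are copied down over invalid ones and the arrays are truncated afterward, so offsets, distances, and group-by values stay index-aligned. A standalone sketch of this compaction pattern over toy parallel arrays:

#include <cstdint>
#include <iostream>
#include <vector>

constexpr int64_t kInvalidSegOffset = -1;  // stand-in for INVALID_SEG_OFFSET

int main() {
    std::vector<int64_t> offsets = {4, kInvalidSegOffset, 9,
                                    kInvalidSegOffset, 2};
    std::vector<float> distances = {0.1f, 0.0f, 0.3f, 0.0f, 0.5f};
    std::vector<int64_t> groups = {7, 0, 9, 0, 7};

    // Copy valid entries down over invalid ones, keeping arrays aligned.
    size_t valid_index = 0;
    for (size_t index = 0; index < offsets.size(); ++index) {
        if (offsets[index] == kInvalidSegOffset) continue;  // skip padding
        offsets[valid_index] = offsets[index];
        distances[valid_index] = distances[index];
        groups[valid_index] = groups[index];
        ++valid_index;
    }
    offsets.resize(valid_index);
    distances.resize(valid_index);
    groups.resize(valid_index);

    for (size_t i = 0; i < offsets.size(); ++i)
        std::cout << offsets[i] << ' ' << distances[i] << ' ' << groups[i]
                  << '\n';  // three surviving aligned rows
}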
@@ -130,9 +142,8 @@ ReduceHelper::FillPrimaryKey() {
         FilterInvalidSearchResult(search_result);
         LOG_DEBUG("the size of search result: {}",
                   search_result->seg_offsets_.size());
-        auto segment = static_cast<SegmentInterface*>(search_result->segment_);
         if (search_result->get_total_result_count() > 0) {
+            auto segment =
+                static_cast<SegmentInterface*>(search_result->segment_);
             segment->FillPrimaryKeys(plan_, *search_result);
             search_results_[valid_index++] = search_result;
         }

@@ -146,6 +157,15 @@ ReduceHelper::RefreshSearchResult() {
    for (int i = 0; i < num_segments_; i++) {
        std::vector<int64_t> real_topks(total_nq_, 0);
        auto search_result = search_results_[i];
        bool need_to_handle_group_by = !search_result->group_by_values_.empty();
        if (need_to_handle_group_by) {
            AssertInfo(search_result->primary_keys_.size() ==
                           search_result->group_by_values_.size(),
                       "Wrong size for group_by_values, size:{}, not equal to "
                       "primary_keys_.size:{}",
                       search_result->group_by_values_.size(),
                       search_result->primary_keys_.size());
        }
        if (search_result->result_offsets_.size() != 0) {
            uint32_t size = 0;
            for (int j = 0; j < total_nq_; j++) {

@@ -154,6 +174,7 @@ ReduceHelper::RefreshSearchResult() {
            std::vector<milvus::PkType> primary_keys(size);
            std::vector<float> distances(size);
            std::vector<int64_t> seg_offsets(size);
            std::vector<GroupByValueType> group_by_values(size);

            uint32_t index = 0;
            for (int j = 0; j < total_nq_; j++) {

@@ -161,6 +182,9 @@ ReduceHelper::RefreshSearchResult() {
                primary_keys[index] = search_result->primary_keys_[offset];
                distances[index] = search_result->distances_[offset];
                seg_offsets[index] = search_result->seg_offsets_[offset];
                if (need_to_handle_group_by)
                    group_by_values[index] =
                        search_result->group_by_values_[offset];
                index++;
                real_topks[j]++;
            }

@@ -168,6 +192,9 @@ ReduceHelper::RefreshSearchResult() {
            search_result->primary_keys_.swap(primary_keys);
            search_result->distances_.swap(distances);
            search_result->seg_offsets_.swap(seg_offsets);
            if (need_to_handle_group_by) {
                search_result->group_by_values_.swap(group_by_values);
            }
        }
        std::partial_sum(real_topks.begin(),
                         real_topks.end(),
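
The per-query bookkeeping above hinges on topk_per_nq_prefix_sum_: the per-query valid-hit counts are turned into a prefix sum with std::partial_sum, so entries [qi] and [qi+1] bound query qi's slice of the flat result arrays. A standalone sketch of that construction:

#include <cstdint>
#include <iostream>
#include <numeric>
#include <vector>

int main() {
    // Per-query valid hit counts after filtering (toy values).
    std::vector<int64_t> real_topks = {3, 0, 2};

    // Prefix sum with a leading zero, as topk_per_nq_prefix_sum_ is built.
    std::vector<int64_t> prefix_sum(real_topks.size() + 1, 0);
    std::partial_sum(real_topks.begin(), real_topks.end(),
                     prefix_sum.begin() + 1);

    for (size_t qi = 0; qi < real_topks.size(); ++qi) {
        std::cout << "query " << qi << " occupies [" << prefix_sum[qi]
                  << ", " << prefix_sum[qi + 1] << ")\n";
    }
}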
@@ -193,8 +220,14 @@ ReduceHelper::ReduceSearchResultForOneNQ(int64_t qi,
    }
    pk_set_.clear();
    pairs_.clear();
    group_by_val_set_.clear();

    pairs_.reserve(num_segments_);
    bool need_handle_group_by_values = false;
    if (num_segments_ > 0) {
        need_handle_group_by_values =
            !search_results_[0]->group_by_values_.empty();
    }
    for (int i = 0; i < num_segments_; i++) {
        auto search_result = search_results_[i];
        auto offset_beg = search_result->topk_per_nq_prefix_sum_[qi];

@@ -206,7 +239,16 @@ ReduceHelper::ReduceSearchResultForOneNQ(int64_t qi,
         auto distance = search_result->distances_[offset_beg];
 
-        pairs_.emplace_back(
-            primary_key, distance, search_result, i, offset_beg, offset_end);
+        pairs_.emplace_back(
+            primary_key,
+            distance,
+            search_result,
+            i,
+            offset_beg,
+            offset_end,
+            need_handle_group_by_values
+                ? std::make_optional(
+                      search_result->group_by_values_.at(offset_beg))
+                : std::nullopt);
         heap_.push(&pairs_.back());
     }

@@ -229,9 +271,22 @@ ReduceHelper::ReduceSearchResultForOneNQ(int64_t qi,
         }
         // remove duplicates
         if (pk_set_.count(pk) == 0) {
-            pilot->search_result_->result_offsets_.push_back(offset++);
-            final_search_records_[index][qi].push_back(pilot->offset_);
-            pk_set_.insert(pk);
+            bool skip_for_group_by = false;
+            if (need_handle_group_by_values &&
+                pilot->group_by_value_.has_value()) {
+                if (group_by_val_set_.count(pilot->group_by_value_.value()) >
+                    0) {
+                    skip_for_group_by = true;
+                }
+            }
+            if (!skip_for_group_by) {
+                pilot->search_result_->result_offsets_.push_back(offset++);
+                final_search_records_[index][qi].push_back(pilot->offset_);
+                pk_set_.insert(pk);
+                if (need_handle_group_by_values &&
+                    pilot->group_by_value_.has_value())
+                    group_by_val_set_.insert(pilot->group_by_value_.value());
+            }
         } else {
             // skip entity with same primary key
             dup_cnt++;
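
During the cross-segment merge, a hit now has to clear two dedup filters: its primary key must be unseen, and, when group-by is active, its group value must be unseen as well. A standalone sketch of that two-level dedup over toy hits:

#include <cstdint>
#include <iostream>
#include <string>
#include <unordered_set>
#include <vector>

struct Hit {
    std::string pk;  // primary key
    int64_t group;   // group_by value
};

int main() {
    std::vector<Hit> merged = {{"a", 1}, {"b", 1}, {"a", 2}, {"c", 2}, {"d", 3}};

    std::unordered_set<std::string> pk_set;
    std::unordered_set<int64_t> group_by_val_set;
    for (const auto& hit : merged) {
        if (pk_set.count(hit.pk) > 0) continue;              // duplicate pk
        if (group_by_val_set.count(hit.group) > 0) continue;  // group taken
        pk_set.insert(hit.pk);
        group_by_val_set.insert(hit.group);
        std::cout << "keep pk=" << hit.pk << " group=" << hit.group << '\n';
    }
    // keeps ("a",1), ("c",2), ("d",3); drops "b" (group 1 taken) and dup "a"
}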
@@ -331,6 +386,10 @@ ReduceHelper::GetSearchResultDataSlice(int slice_index) {
    // reserve space for distances
    search_result_data->mutable_scores()->Resize(result_count, 0);

    // reserve space for group_by_values
    std::vector<GroupByValueType> group_by_values;
    group_by_values.resize(result_count);

    // fill pks and distances
    for (auto qi = nq_begin; qi < nq_end; qi++) {
        int64_t topk_count = 0;

@@ -377,9 +436,11 @@ ReduceHelper::GetSearchResultDataSlice(int slice_index) {
                }
            }

            // set result distances
            search_result_data->mutable_scores()->Set(
                loc, search_result->distances_[ki]);
            // set group by values
            if (ki < search_result->group_by_values_.size())
                group_by_values[loc] = search_result->group_by_values_[ki];
            // set result offset to fill output fields data
            result_pairs[loc] = std::make_pair(search_result, ki);
        }

@@ -388,6 +449,7 @@ ReduceHelper::GetSearchResultDataSlice(int slice_index) {
        // update result topKs
        search_result_data->mutable_topks()->Set(qi - nq_begin, topk_count);
    }
    AssembleGroupByValues(search_result_data, group_by_values);

    AssertInfo(search_result_data->scores_size() == result_count,
               "wrong scores size, size = " +

@@ -417,4 +479,89 @@ ReduceHelper::GetSearchResultDataSlice(int slice_index) {
    return buffer;
}

void
ReduceHelper::AssembleGroupByValues(
    std::unique_ptr<milvus::proto::schema::SearchResultData>& search_result,
    const std::vector<GroupByValueType>& group_by_vals) {
    auto group_by_field_id = plan_->plan_node_->search_info_.group_by_field_id_;
    if (group_by_field_id.has_value() && group_by_vals.size() > 0) {
        auto group_by_values_field =
            std::make_unique<milvus::proto::schema::ScalarField>();
        auto group_by_field =
            plan_->schema_.operator[](group_by_field_id.value());
        DataType group_by_data_type = group_by_field.get_data_type();

        int group_by_val_size = static_cast<int>(group_by_vals.size());
        switch (group_by_data_type) {
            case DataType::INT8: {
                auto field_data = group_by_values_field->mutable_int_data();
                field_data->mutable_data()->Resize(group_by_val_size, 0);
                for (int idx = 0; idx < group_by_val_size; idx++) {
                    int8_t val = std::get<int8_t>(group_by_vals[idx]);
                    field_data->mutable_data()->Set(idx, val);
                }
                break;
            }
            case DataType::INT16: {
                auto field_data = group_by_values_field->mutable_int_data();
                field_data->mutable_data()->Resize(group_by_val_size, 0);
                for (int idx = 0; idx < group_by_val_size; idx++) {
                    int16_t val = std::get<int16_t>(group_by_vals[idx]);
                    field_data->mutable_data()->Set(idx, val);
                }
                break;
            }
            case DataType::INT32: {
                auto field_data = group_by_values_field->mutable_int_data();
                field_data->mutable_data()->Resize(group_by_val_size, 0);
                for (int idx = 0; idx < group_by_val_size; idx++) {
                    int32_t val = std::get<int32_t>(group_by_vals[idx]);
                    field_data->mutable_data()->Set(idx, val);
                }
                break;
            }
            case DataType::INT64: {
                auto field_data = group_by_values_field->mutable_long_data();
                field_data->mutable_data()->Resize(group_by_val_size, 0);
                for (int idx = 0; idx < group_by_val_size; idx++) {
                    int64_t val = std::get<int64_t>(group_by_vals[idx]);
                    field_data->mutable_data()->Set(idx, val);
                }
                break;
            }
            case DataType::BOOL: {
                auto field_data = group_by_values_field->mutable_bool_data();
                field_data->mutable_data()->Resize(group_by_val_size, 0);
                for (int idx = 0; idx < group_by_val_size; idx++) {
                    bool val = std::get<bool>(group_by_vals[idx]);
                    field_data->mutable_data()->Set(idx, val);
                }
                break;
            }
            case DataType::VARCHAR: {
                auto field_data = group_by_values_field->mutable_string_data();
                for (int idx = 0; idx < group_by_val_size; idx++) {
                    std::string_view val =
                        std::get<std::string_view>(group_by_vals[idx]);
                    *(field_data->mutable_data()->Add()) = val;
                }
                break;
            }
            default: {
                PanicInfo(
                    DataTypeInvalid,
                    fmt::format(
                        "unsupported data type {} for group_by operations",
                        group_by_data_type));
            }
        }

        search_result->mutable_group_by_field_value()->set_type(
            milvus::proto::schema::DataType(group_by_data_type));
        search_result->mutable_group_by_field_value()
            ->mutable_scalars()
            ->MergeFrom(*group_by_values_field.get());
        return;
    }
}

}  // namespace milvus::segcore
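
AssembleGroupByValues is essentially a column extraction: a vector of std::variant values known to hold a single alternative is unpacked into a typed array for serialization. A standalone sketch of that unpacking, with a plain std::vector standing in for the protobuf repeated field:

#include <cstdint>
#include <iostream>
#include <string_view>
#include <variant>
#include <vector>

// Local stand-in with the same shape as milvus::GroupByValueType.
using GroupByValueType = std::variant<std::monostate, int8_t, int16_t,
                                      int32_t, int64_t, bool, std::string_view>;

int main() {
    // After reduce, every kept slot holds the group-by column's type
    // (int64 in this toy example).
    std::vector<GroupByValueType> group_by_vals = {int64_t{7}, int64_t{9},
                                                   int64_t{8}};
    std::vector<int64_t> column;  // stand-in for the repeated proto field
    column.reserve(group_by_vals.size());
    for (const auto& v : group_by_vals) {
        // std::get throws std::bad_variant_access on a type mismatch.
        column.push_back(std::get<int64_t>(v));
    }
    for (auto x : column) std::cout << x << ' ';
    std::cout << '\n';
}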
@@ -82,6 +82,11 @@ class ReduceHelper {
    std::vector<char>
    GetSearchResultDataSlice(int slice_index_);

    void
    AssembleGroupByValues(
        std::unique_ptr<milvus::proto::schema::SearchResultData>& search_result,
        const std::vector<GroupByValueType>& group_by_vals);

 private:
    std::vector<SearchResult*>& search_results_;
    milvus::query::Plan* plan_;

@@ -108,6 +113,7 @@ class ReduceHelper {
                        SearchResultPairComparator>
        heap_;
    std::unordered_set<milvus::PkType> pk_set_;
    std::unordered_set<milvus::GroupByValueType> group_by_val_set_;
};

}  // namespace milvus::segcore
@@ -26,7 +26,8 @@ struct SearchResultPair {
     milvus::SearchResult* search_result_;
     int64_t segment_index_;
     int64_t offset_;
     int64_t offset_rb_;  // right bound
+    std::optional<milvus::GroupByValueType> group_by_value_;  // for group_by
 
     SearchResultPair(milvus::PkType primary_key,
                      float distance,

@@ -34,12 +35,24 @@ struct SearchResultPair {
                      int64_t index,
                      int64_t lb,
                      int64_t rb)
+        : SearchResultPair(
+              primary_key, distance, result, index, lb, rb, std::nullopt) {
+    }
+
+    SearchResultPair(milvus::PkType primary_key,
+                     float distance,
+                     SearchResult* result,
+                     int64_t index,
+                     int64_t lb,
+                     int64_t rb,
+                     std::optional<milvus::GroupByValueType> group_by_value)
         : primary_key_(std::move(primary_key)),
           distance_(distance),
           search_result_(result),
           segment_index_(index),
           offset_(lb),
-          offset_rb_(rb) {
+          offset_rb_(rb),
+          group_by_value_(group_by_value) {
     }
 
     bool

@@ -56,6 +69,9 @@ struct SearchResultPair {
         if (offset_ < offset_rb_) {
             primary_key_ = search_result_->primary_keys_.at(offset_);
             distance_ = search_result_->distances_.at(offset_);
+            if (offset_ < search_result_->group_by_values_.size()) {
+                group_by_value_ = search_result_->group_by_values_.at(offset_);
+            }
         } else {
             primary_key_ = INVALID_PK;
             distance_ = std::numeric_limits<float>::min();
@@ -418,6 +418,12 @@ SegmentGrowingImpl::num_chunk() const {
    return upper_div(size, segcore_config_.get_chunk_rows());
}

DataType
SegmentGrowingImpl::GetFieldDataType(milvus::FieldId field_id) const {
    auto& field_meta = schema_->operator[](field_id);
    return field_meta.get_data_type();
}

void
SegmentGrowingImpl::vector_search(SearchInfo& search_info,
                                  const void* query_data,
@@ -224,6 +224,9 @@ class SegmentGrowingImpl : public SegmentGrowing {
                  const BitsetView& bitset,
                  SearchResult& output) const override;

    DataType
    GetFieldDataType(FieldId fieldId) const override;

 public:
    void
    mask_with_delete(BitsetType& bitset,
@@ -186,6 +186,9 @@ class SegmentInternalInterface : public SegmentInterface {
        int64_t chunk_id,
        const milvus::VariableColumn<std::string>& var_column);

    virtual DataType
    GetFieldDataType(FieldId fieldId) const = 0;

 public:
    virtual void
    vector_search(SearchInfo& search_info,
@@ -1301,6 +1301,12 @@ SegmentSealedImpl::HasRawData(int64_t field_id) const {
    return true;
}

DataType
SegmentSealedImpl::GetFieldDataType(milvus::FieldId field_id) const {
    auto& field_meta = schema_->operator[](field_id);
    return field_meta.get_data_type();
}

std::pair<std::unique_ptr<IdArray>, std::vector<SegOffset>>
SegmentSealedImpl::search_ids(const IdArray& id_array,
                              Timestamp timestamp) const {
@@ -86,6 +86,9 @@ class SegmentSealedImpl : public SegmentSealed {
    bool
    HasRawData(int64_t field_id) const override;

    DataType
    GetFieldDataType(FieldId fieldId) const override;

 public:
    int64_t
    GetMemoryUsageInBytes() const override;
@@ -61,6 +61,7 @@ set(MILVUS_TEST_FILES
        test_storage.cpp
        test_exec.cpp
        test_inverted_index.cpp
        test_group_by.cpp
        )

if ( BUILD_DISK_ANN STREQUAL "ON" )
@@ -112,13 +112,13 @@ Search_Sealed(benchmark::State& state) {
     if (choice == 0) {
         // Brute Force
     } else if (choice == 1) {
-        // ivf
+        // hnsw
         auto vec = dataset_.get_col<float>(milvus::FieldId(100));
-        auto indexing = GenVecIndexing(N, dim, vec.data());
+        auto indexing = GenVecIndexing(
+            N, dim, vec.data(), knowhere::IndexEnum::INDEX_HNSW);
         segcore::LoadIndexInfo info;
         info.index = std::move(indexing);
         info.field_id = (*schema)[FieldName("fakevec")].get_id().get();
-        info.index_params["index_type"] = "IVF";
+        info.index_params["index_type"] = "HNSW";
         info.index_params["metric_type"] = knowhere::metric::L2;
         segment->DropFieldData(milvus::FieldId(100));
         segment->LoadIndex(info);
@@ -0,0 +1,480 @@
//
// Created by zilliz on 2023/12/1.
//

#include <gtest/gtest.h>
#include "common/Schema.h"
#include "segcore/SegmentSealedImpl.h"
#include "test_utils/DataGen.h"
#include "query/Plan.h"
#include "segcore/segment_c.h"
#include "segcore/reduce_c.h"
#include "test_utils/c_api_test_utils.h"
#include "segcore/plan_c.h"

using namespace milvus;
using namespace milvus::segcore;
using namespace milvus::query;
using namespace milvus::storage;

const char* METRICS_TYPE = "metric_type";

void
prepareSegmentSystemFieldData(const std::unique_ptr<SegmentSealed>& segment,
                              size_t row_count,
                              GeneratedData& data_set) {
    auto field_data =
        std::make_shared<milvus::FieldData<int64_t>>(DataType::INT64);
    field_data->FillFieldData(data_set.row_ids_.data(), row_count);
    auto field_data_info =
        FieldDataInfo{RowFieldID.get(),
                      row_count,
                      std::vector<milvus::FieldDataPtr>{field_data}};
    segment->LoadFieldData(RowFieldID, field_data_info);

    field_data = std::make_shared<milvus::FieldData<int64_t>>(DataType::INT64);
    field_data->FillFieldData(data_set.timestamps_.data(), row_count);
    field_data_info =
        FieldDataInfo{TimestampFieldID.get(),
                      row_count,
                      std::vector<milvus::FieldDataPtr>{field_data}};
    segment->LoadFieldData(TimestampFieldID, field_data_info);
}
TEST(GroupBY, Normal2) {
    using namespace milvus;
    using namespace milvus::query;
    using namespace milvus::segcore;

    //0. prepare schema
    int dim = 64;
    auto schema = std::make_shared<Schema>();
    auto vec_fid = schema->AddDebugField(
        "fakevec", DataType::VECTOR_FLOAT, dim, knowhere::metric::L2);
    auto int8_fid = schema->AddDebugField("int8", DataType::INT8);
    auto int16_fid = schema->AddDebugField("int16", DataType::INT16);
    auto int32_fid = schema->AddDebugField("int32", DataType::INT32);
    auto int64_fid = schema->AddDebugField("int64", DataType::INT64);
    auto str_fid = schema->AddDebugField("string1", DataType::VARCHAR);
    auto bool_fid = schema->AddDebugField("bool", DataType::BOOL);
    schema->set_primary_field_id(str_fid);
    auto segment = CreateSealedSegment(schema);
    size_t N = 100;

    //1. load raw data
    auto raw_data = DataGen(schema, N);
    auto fields = schema->get_fields();
    for (auto field_data : raw_data.raw_->fields_data()) {
        int64_t field_id = field_data.field_id();

        auto info = FieldDataInfo(field_data.field_id(), N);
        auto field_meta = fields.at(FieldId(field_id));
        info.channel->push(
            CreateFieldDataFromDataArray(N, &field_data, field_meta));
        info.channel->close();

        segment->LoadFieldData(FieldId(field_id), info);
    }
    prepareSegmentSystemFieldData(segment, N, raw_data);

    //2. load index
    auto vector_data = raw_data.get_col<float>(vec_fid);
    auto indexing = GenVecIndexing(
        N, dim, vector_data.data(), knowhere::IndexEnum::INDEX_HNSW);
    LoadIndexInfo load_index_info;
    load_index_info.field_id = vec_fid.get();
    load_index_info.index = std::move(indexing);
    load_index_info.index_params[METRICS_TYPE] = knowhere::metric::L2;
    segment->LoadIndex(load_index_info);

    //3. search group by int8
    {
        const char* raw_plan = R"(vector_anns: <
                                    field_id: 100
                                    query_info: <
                                      topk: 100
                                      metric_type: "L2"
                                      search_params: "{\"ef\": 10}"
                                      group_by_field_id: 101
                                    >
                                    placeholder_tag: "$0"
                                  >)";

        auto plan_str = translate_text_plan_to_binary_plan(raw_plan);
        auto plan =
            CreateSearchPlanByExpr(*schema, plan_str.data(), plan_str.size());
        auto num_queries = 1;
        auto seed = 1024;
        auto ph_group_raw = CreatePlaceholderGroup(num_queries, dim, seed);
        auto ph_group =
            ParsePlaceholderGroup(plan.get(), ph_group_raw.SerializeAsString());
        auto search_result = segment->Search(plan.get(), ph_group.get());
        auto& group_by_values = search_result->group_by_values_;
        ASSERT_EQ(search_result->group_by_values_.size(),
                  search_result->seg_offsets_.size());
        ASSERT_EQ(search_result->distances_.size(),
                  search_result->seg_offsets_.size());

        size_t size = group_by_values.size();
        std::unordered_set<int8_t> i8_set;
        float lastDistance = 0.0;
        for (size_t i = 0; i < size; i++) {
            if (std::holds_alternative<int8_t>(group_by_values[i])) {
                int8_t g_val = std::get<int8_t>(group_by_values[i]);
                // no repetition on the groupBy field
                ASSERT_FALSE(i8_set.count(g_val) > 0);
                i8_set.insert(g_val);
                auto distance = search_result->distances_.at(i);
                // distances should be non-decreasing, as the metric type is L2
                ASSERT_TRUE(lastDistance <= distance);
                lastDistance = distance;
            } else {
                // check padding
                ASSERT_EQ(search_result->seg_offsets_[i], INVALID_SEG_OFFSET);
                ASSERT_EQ(search_result->distances_[i], 0.0);
            }
        }
    }

    //4. search group by int16
    {
        const char* raw_plan = R"(vector_anns: <
                                    field_id: 100
                                    query_info: <
                                      topk: 100
                                      metric_type: "L2"
                                      search_params: "{\"ef\": 10}"
                                      group_by_field_id: 102
                                    >
                                    placeholder_tag: "$0"
                                  >)";

        auto plan_str = translate_text_plan_to_binary_plan(raw_plan);
        auto plan =
            CreateSearchPlanByExpr(*schema, plan_str.data(), plan_str.size());
        auto num_queries = 1;
        auto seed = 1024;
        auto ph_group_raw = CreatePlaceholderGroup(num_queries, dim, seed);
        auto ph_group =
            ParsePlaceholderGroup(plan.get(), ph_group_raw.SerializeAsString());
        auto search_result = segment->Search(plan.get(), ph_group.get());
        auto& group_by_values = search_result->group_by_values_;
        ASSERT_EQ(search_result->group_by_values_.size(),
                  search_result->seg_offsets_.size());
        ASSERT_EQ(search_result->distances_.size(),
                  search_result->seg_offsets_.size());

        size_t size = group_by_values.size();
        std::unordered_set<int16_t> i16_set;
        float lastDistance = 0.0;
        for (size_t i = 0; i < size; i++) {
            if (std::holds_alternative<int16_t>(group_by_values[i])) {
                int16_t g_val = std::get<int16_t>(group_by_values[i]);
                // no repetition on the groupBy field
                ASSERT_FALSE(i16_set.count(g_val) > 0);
                i16_set.insert(g_val);
                auto distance = search_result->distances_.at(i);
                // distances should be non-decreasing, as the metric type is L2
                ASSERT_TRUE(lastDistance <= distance);
                lastDistance = distance;
            } else {
                // check padding
                ASSERT_EQ(search_result->seg_offsets_[i], INVALID_SEG_OFFSET);
                ASSERT_EQ(search_result->distances_[i], 0.0);
            }
        }
    }

    //5. search group by int32
    {
        const char* raw_plan = R"(vector_anns: <
                                    field_id: 100
                                    query_info: <
                                      topk: 100
                                      metric_type: "L2"
                                      search_params: "{\"ef\": 10}"
                                      group_by_field_id: 103
                                    >
                                    placeholder_tag: "$0"
                                  >)";

        auto plan_str = translate_text_plan_to_binary_plan(raw_plan);
        auto plan =
            CreateSearchPlanByExpr(*schema, plan_str.data(), plan_str.size());
        auto num_queries = 1;
        auto seed = 1024;
        auto ph_group_raw = CreatePlaceholderGroup(num_queries, dim, seed);
        auto ph_group =
            ParsePlaceholderGroup(plan.get(), ph_group_raw.SerializeAsString());
        auto search_result = segment->Search(plan.get(), ph_group.get());
        auto& group_by_values = search_result->group_by_values_;
        ASSERT_EQ(search_result->group_by_values_.size(),
                  search_result->seg_offsets_.size());
        ASSERT_EQ(search_result->distances_.size(),
                  search_result->seg_offsets_.size());

        size_t size = group_by_values.size();
        std::unordered_set<int32_t> i32_set;
        float lastDistance = 0.0;
        for (size_t i = 0; i < size; i++) {
            if (std::holds_alternative<int32_t>(group_by_values[i])) {
                int32_t g_val = std::get<int32_t>(group_by_values[i]);
                // no repetition on the groupBy field
                ASSERT_FALSE(i32_set.count(g_val) > 0);
                i32_set.insert(g_val);
                auto distance = search_result->distances_.at(i);
                // distances should be non-decreasing, as the metric type is L2
                ASSERT_TRUE(lastDistance <= distance);
                lastDistance = distance;
            } else {
                // check padding
                ASSERT_EQ(search_result->seg_offsets_[i], INVALID_SEG_OFFSET);
                ASSERT_EQ(search_result->distances_[i], 0.0);
            }
        }
    }

    //6. search group by int64
    {
        const char* raw_plan = R"(vector_anns: <
                                    field_id: 100
                                    query_info: <
                                      topk: 100
                                      metric_type: "L2"
                                      search_params: "{\"ef\": 10}"
                                      group_by_field_id: 104
                                    >
                                    placeholder_tag: "$0"
                                  >)";

        auto plan_str = translate_text_plan_to_binary_plan(raw_plan);
        auto plan =
            CreateSearchPlanByExpr(*schema, plan_str.data(), plan_str.size());
        auto num_queries = 1;
        auto seed = 1024;
        auto ph_group_raw = CreatePlaceholderGroup(num_queries, dim, seed);
        auto ph_group =
            ParsePlaceholderGroup(plan.get(), ph_group_raw.SerializeAsString());
        auto search_result = segment->Search(plan.get(), ph_group.get());
        auto& group_by_values = search_result->group_by_values_;
        ASSERT_EQ(search_result->group_by_values_.size(),
                  search_result->seg_offsets_.size());
        ASSERT_EQ(search_result->distances_.size(),
                  search_result->seg_offsets_.size());

        size_t size = group_by_values.size();
        std::unordered_set<int64_t> i64_set;
        float lastDistance = 0.0;
        for (size_t i = 0; i < size; i++) {
            if (std::holds_alternative<int64_t>(group_by_values[i])) {
                int64_t g_val = std::get<int64_t>(group_by_values[i]);
                // no repetition on the groupBy field
                ASSERT_FALSE(i64_set.count(g_val) > 0);
                i64_set.insert(g_val);
                auto distance = search_result->distances_.at(i);
                // distances should be non-decreasing, as the metric type is L2
                ASSERT_TRUE(lastDistance <= distance);
                lastDistance = distance;
            } else {
                // check padding
                ASSERT_EQ(search_result->seg_offsets_[i], INVALID_SEG_OFFSET);
                ASSERT_EQ(search_result->distances_[i], 0.0);
            }
        }
    }

    //7. search group by string
    {
        const char* raw_plan = R"(vector_anns: <
                                    field_id: 100
                                    query_info: <
                                      topk: 100
                                      metric_type: "L2"
                                      search_params: "{\"ef\": 10}"
                                      group_by_field_id: 105
                                    >
                                    placeholder_tag: "$0"
                                  >)";

        auto plan_str = translate_text_plan_to_binary_plan(raw_plan);
        auto plan =
            CreateSearchPlanByExpr(*schema, plan_str.data(), plan_str.size());
        auto num_queries = 1;
        auto seed = 1024;
        auto ph_group_raw = CreatePlaceholderGroup(num_queries, dim, seed);
        auto ph_group =
            ParsePlaceholderGroup(plan.get(), ph_group_raw.SerializeAsString());
        auto search_result = segment->Search(plan.get(), ph_group.get());
        auto& group_by_values = search_result->group_by_values_;
        ASSERT_EQ(search_result->group_by_values_.size(),
                  search_result->seg_offsets_.size());
        ASSERT_EQ(search_result->distances_.size(),
                  search_result->seg_offsets_.size());

        size_t size = group_by_values.size();
        std::unordered_set<std::string_view> strs_set;
        float lastDistance = 0.0;
        for (size_t i = 0; i < size; i++) {
            if (std::holds_alternative<std::string_view>(group_by_values[i])) {
                std::string_view g_val =
                    std::get<std::string_view>(group_by_values[i]);
                // no repetition on the groupBy field
                ASSERT_FALSE(strs_set.count(g_val) > 0);
                strs_set.insert(g_val);
                auto distance = search_result->distances_.at(i);
                // distances should be non-decreasing, as the metric type is L2
                ASSERT_TRUE(lastDistance <= distance);
                lastDistance = distance;
            } else {
                // check padding
                ASSERT_EQ(search_result->seg_offsets_[i], INVALID_SEG_OFFSET);
                ASSERT_EQ(search_result->distances_[i], 0.0);
            }
        }
    }

    //8. search group by bool
    {
        const char* raw_plan = R"(vector_anns: <
                                    field_id: 100
                                    query_info: <
                                      topk: 100
                                      metric_type: "L2"
                                      search_params: "{\"ef\": 10}"
                                      group_by_field_id: 106
                                    >
                                    placeholder_tag: "$0"
                                  >)";

        auto plan_str = translate_text_plan_to_binary_plan(raw_plan);
        auto plan =
            CreateSearchPlanByExpr(*schema, plan_str.data(), plan_str.size());
        auto num_queries = 1;
        auto seed = 1024;
        auto ph_group_raw = CreatePlaceholderGroup(num_queries, dim, seed);
        auto ph_group =
            ParsePlaceholderGroup(plan.get(), ph_group_raw.SerializeAsString());
        auto search_result = segment->Search(plan.get(), ph_group.get());
        auto& group_by_values = search_result->group_by_values_;
        ASSERT_EQ(search_result->group_by_values_.size(),
                  search_result->seg_offsets_.size());
        ASSERT_EQ(search_result->distances_.size(),
                  search_result->seg_offsets_.size());

        size_t size = group_by_values.size();
        std::unordered_set<bool> bools_set;
        int boolValCount = 0;
        float lastDistance = 0.0;
        for (size_t i = 0; i < size; i++) {
            if (std::holds_alternative<bool>(group_by_values[i])) {
                bool g_val = std::get<bool>(group_by_values[i]);
                // no repetition on the groupBy field
                ASSERT_FALSE(bools_set.count(g_val) > 0);
                bools_set.insert(g_val);
                boolValCount += 1;
                auto distance = search_result->distances_.at(i);
                // distances should be non-decreasing, as the metric type is L2
                ASSERT_TRUE(lastDistance <= distance);
                lastDistance = distance;
            } else {
                // check padding
                ASSERT_EQ(search_result->seg_offsets_[i], INVALID_SEG_OFFSET);
                ASSERT_EQ(search_result->distances_[i], 0.0);
            }
            // at most two distinct bool values can be grouped
            ASSERT_TRUE(boolValCount <= 2);
        }
    }
}
TEST(GroupBY, Reduce) {
    using namespace milvus;
    using namespace milvus::query;
    using namespace milvus::segcore;

    //0. prepare schema
    int dim = 64;
    auto schema = std::make_shared<Schema>();
    auto vec_fid = schema->AddDebugField(
        "fakevec", DataType::VECTOR_FLOAT, dim, knowhere::metric::L2);
    auto int64_fid = schema->AddDebugField("int64", DataType::INT64);
    schema->set_primary_field_id(int64_fid);
    auto segment1 = CreateSealedSegment(schema);
    auto segment2 = CreateSealedSegment(schema);

    //1. load raw data
    size_t N = 100;
    uint64_t seed = 512;
    uint64_t ts_offset = 0;
    int repeat_count_1 = 2;
    int repeat_count_2 = 5;
    auto raw_data1 = DataGen(schema, N, seed, ts_offset, repeat_count_1);
    auto raw_data2 = DataGen(schema, N, seed, ts_offset, repeat_count_2);

    auto fields = schema->get_fields();
    // load segment1 raw data
    for (auto field_data : raw_data1.raw_->fields_data()) {
        int64_t field_id = field_data.field_id();
        auto info = FieldDataInfo(field_data.field_id(), N);
        auto field_meta = fields.at(FieldId(field_id));
        info.channel->push(
            CreateFieldDataFromDataArray(N, &field_data, field_meta));
        info.channel->close();
        segment1->LoadFieldData(FieldId(field_id), info);
    }
    prepareSegmentSystemFieldData(segment1, N, raw_data1);

    // load segment2 raw data
    for (auto field_data : raw_data2.raw_->fields_data()) {
        int64_t field_id = field_data.field_id();
        auto info = FieldDataInfo(field_data.field_id(), N);
        auto field_meta = fields.at(FieldId(field_id));
        info.channel->push(
            CreateFieldDataFromDataArray(N, &field_data, field_meta));
        info.channel->close();
        segment2->LoadFieldData(FieldId(field_id), info);
    }
    prepareSegmentSystemFieldData(segment2, N, raw_data2);

    //2. load index
    auto vector_data_1 = raw_data1.get_col<float>(vec_fid);
    auto indexing_1 = GenVecIndexing(
        N, dim, vector_data_1.data(), knowhere::IndexEnum::INDEX_HNSW);
    LoadIndexInfo load_index_info_1;
    load_index_info_1.field_id = vec_fid.get();
    load_index_info_1.index = std::move(indexing_1);
    load_index_info_1.index_params[METRICS_TYPE] = knowhere::metric::L2;
    segment1->LoadIndex(load_index_info_1);

    auto vector_data_2 = raw_data2.get_col<float>(vec_fid);
    auto indexing_2 = GenVecIndexing(
        N, dim, vector_data_2.data(), knowhere::IndexEnum::INDEX_HNSW);
    LoadIndexInfo load_index_info_2;
    load_index_info_2.field_id = vec_fid.get();
    load_index_info_2.index = std::move(indexing_2);
    load_index_info_2.index_params[METRICS_TYPE] = knowhere::metric::L2;
    segment2->LoadIndex(load_index_info_2);

    //3. search group by on each segment, then reduce
    const char* raw_plan = R"(vector_anns: <
                                field_id: 100
                                query_info: <
                                  topk: 100
                                  metric_type: "L2"
                                  search_params: "{\"ef\": 10}"
                                  group_by_field_id: 101
                                >
                                placeholder_tag: "$0"
                              >)";
    auto plan_str = translate_text_plan_to_binary_plan(raw_plan);
    auto plan =
        CreateSearchPlanByExpr(*schema, plan_str.data(), plan_str.size());
    auto num_queries = 10;
    auto topK = 100;
    auto ph_group_raw = CreatePlaceholderGroup(num_queries, dim, seed);
    auto ph_group =
        ParsePlaceholderGroup(plan.get(), ph_group_raw.SerializeAsString());
    CPlaceholderGroup c_ph_group = ph_group.release();
    CSearchPlan c_plan = plan.release();

    CSegmentInterface c_segment_1 = segment1.release();
    CSegmentInterface c_segment_2 = segment2.release();
    CSearchResult c_search_res_1;
    CSearchResult c_search_res_2;
    auto status = Search(c_segment_1, c_plan, c_ph_group, {}, &c_search_res_1);
    ASSERT_EQ(status.error_code, Success);
    status = Search(c_segment_2, c_plan, c_ph_group, {}, &c_search_res_2);
    ASSERT_EQ(status.error_code, Success);
    std::vector<CSearchResult> results;
    results.push_back(c_search_res_1);
    results.push_back(c_search_res_2);

    auto slice_nqs = std::vector<int64_t>{num_queries / 2, num_queries / 2};
    auto slice_topKs = std::vector<int64_t>{topK / 2, topK};
    CSearchResultDataBlobs cSearchResultData;
    status = ReduceSearchResultsAndFillData(&cSearchResultData,
                                            c_plan,
                                            results.data(),
                                            results.size(),
                                            slice_nqs.data(),
                                            slice_topKs.data(),
                                            slice_nqs.size());
    ASSERT_EQ(status.error_code, Success);
    CheckSearchResultDuplicate(results);
    DeleteSearchResult(c_search_res_1);
    DeleteSearchResult(c_search_res_2);
    DeleteSearchResultDataBlobs(cSearchResultData);

    DeleteSearchPlan(c_plan);
    DeletePlaceholderGroup(c_ph_group);
    DeleteSegment(c_segment_1);
    DeleteSegment(c_segment_2);
}
@@ -400,7 +400,7 @@ TEST(Sealed, LoadFieldData) {
 
     auto fakevec = dataset.get_col<float>(fakevec_id);
 
-    auto indexing = GenVecIndexing(N, dim, fakevec.data());
+    auto indexing = GenVecIndexing(
+        N, dim, fakevec.data(), knowhere::IndexEnum::INDEX_FAISS_IVFFLAT);
 
     auto segment = CreateSealedSegment(schema);
     // std::string dsl = R"({

@@ -525,7 +525,7 @@ TEST(Sealed, LoadFieldDataMmap) {
 
     auto fakevec = dataset.get_col<float>(fakevec_id);
 
-    auto indexing = GenVecIndexing(N, dim, fakevec.data());
+    auto indexing = GenVecIndexing(
+        N, dim, fakevec.data(), knowhere::IndexEnum::INDEX_FAISS_IVFFLAT);
 
     auto segment = CreateSealedSegment(schema);
     const char* raw_plan = R"(vector_anns: <

@@ -616,7 +616,7 @@ TEST(Sealed, LoadScalarIndex) {
 
     auto fakevec = dataset.get_col<float>(fakevec_id);
 
-    auto indexing = GenVecIndexing(N, dim, fakevec.data());
+    auto indexing = GenVecIndexing(
+        N, dim, fakevec.data(), knowhere::IndexEnum::INDEX_FAISS_IVFFLAT);
 
     auto segment = CreateSealedSegment(schema);
     // std::string dsl = R"({

@@ -1135,7 +1135,7 @@ TEST(Sealed, GetVector) {
 
     auto fakevec = dataset.get_col<float>(fakevec_id);
 
-    auto indexing = GenVecIndexing(N, dim, fakevec.data());
+    auto indexing = GenVecIndexing(
+        N, dim, fakevec.data(), knowhere::IndexEnum::INDEX_FAISS_IVFFLAT);
 
     auto segment_sealed = CreateSealedSegment(schema);
@@ -904,7 +904,7 @@ SealedCreator(SchemaPtr schema, const GeneratedData& dataset) {
 }
 
 inline std::unique_ptr<milvus::index::VectorIndex>
-GenVecIndexing(int64_t N, int64_t dim, const float* vec) {
+GenVecIndexing(
+    int64_t N, int64_t dim, const float* vec, const char* index_type) {
     auto conf =
         knowhere::Json{{knowhere::meta::METRIC_TYPE, knowhere::metric::L2},
                        {knowhere::meta::DIM, std::to_string(dim)},

@@ -920,7 +920,7 @@ GenVecIndexing(int64_t N, int64_t dim, const float* vec) {
     milvus::storage::FileManagerContext file_manager_context(
         field_data_meta, index_meta, chunk_manager);
     auto indexing = std::make_unique<index::VectorMemIndex<float>>(
-        knowhere::IndexEnum::INDEX_FAISS_IVFFLAT,
+        index_type,
         knowhere::metric::L2,
         knowhere::Version::GetCurrentVersion().VersionNumber(),
         file_manager_context);
@@ -0,0 +1,139 @@
// Licensed to the LF AI & Data foundation under one
// or more contributor license agreements. See the NOTICE file
// distributed with this work for additional information
// regarding copyright ownership. The ASF licenses this file
// to you under the Apache License, Version 2.0 (the
// "License"); you may not use this file except in compliance
// with the License. You may obtain a copy of the License at
//
//     http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

#pragma once

#include <boost/format.hpp>
#include <cassert>
#include <chrono>
#include <iostream>
#include <limits>
#include <random>
#include <unordered_set>

#include "common/Types.h"
#include "common/type_c.h"
#include "pb/plan.pb.h"
#include "segcore/Collection.h"
#include "segcore/Reduce.h"
#include "segcore/reduce_c.h"
#include "segcore/segment_c.h"
#include "DataGen.h"
#include "PbHelper.h"
#include "c_api_test_utils.h"
#include "indexbuilder_test_utils.h"

using namespace milvus;
using namespace milvus::segcore;

namespace {
const char*
get_default_schema_config() {
    static std::string conf = R"(name: "default-collection"
                                fields: <
                                  fieldID: 100
                                  name: "fakevec"
                                  data_type: FloatVector
                                  type_params: <
                                    key: "dim"
                                    value: "16"
                                  >
                                  index_params: <
                                    key: "metric_type"
                                    value: "L2"
                                  >
                                >
                                fields: <
                                  fieldID: 101
                                  name: "age"
                                  data_type: Int64
                                  is_primary_key: true
                                >)";
    return conf.c_str();
}

std::string
generate_max_float_query_data(int all_nq, int max_float_nq) {
    assert(max_float_nq <= all_nq);
    namespace ser = milvus::proto::common;
    int dim = DIM;
    ser::PlaceholderGroup raw_group;
    auto value = raw_group.add_placeholders();
    value->set_tag("$0");
    value->set_type(ser::PlaceholderType::FloatVector);
    for (int i = 0; i < all_nq; ++i) {
        std::vector<float> vec;
        if (i < max_float_nq) {
            for (int d = 0; d < dim; ++d) {
                vec.push_back(std::numeric_limits<float>::max());
            }
        } else {
            for (int d = 0; d < dim; ++d) {
                vec.push_back(1);
            }
        }
        value->add_values(vec.data(), vec.size() * sizeof(float));
    }
    auto blob = raw_group.SerializeAsString();
    return blob;
}

std::string
generate_query_data(int nq) {
    namespace ser = milvus::proto::common;
    std::default_random_engine e(67);
    int dim = DIM;
    std::normal_distribution<double> dis(0.0, 1.0);
    ser::PlaceholderGroup raw_group;
    auto value = raw_group.add_placeholders();
    value->set_tag("$0");
    value->set_type(ser::PlaceholderType::FloatVector);
    for (int i = 0; i < nq; ++i) {
        std::vector<float> vec;
        for (int d = 0; d < dim; ++d) {
            vec.push_back(static_cast<float>(dis(e)));
        }
        value->add_values(vec.data(), vec.size() * sizeof(float));
    }
    auto blob = raw_group.SerializeAsString();
    return blob;
}

void
CheckSearchResultDuplicate(const std::vector<CSearchResult>& results) {
    auto nq = ((SearchResult*)results[0])->total_nq_;

    std::unordered_set<PkType> pk_set;
    std::unordered_set<GroupByValueType> group_by_val_set;
    for (int qi = 0; qi < nq; qi++) {
        pk_set.clear();
        group_by_val_set.clear();
        for (size_t i = 0; i < results.size(); i++) {
            auto search_result = (SearchResult*)results[i];
            ASSERT_EQ(nq, search_result->total_nq_);
            auto topk_beg = search_result->topk_per_nq_prefix_sum_[qi];
            auto topk_end = search_result->topk_per_nq_prefix_sum_[qi + 1];
            for (size_t ki = topk_beg; ki < topk_end; ki++) {
                ASSERT_NE(search_result->seg_offsets_[ki], INVALID_SEG_OFFSET);
                auto ret = pk_set.insert(search_result->primary_keys_[ki]);
                ASSERT_TRUE(ret.second);

                if (search_result->group_by_values_.size() > ki) {
                    auto group_by_val = search_result->group_by_values_[ki];
                    ASSERT_TRUE(group_by_val_set.count(group_by_val) == 0);
                    group_by_val_set.insert(group_by_val);
                }
            }
        }
    }
}
}  // namespace
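For readers working from the Go side, here is a rough counterpart of the `generate_query_data` helper above. It is illustrative only and not part of this PR; it assumes the `commonpb` proto package path used elsewhere in this repository, and any other names are stand-ins.

package main

import (
	"encoding/binary"
	"fmt"
	"math"
	"math/rand"

	"github.com/golang/protobuf/proto"
	"github.com/milvus-io/milvus-proto/go-api/v2/commonpb"
)

// genPlaceholderGroup serializes nq random float32 vectors of dimension dim
// into the same PlaceholderGroup wire format the C++ helper above produces:
// one placeholder tagged "$0", each vector packed as little-endian float32 bytes.
func genPlaceholderGroup(nq, dim int) ([]byte, error) {
	value := &commonpb.PlaceholderValue{
		Tag:  "$0",
		Type: commonpb.PlaceholderType_FloatVector,
	}
	for i := 0; i < nq; i++ {
		vec := make([]byte, 0, dim*4)
		buf := make([]byte, 4)
		for d := 0; d < dim; d++ {
			binary.LittleEndian.PutUint32(buf, math.Float32bits(rand.Float32()))
			vec = append(vec, buf...)
		}
		value.Values = append(value.Values, vec)
	}
	group := &commonpb.PlaceholderGroup{
		Placeholders: []*commonpb.PlaceholderValue{value},
	}
	return proto.Marshal(group)
}

func main() {
	blob, err := genPlaceholderGroup(2, 16)
	if err != nil {
		panic(err)
	}
	fmt.Println("placeholder group bytes:", len(blob))
}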
@@ -57,6 +57,7 @@ message QueryInfo {
  string metric_type = 3;
  string search_params = 4;
  int64 round_decimal = 5;
  int64 group_by_field_id = 6;
 }

 message ColumnInfo {
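A minimal sketch of what the proxy now emits for a grouped search, under the new proto field; all literal values here are hypothetical, and GroupByFieldId stays 0 when no grouping is requested:

// Illustrative only: a QueryInfo for a grouped top-10 search.
// FieldID 101 is a hypothetical scalar field from the collection schema.
info := &planpb.QueryInfo{
	Topk:           10,
	MetricType:     "L2",
	SearchParams:   `{"nprobe": 16}`,
	RoundDecimal:   -1,
	GroupByFieldId: 101,
}
_ = info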
@@ -44,6 +44,7 @@ import (
 const (
 	IgnoreGrowingKey     = "ignore_growing"
 	ReduceStopForBestKey = "reduce_stop_for_best"
 	GroupByFieldKey      = "group_by_field"
 	AnnsFieldKey         = "anns_field"
 	TopKKey              = "topk"
 	NQKey                = "nq"
@@ -116,7 +116,8 @@ func getPartitionIDs(ctx context.Context, dbName string, collectionName string,
 }

 // parseSearchInfo returns QueryInfo and offset
-func parseSearchInfo(searchParamsPair []*commonpb.KeyValuePair) (*planpb.QueryInfo, int64, error) {
+func parseSearchInfo(searchParamsPair []*commonpb.KeyValuePair, schema *schemapb.CollectionSchema) (*planpb.QueryInfo, int64, error) {
 	// 1. parse offset and real topk
 	topKStr, err := funcutil.GetAttrByKeyFromRepeatedKV(TopKKey, searchParamsPair)
 	if err != nil {
 		return nil, 0, errors.New(TopKKey + " not found in search_params")
@@ -149,11 +150,13 @@ func parseSearchInfo(searchParamsPair []*commonpb.KeyValuePair) (*planpb.QueryIn
 		return nil, 0, fmt.Errorf("%s+%s [%d] is invalid, %w", OffsetKey, TopKKey, queryTopK, err)
 	}

 	// 2. parse metric type
 	metricType, err := funcutil.GetAttrByKeyFromRepeatedKV(common.MetricTypeKey, searchParamsPair)
 	if err != nil {
 		metricType = ""
 	}

 	// 3. parse round decimal
 	roundDecimalStr, err := funcutil.GetAttrByKeyFromRepeatedKV(RoundDecimalKey, searchParamsPair)
 	if err != nil {
 		roundDecimalStr = "-1"
@@ -167,15 +170,39 @@ func parseSearchInfo(searchParamsPair []*commonpb.KeyValuePair) (*planpb.QueryIn
 	if roundDecimal != -1 && (roundDecimal > 6 || roundDecimal < 0) {
 		return nil, 0, fmt.Errorf("%s [%s] is invalid, should be -1 or an integer in range [0, 6]", RoundDecimalKey, roundDecimalStr)
 	}

 	// 4. parse search param str
 	searchParamStr, err := funcutil.GetAttrByKeyFromRepeatedKV(SearchParamsKey, searchParamsPair)
 	if err != nil {
 		searchParamStr = ""
 	}

 	// 5. parse group by field
 	groupByFieldName, err := funcutil.GetAttrByKeyFromRepeatedKV(GroupByFieldKey, searchParamsPair)
 	if err != nil {
 		groupByFieldName = ""
 	}
 	var groupByFieldId int64
 	if groupByFieldName != "" {
 		groupByFieldId = -1
 		fields := schema.GetFields()
 		for _, field := range fields {
 			if field.Name == groupByFieldName {
 				groupByFieldId = field.FieldID
 				break
 			}
 		}
 		if groupByFieldId == -1 {
 			return nil, 0, merr.WrapErrFieldNotFound(groupByFieldName, "groupBy field not found in schema")
 		}
 	}

 	return &planpb.QueryInfo{
-		Topk:         queryTopK,
-		MetricType:   metricType,
-		SearchParams: searchParamStr,
-		RoundDecimal: roundDecimal,
+		Topk:           queryTopK,
+		MetricType:     metricType,
+		SearchParams:   searchParamStr,
+		RoundDecimal:   roundDecimal,
+		GroupByFieldId: groupByFieldId,
 	}, offset, nil
 }
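As a minimal sketch of the new branch from the caller's perspective (written as if inside the proxy package; the field name "color" and the schema variable are hypothetical stand-ins), grouping is driven by one extra key-value pair in the request's search params:

searchParams := []*commonpb.KeyValuePair{
	{Key: TopKKey, Value: "10"},
	{Key: RoundDecimalKey, Value: "-1"},
	{Key: GroupByFieldKey, Value: "color"},
}
queryInfo, offset, err := parseSearchInfo(searchParams, schema)
// On success, queryInfo.GetGroupByFieldId() carries the FieldID of "color";
// if the schema has no such field, err reports the field as not found.
_ = offset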
@@ -299,10 +326,14 @@ func (t *searchTask) PreExecute(ctx context.Context) error {
 		annsField = vecFields[0].Name
 	}
-	queryInfo, offset, err := parseSearchInfo(t.request.GetSearchParams())
+	queryInfo, offset, err := parseSearchInfo(t.request.GetSearchParams(), t.schema.CollectionSchema)
 	if err != nil {
 		return err
 	}
+	if queryInfo.GroupByFieldId != 0 {
+		// group-by searches currently ignore growing segments
+		t.SearchRequest.IgnoreGrowing = true
+	}
 	t.offset = offset

 	plan, err := planparserv2.CreateSearchPlan(t.schema.CollectionSchema, t.request.Dsl, annsField, queryInfo)
@@ -405,7 +436,6 @@ func (t *searchTask) PreExecute(ctx context.Context) error {
 		zap.Bool("use_default_consistency", useDefaultConsistency),
 		zap.Any("consistency level", consistencyLevel),
 		zap.Uint64("timeout_ts", t.SearchRequest.GetTimeoutTimestamp()))

 	return nil
 }
@@ -817,7 +847,6 @@ func reduceSearchResultData(ctx context.Context, subSearchResultData []*schemapb
 	default:
 		return nil, errors.New("unsupported pk type")
 	}

 	for i, sData := range subSearchResultData {
 		pkLength := typeutil.GetSizeOfIDs(sData.GetIds())
 		log.Ctx(ctx).Debug("subSearchResultData",
@@ -852,6 +881,7 @@ func reduceSearchResultData(ctx context.Context, subSearchResultData []*schemapb

 	var retSize int64
 	maxOutputSize := paramtable.Get().QuotaConfig.MaxOutputSize.GetAsInt64()

 	// reducing nq * topk results
 	for i := int64(0); i < nq; i++ {
 		var (
@@ -859,8 +889,9 @@ func reduceSearchResultData(ctx context.Context, subSearchResultData []*schemapb
 			// sum(cursors) == j
 			cursors = make([]int64, subSearchNum)

-			j     int64
-			idSet = make(map[interface{}]struct{})
+			j             int64
+			idSet         = make(map[interface{}]struct{})
+			groupByValSet = make(map[interface{}]struct{})
 		)

 		// skip offset results
@@ -882,17 +913,32 @@ func reduceSearchResultData(ctx context.Context, subSearchResultData []*schemapb
 			if subSearchIdx == -1 {
 				break
 			}
+			subSearchRes := subSearchResultData[subSearchIdx]

-			id := typeutil.GetPK(subSearchResultData[subSearchIdx].GetIds(), resultDataIdx)
-			score := subSearchResultData[subSearchIdx].Scores[resultDataIdx]
+			id := typeutil.GetPK(subSearchRes.GetIds(), resultDataIdx)
+			score := subSearchRes.Scores[resultDataIdx]
+			groupByVal := typeutil.GetData(subSearchRes.GetGroupByFieldValue(), int(resultDataIdx))

 			// remove duplicates
 			if _, ok := idSet[id]; !ok {
-				retSize += typeutil.AppendFieldData(ret.Results.FieldsData, subSearchResultData[subSearchIdx].FieldsData, resultDataIdx)
-				typeutil.AppendPKs(ret.Results.Ids, id)
-				ret.Results.Scores = append(ret.Results.Scores, score)
-				idSet[id] = struct{}{}
-				j++
+				groupByValExist := false
+				if groupByVal != nil {
+					_, groupByValExist = groupByValSet[groupByVal]
+				}
+				if !groupByValExist {
+					retSize += typeutil.AppendFieldData(ret.Results.FieldsData, subSearchResultData[subSearchIdx].FieldsData, resultDataIdx)
+					typeutil.AppendPKs(ret.Results.Ids, id)
+					ret.Results.Scores = append(ret.Results.Scores, score)
+					idSet[id] = struct{}{}
+					if groupByVal != nil {
+						groupByValSet[groupByVal] = struct{}{}
+						if err := typeutil.AppendGroupByValue(ret.Results, groupByVal, subSearchRes.GetGroupByFieldValue().GetType()); err != nil {
+							log.Ctx(ctx).Error("failed to append groupByValues", zap.Error(err))
+							return ret, err
+						}
+					}
+					j++
+				}
 			} else {
 				// skip entity with same id
 				skipDupCnt++
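The selection policy above can be hard to see through the cursor bookkeeping. Below is a self-contained sketch of just that policy, under simplifying assumptions (int64 PKs, string group values, candidates already merged in score order); it is not the PR's code, only an illustration of the same idea.

package main

import "fmt"

type cand struct {
	pk    int64
	score float32
	group string
}

// selectGroupTopK mirrors the merge loop above: walk candidates in score
// order, keep one only if its primary key is unseen AND its group has no
// representative yet, and stop once topk entries are collected.
func selectGroupTopK(sorted []cand, topk int) []cand {
	idSet := make(map[int64]struct{})
	groupSet := make(map[string]struct{})
	out := make([]cand, 0, topk)
	for _, c := range sorted {
		if len(out) >= topk {
			break
		}
		if _, dup := idSet[c.pk]; dup {
			continue // same entity already chosen (the skipDupCnt path)
		}
		if _, taken := groupSet[c.group]; taken {
			continue // this group already has its best-scored representative
		}
		idSet[c.pk] = struct{}{}
		groupSet[c.group] = struct{}{}
		out = append(out, c)
	}
	return out
}

func main() {
	merged := []cand{
		{1, 0.9, "red"}, {2, 0.8, "red"}, {3, 0.7, "blue"}, {4, 0.6, "blue"},
	}
	// Prints the best candidate per group: [{1 0.9 red} {3 0.7 blue}]
	fmt.Println(selectGroupTopK(merged, 2))
}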
@@ -1591,6 +1591,78 @@ func TestTaskSearch_reduceSearchResultData(t *testing.T) {
 	})
 }

 func TestTaskSearch_reduceGroupBySearchResultData(t *testing.T) {
 	var (
 		nq   int64 = 2
 		topK int64 = 5
 	)
 	ids := [][]int64{
 		{1, 3, 5, 7, 9, 1, 3, 5, 7, 9},
 		{2, 4, 6, 8, 10, 2, 4, 6, 8, 10},
 	}
 	scores := [][]float32{
 		{10, 8, 6, 4, 2, 10, 8, 6, 4, 2},
 		{9, 7, 5, 3, 1, 9, 7, 5, 3, 1},
 	}

 	groupByValuesArr := [][][]int64{
 		{
 			{1, 2, 3, 4, 5, 1, 2, 3, 4, 5},
 			{1, 2, 3, 4, 5, 1, 2, 3, 4, 5},
 		}, // result2 has exactly the same group_by values, so nothing from result2 can be selected
 		{
 			{1, 2, 3, 4, 5, 1, 2, 3, 4, 5},
 			{6, 8, 3, 4, 5, 6, 8, 3, 4, 5},
 		}, // result2 contributes the new group_by values 6 and 8
 	}
 	expectedIDs := [][]int64{
 		{1, 3, 5, 7, 9, 1, 3, 5, 7, 9},
 		{1, 2, 3, 4, 5, 1, 2, 3, 4, 5},
 	}
 	expectedScores := [][]float32{
 		{-10, -8, -6, -4, -2, -10, -8, -6, -4, -2},
 		{-10, -9, -8, -7, -6, -10, -9, -8, -7, -6},
 	}
 	expectedGroupByValues := [][]int64{
 		{1, 2, 3, 4, 5, 1, 2, 3, 4, 5},
 		{1, 6, 2, 8, 3, 1, 6, 2, 8, 3},
 	}

 	for i, groupByValues := range groupByValuesArr {
 		t.Run("Group By correctness", func(t *testing.T) {
 			var results []*schemapb.SearchResultData
 			for j := range ids {
 				result := getSearchResultData(nq, topK)
 				result.Ids.IdField = &schemapb.IDs_IntId{IntId: &schemapb.LongArray{Data: ids[j]}}
 				result.Scores = scores[j]
 				result.Topks = []int64{topK, topK}
 				result.GroupByFieldValue = &schemapb.FieldData{
 					Type: schemapb.DataType_Int64,
 					Field: &schemapb.FieldData_Scalars{
 						Scalars: &schemapb.ScalarField{
 							Data: &schemapb.ScalarField_LongData{
 								LongData: &schemapb.LongArray{
 									Data: groupByValues[j],
 								},
 							},
 						},
 					},
 				}
 				results = append(results, result)
 			}

 			reduced, err := reduceSearchResultData(context.TODO(), results, nq, topK, metric.L2, schemapb.DataType_Int64, 0)
 			assert.NoError(t, err)
 			resultIDs := reduced.GetResults().GetIds().GetIntId().Data
 			resultScores := reduced.GetResults().GetScores()
 			resultGroupByValues := reduced.GetResults().GetGroupByFieldValue().GetScalars().GetLongData().GetData()
 			assert.EqualValues(t, expectedIDs[i], resultIDs)
 			assert.EqualValues(t, expectedScores[i], resultScores)
 			assert.EqualValues(t, expectedGroupByValues[i], resultGroupByValues)
 		})
 	}
 }

 func TestSearchTask_ErrExecute(t *testing.T) {
 	var (
 		err error
@@ -1784,7 +1856,7 @@ func TestTaskSearch_parseQueryInfo(t *testing.T) {

 	for _, test := range tests {
 		t.Run(test.description, func(t *testing.T) {
-			info, offset, err := parseSearchInfo(test.validParams)
+			info, offset, err := parseSearchInfo(test.validParams, nil)
 			assert.NoError(t, err)
 			assert.NotNil(t, info)
 			if test.description == "offsetParam" {
@@ -1873,7 +1945,7 @@ func TestTaskSearch_parseQueryInfo(t *testing.T) {

 	for _, test := range tests {
 		t.Run(test.description, func(t *testing.T) {
-			info, offset, err := parseSearchInfo(test.invalidParams)
+			info, offset, err := parseSearchInfo(test.invalidParams, nil)
 			assert.Error(t, err)
 			assert.Nil(t, info)
 			assert.Zero(t, offset)
@@ -129,6 +129,7 @@ func ReduceSearchResultData(ctx context.Context, searchResultData []*schemapb.Se
 	offsets := make([]int64, len(searchResultData))

 	idSet := make(map[interface{}]struct{})
 	groupByValueSet := make(map[interface{}]struct{})
 	var j int64
 	for j = 0; j < topk; {
 		sel := SelectSearchResultData(searchResultData, resultOffsets, offsets, i)
@@ -138,15 +139,29 @@ func ReduceSearchResultData(ctx context.Context, searchResultData []*schemapb.Se
 			idx := resultOffsets[sel][i] + offsets[sel]

 			id := typeutil.GetPK(searchResultData[sel].GetIds(), idx)
+			groupByVal := typeutil.GetData(searchResultData[sel].GetGroupByFieldValue(), int(idx))
 			score := searchResultData[sel].Scores[idx]

 			// remove duplicates
 			if _, ok := idSet[id]; !ok {
-				retSize += typeutil.AppendFieldData(ret.FieldsData, searchResultData[sel].FieldsData, idx)
-				typeutil.AppendPKs(ret.Ids, id)
-				ret.Scores = append(ret.Scores, score)
-				idSet[id] = struct{}{}
-				j++
+				groupByValExist := false
+				if groupByVal != nil {
+					_, groupByValExist = groupByValueSet[groupByVal]
+				}
+				if !groupByValExist {
+					retSize += typeutil.AppendFieldData(ret.FieldsData, searchResultData[sel].FieldsData, idx)
+					typeutil.AppendPKs(ret.Ids, id)
+					ret.Scores = append(ret.Scores, score)
+					if groupByVal != nil {
+						groupByValueSet[groupByVal] = struct{}{}
+						if err := typeutil.AppendGroupByValue(ret, groupByVal, searchResultData[sel].GetGroupByFieldValue().GetType()); err != nil {
+							log.Error("Failed to append groupByValues", zap.Error(err))
+							return ret, err
+						}
+					}
+					idSet[id] = struct{}{}
+					j++
+				}
 			} else {
 				// skip entity with same id
 				skipDupCnt++
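The segment-level reduce above applies the same policy as the proxy-level reduce, built on the two typeutil helpers this PR leans on. A condensed sketch of that round-trip (variable names `sub`, `merged`, and `idx` are hypothetical stand-ins):

// Read the idx-th group-by value from one sub-result (nil when the sub-result
// carries no group-by column), then append it to the merged result's
// GroupByFieldValue in the scalar array matching its data type.
groupByVal := typeutil.GetData(sub.GetGroupByFieldValue(), idx)
if groupByVal != nil {
	if err := typeutil.AppendGroupByValue(merged, groupByVal, sub.GetGroupByFieldValue().GetType()); err != nil {
		return nil, err
	}
}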
@@ -603,6 +603,139 @@ func (suite *ResultSuite) TestResult_ReduceSearchResultData() {
 	})
 }

 func (suite *ResultSuite) TestResult_SearchGroupByResult() {
 	const (
 		nq   = 1
 		topk = 4
 	)
 	suite.Run("reduce_group_by_int", func() {
 		ids1 := []int64{1, 2, 3, 4}
 		scores1 := []float32{-1.0, -2.0, -3.0, -4.0}
 		topks1 := []int64{int64(len(ids1))}
 		ids2 := []int64{5, 1, 3, 4}
 		scores2 := []float32{-1.0, -1.0, -3.0, -4.0}
 		topks2 := []int64{int64(len(ids2))}
 		data1 := genSearchResultData(nq, topk, ids1, scores1, topks1)
 		data2 := genSearchResultData(nq, topk, ids2, scores2, topks2)
 		data1.GroupByFieldValue = &schemapb.FieldData{
 			Type: schemapb.DataType_Int8,
 			Field: &schemapb.FieldData_Scalars{
 				Scalars: &schemapb.ScalarField{
 					Data: &schemapb.ScalarField_IntData{
 						IntData: &schemapb.IntArray{
 							Data: []int32{2, 3, 4, 5},
 						},
 					},
 				},
 			},
 		}
 		data2.GroupByFieldValue = &schemapb.FieldData{
 			Type: schemapb.DataType_Int8,
 			Field: &schemapb.FieldData_Scalars{
 				Scalars: &schemapb.ScalarField{
 					Data: &schemapb.ScalarField_IntData{
 						IntData: &schemapb.IntArray{
 							Data: []int32{2, 3, 4, 5},
 						},
 					},
 				},
 			},
 		}
 		dataArray := make([]*schemapb.SearchResultData, 0)
 		dataArray = append(dataArray, data1)
 		dataArray = append(dataArray, data2)
 		res, err := ReduceSearchResultData(context.TODO(), dataArray, nq, topk)
 		suite.Nil(err)
 		suite.ElementsMatch([]int64{1, 2, 3, 4}, res.Ids.GetIntId().Data)
 		suite.ElementsMatch([]float32{-1.0, -2.0, -3.0, -4.0}, res.Scores)
 		suite.ElementsMatch([]int32{2, 3, 4, 5}, res.GroupByFieldValue.GetScalars().GetIntData().Data)
 	})
 	suite.Run("reduce_group_by_bool", func() {
 		ids1 := []int64{1, 2}
 		scores1 := []float32{-1.0, -2.0}
 		topks1 := []int64{int64(len(ids1))}
 		ids2 := []int64{3, 4}
 		scores2 := []float32{-1.0, -1.0}
 		topks2 := []int64{int64(len(ids2))}
 		data1 := genSearchResultData(nq, topk, ids1, scores1, topks1)
 		data2 := genSearchResultData(nq, topk, ids2, scores2, topks2)
 		data1.GroupByFieldValue = &schemapb.FieldData{
 			Type: schemapb.DataType_Bool,
 			Field: &schemapb.FieldData_Scalars{
 				Scalars: &schemapb.ScalarField{
 					Data: &schemapb.ScalarField_BoolData{
 						BoolData: &schemapb.BoolArray{
 							Data: []bool{true, false},
 						},
 					},
 				},
 			},
 		}
 		data2.GroupByFieldValue = &schemapb.FieldData{
 			Type: schemapb.DataType_Bool,
 			Field: &schemapb.FieldData_Scalars{
 				Scalars: &schemapb.ScalarField{
 					Data: &schemapb.ScalarField_BoolData{
 						BoolData: &schemapb.BoolArray{
 							Data: []bool{true, false},
 						},
 					},
 				},
 			},
 		}
 		dataArray := make([]*schemapb.SearchResultData, 0)
 		dataArray = append(dataArray, data1)
 		dataArray = append(dataArray, data2)
 		res, err := ReduceSearchResultData(context.TODO(), dataArray, nq, topk)
 		suite.Nil(err)
 		suite.ElementsMatch([]int64{1, 4}, res.Ids.GetIntId().Data)
 		suite.ElementsMatch([]float32{-1.0, -1.0}, res.Scores)
 		suite.ElementsMatch([]bool{true, false}, res.GroupByFieldValue.GetScalars().GetBoolData().Data)
 	})
 	suite.Run("reduce_group_by_string", func() {
 		ids1 := []int64{1, 2, 3, 4}
 		scores1 := []float32{-1.0, -2.0, -3.0, -4.0}
 		topks1 := []int64{int64(len(ids1))}
 		ids2 := []int64{5, 1, 3, 4}
 		scores2 := []float32{-1.0, -1.0, -3.0, -4.0}
 		topks2 := []int64{int64(len(ids2))}
 		data1 := genSearchResultData(nq, topk, ids1, scores1, topks1)
 		data2 := genSearchResultData(nq, topk, ids2, scores2, topks2)
 		data1.GroupByFieldValue = &schemapb.FieldData{
 			Type: schemapb.DataType_VarChar,
 			Field: &schemapb.FieldData_Scalars{
 				Scalars: &schemapb.ScalarField{
 					Data: &schemapb.ScalarField_StringData{
 						StringData: &schemapb.StringArray{
 							Data: []string{"1", "2", "3", "4"},
 						},
 					},
 				},
 			},
 		}
 		data2.GroupByFieldValue = &schemapb.FieldData{
 			Type: schemapb.DataType_VarChar,
 			Field: &schemapb.FieldData_Scalars{
 				Scalars: &schemapb.ScalarField{
 					Data: &schemapb.ScalarField_StringData{
 						StringData: &schemapb.StringArray{
 							Data: []string{"1", "2", "3", "4"},
 						},
 					},
 				},
 			},
 		}
 		dataArray := make([]*schemapb.SearchResultData, 0)
 		dataArray = append(dataArray, data1)
 		dataArray = append(dataArray, data2)
 		res, err := ReduceSearchResultData(context.TODO(), dataArray, nq, topk)
 		suite.Nil(err)
 		suite.ElementsMatch([]int64{1, 2, 3, 4}, res.Ids.GetIntId().Data)
 		suite.ElementsMatch([]float32{-1.0, -2.0, -3.0, -4.0}, res.Scores)
 		suite.ElementsMatch([]string{"1", "2", "3", "4"}, res.GroupByFieldValue.GetScalars().GetStringData().Data)
 	})
 }

 func (suite *ResultSuite) TestResult_SelectSearchResultData_int() {
 	type args struct {
 		dataArray []*schemapb.SearchResultData
@@ -1084,3 +1084,61 @@ func SelectMinPK[T ResultWithID](results []T, cursors []int64) (int, bool) {

 	return sel, drainResult
 }

 func AppendGroupByValue(dstResData *schemapb.SearchResultData,
 	groupByVal interface{}, srcDataType schemapb.DataType,
 ) error {
 	if dstResData.GroupByFieldValue == nil {
 		dstResData.GroupByFieldValue = &schemapb.FieldData{
 			Type: srcDataType,
 			Field: &schemapb.FieldData_Scalars{
 				Scalars: &schemapb.ScalarField{},
 			},
 		}
 	}
 	dstScalarField := dstResData.GroupByFieldValue.GetScalars()
 	switch srcDataType {
 	case schemapb.DataType_Bool:
 		if dstScalarField.GetBoolData() == nil {
 			dstScalarField.Data = &schemapb.ScalarField_BoolData{
 				BoolData: &schemapb.BoolArray{
 					Data: []bool{},
 				},
 			}
 		}
 		dstScalarField.GetBoolData().Data = append(dstScalarField.GetBoolData().Data, groupByVal.(bool))
 	case schemapb.DataType_Int8, schemapb.DataType_Int16, schemapb.DataType_Int32:
 		if dstScalarField.GetIntData() == nil {
 			dstScalarField.Data = &schemapb.ScalarField_IntData{
 				IntData: &schemapb.IntArray{
 					Data: []int32{},
 				},
 			}
 		}
 		dstScalarField.GetIntData().Data = append(dstScalarField.GetIntData().Data, groupByVal.(int32))
 	case schemapb.DataType_Int64:
 		if dstScalarField.GetLongData() == nil {
 			dstScalarField.Data = &schemapb.ScalarField_LongData{
 				LongData: &schemapb.LongArray{
 					Data: []int64{},
 				},
 			}
 		}
 		dstScalarField.GetLongData().Data = append(dstScalarField.GetLongData().Data, groupByVal.(int64))
 	case schemapb.DataType_VarChar:
 		if dstScalarField.GetStringData() == nil {
 			dstScalarField.Data = &schemapb.ScalarField_StringData{
 				StringData: &schemapb.StringArray{
 					Data: []string{},
 				},
 			}
 		}
 		dstScalarField.GetStringData().Data = append(dstScalarField.GetStringData().Data, groupByVal.(string))
 	default:
 		log.Error("unsupported field type for group_by value", zap.String("field type",
 			srcDataType.String()))
 		return fmt.Errorf("unsupported field type for group_by value: %s",
 			srcDataType.String())
 	}
 	return nil
 }
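A small, hypothetical usage of AppendGroupByValue: the narrow integer types (Int8/Int16/Int32) all land in the shared IntArray as int32, which is why the segcore test above feeds []int32 values for an Int8 field.

res := &schemapb.SearchResultData{}
for _, v := range []int32{2, 3, 5} {
	// v arrives as interface{} holding int32, matching the type assertion
	// in the Int8/Int16/Int32 branch above.
	if err := AppendGroupByValue(res, v, schemapb.DataType_Int8); err != nil {
		panic(err)
	}
}
// res.GroupByFieldValue.GetScalars().GetIntData().Data is now []int32{2, 3, 5},
// with res.GroupByFieldValue.Type set to DataType_Int8 by the first call.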