feat: support search_group_by for milvus (#25324) (#28983)

related: #25324

This PR adds the search GroupBy function, used to aggregate result
entities based on a specific scalar column.
Several points to mention:

1. Temporarily, the whole group-by path is implemented separately from the
iterative expr framework **for the first period**.
2. In the long term, the groupBy operation will be incorporated into the
iterative expr framework: https://github.com/milvus-io/milvus/pull/28166
3. This PR includes some unrelated mocked interfaces regarding alterIndex,
for reasons not worth detailing here. All of this unrelated content
will be removed before the final PR is merged; this version of the PR is
for review only.
4. All other related details are commented in the file comparison (a brief
usage sketch follows this list).

Signed-off-by: MrPresent-Han <chun.han@zilliz.com>
pull/29933/head
MrPresent-Han 2024-01-05 15:50:47 +08:00 committed by GitHub
parent 22bb84fa9d
commit 9e2e7157e9
31 changed files with 1530 additions and 54 deletions

View File

@ -21,12 +21,14 @@
#include "common/Types.h"
#include "knowhere/config.h"
namespace milvus {
struct SearchInfo {
int64_t topk_;
int64_t round_decimal_;
FieldId field_id_;
MetricType metric_type_;
knowhere::Json search_params_;
std::optional<FieldId> group_by_field_id_;
};
using SearchInfoPtr = std::shared_ptr<SearchInfo>;

View File

@ -28,6 +28,7 @@
#include "common/FieldMeta.h"
#include "pb/schema.pb.h"
#include "knowhere/index_node.h"
namespace milvus {
struct SearchResult {
@ -52,6 +53,7 @@ struct SearchResult {
// first fill data during search, and then update data after reducing search results
std::vector<float> distances_;
std::vector<int64_t> seg_offsets_;
std::vector<GroupByValueType> group_by_values_;
// first fill data during fillPrimaryKey, and then update data after reducing search results
std::vector<PkType> primary_keys_;
@ -67,6 +69,10 @@ struct SearchResult {
// used for reduce, filter invalid pk, get real topks count
std::vector<size_t> topk_per_nq_prefix_sum_;
//knowhere iterators, used for group by or other operators in the future
std::optional<std::vector<std::shared_ptr<knowhere::IndexNode::iterator>>>
iterators;
};
using SearchResultPtr = std::shared_ptr<SearchResult>;

View File

@ -116,6 +116,13 @@ using VectorArray = proto::schema::VectorField;
using IdArray = proto::schema::IDs;
using InsertData = proto::segcore::InsertRecord;
using PkType = std::variant<std::monostate, int64_t, std::string>;
using GroupByValueType = std::variant<std::monostate,
int8_t,
int16_t,
int32_t,
int64_t,
bool,
std::string_view>;
using ContainsType = proto::plan::JSONContainsExpr_JSONOp;
inline bool

View File

@ -522,6 +522,35 @@ VectorMemIndex<T>::Query(const DatasetPtr dataset,
auto num_queries = dataset->GetRows();
knowhere::Json search_conf = search_info.search_params_;
if (search_info.group_by_field_id_.has_value()) {
auto result = std::make_unique<SearchResult>();
try {
knowhere::expected<
std::vector<std::shared_ptr<knowhere::IndexNode::iterator>>>
iterators_val =
index_.AnnIterator(*dataset, search_conf, bitset);
if (iterators_val.has_value()) {
result->iterators = iterators_val.value();
} else {
LOG_ERROR(
"Returned knowhere iterator has non-ready iterators "
"inside, terminate group_by operation");
PanicInfo(ErrorCode::Unsupported,
"Returned knowhere iterator has non-ready iterators "
"inside, terminate group_by operation");
}
} catch (const std::runtime_error& e) {
LOG_ERROR(
"Caught error:{} when trying to initialize ann iterators for "
"group_by: "
"group_by operation will be terminated",
e.what());
throw;  // rethrow the original exception without slicing
}
return result;
//if the target index doesn't support iterators, an empty search result
//is returned and the reduce process will filter out the empty results
}
auto topk = search_info.topk_;
// TODO :: check dim of search data
auto final = [&] {

View File

@ -26,6 +26,7 @@ set(MILVUS_QUERY_SRCS
SearchOnIndex.cpp
SearchBruteForce.cpp
SubSearchResult.cpp
GroupByOperator.cpp
PlanProto.cpp
)
add_library(milvus_query ${MILVUS_QUERY_SRCS})

View File

@ -0,0 +1,211 @@
// Licensed to the LF AI & Data foundation under one
// or more contributor license agreements. See the NOTICE file
// distributed with this work for additional information
// regarding copyright ownership. The ASF licenses this file
// to you under the Apache License, Version 2.0 (the
// "License"); you may not use this file except in compliance
// with the License. You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "GroupByOperator.h"
#include "common/Consts.h"
#include "segcore/SegmentSealedImpl.h"
#include "Utils.h"
namespace milvus {
namespace query {
void
GroupBy(const std::vector<std::shared_ptr<knowhere::IndexNode::iterator>>&
iterators,
const SearchInfo& search_info,
std::vector<GroupByValueType>& group_by_values,
const segcore::SegmentInternalInterface& segment,
std::vector<int64_t>& seg_offsets,
std::vector<float>& distances) {
//0. check segment type; for phase 1, group-by is only supported on sealed segments
if (!dynamic_cast<const segcore::SegmentSealedImpl*>(&segment)) {
LOG_ERROR(
"Not support group_by operation for non-sealed segment, "
"segment_id:{}",
segment.get_segment_id());
return;
}
//1. get search meta
FieldId group_by_field_id = search_info.group_by_field_id_.value();
auto data_type = segment.GetFieldDataType(group_by_field_id);
switch (data_type) {
case DataType::INT8: {
auto field_data = segment.chunk_data<int8_t>(group_by_field_id, 0);
GroupIteratorsByType<int8_t>(iterators,
group_by_field_id,
search_info.topk_,
field_data,
group_by_values,
seg_offsets,
distances,
search_info.metric_type_);
break;
}
case DataType::INT16: {
auto field_data = segment.chunk_data<int16_t>(group_by_field_id, 0);
GroupIteratorsByType<int16_t>(iterators,
group_by_field_id,
search_info.topk_,
field_data,
group_by_values,
seg_offsets,
distances,
search_info.metric_type_);
break;
}
case DataType::INT32: {
auto field_data = segment.chunk_data<int32_t>(group_by_field_id, 0);
GroupIteratorsByType<int32_t>(iterators,
group_by_field_id,
search_info.topk_,
field_data,
group_by_values,
seg_offsets,
distances,
search_info.metric_type_);
break;
}
case DataType::INT64: {
auto field_data = segment.chunk_data<int64_t>(group_by_field_id, 0);
GroupIteratorsByType<int64_t>(iterators,
group_by_field_id,
search_info.topk_,
field_data,
group_by_values,
seg_offsets,
distances,
search_info.metric_type_);
break;
}
case DataType::BOOL: {
auto field_data = segment.chunk_data<bool>(group_by_field_id, 0);
GroupIteratorsByType<bool>(iterators,
group_by_field_id,
search_info.topk_,
field_data,
group_by_values,
seg_offsets,
distances,
search_info.metric_type_);
break;
}
case DataType::VARCHAR: {
auto field_data =
segment.chunk_data<std::string_view>(group_by_field_id, 0);
GroupIteratorsByType<std::string_view>(iterators,
group_by_field_id,
search_info.topk_,
field_data,
group_by_values,
seg_offsets,
distances,
search_info.metric_type_);
break;
}
default: {
PanicInfo(
DataTypeInvalid,
fmt::format("unsupported data type {} for group by operator",
data_type));
}
}
}
template <typename T>
void
GroupIteratorsByType(
const std::vector<std::shared_ptr<knowhere::IndexNode::iterator>>&
iterators,
FieldId field_id,
int64_t topK,
Span<T> field_data,
std::vector<GroupByValueType>& group_by_values,
std::vector<int64_t>& seg_offsets,
std::vector<float>& distances,
const knowhere::MetricType& metrics_type) {
for (auto& iterator : iterators) {
GroupIteratorResult<T>(iterator,
field_id,
topK,
field_data,
group_by_values,
seg_offsets,
distances,
metrics_type);
}
}
template <typename T>
void
GroupIteratorResult(
const std::shared_ptr<knowhere::IndexNode::iterator>& iterator,
FieldId field_id,
int64_t topK,
Span<T> field_data,
std::vector<GroupByValueType>& group_by_values,
std::vector<int64_t>& offsets,
std::vector<float>& distances,
const knowhere::MetricType& metrics_type) {
//1. map from group-by value to the closest (offset, distance) hit
std::unordered_map<T, std::pair<int64_t, float>> groupMap;
//2. iterate until the map is filled or the data is exhausted;
//note this may enumerate all the data inside a segment and may block
//subsequent queries and searches
auto dis_closer = [&](float l, float r) {
    if (PositivelyRelated(metrics_type))
        return l > r;
    return l < r;  //strict comparison, required by std::sort below
};
while (iterator->HasNext() && groupMap.size() < topK) {
auto [offset, dis] = iterator->Next();
const T& row_data = field_data.operator[](offset);
auto it = groupMap.find(row_data);
if (it == groupMap.end()) {
groupMap.insert(
std::make_pair(row_data, std::make_pair(offset, dis)));
} else if (dis_closer(dis, it->second.second)) {
it->second = {offset, dis};
}
}
//3. sort the groups by distance according to the metric type
std::vector<std::pair<T, std::pair<int64_t, float>>> sortedGroupVals(
groupMap.begin(), groupMap.end());
auto customComparator = [&](const auto& lhs, const auto& rhs) {
return dis_closer(lhs.second.second, rhs.second.second);
};
std::sort(sortedGroupVals.begin(), sortedGroupVals.end(), customComparator);
//4. save groupBy results
for (auto iter = sortedGroupVals.cbegin(); iter != sortedGroupVals.cend();
iter++) {
group_by_values.emplace_back(iter->first);
offsets.push_back(iter->second.first);
distances.push_back(iter->second.second);
}
//5. pad up to topK results; the extra entries will be removed during reduce
for (std::size_t idx = groupMap.size();
     idx < static_cast<std::size_t>(topK);
     idx++) {
offsets.push_back(INVALID_SEG_OFFSET);
distances.push_back(0.0);
group_by_values.emplace_back(std::monostate{});
}
}
} // namespace query
} // namespace milvus

View File

@ -0,0 +1,61 @@
// Licensed to the LF AI & Data foundation under one
// or more contributor license agreements. See the NOTICE file
// distributed with this work for additional information
// regarding copyright ownership. The ASF licenses this file
// to you under the Apache License, Version 2.0 (the
// "License"); you may not use this file except in compliance
// with the License. You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#pragma once
#include "common/QueryInfo.h"
#include "knowhere/index_node.h"
#include "segcore/SegmentInterface.h"
#include "common/Span.h"
namespace milvus {
namespace query {
void
GroupBy(const std::vector<std::shared_ptr<knowhere::IndexNode::iterator>>&
iterators,
const SearchInfo& searchInfo,
std::vector<GroupByValueType>& group_by_values,
const segcore::SegmentInternalInterface& segment,
std::vector<int64_t>& seg_offsets,
std::vector<float>& distances);
template <typename T>
void
GroupIteratorsByType(
const std::vector<std::shared_ptr<knowhere::IndexNode::iterator>>&
iterators,
FieldId field_id,
int64_t topK,
Span<T> field_data,
std::vector<GroupByValueType>& group_by_values,
std::vector<int64_t>& seg_offsets,
std::vector<float>& distances,
const knowhere::MetricType& metrics_type);
template <typename T>
void
GroupIteratorResult(
const std::shared_ptr<knowhere::IndexNode::iterator>& iterator,
FieldId field_id,
int64_t topK,
Span<T> field_data,
std::vector<GroupByValueType>& group_by_values,
std::vector<int64_t>& offsets,
std::vector<float>& distances,
const knowhere::MetricType& metrics_type);
} // namespace query
} // namespace milvus

View File

@ -203,6 +203,10 @@ ProtoParser::PlanNodeFromProto(const planpb::PlanNode& plan_node_proto) {
search_info.search_params_ =
nlohmann::json::parse(query_info_proto.search_params());
if (query_info_proto.group_by_field_id() != 0) {
auto group_by_field_id = FieldId(query_info_proto.group_by_field_id());
search_info.group_by_field_id_ = group_by_field_id;
}
auto plan_node = [&]() -> std::unique_ptr<VectorPlanNode> {
if (anns_proto.vector_type() ==
milvus::proto::plan::VectorType::BinaryVector) {

View File

@ -49,18 +49,22 @@ SearchOnSealedIndex(const Schema& schema,
auto index_type = vec_index->GetIndexType();
return vec_index->Query(ds, search_info, bitset);
}();
if (final->iterators.has_value()) {
    result.iterators = std::move(final->iterators);
} else {
    float* distances = final->distances_.data();
    auto total_num = num_queries * topk;
    if (round_decimal != -1) {
        const float multiplier = pow(10.0, round_decimal);
        for (int i = 0; i < total_num; i++) {
            distances[i] =
                std::round(distances[i] * multiplier) / multiplier;
        }
    }
    result.seg_offsets_ = std::move(final->seg_offsets_);
    result.distances_ = std::move(final->distances_);
}
result.total_nq_ = num_queries;
result.unity_topK_ = topk;
}

View File

@ -22,6 +22,8 @@
#include "log/Log.h"
#include "plan/PlanNode.h"
#include "exec/Task.h"
#include "segcore/SegmentInterface.h"
#include "query/GroupByOperator.h"
namespace milvus::query {
@ -188,7 +190,20 @@ ExecPlanNodeVisitor::VectorVisitorImpl(VectorPlanNode& node) {
timestamp_,
final_view,
search_result);
if (search_result.iterators.has_value()) {
GroupBy(search_result.iterators.value(),
node.search_info_,
search_result.group_by_values_,
*segment,
search_result.seg_offsets_,
search_result.distances_);
AssertInfo(search_result.seg_offsets_.size() ==
search_result.group_by_values_.size(),
"Wrong state! search_result group_by_values_ size:{} is not "
"equal to search_result.seg_offsets.size:{}",
search_result.group_by_values_.size(),
search_result.seg_offsets_.size());
}
search_result_opt_ = std::move(search_result);
}

View File

@ -12,8 +12,6 @@
#include "Reduce.h"
#include <log/Log.h>
#include <algorithm>
#include <cstdint>
#include <vector>
@ -90,6 +88,15 @@ ReduceHelper::FilterInvalidSearchResult(SearchResult* search_result) {
auto segment = static_cast<SegmentInterface*>(search_result->segment_);
auto& offsets = search_result->seg_offsets_;
auto& distances = search_result->distances_;
bool need_filter_group_by = !search_result->group_by_values_.empty();
if (need_filter_group_by) {
AssertInfo(search_result->distances_.size() ==
search_result->group_by_values_.size(),
"wrong group_by_values size, size:{}, expected size:{} ",
search_result->group_by_values_.size(),
search_result->distances_.size());
}
for (auto i = 0; i < nq; ++i) {
for (auto j = 0; j < topK; ++j) {
auto index = i * topK + j;
@ -104,12 +111,17 @@ ReduceHelper::FilterInvalidSearchResult(SearchResult* search_result) {
real_topks[i]++;
offsets[valid_index] = offsets[index];
distances[valid_index] = distances[index];
if (need_filter_group_by)
search_result->group_by_values_[valid_index] =
search_result->group_by_values_[index];
valid_index++;
}
}
}
offsets.resize(valid_index);
distances.resize(valid_index);
if (need_filter_group_by)
search_result->group_by_values_.resize(valid_index);
search_result->topk_per_nq_prefix_sum_.resize(nq + 1);
std::partial_sum(real_topks.begin(),
@ -130,9 +142,8 @@ ReduceHelper::FillPrimaryKey() {
FilterInvalidSearchResult(search_result);
LOG_DEBUG("the size of search result: {}",
search_result->seg_offsets_.size());
if (search_result->get_total_result_count() > 0) {
auto segment =
static_cast<SegmentInterface*>(search_result->segment_);
segment->FillPrimaryKeys(plan_, *search_result);
search_results_[valid_index++] = search_result;
}
@ -146,6 +157,15 @@ ReduceHelper::RefreshSearchResult() {
for (int i = 0; i < num_segments_; i++) {
std::vector<int64_t> real_topks(total_nq_, 0);
auto search_result = search_results_[i];
bool need_to_handle_group_by = !search_result->group_by_values_.empty();
if (need_to_handle_group_by) {
AssertInfo(search_result->primary_keys_.size() ==
search_result->group_by_values_.size(),
"Wrong size for group_by_values size:{}, not equal to "
"primary_keys_.size:{}",
search_result->group_by_values_.size(),
search_result->primary_keys_.size());
}
if (search_result->result_offsets_.size() != 0) {
uint32_t size = 0;
for (int j = 0; j < total_nq_; j++) {
@ -154,6 +174,7 @@ ReduceHelper::RefreshSearchResult() {
std::vector<milvus::PkType> primary_keys(size);
std::vector<float> distances(size);
std::vector<int64_t> seg_offsets(size);
std::vector<GroupByValueType> group_by_values(size);
uint32_t index = 0;
for (int j = 0; j < total_nq_; j++) {
@ -161,6 +182,9 @@ ReduceHelper::RefreshSearchResult() {
primary_keys[index] = search_result->primary_keys_[offset];
distances[index] = search_result->distances_[offset];
seg_offsets[index] = search_result->seg_offsets_[offset];
if (need_to_handle_group_by)
group_by_values[index] =
search_result->group_by_values_[offset];
index++;
real_topks[j]++;
}
@ -168,6 +192,9 @@ ReduceHelper::RefreshSearchResult() {
search_result->primary_keys_.swap(primary_keys);
search_result->distances_.swap(distances);
search_result->seg_offsets_.swap(seg_offsets);
if (need_to_handle_group_by) {
search_result->group_by_values_.swap(group_by_values);
}
}
std::partial_sum(real_topks.begin(),
real_topks.end(),
@ -193,8 +220,14 @@ ReduceHelper::ReduceSearchResultForOneNQ(int64_t qi,
}
pk_set_.clear();
pairs_.clear();
group_by_val_set_.clear();
pairs_.reserve(num_segments_);
bool need_handle_group_by_values = false;
if (num_segments_ > 0) {
need_handle_group_by_values =
!search_results_[0]->group_by_values_.empty();
}
for (int i = 0; i < num_segments_; i++) {
auto search_result = search_results_[i];
auto offset_beg = search_result->topk_per_nq_prefix_sum_[qi];
@ -206,7 +239,16 @@ ReduceHelper::ReduceSearchResultForOneNQ(int64_t qi,
auto distance = search_result->distances_[offset_beg];
pairs_.emplace_back(
primary_key, distance, search_result, i, offset_beg, offset_end);
primary_key,
distance,
search_result,
i,
offset_beg,
offset_end,
need_handle_group_by_values
? std::make_optional(
search_result->group_by_values_.at(offset_beg))
: std::nullopt);
heap_.push(&pairs_.back());
}
@ -229,9 +271,22 @@ ReduceHelper::ReduceSearchResultForOneNQ(int64_t qi,
}
// remove duplicates
if (pk_set_.count(pk) == 0) {
pilot->search_result_->result_offsets_.push_back(offset++);
final_search_records_[index][qi].push_back(pilot->offset_);
pk_set_.insert(pk);
bool skip_for_group_by = false;
if (need_handle_group_by_values &&
pilot->group_by_value_.has_value()) {
if (group_by_val_set_.count(pilot->group_by_value_.value()) >
0) {
skip_for_group_by = true;
}
}
if (!skip_for_group_by) {
pilot->search_result_->result_offsets_.push_back(offset++);
final_search_records_[index][qi].push_back(pilot->offset_);
pk_set_.insert(pk);
if (need_handle_group_by_values &&
pilot->group_by_value_.has_value())
group_by_val_set_.insert(pilot->group_by_value_.value());
}
} else {
// skip entity with same primary key
dup_cnt++;
@ -331,6 +386,10 @@ ReduceHelper::GetSearchResultDataSlice(int slice_index) {
// reserve space for distances
search_result_data->mutable_scores()->Resize(result_count, 0);
//reserve space for group_by_values
std::vector<GroupByValueType> group_by_values;
group_by_values.resize(result_count);
// fill pks and distances
for (auto qi = nq_begin; qi < nq_end; qi++) {
int64_t topk_count = 0;
@ -377,9 +436,11 @@ ReduceHelper::GetSearchResultDataSlice(int slice_index) {
}
}
// set result distances
search_result_data->mutable_scores()->Set(
loc, search_result->distances_[ki]);
// set group by values
if (ki < search_result->group_by_values_.size())
group_by_values[loc] = search_result->group_by_values_[ki];
// set result offset to fill output fields data
result_pairs[loc] = std::make_pair(search_result, ki);
}
@ -388,6 +449,7 @@ ReduceHelper::GetSearchResultDataSlice(int slice_index) {
// update result topKs
search_result_data->mutable_topks()->Set(qi - nq_begin, topk_count);
}
AssembleGroupByValues(search_result_data, group_by_values);
AssertInfo(search_result_data->scores_size() == result_count,
"wrong scores size, size = " +
@ -417,4 +479,89 @@ ReduceHelper::GetSearchResultDataSlice(int slice_index) {
return buffer;
}
void
ReduceHelper::AssembleGroupByValues(
std::unique_ptr<milvus::proto::schema::SearchResultData>& search_result,
const std::vector<GroupByValueType>& group_by_vals) {
auto group_by_field_id = plan_->plan_node_->search_info_.group_by_field_id_;
if (group_by_field_id.has_value() && group_by_vals.size() > 0) {
auto group_by_values_field =
std::make_unique<milvus::proto::schema::ScalarField>();
auto group_by_field =
plan_->schema_.operator[](group_by_field_id.value());
DataType group_by_data_type = group_by_field.get_data_type();
int group_by_val_size = group_by_vals.size();
switch (group_by_data_type) {
case DataType::INT8: {
auto field_data = group_by_values_field->mutable_int_data();
field_data->mutable_data()->Resize(group_by_val_size, 0);
for (std::size_t idx = 0; idx < group_by_val_size; idx++) {
int8_t val = std::get<int8_t>(group_by_vals[idx]);
field_data->mutable_data()->Set(idx, val);
}
break;
}
case DataType::INT16: {
auto field_data = group_by_values_field->mutable_int_data();
field_data->mutable_data()->Resize(group_by_val_size, 0);
for (std::size_t idx = 0; idx < group_by_val_size; idx++) {
int16_t val = std::get<int16_t>(group_by_vals[idx]);
field_data->mutable_data()->Set(idx, val);
}
break;
}
case DataType::INT32: {
auto field_data = group_by_values_field->mutable_int_data();
field_data->mutable_data()->Resize(group_by_val_size, 0);
for (std::size_t idx = 0; idx < group_by_val_size; idx++) {
int32_t val = std::get<int32_t>(group_by_vals[idx]);
field_data->mutable_data()->Set(idx, val);
}
break;
}
case DataType::INT64: {
auto field_data = group_by_values_field->mutable_long_data();
field_data->mutable_data()->Resize(group_by_val_size, 0);
for (std::size_t idx = 0; idx < group_by_val_size; idx++) {
int64_t val = std::get<int64_t>(group_by_vals[idx]);
field_data->mutable_data()->Set(idx, val);
}
break;
}
case DataType::BOOL: {
auto field_data = group_by_values_field->mutable_bool_data();
field_data->mutable_data()->Resize(group_by_val_size, 0);
for (std::size_t idx = 0; idx < group_by_val_size; idx++) {
bool val = std::get<bool>(group_by_vals[idx]);
field_data->mutable_data()->Set(idx, val);
}
break;
}
case DataType::VARCHAR: {
auto field_data = group_by_values_field->mutable_string_data();
for (std::size_t idx = 0; idx < group_by_val_size; idx++) {
std::string_view val =
std::get<std::string_view>(group_by_vals[idx]);
*(field_data->mutable_data()->Add()) = val;
}
break;
}
default: {
PanicInfo(
DataTypeInvalid,
fmt::format("unsupported datatype for group_by operations ",
group_by_data_type));
}
}
search_result->mutable_group_by_field_value()->set_type(
milvus::proto::schema::DataType(group_by_data_type));
search_result->mutable_group_by_field_value()
->mutable_scalars()
->MergeFrom(*group_by_values_field.get());
return;
}
}
} // namespace milvus::segcore

View File

@ -82,6 +82,11 @@ class ReduceHelper {
std::vector<char>
GetSearchResultDataSlice(int slice_index_);
void
AssembleGroupByValues(
std::unique_ptr<milvus::proto::schema::SearchResultData>& search_result,
const std::vector<GroupByValueType>& group_by_vals);
private:
std::vector<SearchResult*>& search_results_;
milvus::query::Plan* plan_;
@ -108,6 +113,7 @@ class ReduceHelper {
SearchResultPairComparator>
heap_;
std::unordered_set<milvus::PkType> pk_set_;
std::unordered_set<milvus::GroupByValueType> group_by_val_set_;
};
} // namespace milvus::segcore

View File

@ -26,7 +26,8 @@ struct SearchResultPair {
milvus::SearchResult* search_result_;
int64_t segment_index_;
int64_t offset_;
int64_t offset_rb_;  // right bound
std::optional<milvus::GroupByValueType> group_by_value_; //for group_by
SearchResultPair(milvus::PkType primary_key,
float distance,
@ -34,12 +35,24 @@ struct SearchResultPair {
int64_t index,
int64_t lb,
int64_t rb)
: SearchResultPair(
primary_key, distance, result, index, lb, rb, std::nullopt) {
}
SearchResultPair(milvus::PkType primary_key,
float distance,
SearchResult* result,
int64_t index,
int64_t lb,
int64_t rb,
std::optional<milvus::GroupByValueType> group_by_value)
: primary_key_(std::move(primary_key)),
distance_(distance),
search_result_(result),
segment_index_(index),
offset_(lb),
offset_rb_(rb) {
offset_rb_(rb),
group_by_value_(group_by_value) {
}
bool
@ -56,6 +69,9 @@ struct SearchResultPair {
if (offset_ < offset_rb_) {
primary_key_ = search_result_->primary_keys_.at(offset_);
distance_ = search_result_->distances_.at(offset_);
if (offset_ < search_result_->group_by_values_.size()) {
group_by_value_ = search_result_->group_by_values_.at(offset_);
}
} else {
primary_key_ = INVALID_PK;
distance_ = std::numeric_limits<float>::min();

View File

@ -418,6 +418,12 @@ SegmentGrowingImpl::num_chunk() const {
return upper_div(size, segcore_config_.get_chunk_rows());
}
DataType
SegmentGrowingImpl::GetFieldDataType(milvus::FieldId field_id) const {
auto& field_meta = schema_->operator[](field_id);
return field_meta.get_data_type();
}
void
SegmentGrowingImpl::vector_search(SearchInfo& search_info,
const void* query_data,

View File

@ -224,6 +224,9 @@ class SegmentGrowingImpl : public SegmentGrowing {
const BitsetView& bitset,
SearchResult& output) const override;
DataType
GetFieldDataType(FieldId fieldId) const override;
public:
void
mask_with_delete(BitsetType& bitset,

View File

@ -186,6 +186,9 @@ class SegmentInternalInterface : public SegmentInterface {
int64_t chunk_id,
const milvus::VariableColumn<std::string>& var_column);
virtual DataType
GetFieldDataType(FieldId fieldId) const = 0;
public:
virtual void
vector_search(SearchInfo& search_info,

View File

@ -1301,6 +1301,12 @@ SegmentSealedImpl::HasRawData(int64_t field_id) const {
return true;
}
DataType
SegmentSealedImpl::GetFieldDataType(milvus::FieldId field_id) const {
auto& field_meta = schema_->operator[](field_id);
return field_meta.get_data_type();
}
std::pair<std::unique_ptr<IdArray>, std::vector<SegOffset>>
SegmentSealedImpl::search_ids(const IdArray& id_array,
Timestamp timestamp) const {

View File

@ -86,6 +86,9 @@ class SegmentSealedImpl : public SegmentSealed {
bool
HasRawData(int64_t field_id) const override;
DataType
GetFieldDataType(FieldId fieldId) const override;
public:
int64_t
GetMemoryUsageInBytes() const override;

View File

@ -61,6 +61,7 @@ set(MILVUS_TEST_FILES
test_storage.cpp
test_exec.cpp
test_inverted_index.cpp
test_group_by.cpp
)
if ( BUILD_DISK_ANN STREQUAL "ON" )

View File

@ -112,13 +112,13 @@ Search_Sealed(benchmark::State& state) {
if (choice == 0) {
// Brute Force
} else if (choice == 1) {
// hnsw
auto vec = dataset_.get_col<float>(milvus::FieldId(100));
auto indexing = GenVecIndexing(
    N, dim, vec.data(), knowhere::IndexEnum::INDEX_HNSW);
segcore::LoadIndexInfo info;
info.index = std::move(indexing);
info.field_id = (*schema)[FieldName("fakevec")].get_id().get();
info.index_params["index_type"] = "HNSW";
info.index_params["metric_type"] = knowhere::metric::L2;
segment->DropFieldData(milvus::FieldId(100));
segment->LoadIndex(info);

View File

@ -0,0 +1,480 @@
//
// Created by zilliz on 2023/12/1.
//
#include <gtest/gtest.h>
#include "common/Schema.h"
#include "segcore/SegmentSealedImpl.h"
#include "test_utils/DataGen.h"
#include "query/Plan.h"
#include "segcore/segment_c.h"
#include "segcore/reduce_c.h"
#include "test_utils/c_api_test_utils.h"
#include "segcore/plan_c.h"
using namespace milvus;
using namespace milvus::segcore;
using namespace milvus::query;
using namespace milvus::storage;
const char* METRICS_TYPE = "metric_type";
void
prepareSegmentSystemFieldData(const std::unique_ptr<SegmentSealed>& segment,
size_t row_count,
GeneratedData& data_set){
auto field_data =
std::make_shared<milvus::FieldData<int64_t>>(DataType::INT64);
field_data->FillFieldData(data_set.row_ids_.data(), row_count);
auto field_data_info = FieldDataInfo{
RowFieldID.get(), row_count, std::vector<milvus::FieldDataPtr>{field_data}};
segment->LoadFieldData(RowFieldID, field_data_info);
field_data =
std::make_shared<milvus::FieldData<int64_t>>(DataType::INT64);
field_data->FillFieldData(data_set.timestamps_.data(), row_count);
field_data_info =
FieldDataInfo{TimestampFieldID.get(),
row_count,
std::vector<milvus::FieldDataPtr>{field_data}};
segment->LoadFieldData(TimestampFieldID, field_data_info);
}
TEST(GroupBY, Normal2){
using namespace milvus;
using namespace milvus::query;
using namespace milvus::segcore;
//0. prepare schema
int dim = 64;
auto schema = std::make_shared<Schema>();
auto vec_fid = schema->AddDebugField(
"fakevec", DataType::VECTOR_FLOAT, dim, knowhere::metric::L2);
auto int8_fid = schema->AddDebugField("int8", DataType::INT8);
auto int16_fid = schema->AddDebugField("int16", DataType::INT16);
auto int32_fid = schema->AddDebugField("int32", DataType::INT32);
auto int64_fid = schema->AddDebugField("int64", DataType::INT64);
auto str_fid = schema->AddDebugField("string1", DataType::VARCHAR);
auto bool_fid = schema->AddDebugField("bool", DataType::BOOL);
schema->set_primary_field_id(str_fid);
auto segment = CreateSealedSegment(schema);
size_t N = 100;
//1. load raw data
auto raw_data = DataGen(schema, N);
auto fields = schema->get_fields();
for (auto field_data : raw_data.raw_->fields_data()) {
int64_t field_id = field_data.field_id();
auto info = FieldDataInfo(field_data.field_id(), N);
auto field_meta = fields.at(FieldId(field_id));
info.channel->push(
CreateFieldDataFromDataArray(N, &field_data, field_meta));
info.channel->close();
segment->LoadFieldData(FieldId(field_id), info);
}
prepareSegmentSystemFieldData(segment, N, raw_data);
//2. load index
auto vector_data = raw_data.get_col<float>(vec_fid);
auto indexing = GenVecIndexing(N, dim, vector_data.data(), knowhere::IndexEnum::INDEX_HNSW);
LoadIndexInfo load_index_info;
load_index_info.field_id = vec_fid.get();
load_index_info.index = std::move(indexing);
load_index_info.index_params[METRICS_TYPE] = knowhere::metric::L2;
segment->LoadIndex(load_index_info);
//3. search group by int8
{
const char* raw_plan = R"(vector_anns: <
field_id: 100
query_info: <
topk: 100
metric_type: "L2"
search_params: "{\"ef\": 10}"
group_by_field_id: 101
>
placeholder_tag: "$0"
>)";
auto plan_str = translate_text_plan_to_binary_plan(raw_plan);
auto plan = CreateSearchPlanByExpr(*schema, plan_str.data(), plan_str.size());
auto num_queries = 1;
auto seed = 1024;
auto ph_group_raw = CreatePlaceholderGroup(num_queries, dim, seed);
auto ph_group = ParsePlaceholderGroup(plan.get(), ph_group_raw.SerializeAsString());
auto search_result = segment->Search(plan.get(), ph_group.get());
auto& group_by_values = search_result->group_by_values_;
ASSERT_EQ(search_result->group_by_values_.size(), search_result->seg_offsets_.size());
ASSERT_EQ(search_result->distances_.size(), search_result->seg_offsets_.size());
int size = group_by_values.size();
std::unordered_set<int8_t> i8_set;
float lastDistance = 0.0;
for(size_t i = 0; i < size; i++){
if(std::holds_alternative<int8_t>(group_by_values[i])){
int8_t g_val = std::get<int8_t>(group_by_values[i]);
ASSERT_FALSE(i8_set.count(g_val)>0);//no repetition on groupBy field
i8_set.insert(g_val);
auto distance = search_result->distances_.at(i);
ASSERT_TRUE(lastDistance <= distance);  //distances should be non-decreasing as the metric type is L2
lastDistance = distance;
} else {
//check padding
ASSERT_EQ(search_result->seg_offsets_[i], INVALID_SEG_OFFSET);
ASSERT_EQ(search_result->distances_[i], 0.0);
}
}
}
//4. search group by int16
{
const char* raw_plan = R"(vector_anns: <
field_id: 100
query_info: <
topk: 100
metric_type: "L2"
search_params: "{\"ef\": 10}"
group_by_field_id: 102
>
placeholder_tag: "$0"
>)";
auto plan_str = translate_text_plan_to_binary_plan(raw_plan);
auto plan = CreateSearchPlanByExpr(*schema, plan_str.data(), plan_str.size());
auto num_queries = 1;
auto seed = 1024;
auto ph_group_raw = CreatePlaceholderGroup(num_queries, dim, seed);
auto ph_group = ParsePlaceholderGroup(plan.get(), ph_group_raw.SerializeAsString());
auto search_result = segment->Search(plan.get(), ph_group.get());
auto& group_by_values = search_result->group_by_values_;
ASSERT_EQ(search_result->group_by_values_.size(), search_result->seg_offsets_.size());
ASSERT_EQ(search_result->distances_.size(), search_result->seg_offsets_.size());
int size = group_by_values.size();
std::unordered_set<int16_t> i16_set;
float lastDistance = 0.0;
for(size_t i = 0; i < size; i++){
if(std::holds_alternative<int16_t>(group_by_values[i])){
int16_t g_val = std::get<int16_t>(group_by_values[i]);
ASSERT_FALSE(i16_set.count(g_val)>0);//no repetition on groupBy field
i16_set.insert(g_val);
auto distance = search_result->distances_.at(i);
ASSERT_TRUE(lastDistance <= distance);  //distances should be non-decreasing as the metric type is L2
lastDistance = distance;
} else {
//check padding
ASSERT_EQ(search_result->seg_offsets_[i], INVALID_SEG_OFFSET);
ASSERT_EQ(search_result->distances_[i], 0.0);
}
}
}
//5. search group by int32
{
const char* raw_plan = R"(vector_anns: <
field_id: 100
query_info: <
topk: 100
metric_type: "L2"
search_params: "{\"ef\": 10}"
group_by_field_id: 103
>
placeholder_tag: "$0"
>)";
auto plan_str = translate_text_plan_to_binary_plan(raw_plan);
auto plan = CreateSearchPlanByExpr(*schema, plan_str.data(), plan_str.size());
auto num_queries = 1;
auto seed = 1024;
auto ph_group_raw = CreatePlaceholderGroup(num_queries, dim, seed);
auto ph_group = ParsePlaceholderGroup(plan.get(), ph_group_raw.SerializeAsString());
auto search_result = segment->Search(plan.get(), ph_group.get());
auto& group_by_values = search_result->group_by_values_;
ASSERT_EQ(search_result->group_by_values_.size(), search_result->seg_offsets_.size());
ASSERT_EQ(search_result->distances_.size(), search_result->seg_offsets_.size());
int size = group_by_values.size();
std::unordered_set<int32_t> i32_set;
float lastDistance = 0.0;
for(size_t i = 0; i < size; i++){
if(std::holds_alternative<int32_t>(group_by_values[i])){
int32_t g_val = std::get<int32_t>(group_by_values[i]);
ASSERT_FALSE(i32_set.count(g_val)>0);//no repetition on groupBy field
i32_set.insert(g_val);
auto distance = search_result->distances_.at(i);
ASSERT_TRUE(lastDistance <= distance);  //distances should be non-decreasing as the metric type is L2
lastDistance = distance;
} else {
//check padding
ASSERT_EQ(search_result->seg_offsets_[i], INVALID_SEG_OFFSET);
ASSERT_EQ(search_result->distances_[i], 0.0);
}
}
}
//6. search group by int64
{
const char* raw_plan = R"(vector_anns: <
field_id: 100
query_info: <
topk: 100
metric_type: "L2"
search_params: "{\"ef\": 10}"
group_by_field_id: 104
>
placeholder_tag: "$0"
>)";
auto plan_str = translate_text_plan_to_binary_plan(raw_plan);
auto plan = CreateSearchPlanByExpr(*schema, plan_str.data(), plan_str.size());
auto num_queries = 1;
auto seed = 1024;
auto ph_group_raw = CreatePlaceholderGroup(num_queries, dim, seed);
auto ph_group = ParsePlaceholderGroup(plan.get(), ph_group_raw.SerializeAsString());
auto search_result = segment->Search(plan.get(), ph_group.get());
auto& group_by_values = search_result->group_by_values_;
ASSERT_EQ(search_result->group_by_values_.size(), search_result->seg_offsets_.size());
ASSERT_EQ(search_result->distances_.size(), search_result->seg_offsets_.size());
int size = group_by_values.size();
std::unordered_set<int64_t> i64_set;
float lastDistance = 0.0;
for(size_t i = 0; i < size; i++){
if(std::holds_alternative<int64_t>(group_by_values[i])){
int64_t g_val = std::get<int64_t>(group_by_values[i]);
ASSERT_FALSE(i64_set.count(g_val)>0);//no repetition on groupBy field
i64_set.insert(g_val);
auto distance = search_result->distances_.at(i);
ASSERT_TRUE(lastDistance <= distance);  //distances should be non-decreasing as the metric type is L2
lastDistance = distance;
} else {
//check padding
ASSERT_EQ(search_result->seg_offsets_[i], INVALID_SEG_OFFSET);
ASSERT_EQ(search_result->distances_[i], 0.0);
}
}
}
//7. search group by string
{
const char* raw_plan = R"(vector_anns: <
field_id: 100
query_info: <
topk: 100
metric_type: "L2"
search_params: "{\"ef\": 10}"
group_by_field_id: 105
>
placeholder_tag: "$0"
>)";
auto plan_str = translate_text_plan_to_binary_plan(raw_plan);
auto plan = CreateSearchPlanByExpr(*schema, plan_str.data(), plan_str.size());
auto num_queries = 1;
auto seed = 1024;
auto ph_group_raw = CreatePlaceholderGroup(num_queries, dim, seed);
auto ph_group = ParsePlaceholderGroup(plan.get(), ph_group_raw.SerializeAsString());
auto search_result = segment->Search(plan.get(), ph_group.get());
auto& group_by_values = search_result->group_by_values_;
ASSERT_EQ(search_result->group_by_values_.size(), search_result->seg_offsets_.size());
ASSERT_EQ(search_result->distances_.size(), search_result->seg_offsets_.size());
int size = group_by_values.size();
std::unordered_set<std::string_view> strs_set;
float lastDistance = 0.0;
for(size_t i = 0; i < size; i++){
if(std::holds_alternative<std::string_view>(group_by_values[i])){
std::string_view g_val = std::get<std::string_view>(group_by_values[i]);
ASSERT_FALSE(strs_set.count(g_val)>0);//no repetition on groupBy field
strs_set.insert(g_val);
auto distance = search_result->distances_.at(i);
ASSERT_TRUE(lastDistance <= distance);  //distances should be non-decreasing as the metric type is L2
lastDistance = distance;
} else {
//check padding
ASSERT_EQ(search_result->seg_offsets_[i], INVALID_SEG_OFFSET);
ASSERT_EQ(search_result->distances_[i], 0.0);
}
}
}
//8. search group by bool
{
const char* raw_plan = R"(vector_anns: <
field_id: 100
query_info: <
topk: 100
metric_type: "L2"
search_params: "{\"ef\": 10}"
group_by_field_id: 106
>
placeholder_tag: "$0"
>)";
auto plan_str = translate_text_plan_to_binary_plan(raw_plan);
auto plan = CreateSearchPlanByExpr(*schema, plan_str.data(), plan_str.size());
auto num_queries = 1;
auto seed = 1024;
auto ph_group_raw = CreatePlaceholderGroup(num_queries, dim, seed);
auto ph_group = ParsePlaceholderGroup(plan.get(), ph_group_raw.SerializeAsString());
auto search_result = segment->Search(plan.get(), ph_group.get());
auto& group_by_values = search_result->group_by_values_;
ASSERT_EQ(search_result->group_by_values_.size(), search_result->seg_offsets_.size());
ASSERT_EQ(search_result->distances_.size(), search_result->seg_offsets_.size());
int size = group_by_values.size();
std::unordered_set<bool> bools_set;
int boolValCount = 0;
float lastDistance = 0.0;
for(size_t i = 0; i < size; i++){
if(std::holds_alternative<bool>(group_by_values[i])){
bool g_val = std::get<bool>(group_by_values[i]);
ASSERT_FALSE(bools_set.count(g_val)>0);//no repetition on groupBy field
bools_set.insert(g_val);
boolValCount+=1;
auto distance = search_result->distances_.at(i);
ASSERT_TRUE(lastDistance <= distance);  //distances should be non-decreasing as the metric type is L2
lastDistance = distance;
} else {
//check padding
ASSERT_EQ(search_result->seg_offsets_[i], INVALID_SEG_OFFSET);
ASSERT_EQ(search_result->distances_[i], 0.0);
}
ASSERT_TRUE(boolValCount <= 2);  //there are at most two distinct bool values
}
}
}
TEST(GroupBY, Reduce){
using namespace milvus;
using namespace milvus::query;
using namespace milvus::segcore;
//0. prepare schema
int dim = 64;
auto schema = std::make_shared<Schema>();
auto vec_fid = schema->AddDebugField(
"fakevec", DataType::VECTOR_FLOAT, dim, knowhere::metric::L2);
auto int64_fid = schema->AddDebugField("int64", DataType::INT64);
schema->set_primary_field_id(int64_fid);
auto segment1 = CreateSealedSegment(schema);
auto segment2 = CreateSealedSegment(schema);
//1. load raw data
size_t N = 100;
uint64_t seed = 512;
uint64_t ts_offset = 0;
int repeat_count_1 = 2;
int repeat_count_2 = 5;
auto raw_data1 = DataGen(schema, N, seed, ts_offset, repeat_count_1);
auto raw_data2 = DataGen(schema, N, seed, ts_offset, repeat_count_2);
auto fields = schema->get_fields();
//load segment1 raw data
for (auto field_data : raw_data1.raw_->fields_data()) {
int64_t field_id = field_data.field_id();
auto info = FieldDataInfo(field_data.field_id(), N);
auto field_meta = fields.at(FieldId(field_id));
info.channel->push(
CreateFieldDataFromDataArray(N, &field_data, field_meta));
info.channel->close();
segment1->LoadFieldData(FieldId(field_id), info);
}
prepareSegmentSystemFieldData(segment1, N, raw_data1);
//load segment2 raw data
for (auto field_data : raw_data2.raw_->fields_data()) {
int64_t field_id = field_data.field_id();
auto info = FieldDataInfo(field_data.field_id(), N);
auto field_meta = fields.at(FieldId(field_id));
info.channel->push(
CreateFieldDataFromDataArray(N, &field_data, field_meta));
info.channel->close();
segment2->LoadFieldData(FieldId(field_id), info);
}
prepareSegmentSystemFieldData(segment2, N, raw_data2);
//2. load index
auto vector_data_1 = raw_data1.get_col<float>(vec_fid);
auto indexing_1 = GenVecIndexing(N, dim, vector_data_1.data(), knowhere::IndexEnum::INDEX_HNSW);
LoadIndexInfo load_index_info_1;
load_index_info_1.field_id = vec_fid.get();
load_index_info_1.index = std::move(indexing_1);
load_index_info_1.index_params[METRICS_TYPE] = knowhere::metric::L2;
segment1->LoadIndex(load_index_info_1);
auto vector_data_2 = raw_data2.get_col<float>(vec_fid);
auto indexing_2 = GenVecIndexing(N, dim, vector_data_2.data(), knowhere::IndexEnum::INDEX_HNSW);
LoadIndexInfo load_index_info_2;
load_index_info_2.field_id = vec_fid.get();
load_index_info_2.index = std::move(indexing_2);
load_index_info_2.index_params[METRICS_TYPE] = knowhere::metric::L2;
segment2->LoadIndex(load_index_info_2);
//3. search with group-by on each segment, then reduce the results
const char* raw_plan = R"(vector_anns: <
field_id: 100
query_info: <
topk: 100
metric_type: "L2"
search_params: "{\"ef\": 10}"
group_by_field_id: 101
>
placeholder_tag: "$0"
>)";
auto plan_str = translate_text_plan_to_binary_plan(raw_plan);
auto plan = CreateSearchPlanByExpr(*schema, plan_str.data(), plan_str.size());
auto num_queries = 10;
auto topK = 100;
auto ph_group_raw = CreatePlaceholderGroup(num_queries, dim, seed);
auto ph_group = ParsePlaceholderGroup(plan.get(), ph_group_raw.SerializeAsString());
CPlaceholderGroup c_ph_group = ph_group.release();
CSearchPlan c_plan = plan.release();
CSegmentInterface c_segment_1 = segment1.release();
CSegmentInterface c_segment_2 = segment2.release();
CSearchResult c_search_res_1;
CSearchResult c_search_res_2;
auto status = Search(c_segment_1, c_plan, c_ph_group, {}, &c_search_res_1);
ASSERT_EQ(status.error_code, Success);
status = Search(c_segment_2, c_plan, c_ph_group, {}, &c_search_res_2);
ASSERT_EQ(status.error_code, Success);
std::vector<CSearchResult> results;
results.push_back(c_search_res_1);
results.push_back(c_search_res_2);
auto slice_nqs = std::vector<int64_t>{num_queries / 2, num_queries / 2};
auto slice_topKs = std::vector<int64_t>{topK / 2, topK};
CSearchResultDataBlobs cSearchResultData;
status = ReduceSearchResultsAndFillData(
&cSearchResultData,
c_plan,
results.data(),
results.size(),
slice_nqs.data(),
slice_topKs.data(),
slice_nqs.size()
);
CheckSearchResultDuplicate(results);
DeleteSearchResult(c_search_res_1);
DeleteSearchResult(c_search_res_2);
DeleteSearchResultDataBlobs(cSearchResultData);
DeleteSearchPlan(c_plan);
DeletePlaceholderGroup(c_ph_group);
DeleteSegment(c_segment_1);
DeleteSegment(c_segment_2);
}

View File

@ -400,7 +400,7 @@ TEST(Sealed, LoadFieldData) {
auto fakevec = dataset.get_col<float>(fakevec_id);
auto indexing = GenVecIndexing(N, dim, fakevec.data());
auto indexing = GenVecIndexing(N, dim, fakevec.data(), knowhere::IndexEnum::INDEX_FAISS_IVFFLAT);
auto segment = CreateSealedSegment(schema);
// std::string dsl = R"({
@ -525,7 +525,7 @@ TEST(Sealed, LoadFieldDataMmap) {
auto fakevec = dataset.get_col<float>(fakevec_id);
auto indexing = GenVecIndexing(
    N, dim, fakevec.data(), knowhere::IndexEnum::INDEX_FAISS_IVFFLAT);
auto segment = CreateSealedSegment(schema);
const char* raw_plan = R"(vector_anns: <
@ -616,7 +616,7 @@ TEST(Sealed, LoadScalarIndex) {
auto fakevec = dataset.get_col<float>(fakevec_id);
auto indexing = GenVecIndexing(
    N, dim, fakevec.data(), knowhere::IndexEnum::INDEX_FAISS_IVFFLAT);
auto segment = CreateSealedSegment(schema);
// std::string dsl = R"({
@ -1135,7 +1135,7 @@ TEST(Sealed, GetVector) {
auto fakevec = dataset.get_col<float>(fakevec_id);
auto indexing = GenVecIndexing(
    N, dim, fakevec.data(), knowhere::IndexEnum::INDEX_FAISS_IVFFLAT);
auto segment_sealed = CreateSealedSegment(schema);

View File

@ -904,7 +904,7 @@ SealedCreator(SchemaPtr schema, const GeneratedData& dataset) {
}
inline std::unique_ptr<milvus::index::VectorIndex>
GenVecIndexing(int64_t N, int64_t dim, const float* vec, const char* index_type) {
auto conf =
knowhere::Json{{knowhere::meta::METRIC_TYPE, knowhere::metric::L2},
{knowhere::meta::DIM, std::to_string(dim)},
@ -920,7 +920,7 @@ GenVecIndexing(int64_t N, int64_t dim, const float* vec) {
milvus::storage::FileManagerContext file_manager_context(
field_data_meta, index_meta, chunk_manager);
auto indexing = std::make_unique<index::VectorMemIndex<float>>(
index_type,
knowhere::metric::L2,
knowhere::Version::GetCurrentVersion().VersionNumber(),
file_manager_context);

View File

@ -0,0 +1,139 @@
// Licensed to the LF AI & Data foundation under one
// or more contributor license agreements. See the NOTICE file
// distributed with this work for additional information
// regarding copyright ownership. The ASF licenses this file
// to you under the Apache License, Version 2.0 (the
// "License"); you may not use this file except in compliance
// with the License. You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#pragma once
#include <boost/format.hpp>
#include <chrono>
#include <iostream>
#include <unordered_set>
#include "common/Types.h"
#include "common/type_c.h"
#include "pb/plan.pb.h"
#include "segcore/Collection.h"
#include "segcore/Reduce.h"
#include "segcore/reduce_c.h"
#include "segcore/segment_c.h"
#include "DataGen.h"
#include "PbHelper.h"
#include "c_api_test_utils.h"
#include "indexbuilder_test_utils.h"
using namespace milvus;
using namespace milvus::segcore;
namespace {
const char*
get_default_schema_config() {
static std::string conf = R"(name: "default-collection"
fields: <
fieldID: 100
name: "fakevec"
data_type: FloatVector
type_params: <
key: "dim"
value: "16"
>
index_params: <
key: "metric_type"
value: "L2"
>
>
fields: <
fieldID: 101
name: "age"
data_type: Int64
is_primary_key: true
>)";
return conf.c_str();
}
std::string
generate_max_float_query_data(int all_nq, int max_float_nq) {
assert(max_float_nq <= all_nq);
namespace ser = milvus::proto::common;
int dim = DIM;
ser::PlaceholderGroup raw_group;
auto value = raw_group.add_placeholders();
value->set_tag("$0");
value->set_type(ser::PlaceholderType::FloatVector);
for (int i = 0; i < all_nq; ++i) {
std::vector<float> vec;
if (i < max_float_nq) {
for (int d = 0; d < dim; ++d) {
vec.push_back(std::numeric_limits<float>::max());
}
} else {
for (int d = 0; d < dim; ++d) {
vec.push_back(1);
}
}
value->add_values(vec.data(), vec.size() * sizeof(float));
}
auto blob = raw_group.SerializeAsString();
return blob;
}
std::string
generate_query_data(int nq) {
namespace ser = milvus::proto::common;
std::default_random_engine e(67);
int dim = DIM;
std::normal_distribution<double> dis(0.0, 1.0);
ser::PlaceholderGroup raw_group;
auto value = raw_group.add_placeholders();
value->set_tag("$0");
value->set_type(ser::PlaceholderType::FloatVector);
for (int i = 0; i < nq; ++i) {
std::vector<float> vec;
for (int d = 0; d < dim; ++d) {
vec.push_back(dis(e));
}
value->add_values(vec.data(), vec.size() * sizeof(float));
}
auto blob = raw_group.SerializeAsString();
return blob;
}
void
CheckSearchResultDuplicate(const std::vector<CSearchResult>& results) {
auto nq = ((SearchResult*)results[0])->total_nq_;
std::unordered_set<PkType> pk_set;
std::unordered_set<GroupByValueType> group_by_val_set;
for (int qi = 0; qi < nq; qi++) {
pk_set.clear();
group_by_val_set.clear();
for (size_t i = 0; i < results.size(); i++) {
auto search_result = (SearchResult*)results[i];
ASSERT_EQ(nq, search_result->total_nq_);
auto topk_beg = search_result->topk_per_nq_prefix_sum_[qi];
auto topk_end = search_result->topk_per_nq_prefix_sum_[qi + 1];
for (size_t ki = topk_beg; ki < topk_end; ki++) {
ASSERT_NE(search_result->seg_offsets_[ki], INVALID_SEG_OFFSET);
auto ret = pk_set.insert(search_result->primary_keys_[ki]);
ASSERT_TRUE(ret.second);
if(search_result->group_by_values_.size()>ki){
auto group_by_val = search_result->group_by_values_[ki];
ASSERT_TRUE(group_by_val_set.count(group_by_val)==0);
group_by_val_set.insert(group_by_val);
}
}
}
}
}
}

View File

@ -57,6 +57,7 @@ message QueryInfo {
string metric_type = 3;
string search_params = 4;
int64 round_decimal = 5;
int64 group_by_field_id = 6;
}
message ColumnInfo {

View File

@ -44,6 +44,7 @@ import (
const (
IgnoreGrowingKey = "ignore_growing"
ReduceStopForBestKey = "reduce_stop_for_best"
GroupByFieldKey = "group_by_field"
AnnsFieldKey = "anns_field"
TopKKey = "topk"
NQKey = "nq"

View File

@ -116,7 +116,8 @@ func getPartitionIDs(ctx context.Context, dbName string, collectionName string,
}
// parseSearchInfo returns QueryInfo and offset
func parseSearchInfo(searchParamsPair []*commonpb.KeyValuePair, schema *schemapb.CollectionSchema) (*planpb.QueryInfo, int64, error) {
// 1. parse offset and real topk
topKStr, err := funcutil.GetAttrByKeyFromRepeatedKV(TopKKey, searchParamsPair)
if err != nil {
return nil, 0, errors.New(TopKKey + " not found in search_params")
@ -149,11 +150,13 @@ func parseSearchInfo(searchParamsPair []*commonpb.KeyValuePair) (*planpb.QueryIn
return nil, 0, fmt.Errorf("%s+%s [%d] is invalid, %w", OffsetKey, TopKKey, queryTopK, err)
}
// 2. parse metrics type
metricType, err := funcutil.GetAttrByKeyFromRepeatedKV(common.MetricTypeKey, searchParamsPair)
if err != nil {
metricType = ""
}
// 3. parse round decimal
roundDecimalStr, err := funcutil.GetAttrByKeyFromRepeatedKV(RoundDecimalKey, searchParamsPair)
if err != nil {
roundDecimalStr = "-1"
@ -167,15 +170,39 @@ func parseSearchInfo(searchParamsPair []*commonpb.KeyValuePair) (*planpb.QueryIn
if roundDecimal != -1 && (roundDecimal > 6 || roundDecimal < 0) {
return nil, 0, fmt.Errorf("%s [%s] is invalid, should be -1 or an integer in range [0, 6]", RoundDecimalKey, roundDecimalStr)
}
// 4. parse search param str
searchParamStr, err := funcutil.GetAttrByKeyFromRepeatedKV(SearchParamsKey, searchParamsPair)
if err != nil {
searchParamStr = ""
}
// 5. parse group by field
groupByFieldName, err := funcutil.GetAttrByKeyFromRepeatedKV(GroupByFieldKey, searchParamsPair)
if err != nil {
groupByFieldName = ""
}
var groupByFieldId int64
if groupByFieldName != "" {
groupByFieldId = -1
fields := schema.GetFields()
for _, field := range fields {
if field.Name == groupByFieldName {
groupByFieldId = field.FieldID
break
}
}
if groupByFieldId == -1 {
return nil, 0, merr.WrapErrFieldNotFound(groupByFieldName, "groupBy field not found in schema")
}
}
return &planpb.QueryInfo{
		Topk:           queryTopK,
		MetricType:     metricType,
		SearchParams:   searchParamStr,
		RoundDecimal:   roundDecimal,
		GroupByFieldId: groupByFieldId,
}, offset, nil
}
@ -299,10 +326,14 @@ func (t *searchTask) PreExecute(ctx context.Context) error {
annsField = vecFields[0].Name
}
	queryInfo, offset, err := parseSearchInfo(t.request.GetSearchParams(), t.schema.CollectionSchema)
if err != nil {
return err
}
if queryInfo.GroupByFieldId != 0 {
		// for group-by operations, growing segments are currently ignored
		t.SearchRequest.IgnoreGrowing = true
}
t.offset = offset
plan, err := planparserv2.CreateSearchPlan(t.schema.CollectionSchema, t.request.Dsl, annsField, queryInfo)
@ -405,7 +436,6 @@ func (t *searchTask) PreExecute(ctx context.Context) error {
zap.Bool("use_default_consistency", useDefaultConsistency),
zap.Any("consistency level", consistencyLevel),
zap.Uint64("timeout_ts", t.SearchRequest.GetTimeoutTimestamp()))
return nil
}
@ -817,7 +847,6 @@ func reduceSearchResultData(ctx context.Context, subSearchResultData []*schemapb
default:
return nil, errors.New("unsupported pk type")
}
for i, sData := range subSearchResultData {
pkLength := typeutil.GetSizeOfIDs(sData.GetIds())
log.Ctx(ctx).Debug("subSearchResultData",
@ -852,6 +881,7 @@ func reduceSearchResultData(ctx context.Context, subSearchResultData []*schemapb
var retSize int64
maxOutputSize := paramtable.Get().QuotaConfig.MaxOutputSize.GetAsInt64()
// reducing nq * topk results
for i := int64(0); i < nq; i++ {
var (
@ -859,8 +889,9 @@ func reduceSearchResultData(ctx context.Context, subSearchResultData []*schemapb
// sum(cursors) == j
cursors = make([]int64, subSearchNum)
			j             int64
			idSet         = make(map[interface{}]struct{})
			groupByValSet = make(map[interface{}]struct{})
)
// skip offset results
@ -882,17 +913,32 @@ func reduceSearchResultData(ctx context.Context, subSearchResultData []*schemapb
if subSearchIdx == -1 {
break
}
subSearchRes := subSearchResultData[subSearchIdx]
			id := typeutil.GetPK(subSearchRes.GetIds(), resultDataIdx)
			score := subSearchRes.Scores[resultDataIdx]
groupByVal := typeutil.GetData(subSearchRes.GetGroupByFieldValue(), int(resultDataIdx))
// remove duplicates
if _, ok := idSet[id]; !ok {
groupByValExist := false
if groupByVal != nil {
_, groupByValExist = groupByValSet[groupByVal]
}
if !groupByValExist {
retSize += typeutil.AppendFieldData(ret.Results.FieldsData, subSearchResultData[subSearchIdx].FieldsData, resultDataIdx)
typeutil.AppendPKs(ret.Results.Ids, id)
ret.Results.Scores = append(ret.Results.Scores, score)
idSet[id] = struct{}{}
if groupByVal != nil {
groupByValSet[groupByVal] = struct{}{}
if err := typeutil.AppendGroupByValue(ret.Results, groupByVal, subSearchRes.GetGroupByFieldValue().GetType()); err != nil {
log.Ctx(ctx).Error("failed to append groupByValues", zap.Error(err))
return ret, err
}
}
j++
}
} else {
// skip entity with same id
skipDupCnt++

View File

@ -1591,6 +1591,78 @@ func TestTaskSearch_reduceSearchResultData(t *testing.T) {
})
}
func TestTaskSearch_reduceGroupBySearchResultData(t *testing.T) {
var (
nq int64 = 2
topK int64 = 5
)
ids := [][]int64{
{1, 3, 5, 7, 9, 1, 3, 5, 7, 9},
{2, 4, 6, 8, 10, 2, 4, 6, 8, 10},
}
scores := [][]float32{
{10, 8, 6, 4, 2, 10, 8, 6, 4, 2},
{9, 7, 5, 3, 1, 9, 7, 5, 3, 1},
}
groupByValuesArr := [][][]int64{
{
{1, 2, 3, 4, 5, 1, 2, 3, 4, 5},
{1, 2, 3, 4, 5, 1, 2, 3, 4, 5},
		}, // both sub-results share exactly the same group_by values, so nothing from result2 can be selected
{
{1, 2, 3, 4, 5, 1, 2, 3, 4, 5},
{6, 8, 3, 4, 5, 6, 8, 3, 4, 5},
}, // result2 will contribute group_by values 6 and 8
}
expectedIDs := [][]int64{
{1, 3, 5, 7, 9, 1, 3, 5, 7, 9},
{1, 2, 3, 4, 5, 1, 2, 3, 4, 5},
}
expectedScores := [][]float32{
{-10, -8, -6, -4, -2, -10, -8, -6, -4, -2},
{-10, -9, -8, -7, -6, -10, -9, -8, -7, -6},
}
expectedGroupByValues := [][]int64{
{1, 2, 3, 4, 5, 1, 2, 3, 4, 5},
{1, 6, 2, 8, 3, 1, 6, 2, 8, 3},
}
for i, groupByValues := range groupByValuesArr {
t.Run("Group By correctness", func(t *testing.T) {
var results []*schemapb.SearchResultData
for j := range ids {
result := getSearchResultData(nq, topK)
result.Ids.IdField = &schemapb.IDs_IntId{IntId: &schemapb.LongArray{Data: ids[j]}}
result.Scores = scores[j]
result.Topks = []int64{topK, topK}
result.GroupByFieldValue = &schemapb.FieldData{
Type: schemapb.DataType_Int64,
Field: &schemapb.FieldData_Scalars{
Scalars: &schemapb.ScalarField{
Data: &schemapb.ScalarField_LongData{
LongData: &schemapb.LongArray{
Data: groupByValues[j],
},
},
},
},
}
results = append(results, result)
}
reduced, err := reduceSearchResultData(context.TODO(), results, nq, topK, metric.L2, schemapb.DataType_Int64, 0)
resultIDs := reduced.GetResults().GetIds().GetIntId().Data
resultScores := reduced.GetResults().GetScores()
resultGroupByValues := reduced.GetResults().GetGroupByFieldValue().GetScalars().GetLongData().GetData()
assert.EqualValues(t, expectedIDs[i], resultIDs)
assert.EqualValues(t, expectedScores[i], resultScores)
assert.EqualValues(t, expectedGroupByValues[i], resultGroupByValues)
assert.NoError(t, err)
})
}
}
func TestSearchTask_ErrExecute(t *testing.T) {
var (
err error
@ -1784,7 +1856,7 @@ func TestTaskSearch_parseQueryInfo(t *testing.T) {
for _, test := range tests {
t.Run(test.description, func(t *testing.T) {
info, offset, err := parseSearchInfo(test.validParams, nil)
assert.NoError(t, err)
assert.NotNil(t, info)
if test.description == "offsetParam" {
@ -1873,7 +1945,7 @@ func TestTaskSearch_parseQueryInfo(t *testing.T) {
for _, test := range tests {
t.Run(test.description, func(t *testing.T) {
info, offset, err := parseSearchInfo(test.invalidParams, nil)
assert.Error(t, err)
assert.Nil(t, info)
assert.Zero(t, offset)
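// Note (reviewer comment): parseSearchInfo gained a second parameter in this
// patch, passed as nil throughout these cases; presumably it carries the
// schema used to resolve the group_by field, which these parameter-validation
// tests do not exercise.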

View File

@ -129,6 +129,7 @@ func ReduceSearchResultData(ctx context.Context, searchResultData []*schemapb.Se
offsets := make([]int64, len(searchResultData))
idSet := make(map[interface{}]struct{})
groupByValueSet := make(map[interface{}]struct{})
var j int64
for j = 0; j < topk; {
sel := SelectSearchResultData(searchResultData, resultOffsets, offsets, i)
@ -138,15 +139,29 @@ func ReduceSearchResultData(ctx context.Context, searchResultData []*schemapb.Se
idx := resultOffsets[sel][i] + offsets[sel]
id := typeutil.GetPK(searchResultData[sel].GetIds(), idx)
groupByVal := typeutil.GetData(searchResultData[sel].GetGroupByFieldValue(), int(idx))
score := searchResultData[sel].Scores[idx]
// remove duplicates
if _, ok := idSet[id]; !ok {
groupByValExist := false
if groupByVal != nil {
_, groupByValExist = groupByValueSet[groupByVal]
}
if !groupByValExist {
retSize += typeutil.AppendFieldData(ret.FieldsData, searchResultData[sel].FieldsData, idx)
typeutil.AppendPKs(ret.Ids, id)
ret.Scores = append(ret.Scores, score)
if groupByVal != nil {
groupByValueSet[groupByVal] = struct{}{}
if err := typeutil.AppendGroupByValue(ret, groupByVal, searchResultData[sel].GetGroupByFieldValue().GetType()); err != nil {
log.Error("Failed to append groupByValues", zap.Error(err))
return ret, err
}
}
idSet[id] = struct{}{}
j++
}
} else {
// skip entity with same id
skipDupCnt++
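// Reviewer note (illustrative): groupByVal is an interface{} returned by
// typeutil.GetData; judging from the type assertions in
// typeutil.AppendGroupByValue further below, it wraps a comparable scalar
// (bool, int32 for Int8/Int16/Int32, int64, or string), which is what makes
// it usable directly as a map key. A minimal sketch:
func groupKeyDemo() bool {
	seen := make(map[interface{}]struct{})
	seen[interface{}(int64(7))] = struct{}{}
	_, dup := seen[interface{}(int64(7))] // true: same dynamic type and value
	return dup
}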

View File

@ -603,6 +603,139 @@ func (suite *ResultSuite) TestResult_ReduceSearchResultData() {
})
}
func (suite *ResultSuite) TestResult_SearchGroupByResult() {
const (
nq = 1
topk = 4
)
suite.Run("reduce_group_by_int", func() {
ids1 := []int64{1, 2, 3, 4}
scores1 := []float32{-1.0, -2.0, -3.0, -4.0}
topks1 := []int64{int64(len(ids1))}
ids2 := []int64{5, 1, 3, 4}
scores2 := []float32{-1.0, -1.0, -3.0, -4.0}
topks2 := []int64{int64(len(ids2))}
data1 := genSearchResultData(nq, topk, ids1, scores1, topks1)
data2 := genSearchResultData(nq, topk, ids2, scores2, topks2)
data1.GroupByFieldValue = &schemapb.FieldData{
Type: schemapb.DataType_Int8,
Field: &schemapb.FieldData_Scalars{
Scalars: &schemapb.ScalarField{
Data: &schemapb.ScalarField_IntData{
IntData: &schemapb.IntArray{
Data: []int32{2, 3, 4, 5},
},
},
},
},
}
data2.GroupByFieldValue = &schemapb.FieldData{
Type: schemapb.DataType_Int8,
Field: &schemapb.FieldData_Scalars{
Scalars: &schemapb.ScalarField{
Data: &schemapb.ScalarField_IntData{
IntData: &schemapb.IntArray{
Data: []int32{2, 3, 4, 5},
},
},
},
},
}
dataArray := make([]*schemapb.SearchResultData, 0)
dataArray = append(dataArray, data1)
dataArray = append(dataArray, data2)
res, err := ReduceSearchResultData(context.TODO(), dataArray, nq, topk)
suite.Nil(err)
suite.ElementsMatch([]int64{1, 2, 3, 4}, res.Ids.GetIntId().Data)
suite.ElementsMatch([]float32{-1.0, -2.0, -3.0, -4.0}, res.Scores)
suite.ElementsMatch([]int32{2, 3, 4, 5}, res.GroupByFieldValue.GetScalars().GetIntData().Data)
})
suite.Run("reduce_group_by_bool", func() {
ids1 := []int64{1, 2}
scores1 := []float32{-1.0, -2.0}
topks1 := []int64{int64(len(ids1))}
ids2 := []int64{3, 4}
scores2 := []float32{-1.0, -1.0}
topks2 := []int64{int64(len(ids2))}
data1 := genSearchResultData(nq, topk, ids1, scores1, topks1)
data2 := genSearchResultData(nq, topk, ids2, scores2, topks2)
data1.GroupByFieldValue = &schemapb.FieldData{
Type: schemapb.DataType_Bool,
Field: &schemapb.FieldData_Scalars{
Scalars: &schemapb.ScalarField{
Data: &schemapb.ScalarField_BoolData{
BoolData: &schemapb.BoolArray{
Data: []bool{true, false},
},
},
},
},
}
data2.GroupByFieldValue = &schemapb.FieldData{
Type: schemapb.DataType_Bool,
Field: &schemapb.FieldData_Scalars{
Scalars: &schemapb.ScalarField{
Data: &schemapb.ScalarField_BoolData{
BoolData: &schemapb.BoolArray{
Data: []bool{true, false},
},
},
},
},
}
dataArray := make([]*schemapb.SearchResultData, 0)
dataArray = append(dataArray, data1)
dataArray = append(dataArray, data2)
res, err := ReduceSearchResultData(context.TODO(), dataArray, nq, topk)
suite.Nil(err)
suite.ElementsMatch([]int64{1, 4}, res.Ids.GetIntId().Data)
suite.ElementsMatch([]float32{-1.0, -1.0}, res.Scores)
suite.ElementsMatch([]bool{true, false}, res.GroupByFieldValue.GetScalars().GetBoolData().Data)
})
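// Illustrative trace for the bool case above: merged by descending score (in
// one possible tie order), the candidates are id 1 (score -1, true, kept),
// id 3 (score -1, true, dropped as a duplicate group), id 4 (score -1, false,
// kept), and id 2 (score -2, false, dropped as a duplicate group) -- hence
// exactly one row per bool group value.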
suite.Run("reduce_group_by_string", func() {
ids1 := []int64{1, 2, 3, 4}
scores1 := []float32{-1.0, -2.0, -3.0, -4.0}
topks1 := []int64{int64(len(ids1))}
ids2 := []int64{5, 1, 3, 4}
scores2 := []float32{-1.0, -1.0, -3.0, -4.0}
topks2 := []int64{int64(len(ids2))}
data1 := genSearchResultData(nq, topk, ids1, scores1, topks1)
data2 := genSearchResultData(nq, topk, ids2, scores2, topks2)
data1.GroupByFieldValue = &schemapb.FieldData{
Type: schemapb.DataType_VarChar,
Field: &schemapb.FieldData_Scalars{
Scalars: &schemapb.ScalarField{
Data: &schemapb.ScalarField_StringData{
StringData: &schemapb.StringArray{
Data: []string{"1", "2", "3", "4"},
},
},
},
},
}
data2.GroupByFieldValue = &schemapb.FieldData{
Type: schemapb.DataType_VarChar,
Field: &schemapb.FieldData_Scalars{
Scalars: &schemapb.ScalarField{
Data: &schemapb.ScalarField_StringData{
StringData: &schemapb.StringArray{
Data: []string{"1", "2", "3", "4"},
},
},
},
},
}
dataArray := make([]*schemapb.SearchResultData, 0)
dataArray = append(dataArray, data1)
dataArray = append(dataArray, data2)
res, err := ReduceSearchResultData(context.TODO(), dataArray, nq, topk)
suite.Nil(err)
suite.ElementsMatch([]int64{1, 2, 3, 4}, res.Ids.GetIntId().Data)
suite.ElementsMatch([]float32{-1.0, -2.0, -3.0, -4.0}, res.Scores)
suite.ElementsMatch([]string{"1", "2", "3", "4"}, res.GroupByFieldValue.GetScalars().GetStringData().Data)
})
}
func (suite *ResultSuite) TestResult_SelectSearchResultData_int() {
type args struct {
dataArray []*schemapb.SearchResultData

View File

@ -1084,3 +1084,61 @@ func SelectMinPK[T ResultWithID](results []T, cursors []int64) (int, bool) {
return sel, drainResult
}
func AppendGroupByValue(dstResData *schemapb.SearchResultData,
groupByVal interface{}, srcDataType schemapb.DataType,
) error {
if dstResData.GroupByFieldValue == nil {
dstResData.GroupByFieldValue = &schemapb.FieldData{
Type: srcDataType,
Field: &schemapb.FieldData_Scalars{
Scalars: &schemapb.ScalarField{},
},
}
}
dstScalarField := dstResData.GroupByFieldValue.GetScalars()
switch srcDataType {
case schemapb.DataType_Bool:
if dstScalarField.GetBoolData() == nil {
dstScalarField.Data = &schemapb.ScalarField_BoolData{
BoolData: &schemapb.BoolArray{
Data: []bool{},
},
}
}
dstScalarField.GetBoolData().Data = append(dstScalarField.GetBoolData().Data, groupByVal.(bool))
case schemapb.DataType_Int8, schemapb.DataType_Int16, schemapb.DataType_Int32:
if dstScalarField.GetIntData() == nil {
dstScalarField.Data = &schemapb.ScalarField_IntData{
IntData: &schemapb.IntArray{
Data: []int32{},
},
}
}
dstScalarField.GetIntData().Data = append(dstScalarField.GetIntData().Data, groupByVal.(int32))
case schemapb.DataType_Int64:
if dstScalarField.GetLongData() == nil {
dstScalarField.Data = &schemapb.ScalarField_LongData{
LongData: &schemapb.LongArray{
Data: []int64{},
},
}
}
dstScalarField.GetLongData().Data = append(dstScalarField.GetLongData().Data, groupByVal.(int64))
case schemapb.DataType_VarChar:
if dstScalarField.GetStringData() == nil {
dstScalarField.Data = &schemapb.ScalarField_StringData{
StringData: &schemapb.StringArray{
Data: []string{},
},
}
}
dstScalarField.GetStringData().Data = append(dstScalarField.GetStringData().Data, groupByVal.(string))
default:
log.Error("Not supported field type from group_by value field", zap.String("field type",
srcDataType.String()))
return fmt.Errorf("not supported field type from group_by value field: %s",
srcDataType.String())
}
return nil
}
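// Illustrative usage with hypothetical values: AppendGroupByValue lazily
// initializes GroupByFieldValue and the per-type scalar array on first use,
// then appends one value per accepted row. Note the caller-side convention
// implied by the type assertions above: pass int32 for Int8/Int16/Int32
// fields and int64 for Int64.
func appendGroupValuesDemo(dst *schemapb.SearchResultData) error {
	for _, v := range []int64{7, 42} {
		if err := AppendGroupByValue(dst, v, schemapb.DataType_Int64); err != nil {
			return err
		}
	}
	return nil
}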