enhance: support groupby based on scalar-index(#29965) (#30091)

related: #29965

Signed-off-by: MrPresent-Han <chun.han@zilliz.com>
pull/30163/head
MrPresent-Han 2024-01-22 10:50:54 +08:00 committed by GitHub
parent ae4f62ab4b
commit 4436effdc3
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
5 changed files with 78 additions and 34 deletions

View File

@ -113,7 +113,7 @@ using GroupByValueType = std::variant<std::monostate,
int32_t,
int64_t,
bool,
std::string_view>;
std::string>;
using ContainsType = proto::plan::JSONContainsExpr_JSONOp;
inline bool

View File

@ -44,11 +44,11 @@ GroupBy(const std::vector<std::shared_ptr<knowhere::IndexNode::iterator>>&
switch (data_type) {
case DataType::INT8: {
auto field_data = segment.chunk_data<int8_t>(group_by_field_id, 0);
DataGetter<int8_t> dataGetter(segment, group_by_field_id);
GroupIteratorsByType<int8_t>(iterators,
group_by_field_id,
search_info.topk_,
field_data,
dataGetter,
group_by_values,
seg_offsets,
distances,
@ -56,11 +56,11 @@ GroupBy(const std::vector<std::shared_ptr<knowhere::IndexNode::iterator>>&
break;
}
case DataType::INT16: {
auto field_data = segment.chunk_data<int16_t>(group_by_field_id, 0);
DataGetter<int16_t> dataGetter(segment, group_by_field_id);
GroupIteratorsByType<int16_t>(iterators,
group_by_field_id,
search_info.topk_,
field_data,
dataGetter,
group_by_values,
seg_offsets,
distances,
@ -68,11 +68,11 @@ GroupBy(const std::vector<std::shared_ptr<knowhere::IndexNode::iterator>>&
break;
}
case DataType::INT32: {
auto field_data = segment.chunk_data<int32_t>(group_by_field_id, 0);
DataGetter<int32_t> dataGetter(segment, group_by_field_id);
GroupIteratorsByType<int32_t>(iterators,
group_by_field_id,
search_info.topk_,
field_data,
dataGetter,
group_by_values,
seg_offsets,
distances,
@ -80,11 +80,11 @@ GroupBy(const std::vector<std::shared_ptr<knowhere::IndexNode::iterator>>&
break;
}
case DataType::INT64: {
auto field_data = segment.chunk_data<int64_t>(group_by_field_id, 0);
DataGetter<int64_t> dataGetter(segment, group_by_field_id);
GroupIteratorsByType<int64_t>(iterators,
group_by_field_id,
search_info.topk_,
field_data,
dataGetter,
group_by_values,
seg_offsets,
distances,
@ -92,11 +92,11 @@ GroupBy(const std::vector<std::shared_ptr<knowhere::IndexNode::iterator>>&
break;
}
case DataType::BOOL: {
auto field_data = segment.chunk_data<bool>(group_by_field_id, 0);
DataGetter<bool> dataGetter(segment, group_by_field_id);
GroupIteratorsByType<bool>(iterators,
group_by_field_id,
search_info.topk_,
field_data,
dataGetter,
group_by_values,
seg_offsets,
distances,
@ -104,21 +104,20 @@ GroupBy(const std::vector<std::shared_ptr<knowhere::IndexNode::iterator>>&
break;
}
case DataType::VARCHAR: {
auto field_data =
segment.chunk_data<std::string_view>(group_by_field_id, 0);
GroupIteratorsByType<std::string_view>(iterators,
group_by_field_id,
search_info.topk_,
field_data,
group_by_values,
seg_offsets,
distances,
search_info.metric_type_);
DataGetter<std::string> dataGetter(segment, group_by_field_id);
GroupIteratorsByType<std::string>(iterators,
group_by_field_id,
search_info.topk_,
dataGetter,
group_by_values,
seg_offsets,
distances,
search_info.metric_type_);
break;
}
default: {
PanicInfo(
DataTypeInvalid,
Unsupported,
fmt::format("unsupported data type {} for group by operator",
data_type));
}
@ -132,7 +131,7 @@ GroupIteratorsByType(
iterators,
FieldId field_id,
int64_t topK,
Span<T> field_data,
const DataGetter<T>& data_getter,
std::vector<GroupByValueType>& group_by_values,
std::vector<int64_t>& seg_offsets,
std::vector<float>& distances,
@ -141,7 +140,7 @@ GroupIteratorsByType(
GroupIteratorResult<T>(iterator,
field_id,
topK,
field_data,
data_getter,
group_by_values,
seg_offsets,
distances,
@ -155,7 +154,7 @@ GroupIteratorResult(
const std::shared_ptr<knowhere::IndexNode::iterator>& iterator,
FieldId field_id,
int64_t topK,
Span<T> field_data,
const DataGetter<T>& data_getter,
std::vector<GroupByValueType>& group_by_values,
std::vector<int64_t>& offsets,
std::vector<float>& distances,
@ -173,7 +172,7 @@ GroupIteratorResult(
};
while (iterator->HasNext() && groupMap.size() < topK) {
auto [offset, dis] = iterator->Next();
const T& row_data = field_data.operator[](offset);
T row_data = data_getter.Get(offset);
auto it = groupMap.find(row_data);
if (it == groupMap.end()) {
groupMap.emplace(row_data, std::make_pair(offset, dis));

View File

@ -23,6 +23,51 @@
namespace milvus {
namespace query {
template <typename T>
struct DataGetter {
std::shared_ptr<Span<T>> field_data_;
std::shared_ptr<Span<std::string_view>> str_field_data_;
const index::ScalarIndex<T>* field_index_;
DataGetter(const segcore::SegmentInternalInterface& segment,
FieldId& field_id) {
if (segment.HasFieldData(field_id)) {
if constexpr (std::is_same_v<T, std::string>) {
auto span = segment.chunk_data<std::string_view>(field_id, 0);
str_field_data_ = std::make_shared<Span<std::string_view>>(
span.data(), span.row_count());
} else {
auto span = segment.chunk_data<T>(field_id, 0);
field_data_ =
std::make_shared<Span<T>>(span.data(), span.row_count());
}
} else if (segment.HasIndex(field_id)) {
this->field_index_ = &(segment.chunk_scalar_index<T>(field_id, 0));
} else {
PanicInfo(UnexpectedError,
"The segment used to init data getter has no effective "
"data source, neither"
"index or data");
}
}
public:
T
Get(int64_t idx) const {
if (field_data_ || str_field_data_) {
if constexpr (std::is_same_v<T, std::string>) {
std::string_view str_val_view =
str_field_data_->operator[](idx);
return std::string(str_val_view.data(), str_val_view.length());
}
return field_data_->operator[](idx);
} else {
return (*field_index_).Reverse_Lookup(idx);
}
}
};
void
GroupBy(const std::vector<std::shared_ptr<knowhere::IndexNode::iterator>>&
iterators,
@ -39,7 +84,7 @@ GroupIteratorsByType(
iterators,
FieldId field_id,
int64_t topK,
Span<T> field_data,
const DataGetter<T>& data_getter,
std::vector<GroupByValueType>& group_by_values,
std::vector<int64_t>& seg_offsets,
std::vector<float>& distances,
@ -51,7 +96,7 @@ GroupIteratorResult(
const std::shared_ptr<knowhere::IndexNode::iterator>& iterator,
FieldId field_id,
int64_t topK,
Span<T> field_data,
const DataGetter<T>& data_getter,
std::vector<GroupByValueType>& group_by_values,
std::vector<int64_t>& offsets,
std::vector<float>& distances,

View File

@ -541,8 +541,8 @@ ReduceHelper::AssembleGroupByValues(
case DataType::VARCHAR: {
auto field_data = group_by_values_field->mutable_string_data();
for (std::size_t idx = 0; idx < group_by_val_size; idx++) {
std::string_view val =
std::get<std::string_view>(group_by_vals[idx]);
std::string val =
std::move(std::get<std::string>(group_by_vals[idx]));
*(field_data->mutable_data()->Add()) = val;
}
break;

View File

@ -337,12 +337,12 @@ TEST(GroupBY, Normal2) {
search_result->seg_offsets_.size());
int size = group_by_values.size();
std::unordered_set<std::string_view> strs_set;
std::unordered_set<std::string> strs_set;
float lastDistance = 0.0;
for (size_t i = 0; i < size; i++) {
if (std::holds_alternative<std::string_view>(group_by_values[i])) {
std::string_view g_val =
std::get<std::string_view>(group_by_values[i]);
if (std::holds_alternative<std::string>(group_by_values[i])) {
std::string g_val =
std::move(std::get<std::string>(group_by_values[i]));
ASSERT_FALSE(strs_set.count(g_val) >
0); //no repetition on groupBy field
strs_set.insert(g_val);