mirror of https://github.com/milvus-io/milvus.git
related: #29965
Signed-off-by: MrPresent-Han <chun.han@zilliz.com>
pull/30163/head
parent
ae4f62ab4b
commit
4436effdc3
|
@ -113,7 +113,7 @@ using GroupByValueType = std::variant<std::monostate,
|
|||
int32_t,
|
||||
int64_t,
|
||||
bool,
|
||||
std::string_view>;
|
||||
std::string>;
|
||||
using ContainsType = proto::plan::JSONContainsExpr_JSONOp;
|
||||
|
||||
inline bool
|
||||
|
|
|
@ -44,11 +44,11 @@ GroupBy(const std::vector<std::shared_ptr<knowhere::IndexNode::iterator>>&
|
|||
|
||||
switch (data_type) {
|
||||
case DataType::INT8: {
|
||||
auto field_data = segment.chunk_data<int8_t>(group_by_field_id, 0);
|
||||
DataGetter<int8_t> dataGetter(segment, group_by_field_id);
|
||||
GroupIteratorsByType<int8_t>(iterators,
|
||||
group_by_field_id,
|
||||
search_info.topk_,
|
||||
field_data,
|
||||
dataGetter,
|
||||
group_by_values,
|
||||
seg_offsets,
|
||||
distances,
|
||||
|
@ -56,11 +56,11 @@ GroupBy(const std::vector<std::shared_ptr<knowhere::IndexNode::iterator>>&
|
|||
break;
|
||||
}
|
||||
case DataType::INT16: {
|
||||
auto field_data = segment.chunk_data<int16_t>(group_by_field_id, 0);
|
||||
DataGetter<int16_t> dataGetter(segment, group_by_field_id);
|
||||
GroupIteratorsByType<int16_t>(iterators,
|
||||
group_by_field_id,
|
||||
search_info.topk_,
|
||||
field_data,
|
||||
dataGetter,
|
||||
group_by_values,
|
||||
seg_offsets,
|
||||
distances,
|
||||
|
@ -68,11 +68,11 @@ GroupBy(const std::vector<std::shared_ptr<knowhere::IndexNode::iterator>>&
|
|||
break;
|
||||
}
|
||||
case DataType::INT32: {
|
||||
auto field_data = segment.chunk_data<int32_t>(group_by_field_id, 0);
|
||||
DataGetter<int32_t> dataGetter(segment, group_by_field_id);
|
||||
GroupIteratorsByType<int32_t>(iterators,
|
||||
group_by_field_id,
|
||||
search_info.topk_,
|
||||
field_data,
|
||||
dataGetter,
|
||||
group_by_values,
|
||||
seg_offsets,
|
||||
distances,
|
||||
|
@ -80,11 +80,11 @@ GroupBy(const std::vector<std::shared_ptr<knowhere::IndexNode::iterator>>&
|
|||
break;
|
||||
}
|
||||
case DataType::INT64: {
|
||||
auto field_data = segment.chunk_data<int64_t>(group_by_field_id, 0);
|
||||
DataGetter<int64_t> dataGetter(segment, group_by_field_id);
|
||||
GroupIteratorsByType<int64_t>(iterators,
|
||||
group_by_field_id,
|
||||
search_info.topk_,
|
||||
field_data,
|
||||
dataGetter,
|
||||
group_by_values,
|
||||
seg_offsets,
|
||||
distances,
|
||||
|
@ -92,11 +92,11 @@ GroupBy(const std::vector<std::shared_ptr<knowhere::IndexNode::iterator>>&
|
|||
break;
|
||||
}
|
||||
case DataType::BOOL: {
|
||||
auto field_data = segment.chunk_data<bool>(group_by_field_id, 0);
|
||||
DataGetter<bool> dataGetter(segment, group_by_field_id);
|
||||
GroupIteratorsByType<bool>(iterators,
|
||||
group_by_field_id,
|
||||
search_info.topk_,
|
||||
field_data,
|
||||
dataGetter,
|
||||
group_by_values,
|
||||
seg_offsets,
|
||||
distances,
|
||||
|
@ -104,21 +104,20 @@ GroupBy(const std::vector<std::shared_ptr<knowhere::IndexNode::iterator>>&
|
|||
break;
|
||||
}
|
||||
case DataType::VARCHAR: {
|
||||
auto field_data =
|
||||
segment.chunk_data<std::string_view>(group_by_field_id, 0);
|
||||
GroupIteratorsByType<std::string_view>(iterators,
|
||||
group_by_field_id,
|
||||
search_info.topk_,
|
||||
field_data,
|
||||
group_by_values,
|
||||
seg_offsets,
|
||||
distances,
|
||||
search_info.metric_type_);
|
||||
DataGetter<std::string> dataGetter(segment, group_by_field_id);
|
||||
GroupIteratorsByType<std::string>(iterators,
|
||||
group_by_field_id,
|
||||
search_info.topk_,
|
||||
dataGetter,
|
||||
group_by_values,
|
||||
seg_offsets,
|
||||
distances,
|
||||
search_info.metric_type_);
|
||||
break;
|
||||
}
|
||||
default: {
|
||||
PanicInfo(
|
||||
DataTypeInvalid,
|
||||
Unsupported,
|
||||
fmt::format("unsupported data type {} for group by operator",
|
||||
data_type));
|
||||
}
|
||||
|
@ -132,7 +131,7 @@ GroupIteratorsByType(
|
|||
iterators,
|
||||
FieldId field_id,
|
||||
int64_t topK,
|
||||
Span<T> field_data,
|
||||
const DataGetter<T>& data_getter,
|
||||
std::vector<GroupByValueType>& group_by_values,
|
||||
std::vector<int64_t>& seg_offsets,
|
||||
std::vector<float>& distances,
|
||||
|
@ -141,7 +140,7 @@ GroupIteratorsByType(
|
|||
GroupIteratorResult<T>(iterator,
|
||||
field_id,
|
||||
topK,
|
||||
field_data,
|
||||
data_getter,
|
||||
group_by_values,
|
||||
seg_offsets,
|
||||
distances,
|
||||
|
@ -155,7 +154,7 @@ GroupIteratorResult(
|
|||
const std::shared_ptr<knowhere::IndexNode::iterator>& iterator,
|
||||
FieldId field_id,
|
||||
int64_t topK,
|
||||
Span<T> field_data,
|
||||
const DataGetter<T>& data_getter,
|
||||
std::vector<GroupByValueType>& group_by_values,
|
||||
std::vector<int64_t>& offsets,
|
||||
std::vector<float>& distances,
|
||||
|
@ -173,7 +172,7 @@ GroupIteratorResult(
|
|||
};
|
||||
while (iterator->HasNext() && groupMap.size() < topK) {
|
||||
auto [offset, dis] = iterator->Next();
|
||||
const T& row_data = field_data.operator[](offset);
|
||||
T row_data = data_getter.Get(offset);
|
||||
auto it = groupMap.find(row_data);
|
||||
if (it == groupMap.end()) {
|
||||
groupMap.emplace(row_data, std::make_pair(offset, dis));
|
||||
|
|
|
@ -23,6 +23,51 @@
|
|||
|
||||
namespace milvus {
|
||||
namespace query {
|
||||
|
||||
template <typename T>
|
||||
struct DataGetter {
|
||||
std::shared_ptr<Span<T>> field_data_;
|
||||
std::shared_ptr<Span<std::string_view>> str_field_data_;
|
||||
const index::ScalarIndex<T>* field_index_;
|
||||
|
||||
DataGetter(const segcore::SegmentInternalInterface& segment,
|
||||
FieldId& field_id) {
|
||||
if (segment.HasFieldData(field_id)) {
|
||||
if constexpr (std::is_same_v<T, std::string>) {
|
||||
auto span = segment.chunk_data<std::string_view>(field_id, 0);
|
||||
str_field_data_ = std::make_shared<Span<std::string_view>>(
|
||||
span.data(), span.row_count());
|
||||
} else {
|
||||
auto span = segment.chunk_data<T>(field_id, 0);
|
||||
field_data_ =
|
||||
std::make_shared<Span<T>>(span.data(), span.row_count());
|
||||
}
|
||||
} else if (segment.HasIndex(field_id)) {
|
||||
this->field_index_ = &(segment.chunk_scalar_index<T>(field_id, 0));
|
||||
} else {
|
||||
PanicInfo(UnexpectedError,
|
||||
"The segment used to init data getter has no effective "
|
||||
"data source, neither"
|
||||
"index or data");
|
||||
}
|
||||
}
|
||||
|
||||
public:
|
||||
T
|
||||
Get(int64_t idx) const {
|
||||
if (field_data_ || str_field_data_) {
|
||||
if constexpr (std::is_same_v<T, std::string>) {
|
||||
std::string_view str_val_view =
|
||||
str_field_data_->operator[](idx);
|
||||
return std::string(str_val_view.data(), str_val_view.length());
|
||||
}
|
||||
return field_data_->operator[](idx);
|
||||
} else {
|
||||
return (*field_index_).Reverse_Lookup(idx);
|
||||
}
|
||||
}
|
||||
};
|
||||
|
||||
void
|
||||
GroupBy(const std::vector<std::shared_ptr<knowhere::IndexNode::iterator>>&
|
||||
iterators,
|
||||
|
@ -39,7 +84,7 @@ GroupIteratorsByType(
|
|||
iterators,
|
||||
FieldId field_id,
|
||||
int64_t topK,
|
||||
Span<T> field_data,
|
||||
const DataGetter<T>& data_getter,
|
||||
std::vector<GroupByValueType>& group_by_values,
|
||||
std::vector<int64_t>& seg_offsets,
|
||||
std::vector<float>& distances,
|
||||
|
@ -51,7 +96,7 @@ GroupIteratorResult(
|
|||
const std::shared_ptr<knowhere::IndexNode::iterator>& iterator,
|
||||
FieldId field_id,
|
||||
int64_t topK,
|
||||
Span<T> field_data,
|
||||
const DataGetter<T>& data_getter,
|
||||
std::vector<GroupByValueType>& group_by_values,
|
||||
std::vector<int64_t>& offsets,
|
||||
std::vector<float>& distances,
|
||||
|
|
|
@ -541,8 +541,8 @@ ReduceHelper::AssembleGroupByValues(
|
|||
case DataType::VARCHAR: {
|
||||
auto field_data = group_by_values_field->mutable_string_data();
|
||||
for (std::size_t idx = 0; idx < group_by_val_size; idx++) {
|
||||
std::string_view val =
|
||||
std::get<std::string_view>(group_by_vals[idx]);
|
||||
std::string val =
|
||||
std::move(std::get<std::string>(group_by_vals[idx]));
|
||||
*(field_data->mutable_data()->Add()) = val;
|
||||
}
|
||||
break;
|
||||
|
|
|
@ -337,12 +337,12 @@ TEST(GroupBY, Normal2) {
|
|||
search_result->seg_offsets_.size());
|
||||
|
||||
int size = group_by_values.size();
|
||||
std::unordered_set<std::string_view> strs_set;
|
||||
std::unordered_set<std::string> strs_set;
|
||||
float lastDistance = 0.0;
|
||||
for (size_t i = 0; i < size; i++) {
|
||||
if (std::holds_alternative<std::string_view>(group_by_values[i])) {
|
||||
std::string_view g_val =
|
||||
std::get<std::string_view>(group_by_values[i]);
|
||||
if (std::holds_alternative<std::string>(group_by_values[i])) {
|
||||
std::string g_val =
|
||||
std::move(std::get<std::string>(group_by_values[i]));
|
||||
ASSERT_FALSE(strs_set.count(g_val) >
|
||||
0); //no repetition on groupBy field
|
||||
strs_set.insert(g_val);
|
||||
|
|
Loading…
Reference in New Issue