mirror of https://github.com/milvus-io/milvus.git
fix: validate sparse vector in search request (#32856)
issue: #32368 Signed-off-by: Buqian Zheng <zhengbuqian@gmail.com>pull/33070/head
parent
4382cf5283
commit
7c60d725cc
|
@ -18,6 +18,7 @@
|
|||
#include <unistd.h>
|
||||
|
||||
#include <cstring>
|
||||
#include <cmath>
|
||||
#include <filesystem>
|
||||
#include <memory>
|
||||
#include <string>
|
||||
|
@ -215,12 +216,32 @@ GetCommonPrefix(const std::string& str1, const std::string& str2) {
|
|||
}
|
||||
|
||||
inline knowhere::sparse::SparseRow<float>
|
||||
CopyAndWrapSparseRow(const void* data, size_t size) {
|
||||
CopyAndWrapSparseRow(const void* data,
|
||||
size_t size,
|
||||
const bool validate = false) {
|
||||
size_t num_elements =
|
||||
size / knowhere::sparse::SparseRow<float>::element_size();
|
||||
knowhere::sparse::SparseRow<float> row(num_elements);
|
||||
std::memcpy(row.data(), data, size);
|
||||
// TODO(SPARSE): validate
|
||||
if (validate) {
|
||||
AssertInfo(size > 0, "Sparse row data should not be empty");
|
||||
AssertInfo(
|
||||
size % knowhere::sparse::SparseRow<float>::element_size() == 0,
|
||||
"Invalid size for sparse row data");
|
||||
for (size_t i = 0; i < num_elements; ++i) {
|
||||
auto element = row[i];
|
||||
AssertInfo(std::isfinite(element.val),
|
||||
"Invalid sparse row: NaN or Inf value");
|
||||
AssertInfo(element.val >= 0, "Invalid sparse row: negative value");
|
||||
AssertInfo(
|
||||
element.id < std::numeric_limits<uint32_t>::max(),
|
||||
"Invalid sparse row: id should be smaller than uint32 max");
|
||||
if (i > 0) {
|
||||
AssertInfo(row[i - 1].id < element.id,
|
||||
"Invalid sparse row: id should be strict ascending");
|
||||
}
|
||||
}
|
||||
}
|
||||
return row;
|
||||
}
|
||||
|
||||
|
@ -228,15 +249,18 @@ CopyAndWrapSparseRow(const void* data, size_t size) {
|
|||
// sparse float row. This helper function converts such byte arrays into a list
|
||||
// of knowhere::sparse::SparseRow<float>. The resulting list is a deep copy of
|
||||
// the source data.
|
||||
//
|
||||
// Here in segcore we validate the sparse row data only for search requests,
|
||||
// as the insert/upsert data are already validated in go code.
|
||||
template <typename Iterable>
|
||||
std::unique_ptr<knowhere::sparse::SparseRow<float>[]>
|
||||
SparseBytesToRows(const Iterable& rows) {
|
||||
SparseBytesToRows(const Iterable& rows, const bool validate = false) {
|
||||
AssertInfo(rows.size() > 0, "at least 1 sparse row should be provided");
|
||||
auto res =
|
||||
std::make_unique<knowhere::sparse::SparseRow<float>[]>(rows.size());
|
||||
for (size_t i = 0; i < rows.size(); ++i) {
|
||||
res[i] =
|
||||
std::move(CopyAndWrapSparseRow(rows[i].data(), rows[i].size()));
|
||||
res[i] = std::move(
|
||||
CopyAndWrapSparseRow(rows[i].data(), rows[i].size(), validate));
|
||||
}
|
||||
return res;
|
||||
}
|
||||
|
|
|
@ -56,7 +56,8 @@ ParsePlaceholderGroup(const Plan* plan,
|
|||
AssertInfo(element.num_of_queries_ > 0, "must have queries");
|
||||
if (info.type() ==
|
||||
milvus::proto::common::PlaceholderType::SparseFloatVector) {
|
||||
element.sparse_matrix_ = SparseBytesToRows(info.values());
|
||||
element.sparse_matrix_ =
|
||||
SparseBytesToRows(info.values(), /*validate=*/true);
|
||||
} else {
|
||||
auto line_size = info.values().Get(0).size();
|
||||
if (field_meta.get_sizeof() != line_size) {
|
||||
|
|
Loading…
Reference in New Issue