feat: [Sparse Float Vector] segcore to support sparse vector search and get raw vector by id (#30629)

This PR adds the ability to search/get sparse float vectors in segcore,
and added unit tests by modifying lots of existing tests into
parameterized ones.

https://github.com/milvus-io/milvus/issues/29419

Signed-off-by: Buqian Zheng <zhengbuqian@gmail.com>
pull/31170/head
Buqian Zheng 2024-03-13 00:16:30 +08:00 committed by GitHub
parent c8c906b939
commit 96cfae55a5
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
60 changed files with 1241 additions and 858 deletions

View File

@ -425,7 +425,7 @@ class FieldDataSparseVectorImpl
}
private:
int64_t vec_dim_;
int64_t vec_dim_ = 0;
};
class FieldDataArrayImpl : public FieldDataImpl<Array, true> {

View File

@ -179,6 +179,12 @@ using IndexVersion = knowhere::IndexVersion;
// TODO :: type define milvus index type(vector index type and scalar index type)
using IndexType = knowhere::IndexType;
inline bool
IndexIsSparse(const IndexType& index_type) {
return index_type == knowhere::IndexEnum::INDEX_SPARSE_INVERTED_INDEX ||
index_type == knowhere::IndexEnum::INDEX_SPARSE_WAND;
}
// Plus 1 because we can't use greater(>) symbol
constexpr size_t REF_SIZE_THRESHOLD = 16 + 1;

View File

@ -241,22 +241,24 @@ SparseBytesToRows(const Iterable& rows) {
return res;
}
// SparseRowsToProto converts a vector of knowhere::sparse::SparseRow<float> to
// SparseRowsToProto converts a list of knowhere::sparse::SparseRow<float> to
// a milvus::proto::schema::SparseFloatArray. The resulting proto is a deep copy
// of the source data.
inline void SparseRowsToProto(const knowhere::sparse::SparseRow<float>* source,
int64_t rows,
milvus::proto::schema::SparseFloatArray* proto) {
// of the source data. source(i) returns the i-th row to be copied.
inline void SparseRowsToProto(
const std::function<const knowhere::sparse::SparseRow<float>*(size_t)>&
source,
int64_t rows,
milvus::proto::schema::SparseFloatArray* proto) {
int64_t max_dim = 0;
for (size_t i = 0; i < rows; ++i) {
if (source + i == nullptr) {
const auto* row = source(i);
if (row == nullptr) {
// empty row
proto->add_contents();
continue;
}
auto& row = source[i];
max_dim = std::max(max_dim, row.dim());
proto->add_contents(row.data(), row.data_byte_size());
max_dim = std::max(max_dim, row->dim());
proto->add_contents(row->data(), row->data_byte_size());
}
proto->set_dim(max_dim);
}

View File

@ -68,33 +68,6 @@ template <typename T>
constexpr bool IsSparse = std::is_same_v<T, SparseFloatVector> ||
std::is_same_v<T, knowhere::sparse::SparseRow<float>>;
template <typename T, typename Enabled = void>
struct EmbeddedTypeImpl;
template <typename T>
struct EmbeddedTypeImpl<T, std::enable_if_t<IsScalar<T>>> {
using type = T;
};
template <typename T>
struct EmbeddedTypeImpl<T, std::enable_if_t<IsVector<T>>> {
using type = std::conditional_t<
std::is_same_v<T, FloatVector>,
float,
std::conditional_t<
std::is_same_v<T, Float16Vector>,
float16,
std::conditional_t<
std::is_same_v<T, BFloat16Vector>,
bfloat16,
std::conditional_t<std::is_same_v<T, SparseFloatVector>,
void,
uint8_t>>>>;
};
template <typename T>
using EmbeddedType = typename EmbeddedTypeImpl<T>::type;
struct FundamentalTag {};
struct StringTag {};

View File

@ -416,6 +416,11 @@ VectorDiskAnnIndex<T>::HasRawData() const {
template <typename T>
std::vector<uint8_t>
VectorDiskAnnIndex<T>::GetVector(const DatasetPtr dataset) const {
auto index_type = GetIndexType();
if (IndexIsSparse(index_type)) {
PanicInfo(ErrorCode::UnexpectedError,
"failed to get vector, index is sparse");
}
auto res = index_.GetVectorByIds(*dataset);
if (!res.has_value()) {
PanicInfo(ErrorCode::UnexpectedError,
@ -423,7 +428,6 @@ VectorDiskAnnIndex<T>::GetVector(const DatasetPtr dataset) const {
KnowhereStatusString(res.error()),
res.what()));
}
auto index_type = GetIndexType();
auto tensor = res.value()->GetTensor();
auto row_num = res.value()->GetRows();
auto dim = res.value()->GetDim();

View File

@ -98,8 +98,13 @@ class VectorDiskAnnIndex : public VectorIndex {
std::vector<uint8_t>
GetVector(const DatasetPtr dataset) const override;
void
CleanLocalData() override;
std::unique_ptr<const knowhere::sparse::SparseRow<float>[]>
GetSparseVector(const DatasetPtr dataset) const override {
PanicInfo(ErrorCode::Unsupported,
"get sparse vector not supported for disk index");
}
void CleanLocalData() override;
knowhere::expected<
std::vector<std::shared_ptr<knowhere::IndexNode::iterator>>>

View File

@ -76,6 +76,9 @@ class VectorIndex : public IndexBase {
virtual std::vector<uint8_t>
GetVector(const DatasetPtr dataset) const = 0;
virtual std::unique_ptr<const knowhere::sparse::SparseRow<float>[]>
GetSparseVector(const DatasetPtr dataset) const = 0;
IndexType
GetIndexType() const {
return index_type_;

View File

@ -491,7 +491,7 @@ VectorMemIndex<T>::Build(const Config& config) {
build_config.update(config);
build_config.erase("insert_files");
build_config.erase(VEC_OPT_FIELDS);
if (GetIndexType().find("SPARSE") == std::string::npos) {
if (!IndexIsSparse(GetIndexType())) {
int64_t total_size = 0;
int64_t total_num_rows = 0;
int64_t dim = 0;
@ -537,6 +537,7 @@ VectorMemIndex<T>::Build(const Config& config) {
// this does a deep copy of field_data's data.
// TODO: avoid copying by enforcing field data to give up
// ownership.
AssertInfo(dim >= ptr[i].dim(), "bad dim");
vec[offset + i] = ptr[i];
}
offset += field_data->Length();
@ -639,12 +640,17 @@ VectorMemIndex<T>::HasRawData() const {
template <typename T>
std::vector<uint8_t>
VectorMemIndex<T>::GetVector(const DatasetPtr dataset) const {
auto index_type = GetIndexType();
if (IndexIsSparse(index_type)) {
PanicInfo(ErrorCode::UnexpectedError,
"failed to get vector, index is sparse");
}
auto res = index_.GetVectorByIds(*dataset);
if (!res.has_value()) {
PanicInfo(ErrorCode::UnexpectedError,
"failed to get vector, " + KnowhereStatusString(res.error()));
}
auto index_type = GetIndexType();
auto tensor = res.value()->GetTensor();
auto row_num = res.value()->GetRows();
auto dim = res.value()->GetDim();
@ -661,8 +667,22 @@ VectorMemIndex<T>::GetVector(const DatasetPtr dataset) const {
}
template <typename T>
void
VectorMemIndex<T>::LoadFromFile(const Config& config) {
std::unique_ptr<const knowhere::sparse::SparseRow<float>[]>
VectorMemIndex<T>::GetSparseVector(const DatasetPtr dataset) const {
auto res = index_.GetVectorByIds(*dataset);
if (!res.has_value()) {
PanicInfo(ErrorCode::UnexpectedError,
"failed to get vector, " + KnowhereStatusString(res.error()));
}
// release and transfer ownership to the result unique ptr.
res.value()->SetIsOwner(false);
return std::unique_ptr<const knowhere::sparse::SparseRow<float>[]>(
static_cast<const knowhere::sparse::SparseRow<float>*>(
res.value()->GetTensor()));
}
template <typename T>
void VectorMemIndex<T>::LoadFromFile(const Config& config) {
auto filepath = GetValueFromConfig<std::string>(config, kMmapFilepath);
AssertInfo(filepath.has_value(), "mmap filepath is empty when load index");

View File

@ -85,6 +85,9 @@ class VectorMemIndex : public VectorIndex {
std::vector<uint8_t>
GetVector(const DatasetPtr dataset) const override;
std::unique_ptr<const knowhere::sparse::SparseRow<float>[]>
GetSparseVector(const DatasetPtr dataset) const override;
BinarySet
Upload(const Config& config = {}) override;

View File

@ -45,7 +45,9 @@ class ColumnBase {
public:
// memory mode ctor
ColumnBase(size_t reserve, const FieldMeta& field_meta)
: type_size_(field_meta.get_sizeof()) {
: type_size_(datatype_is_sparse_vector(field_meta.get_data_type())
? 1
: field_meta.get_sizeof()) {
// simdjson requires a padding following the json data
padding_ = field_meta.get_data_type() == DataType::JSON
? simdjson::SIMDJSON_PADDING
@ -55,7 +57,7 @@ class ColumnBase {
return;
}
cap_size_ = field_meta.get_sizeof() * reserve;
cap_size_ = type_size_ * reserve;
// use anon mapping so we are able to free these memory with munmap only
data_ = static_cast<char*>(mmap(nullptr,
@ -72,8 +74,10 @@ class ColumnBase {
// mmap mode ctor
ColumnBase(const File& file, size_t size, const FieldMeta& field_meta)
: type_size_(field_meta.get_sizeof()),
num_rows_(size / field_meta.get_sizeof()) {
: type_size_(datatype_is_sparse_vector(field_meta.get_data_type())
? 1
: field_meta.get_sizeof()),
num_rows_(size / type_size_) {
padding_ = field_meta.get_data_type() == DataType::JSON
? simdjson::SIMDJSON_PADDING
: 0;

View File

@ -15,6 +15,7 @@
// limitations under the License.
#include "Plan.h"
#include "common/Utils.h"
#include "PlanProto.h"
#include "generated/ShowPlanNodeVisitor.h"
@ -34,9 +35,8 @@ std::unique_ptr<PlaceholderGroup>
ParsePlaceholderGroup(const Plan* plan,
const uint8_t* blob,
const int64_t blob_len) {
namespace set = milvus::proto::common;
auto result = std::make_unique<PlaceholderGroup>();
set::PlaceholderGroup ph_group;
milvus::proto::common::PlaceholderGroup ph_group;
auto ok = ph_group.ParseFromArray(blob, blob_len);
Assert(ok);
for (auto& info : ph_group.placeholders()) {
@ -46,22 +46,26 @@ ParsePlaceholderGroup(const Plan* plan,
auto field_id = plan->tag2field_.at(element.tag_);
auto& field_meta = plan->schema_[field_id];
element.num_of_queries_ = info.values_size();
AssertInfo(element.num_of_queries_, "must have queries");
Assert(element.num_of_queries_ > 0);
element.line_sizeof_ = info.values().Get(0).size();
if (field_meta.get_sizeof() != element.line_sizeof_) {
throw SegcoreError(
DimNotMatch,
fmt::format("vector dimension mismatch, expected vector "
"size(byte) {}, actual {}.",
field_meta.get_sizeof(),
element.line_sizeof_));
}
auto& target = element.blob_;
target.reserve(element.line_sizeof_ * element.num_of_queries_);
for (auto& line : info.values()) {
Assert(element.line_sizeof_ == line.size());
target.insert(target.end(), line.begin(), line.end());
AssertInfo(element.num_of_queries_ > 0, "must have queries");
if (info.type() ==
milvus::proto::common::PlaceholderType::SparseFloatVector) {
element.sparse_matrix_ = SparseBytesToRows(info.values());
} else {
auto line_size = info.values().Get(0).size();
if (field_meta.get_sizeof() != line_size) {
throw SegcoreError(
DimNotMatch,
fmt::format("vector dimension mismatch, expected vector "
"size(byte) {}, actual {}.",
field_meta.get_sizeof(),
line_size));
}
auto& target = element.blob_;
target.reserve(line_size * element.num_of_queries_);
for (auto& line : info.values()) {
Assert(line_size == line.size());
target.insert(target.end(), line.begin(), line.end());
}
}
result->emplace_back(std::move(element));
}

View File

@ -64,19 +64,30 @@ struct Plan {
struct Placeholder {
std::string tag_;
int64_t num_of_queries_;
int64_t line_sizeof_;
aligned_vector<char> blob_;
// TODO(SPARSE): add a dim_ field here, use the dim passed in search request
// instead of the dim in schema, since the dim of sparse float column is
// dynamic. This change will likely affect lots of code, thus I'll do it in
// a separate PR, and use dim=0 for sparse float vector searches for now.
template <typename T>
const T*
// only one of blob_ and sparse_matrix_ should be set. blob_ is used for
// dense vector search and sparse_matrix_ is for sparse vector search.
aligned_vector<char> blob_;
std::unique_ptr<knowhere::sparse::SparseRow<float>[]> sparse_matrix_;
const void*
get_blob() const {
return reinterpret_cast<const T*>(blob_.data());
if (blob_.empty()) {
return sparse_matrix_.get();
}
return blob_.data();
}
template <typename T>
T*
void*
get_blob() {
return reinterpret_cast<T*>(blob_.data());
if (blob_.empty()) {
return sparse_matrix_.get();
}
return blob_.data();
}
};

View File

@ -67,6 +67,12 @@ struct BFloat16VectorANNS : VectorPlanNode {
accept(PlanNodeVisitor&) override;
};
struct SparseFloatVectorANNS : VectorPlanNode {
public:
void
accept(PlanNodeVisitor&) override;
};
struct RetrievePlanNode : PlanNode {
public:
void

View File

@ -217,6 +217,9 @@ ProtoParser::PlanNodeFromProto(const planpb::PlanNode& plan_node_proto) {
} else if (anns_proto.vector_type() ==
milvus::proto::plan::VectorType::BFloat16Vector) {
return std::make_unique<BFloat16VectorANNS>();
} else if (anns_proto.vector_type() ==
milvus::proto::plan::VectorType::SparseFloatVector) {
return std::make_unique<SparseFloatVectorANNS>();
} else {
return std::make_unique<FloatVectorANNS>();
}

View File

@ -36,7 +36,8 @@ CheckBruteForceSearchParam(const FieldMeta& field,
"[BruteForceSearch] Data type isn't vector type");
bool is_float_data_type = (data_type == DataType::VECTOR_FLOAT ||
data_type == DataType::VECTOR_FLOAT16 ||
data_type == DataType::VECTOR_BFLOAT16);
data_type == DataType::VECTOR_BFLOAT16 ||
data_type == DataType::VECTOR_SPARSE_FLOAT);
bool is_float_metric_type = IsFloatMetricType(metric_type);
AssertInfo(is_float_data_type == is_float_metric_type,
"[BruteForceSearch] Data type and metric type miss-match");
@ -86,7 +87,25 @@ BruteForceSearch(const dataset::SearchDataset& dataset,
sub_result.mutable_seg_offsets().resize(nq * topk);
sub_result.mutable_distances().resize(nq * topk);
if (search_cfg.contains(RADIUS)) {
if (data_type == DataType::VECTOR_SPARSE_FLOAT) {
// TODO(SPARSE): support sparse brute force range search
AssertInfo(
!search_cfg.contains(RADIUS) && !search_cfg.contains(RANGE_FILTER),
"sparse vector not support range search");
base_dataset->SetIsSparse(true);
query_dataset->SetIsSparse(true);
auto stat = knowhere::BruteForce::SearchSparseWithBuf(
base_dataset,
query_dataset,
sub_result.mutable_seg_offsets().data(),
sub_result.mutable_distances().data(),
search_cfg,
bitset);
milvus::tracer::AddEvent("knowhere_finish_BruteForce_SearchWithBuf");
if (stat != knowhere::Status::success) {
throw SegcoreError(KnowhereError, KnowhereStatusString(stat));
}
} else if (search_cfg.contains(RADIUS)) {
if (search_cfg.contains(RANGE_FILTER)) {
CheckRangeSearchParam(search_cfg[RADIUS],
search_cfg[RANGE_FILTER],
@ -196,6 +215,7 @@ BruteForceSearchIterators(const dataset::SearchDataset& dataset,
base_dataset, query_dataset, search_cfg, bitset);
break;
default:
// TODO(SPARSE): support sparse brute force iterator
PanicInfo(ErrorCode::Unsupported,
"Unsupported dataType for chunk brute force iterator:{}",
data_type);

View File

@ -32,14 +32,19 @@ FloatSegmentIndexSearch(const segcore::SegmentGrowingImpl& segment,
auto vecfield_id = info.field_id_;
auto& field = schema[vecfield_id];
auto is_sparse = field.get_data_type() == DataType::VECTOR_SPARSE_FLOAT;
// TODO(SPARSE): see todo in PlanImpl.h::PlaceHolder.
auto dim = is_sparse ? 0 : field.get_dim();
AssertInfo(field.get_data_type() == DataType::VECTOR_FLOAT,
"[FloatSearch]Field data type isn't VECTOR_FLOAT");
AssertInfo(field.get_data_type() == DataType::VECTOR_FLOAT ||
field.get_data_type() == DataType::VECTOR_SPARSE_FLOAT,
"[FloatSearch]Field data type isn't VECTOR_FLOAT or "
"VECTOR_SPARSE_FLOAT");
dataset::SearchDataset search_dataset{info.metric_type_,
num_queries,
info.topk_,
info.round_decimal_,
field.get_dim(),
dim,
query_data};
if (indexing_record.is_in(vecfield_id)) {
const auto& field_indexing =
@ -48,8 +53,12 @@ FloatSegmentIndexSearch(const segcore::SegmentGrowingImpl& segment,
auto indexing = field_indexing.get_segment_indexing();
SearchInfo search_conf = field_indexing.get_search_params(info);
auto vec_index = dynamic_cast<index::VectorIndex*>(indexing);
SearchOnIndex(
search_dataset, *vec_index, search_conf, bitset, search_result);
SearchOnIndex(search_dataset,
*vec_index,
search_conf,
bitset,
search_result,
is_sparse);
}
}
@ -76,7 +85,6 @@ SearchOnGrowing(const segcore::SegmentGrowingImpl& segment,
AssertInfo(datatype_is_vector(data_type),
"[SearchOnGrowing]Data type isn't vector type");
auto dim = field.get_dim();
auto topk = info.topk_;
auto metric_type = info.metric_type_;
auto round_decimal = info.round_decimal_;
@ -87,6 +95,10 @@ SearchOnGrowing(const segcore::SegmentGrowingImpl& segment,
segment, info, query_data, num_queries, bitset, search_result);
} else {
SubSearchResult final_qr(num_queries, topk, metric_type, round_decimal);
// TODO(SPARSE): see todo in PlanImpl.h::PlaceHolder.
auto dim = field.get_data_type() == DataType::VECTOR_SPARSE_FLOAT
? 0
: field.get_dim();
dataset::SearchDataset search_dataset{
metric_type, num_queries, topk, round_decimal, dim, query_data};
std::shared_lock<std::shared_mutex> read_chunk_mutex(

View File

@ -18,12 +18,14 @@ SearchOnIndex(const dataset::SearchDataset& search_dataset,
const index::VectorIndex& indexing,
const SearchInfo& search_conf,
const BitsetView& bitset,
SearchResult& search_result) {
SearchResult& search_result,
bool is_sparse) {
auto num_queries = search_dataset.num_queries;
auto dim = search_dataset.dim;
auto metric_type = search_dataset.metric_type;
auto dataset =
knowhere::GenDataSet(num_queries, dim, search_dataset.query_data);
dataset->SetIsSparse(is_sparse);
if (!PrepareVectorIteratorsFromIndex(search_conf,
num_queries,
dataset,

View File

@ -24,6 +24,7 @@ SearchOnIndex(const dataset::SearchDataset& search_dataset,
const index::VectorIndex& indexing,
const SearchInfo& search_conf,
const BitsetView& bitset,
SearchResult& search_result);
SearchResult& search_result,
bool is_sparse = false);
} // namespace milvus::query

View File

@ -34,8 +34,9 @@ SearchOnSealedIndex(const Schema& schema,
auto field_id = search_info.field_id_;
auto& field = schema[field_id];
// Assert(field.get_data_type() == DataType::VECTOR_FLOAT);
auto dim = field.get_dim();
auto is_sparse = field.get_data_type() == DataType::VECTOR_SPARSE_FLOAT;
// TODO(SPARSE): see todo in PlanImpl.h::PlaceHolder.
auto dim = is_sparse ? 0 : field.get_dim();
AssertInfo(record.is_ready(field_id), "[SearchOnSealed]Record isn't ready");
// Keep the field_indexing smart pointer, until all reference by raw dropped.
@ -44,6 +45,7 @@ SearchOnSealedIndex(const Schema& schema,
"Metric type of field index isn't the same with search info");
auto dataset = knowhere::GenDataSet(num_queries, dim, query_data);
dataset->SetIsSparse(is_sparse);
auto vec_index =
dynamic_cast<index::VectorIndex*>(field_indexing->indexing_.get());
if (!PrepareVectorIteratorsFromIndex(search_info,
@ -80,11 +82,16 @@ SearchOnSealed(const Schema& schema,
auto field_id = search_info.field_id_;
auto& field = schema[field_id];
// TODO(SPARSE): see todo in PlanImpl.h::PlaceHolder.
auto dim = field.get_data_type() == DataType::VECTOR_SPARSE_FLOAT
? 0
: field.get_dim();
query::dataset::SearchDataset dataset{search_info.metric_type_,
num_queries,
search_info.topk_,
search_info.round_decimal_,
field.get_dim(),
dim,
query_data};
auto data_type = field.get_data_type();

View File

@ -34,6 +34,9 @@ class ExecPlanNodeVisitor : public PlanNodeVisitor {
void
visit(BFloat16VectorANNS& node) override;
void
visit(SparseFloatVectorANNS& node) override;
void
visit(RetrievePlanNode& node) override;

View File

@ -30,6 +30,9 @@ class ExtractInfoPlanNodeVisitor : public PlanNodeVisitor {
void
visit(BFloat16VectorANNS& node) override;
void
visit(SparseFloatVectorANNS& node) override;
void
visit(RetrievePlanNode& node) override;

View File

@ -35,6 +35,11 @@ BFloat16VectorANNS::accept(PlanNodeVisitor& visitor) {
visitor.visit(*this);
}
void
SparseFloatVectorANNS::accept(PlanNodeVisitor& visitor) {
visitor.visit(*this);
}
void
RetrievePlanNode::accept(PlanNodeVisitor& visitor) {
visitor.visit(*this);

View File

@ -31,6 +31,9 @@ class PlanNodeVisitor {
virtual void
visit(BFloat16VectorANNS&) = 0;
virtual void
visit(SparseFloatVectorANNS&) = 0;
virtual void
visit(RetrievePlanNode&) = 0;
};

View File

@ -34,6 +34,9 @@ class ShowPlanNodeVisitor : public PlanNodeVisitor {
void
visit(BFloat16VectorANNS& node) override;
void
visit(SparseFloatVectorANNS& node) override;
void
visit(RetrievePlanNode& node) override;

View File

@ -33,6 +33,9 @@ class VerifyPlanNodeVisitor : public PlanNodeVisitor {
void
visit(BFloat16VectorANNS& node) override;
void
visit(SparseFloatVectorANNS& node) override;
void
visit(RetrievePlanNode& node) override;

View File

@ -149,7 +149,7 @@ ExecPlanNodeVisitor::VectorVisitorImpl(VectorPlanNode& node) {
AssertInfo(segment, "support SegmentSmallIndex Only");
SearchResult search_result;
auto& ph = placeholder_group_->at(0);
auto src_data = ph.get_blob<EmbeddedType<VectorType>>();
auto src_data = ph.get_blob();
auto num_queries = ph.num_of_queries_;
// TODO: add API to unify row_count
@ -308,4 +308,9 @@ ExecPlanNodeVisitor::visit(BFloat16VectorANNS& node) {
VectorVisitorImpl<BFloat16Vector>(node);
}
void
ExecPlanNodeVisitor::visit(SparseFloatVectorANNS& node) {
VectorVisitorImpl<SparseFloatVector>(node);
}
} // namespace milvus::query

View File

@ -65,6 +65,15 @@ ExtractInfoPlanNodeVisitor::visit(BFloat16VectorANNS& node) {
}
}
void
ExtractInfoPlanNodeVisitor::visit(SparseFloatVectorANNS& node) {
plan_info_.add_involved_field(node.search_info_.field_id_);
if (node.predicate_.has_value()) {
ExtractInfoExprVisitor expr_visitor(plan_info_);
node.predicate_.value()->accept(expr_visitor);
}
}
void
ExtractInfoPlanNodeVisitor::visit(RetrievePlanNode& node) {
// Assert(node.predicate_.has_value());

View File

@ -144,6 +144,30 @@ ShowPlanNodeVisitor::visit(BFloat16VectorANNS& node) {
ret_ = json_body;
}
void
ShowPlanNodeVisitor::visit(SparseFloatVectorANNS& node) {
assert(!ret_);
auto& info = node.search_info_;
Json json_body{
{"node_type", "SparseFloatVectorANNS"}, //
{"metric_type", info.metric_type_}, //
{"field_id_", info.field_id_.get()}, //
{"topk", info.topk_}, //
{"search_params", info.search_params_}, //
{"placeholder_tag", node.placeholder_tag_}, //
};
if (node.predicate_.has_value()) {
ShowExprVisitor expr_show;
AssertInfo(node.predicate_.value(),
"[ShowPlanNodeVisitor]Can't get value from node predict");
json_body["predicate"] =
expr_show.call_child(node.predicate_->operator*());
} else {
json_body["predicate"] = "None";
}
ret_ = json_body;
}
void
ShowPlanNodeVisitor::visit(RetrievePlanNode& node) {
}

View File

@ -42,6 +42,10 @@ void
VerifyPlanNodeVisitor::visit(BFloat16VectorANNS&) {
}
void
VerifyPlanNodeVisitor::visit(SparseFloatVectorANNS&) {
}
void
VerifyPlanNodeVisitor::visit(RetrievePlanNode&) {
}

View File

@ -129,6 +129,9 @@ class VectorBase {
virtual bool
empty() = 0;
virtual void
clear() = 0;
protected:
const int64_t size_per_chunk_;
};
@ -282,7 +285,7 @@ class ConcurrentVectorImpl : public VectorBase {
}
void
clear() {
clear() override {
chunks_.clear();
}

View File

@ -70,6 +70,9 @@ VectorFieldIndexing::BuildIndexRange(int64_t ack_beg,
}
}
// for sparse float vector:
// * element_size is not used
// * output_raw pooints at a milvus::schema::proto::SparseFloatArray.
void
VectorFieldIndexing::GetDataFromIndex(const int64_t* seg_offsets,
int64_t count,
@ -80,10 +83,16 @@ VectorFieldIndexing::GetDataFromIndex(const int64_t* seg_offsets,
ids_ds->SetDim(1);
ids_ds->SetIds(seg_offsets);
ids_ds->SetIsOwner(false);
auto vector = index_->GetVector(ids_ds);
std::memcpy(output, vector.data(), count * element_size);
if (field_meta_.get_data_type() == DataType::VECTOR_SPARSE_FLOAT) {
auto vector = index_->GetSparseVector(ids_ds);
SparseRowsToProto(
[vec_ptr = vector.get()](size_t i) { return vec_ptr + i; },
count,
reinterpret_cast<milvus::proto::schema::SparseFloatArray*>(output));
} else {
auto vector = index_->GetVector(ids_ds);
std::memcpy(output, vector.data(), count * element_size);
}
}
void
@ -242,7 +251,9 @@ VectorFieldIndexing::AppendSegmentIndexDense(int64_t reserved_offset,
knowhere::Json
VectorFieldIndexing::get_build_params() const {
auto config = config_->GetBuildBaseParams();
config[knowhere::meta::DIM] = std::to_string(field_meta_.get_dim());
if (!datatype_is_sparse_vector(field_meta_.get_data_type())) {
config[knowhere::meta::DIM] = std::to_string(field_meta_.get_dim());
}
config[knowhere::meta::NUM_BUILD_THREAD] = std::to_string(1);
// for sparse float vector: drop_ratio_build config is not allowed to be set
// on growing segment index.
@ -255,10 +266,6 @@ VectorFieldIndexing::get_search_params(const SearchInfo& searchInfo) const {
return conf;
}
idx_t
VectorFieldIndexing::get_index_cursor() {
return index_cur_.load();
}
bool
VectorFieldIndexing::sync_data_with_index() const {
return sync_with_index_.load();

View File

@ -86,9 +86,6 @@ class FieldIndexing {
return field_meta_;
}
virtual idx_t
get_index_cursor() = 0;
int64_t
get_size_per_chunk() const {
return segcore_config_.get_chunk_rows();
@ -143,10 +140,6 @@ class ScalarFieldIndexing : public FieldIndexing {
PanicInfo(Unsupported,
"scalar index don't support get data from index");
}
idx_t
get_index_cursor() override {
return 0;
}
int64_t
get_build_threshold() const override {
@ -201,6 +194,9 @@ class VectorFieldIndexing : public FieldIndexing {
const VectorBase* field_raw_data,
const void* data_source) override;
// for sparse float vector:
// * element_size is not used
// * output_raw pooints at a milvus::schema::proto::SparseFloatArray.
void
GetDataFromIndex(const int64_t* seg_offsets,
int64_t count,
@ -229,9 +225,6 @@ class VectorFieldIndexing : public FieldIndexing {
bool
has_raw_data() const override;
idx_t
get_index_cursor() override;
knowhere::Json
get_build_params() const;
@ -370,6 +363,9 @@ class IndexingRecord {
}
}
// for sparse float vector:
// * element_size is not used
// * output_raw pooints at a milvus::schema::proto::SparseFloatArray.
void
GetDataFromIndex(FieldId fieldId,
const int64_t* seg_offsets,
@ -378,9 +374,10 @@ class IndexingRecord {
void* output_raw) const {
if (is_in(fieldId)) {
auto& indexing = field_indexings_.at(fieldId);
if (indexing->get_field_meta().is_vector() &&
if (indexing->get_field_meta().get_data_type() ==
DataType::VECTOR_FLOAT ||
indexing->get_field_meta().get_data_type() ==
DataType::VECTOR_FLOAT) {
DataType::VECTOR_SPARSE_FLOAT) {
indexing->GetDataFromIndex(
seg_offsets, count, element_size, output_raw);
}

View File

@ -38,7 +38,7 @@ class VecIndexConfig {
{knowhere::IndexEnum::INDEX_FAISS_IVFFLAT_CC, 0.1}};
inline static const std::unordered_set<std::string> maintain_params = {
"radius", "range_filter"};
"radius", "range_filter", "drop_ratio_search"};
public:
VecIndexConfig(const int64_t max_index_row_count,

View File

@ -552,7 +552,7 @@ struct InsertRecord {
return ptr;
}
// append a column of scalar type
// append a column of scalar or sparse float vector type
template <typename Type>
void
append_field_data(FieldId field_id, int64_t size_per_chunk) {

View File

@ -71,9 +71,14 @@ void
SegmentGrowingImpl::try_remove_chunks(FieldId fieldId) {
//remove the chunk data to reduce memory consumption
if (indexing_record_.SyncDataWithIndex(fieldId)) {
auto vec_data_base =
VectorBase* vec_data_base =
dynamic_cast<segcore::ConcurrentVector<FloatVector>*>(
insert_record_.get_field_data_base(fieldId));
if (!vec_data_base) {
vec_data_base =
dynamic_cast<segcore::ConcurrentVector<SparseFloatVector>*>(
insert_record_.get_field_data_base(fieldId));
}
if (vec_data_base && vec_data_base->num_chunk() > 0 &&
chunk_mutex_.try_lock()) {
vec_data_base->clear();
@ -487,6 +492,16 @@ SegmentGrowingImpl::bulk_subscript(FieldId field_id,
seg_offsets,
count,
result->mutable_vectors()->mutable_bfloat16_vector()->data());
} else if (field_meta.get_data_type() ==
DataType::VECTOR_SPARSE_FLOAT) {
bulk_subscript_sparse_float_vector_impl(
field_id,
(const ConcurrentVector<SparseFloatVector>*)vec_ptr,
seg_offsets,
count,
result->mutable_vectors()->mutable_sparse_float_vector());
result->mutable_vectors()->set_dim(
result->vectors().sparse_float_vector().dim());
} else {
PanicInfo(DataTypeInvalid, "logical error");
}
@ -603,6 +618,33 @@ SegmentGrowingImpl::bulk_subscript(FieldId field_id,
return result;
}
void
SegmentGrowingImpl::bulk_subscript_sparse_float_vector_impl(
FieldId field_id,
const ConcurrentVector<SparseFloatVector>* vec_raw,
const int64_t* seg_offsets,
int64_t count,
milvus::proto::schema::SparseFloatArray* output) const {
AssertInfo(HasRawData(field_id.get()), "Growing segment loss raw data");
// if index has finished building index, grab from index
if (indexing_record_.SyncDataWithIndex(field_id)) {
indexing_record_.GetDataFromIndex(
field_id, seg_offsets, count, 0, output);
return;
}
// else copy from raw data
std::lock_guard<std::shared_mutex> guard(chunk_mutex_);
SparseRowsToProto(
[&](size_t i) {
auto offset = seg_offsets[i];
return offset != INVALID_SEG_OFFSET ? vec_raw->get_element(offset)
: nullptr;
},
count,
output);
}
template <typename S, typename T>
void
SegmentGrowingImpl::bulk_subscript_ptr_impl(
@ -631,32 +673,27 @@ SegmentGrowingImpl::bulk_subscript_impl(FieldId field_id,
AssertInfo(vec_ptr, "Pointer of vec_raw is nullptr");
auto& vec = *vec_ptr;
auto copy_from_chunk = [&]() {
auto output_base = reinterpret_cast<char*>(output_raw);
for (int i = 0; i < count; ++i) {
auto dst = output_base + i * element_sizeof;
auto offset = seg_offsets[i];
if (offset == INVALID_SEG_OFFSET) {
memset(dst, 0, element_sizeof);
} else {
auto src = (const uint8_t*)vec.get_element(offset);
memcpy(dst, src, element_sizeof);
}
}
};
//HasRawData interface guarantees that data can be fetched from growing segment
if (HasRawData(field_id.get())) {
//When data sync with index
if (indexing_record_.SyncDataWithIndex(field_id)) {
indexing_record_.GetDataFromIndex(
field_id, seg_offsets, count, element_sizeof, output_raw);
// HasRawData interface guarantees that data can be fetched from growing segment
AssertInfo(HasRawData(field_id.get()), "Growing segment loss raw data");
// when data is in sync with index
if (indexing_record_.SyncDataWithIndex(field_id)) {
indexing_record_.GetDataFromIndex(
field_id, seg_offsets, count, element_sizeof, output_raw);
return;
}
// else copy from chunk
std::lock_guard<std::shared_mutex> guard(chunk_mutex_);
auto output_base = reinterpret_cast<char*>(output_raw);
for (int i = 0; i < count; ++i) {
auto dst = output_base + i * element_sizeof;
auto offset = seg_offsets[i];
if (offset == INVALID_SEG_OFFSET) {
memset(dst, 0, element_sizeof);
} else {
//Else copy from chunk
std::lock_guard<std::shared_mutex> guard(chunk_mutex_);
copy_from_chunk();
auto src = (const uint8_t*)vec.get_element(offset);
memcpy(dst, src, element_sizeof);
}
}
AssertInfo(HasRawData(field_id.get()), "Growing segment loss raw data");
}
template <typename S, typename T>

View File

@ -96,11 +96,6 @@ class SegmentGrowingImpl : public SegmentGrowing {
return chunk_mutex_;
}
const SealedIndexingRecord&
get_sealed_indexing_record() const {
return sealed_indexing_record_;
}
const Schema&
get_schema() const override {
return *schema_;
@ -180,6 +175,14 @@ class SegmentGrowingImpl : public SegmentGrowing {
int64_t count,
void* output_raw) const;
void
bulk_subscript_sparse_float_vector_impl(
FieldId field_id,
const ConcurrentVector<SparseFloatVector>* vec_raw,
const int64_t* seg_offsets,
int64_t count,
milvus::proto::schema::SparseFloatArray* output) const;
void
bulk_subscript(SystemFieldType system_type,
const int64_t* seg_offsets,
@ -292,7 +295,6 @@ class SegmentGrowingImpl : public SegmentGrowing {
// small indexes for every chunk
IndexingRecord indexing_record_;
SealedIndexingRecord sealed_indexing_record_; // not used
// inserted fields data and row_ids, timestamps
InsertRecord<false> insert_record_;

View File

@ -495,7 +495,7 @@ SegmentSealedImpl::LoadFieldData(FieldId field_id, FieldDataInfo& data) {
update_row_count(num_rows);
}
if (generate_binlog_index(field_id)) {
if (generate_interim_index(field_id)) {
std::unique_lock lck(mutex_);
fields_.erase(field_id);
set_bit(field_data_ready_bitset_, field_id, false);
@ -848,65 +848,68 @@ SegmentSealedImpl::get_vector(FieldId field_id,
if (has_raw_data) {
// If index has raw data, get vector from memory.
auto ids_ds = GenIdsDataset(count, ids);
auto vector = vec_index->GetVector(ids_ds);
return segcore::CreateVectorDataArrayFrom(
vector.data(), count, field_meta);
} else {
// If index doesn't have raw data, get vector from chunk cache.
auto cc = storage::ChunkCacheSingleton::GetInstance().GetChunkCache();
// group by data_path
auto id_to_data_path =
std::unordered_map<std::int64_t,
std::tuple<std::string, int64_t>>{};
auto path_to_column =
std::unordered_map<std::string, std::shared_ptr<ColumnBase>>{};
for (auto i = 0; i < count; i++) {
const auto& tuple = GetFieldDataPath(field_id, ids[i]);
id_to_data_path.emplace(ids[i], tuple);
path_to_column.emplace(std::get<0>(tuple), nullptr);
if (field_meta.get_data_type() == DataType::VECTOR_SPARSE_FLOAT) {
auto res = vec_index->GetSparseVector(ids_ds);
return segcore::CreateVectorDataArrayFrom(
res.get(), count, field_meta);
} else {
// dense vector:
auto vector = vec_index->GetVector(ids_ds);
return segcore::CreateVectorDataArrayFrom(
vector.data(), count, field_meta);
}
// read and prefetch
auto& pool =
ThreadPools::GetThreadPool(milvus::ThreadPoolPriority::HIGH);
std::vector<
std::future<std::tuple<std::string, std::shared_ptr<ColumnBase>>>>
futures;
futures.reserve(path_to_column.size());
for (const auto& iter : path_to_column) {
const auto& data_path = iter.first;
futures.emplace_back(
pool.Submit(ReadFromChunkCache, cc, data_path));
}
for (int i = 0; i < futures.size(); ++i) {
const auto& [data_path, column] = futures[i].get();
path_to_column[data_path] = column;
}
// assign to data array
auto row_bytes = field_meta.get_sizeof();
auto buf = std::vector<char>(count * row_bytes);
for (auto i = 0; i < count; i++) {
AssertInfo(id_to_data_path.count(ids[i]) != 0, "id not found");
const auto& [data_path, offset_in_binlog] =
id_to_data_path.at(ids[i]);
AssertInfo(path_to_column.count(data_path) != 0,
"column not found");
const auto& column = path_to_column.at(data_path);
AssertInfo(
offset_in_binlog * row_bytes < column->ByteSize(),
"column idx out of range, idx: {}, size: {}, data_path: {}",
offset_in_binlog * row_bytes,
column->ByteSize(),
data_path);
auto vector = &column->Data()[offset_in_binlog * row_bytes];
std::memcpy(buf.data() + i * row_bytes, vector, row_bytes);
}
return segcore::CreateVectorDataArrayFrom(
buf.data(), count, field_meta);
}
AssertInfo(field_meta.get_data_type() != DataType::VECTOR_SPARSE_FLOAT,
"index of sparse float vector is guaranteed to have raw data");
// If index doesn't have raw data, get vector from chunk cache.
auto cc = storage::ChunkCacheSingleton::GetInstance().GetChunkCache();
// group by data_path
auto id_to_data_path =
std::unordered_map<std::int64_t, std::tuple<std::string, int64_t>>{};
auto path_to_column =
std::unordered_map<std::string, std::shared_ptr<ColumnBase>>{};
for (auto i = 0; i < count; i++) {
const auto& tuple = GetFieldDataPath(field_id, ids[i]);
id_to_data_path.emplace(ids[i], tuple);
path_to_column.emplace(std::get<0>(tuple), nullptr);
}
// read and prefetch
auto& pool = ThreadPools::GetThreadPool(milvus::ThreadPoolPriority::HIGH);
std::vector<
std::future<std::tuple<std::string, std::shared_ptr<ColumnBase>>>>
futures;
futures.reserve(path_to_column.size());
for (const auto& iter : path_to_column) {
const auto& data_path = iter.first;
futures.emplace_back(pool.Submit(ReadFromChunkCache, cc, data_path));
}
for (int i = 0; i < futures.size(); ++i) {
const auto& [data_path, column] = futures[i].get();
path_to_column[data_path] = column;
}
// assign to data array
auto row_bytes = field_meta.get_sizeof();
auto buf = std::vector<char>(count * row_bytes);
for (auto i = 0; i < count; i++) {
AssertInfo(id_to_data_path.count(ids[i]) != 0, "id not found");
const auto& [data_path, offset_in_binlog] = id_to_data_path.at(ids[i]);
AssertInfo(path_to_column.count(data_path) != 0, "column not found");
const auto& column = path_to_column.at(data_path);
AssertInfo(offset_in_binlog * row_bytes < column->ByteSize(),
"column idx out of range, idx: {}, size: {}, data_path: {}",
offset_in_binlog * row_bytes,
column->ByteSize(),
data_path);
auto vector = &column->Data()[offset_in_binlog * row_bytes];
std::memcpy(buf.data() + i * row_bytes, vector, row_bytes);
}
return segcore::CreateVectorDataArrayFrom(buf.data(), count, field_meta);
}
void
@ -1102,7 +1105,7 @@ SegmentSealedImpl::bulk_subscript_array_impl(
}
}
// for vector
// for dense vector
void
SegmentSealedImpl::bulk_subscript_impl(int64_t element_sizeof,
const void* src_raw,
@ -1250,7 +1253,6 @@ SegmentSealedImpl::get_raw_data(FieldId field_id,
->mutable_data());
break;
}
case DataType::VECTOR_FLOAT: {
bulk_subscript_impl(field_meta.get_sizeof(),
column->Data(),
@ -1289,6 +1291,21 @@ SegmentSealedImpl::get_raw_data(FieldId field_id,
ret->mutable_vectors()->mutable_binary_vector()->data());
break;
}
case DataType::VECTOR_SPARSE_FLOAT: {
auto rows = static_cast<const knowhere::sparse::SparseRow<float>*>(
static_cast<const void*>(column->Data()));
auto dst = ret->mutable_vectors()->mutable_sparse_float_vector();
SparseRowsToProto(
[&](size_t i) {
auto offset = seg_offsets[i];
return offset != INVALID_SEG_OFFSET ? (rows + offset)
: nullptr;
},
count,
dst);
ret->mutable_vectors()->set_dim(dst->dim());
break;
}
default: {
PanicInfo(DataTypeInvalid,
@ -1519,7 +1536,7 @@ SegmentSealedImpl::mask_with_timestamps(BitsetType& bitset_chunk,
}
bool
SegmentSealedImpl::generate_binlog_index(const FieldId field_id) {
SegmentSealedImpl::generate_interim_index(const FieldId field_id) {
if (col_index_meta_ == nullptr || !col_index_meta_->HasFiled(field_id)) {
return false;
}

View File

@ -267,7 +267,7 @@ class SegmentSealedImpl : public SegmentSealed {
WarmupChunkCache(const FieldId field_id) override;
bool
generate_binlog_index(const FieldId field_id);
generate_interim_index(const FieldId field_id);
private:
// segment loading state

View File

@ -315,8 +315,11 @@ CreateVectorDataArray(int64_t count, const FieldMeta& field_meta) {
field_meta.get_data_type()));
auto vector_array = data_array->mutable_vectors();
auto dim = field_meta.get_dim();
vector_array->set_dim(dim);
auto dim = 0;
if (data_type != DataType::VECTOR_SPARSE_FLOAT) {
dim = field_meta.get_dim();
vector_array->set_dim(dim);
}
switch (data_type) {
case DataType::VECTOR_FLOAT: {
auto length = count * dim;
@ -494,8 +497,12 @@ CreateVectorDataArrayFrom(const void* data_raw,
}
case DataType::VECTOR_SPARSE_FLOAT: {
SparseRowsToProto(
reinterpret_cast<const knowhere::sparse::SparseRow<float>*>(
data_raw),
[&](size_t i) {
return reinterpret_cast<
const knowhere::sparse::SparseRow<float>*>(
data_raw) +
i;
},
count,
vector_array->mutable_sparse_float_vector());
vector_array->set_dim(vector_array->sparse_float_vector().dim());
@ -541,8 +548,11 @@ MergeDataArray(
"merge field data type not consistent");
if (field_meta.is_vector()) {
auto vector_array = data_array->mutable_vectors();
auto dim = field_meta.get_dim();
vector_array->set_dim(dim);
auto dim = 0;
if (!datatype_is_sparse_vector(data_type)) {
dim = field_meta.get_dim();
vector_array->set_dim(dim);
}
if (field_meta.get_data_type() == DataType::VECTOR_FLOAT) {
auto data = VEC_FIELD_DATA(src_field_data, float).data();
auto obj = vector_array->mutable_float_vector();

View File

@ -306,7 +306,8 @@ LoadFieldRawData(CSegmentInterface c_segment,
auto field_meta = segment->get_schema()[milvus::FieldId(field_id)];
data_type = field_meta.get_data_type();
if (milvus::datatype_is_vector(data_type)) {
if (milvus::datatype_is_vector(data_type) &&
!milvus::datatype_is_sparse_vector(data_type)) {
dim = field_meta.get_dim();
}
}

View File

@ -18,6 +18,7 @@ add_definitions(-DMILVUS_TEST_SEGCORE_YAML_PATH="${CMAKE_SOURCE_DIR}/unittest/te
set(MILVUS_TEST_FILES
init_gtest.cpp
test_bf.cpp
test_bf_sparse.cpp
test_binary.cpp
test_bitmap.cpp
test_bool_index.cpp

View File

@ -23,13 +23,24 @@
#include "expr/ITypeExpr.h"
#include "plan/PlanNode.h"
TEST(Expr, AlwaysTrue) {
class ExprAlwaysTrueTest : public ::testing::TestWithParam<milvus::DataType> {};
INSTANTIATE_TEST_SUITE_P(
ExprAlwaysTrueParameters,
ExprAlwaysTrueTest,
::testing::Values(milvus::DataType::VECTOR_FLOAT,
milvus::DataType::VECTOR_SPARSE_FLOAT));
TEST_P(ExprAlwaysTrueTest, AlwaysTrue) {
using namespace milvus;
using namespace milvus::query;
using namespace milvus::segcore;
auto data_type = GetParam();
auto metric_type = data_type == DataType::VECTOR_FLOAT
? knowhere::metric::L2
: knowhere::metric::IP;
auto schema = std::make_shared<Schema>();
auto vec_fid = schema->AddDebugField(
"fakevec", DataType::VECTOR_FLOAT, 16, knowhere::metric::L2);
auto vec_fid = schema->AddDebugField("fakevec", data_type, 16, metric_type);
auto i64_fid = schema->AddDebugField("age", DataType::INT64);
schema->set_primary_field_id(i64_fid);
@ -64,4 +75,4 @@ TEST(Expr, AlwaysTrue) {
auto val = age_col[i];
ASSERT_EQ(ans, true) << "@" << i << "!!" << val;
}
}
}

View File

@ -0,0 +1,115 @@
// Copyright (C) 2019-2024 Zilliz. All rights reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance
// with the License. You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software distributed under the License
// is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express
// or implied. See the License for the specific language governing permissions and limitations under the License
#include <gtest/gtest.h>
#include <random>
#include "common/Utils.h"
#include "query/SearchBruteForce.h"
#include "test_utils/Constants.h"
#include "test_utils/Distance.h"
#include "test_utils/DataGen.h"
using namespace milvus;
using namespace milvus::segcore;
using namespace milvus::query;
namespace {
std::vector<int>
Ref(const knowhere::sparse::SparseRow<float>* base,
const knowhere::sparse::SparseRow<float>& query,
int nb,
int topk,
const knowhere::MetricType& metric) {
std::vector<std::tuple<float, int>> res;
for (int i = 0; i < nb; i++) {
auto& row = base[i];
auto distance = row.dot(query);
res.emplace_back(-distance, i);
}
std::sort(res.begin(), res.end());
std::vector<int> offsets;
for (int i = 0; i < topk; i++) {
auto [distance, offset] = res[i];
if (distance == 0) {
distance = std::numeric_limits<float>::quiet_NaN();
offset = -1;
}
offsets.push_back(offset);
}
return offsets;
}
void
AssertMatch(const std::vector<int>& expected, const int64_t* actual) {
for (int i = 0; i < expected.size(); i++) {
ASSERT_EQ(expected[i], actual[i]);
}
}
bool
is_supported_sparse_float_metric(const std::string& metric) {
return milvus::IsMetricType(metric, knowhere::metric::IP);
}
} // namespace
class TestSparseFloatSearchBruteForce : public ::testing::Test {
public:
void
Run(int nb, int nq, int topk, const knowhere::MetricType& metric_type) {
auto bitset = std::make_shared<BitsetType>();
bitset->resize(nb);
auto bitset_view = BitsetView(*bitset);
auto base = milvus::segcore::GenerateRandomSparseFloatVector(nb);
auto query = milvus::segcore::GenerateRandomSparseFloatVector(nq);
SearchInfo search_info;
search_info.topk_ = topk;
search_info.metric_type_ = metric_type;
dataset::SearchDataset dataset{
metric_type, nq, topk, -1, kTestSparseDim, query.get()};
if (!is_supported_sparse_float_metric(metric_type)) {
ASSERT_ANY_THROW(BruteForceSearch(dataset,
base.get(),
nb,
search_info,
bitset_view,
DataType::VECTOR_SPARSE_FLOAT));
return;
}
auto result = BruteForceSearch(dataset,
base.get(),
nb,
search_info,
bitset_view,
DataType::VECTOR_SPARSE_FLOAT);
for (int i = 0; i < nq; i++) {
auto ref =
Ref(base.get(), *(query.get() + i), nb, topk, metric_type);
auto ans = result.get_seg_offsets() + i * topk;
AssertMatch(ref, ans);
}
}
};
TEST_F(TestSparseFloatSearchBruteForce, NotSupported) {
Run(100, 10, 5, "L2");
Run(100, 10, 5, "l2");
Run(100, 10, 5, "lxxx");
}
TEST_F(TestSparseFloatSearchBruteForce, IP) {
Run(100, 10, 5, "IP");
Run(100, 10, 5, "ip");
}

View File

@ -27,14 +27,13 @@ using namespace milvus;
using namespace milvus::segcore;
namespace pb = milvus::proto;
std::shared_ptr<float[]>
std::unique_ptr<float[]>
GenRandomFloatVecData(int rows, int dim, int seed = 42) {
std::shared_ptr<float[]> vecs =
std::shared_ptr<float[]>(new float[rows * dim]);
auto vecs = std::make_unique<float[]>(rows * dim);
std::mt19937 rng(seed);
std::uniform_int_distribution<> distrib(0.0, 100.0);
for (int i = 0; i < rows * dim; ++i) vecs[i] = (float)distrib(rng);
return std::move(vecs);
return vecs;
}
inline float
@ -60,27 +59,42 @@ GetKnnSearchRecall(
return ((float)matched_num) / ((float)nq * res_k);
}
using Param = const char*;
using Param =
std::tuple<DataType, knowhere::MetricType, /* IndexType */ std::string>;
class BinlogIndexTest : public ::testing::TestWithParam<Param> {
void
SetUp() override {
auto param = GetParam();
metricType = param;
std::tie(data_type, metric_type, index_type) = GetParam();
schema = std::make_shared<Schema>();
auto metric_type = metricType;
vec_field_id = schema->AddDebugField(
"fakevec", DataType::VECTOR_FLOAT, data_d, metric_type);
vec_field_id =
schema->AddDebugField("fakevec", data_type, data_d, metric_type);
auto i64_fid = schema->AddDebugField("counter", DataType::INT64);
schema->set_primary_field_id(i64_fid);
vec_field_data = storage::CreateFieldData(data_type, data_d);
// generate vector field data
vec_data = GenRandomFloatVecData(data_n, data_d);
vec_field_data =
storage::CreateFieldData(DataType::VECTOR_FLOAT, data_d);
vec_field_data->FillFieldData(vec_data.get(), data_n);
if (data_type == DataType::VECTOR_FLOAT) {
auto vec_data = GenRandomFloatVecData(data_n, data_d);
vec_field_data->FillFieldData(vec_data.get(), data_n);
raw_dataset = knowhere::GenDataSet(data_n, data_d, vec_data.get());
raw_dataset->SetIsOwner(true);
vec_data.release();
} else if (data_type == DataType::VECTOR_SPARSE_FLOAT) {
auto sparse_vecs = GenerateRandomSparseFloatVector(data_n);
vec_field_data->FillFieldData(sparse_vecs.get(), data_n);
data_d = std::dynamic_pointer_cast<
milvus::FieldData<milvus::SparseFloatVector>>(
vec_field_data)
->Dim();
raw_dataset =
knowhere::GenDataSet(data_n, data_d, sparse_vecs.get());
raw_dataset->SetIsOwner(true);
raw_dataset->SetIsSparse(true);
sparse_vecs.release();
} else {
throw std::runtime_error("not implemented");
}
}
public:
@ -88,7 +102,7 @@ class BinlogIndexTest : public ::testing::TestWithParam<Param> {
GetCollectionIndexMeta(std::string index_type) {
std::map<std::string, std::string> index_params = {
{"index_type", index_type},
{"metric_type", metricType},
{"metric_type", metric_type},
{"nlist", "1024"}};
std::map<std::string, std::string> type_params = {{"dim", "128"}};
FieldIndexMeta fieldIndexMeta(
@ -131,23 +145,34 @@ class BinlogIndexTest : public ::testing::TestWithParam<Param> {
protected:
milvus::SchemaPtr schema;
const char* metricType;
knowhere::MetricType metric_type;
DataType data_type;
std::string index_type;
size_t data_n = 10000;
size_t data_d = 128;
size_t topk = 10;
milvus::FieldDataPtr vec_field_data = nullptr;
milvus::segcore::SegmentSealedUPtr segment = nullptr;
milvus::FieldId vec_field_id;
std::shared_ptr<float[]> vec_data;
knowhere::DataSetPtr raw_dataset;
};
INSTANTIATE_TEST_SUITE_P(MetricTypeParameters,
BinlogIndexTest,
::testing::Values(knowhere::metric::L2));
INSTANTIATE_TEST_SUITE_P(
MetricTypeParameters,
BinlogIndexTest,
::testing::Values(
std::make_tuple(DataType::VECTOR_FLOAT,
knowhere::metric::L2,
knowhere::IndexEnum::INDEX_FAISS_IVFFLAT),
std::make_tuple(DataType::VECTOR_SPARSE_FLOAT,
knowhere::metric::IP,
knowhere::IndexEnum::INDEX_SPARSE_INVERTED_INDEX),
std::make_tuple(DataType::VECTOR_SPARSE_FLOAT,
knowhere::metric::IP,
knowhere::IndexEnum::INDEX_SPARSE_WAND)));
TEST_P(BinlogIndexTest, Accuracy) {
IndexMetaPtr collection_index_meta =
GetCollectionIndexMeta(knowhere::IndexEnum::INDEX_FAISS_IVFFLAT);
IndexMetaPtr collection_index_meta = GetCollectionIndexMeta(index_type);
segment = CreateSealedSegment(schema, collection_index_meta);
LoadOtherFields();
@ -159,6 +184,7 @@ TEST_P(BinlogIndexTest, Accuracy) {
auto field_data_info = FieldDataInfo{
vec_field_id.get(), data_n, std::vector<FieldDataPtr>{vec_field_data}};
segment->LoadFieldData(vec_field_id, field_data_info);
//assert segment has been built binlog index
EXPECT_TRUE(segment->HasIndex(vec_field_id));
EXPECT_EQ(segment->get_row_count(), data_n);
@ -166,7 +192,6 @@ TEST_P(BinlogIndexTest, Accuracy) {
// 2. search binlog index
auto num_queries = 10;
auto query_ptr = GenRandomFloatVecData(num_queries, data_d);
milvus::proto::plan::PlanNode plan_node;
auto vector_anns = plan_node.mutable_vector_anns();
@ -176,12 +201,17 @@ TEST_P(BinlogIndexTest, Accuracy) {
auto query_info = vector_anns->mutable_query_info();
query_info->set_topk(topk);
query_info->set_round_decimal(3);
query_info->set_metric_type(metricType);
query_info->set_metric_type(metric_type);
query_info->set_search_params(R"({"nprobe": 1024})");
auto plan_str = plan_node.SerializeAsString();
auto ph_group_raw =
CreatePlaceholderGroupFromBlob(num_queries, data_d, query_ptr.get());
data_type == DataType::VECTOR_FLOAT
? CreatePlaceholderGroupFromBlob(
num_queries,
data_d,
GenRandomFloatVecData(num_queries, data_d).get())
: CreateSparseFloatPlaceholderGroup(num_queries);
auto plan = milvus::query::CreateSearchPlanByExpr(
*schema, plan_str.data(), plan_str.size());
@ -201,27 +231,25 @@ TEST_P(BinlogIndexTest, Accuracy) {
// 3. update vector index
{
milvus::index::CreateIndexInfo create_index_info;
create_index_info.field_type = DataType::VECTOR_FLOAT;
create_index_info.metric_type = metricType;
create_index_info.index_type = knowhere::IndexEnum::INDEX_FAISS_IVFFLAT;
create_index_info.field_type = data_type;
create_index_info.metric_type = metric_type;
create_index_info.index_type = index_type;
create_index_info.index_engine_version =
knowhere::Version::GetCurrentVersion().VersionNumber();
auto indexing = milvus::index::IndexFactory::GetInstance().CreateIndex(
create_index_info, milvus::storage::FileManagerContext());
auto build_conf =
knowhere::Json{{knowhere::meta::METRIC_TYPE, metricType},
knowhere::Json{{knowhere::meta::METRIC_TYPE, metric_type},
{knowhere::meta::DIM, std::to_string(data_d)},
{knowhere::indexparam::NLIST, "1024"}};
auto database = knowhere::GenDataSet(data_n, data_d, vec_data.get());
indexing->BuildWithDataset(database, build_conf);
indexing->BuildWithDataset(raw_dataset, build_conf);
LoadIndexInfo load_info;
load_info.field_id = vec_field_id.get();
load_info.index = std::move(indexing);
load_info.index_params["metric_type"] = metricType;
load_info.index_params["metric_type"] = metric_type;
segment->DropFieldData(vec_field_id);
ASSERT_NO_THROW(segment->LoadIndex(load_info));
EXPECT_TRUE(segment->HasIndex(vec_field_id));
@ -238,8 +266,7 @@ TEST_P(BinlogIndexTest, Accuracy) {
}
TEST_P(BinlogIndexTest, DisableInterimIndex) {
IndexMetaPtr collection_index_meta =
GetCollectionIndexMeta(knowhere::IndexEnum::INDEX_FAISS_IVFFLAT);
IndexMetaPtr collection_index_meta = GetCollectionIndexMeta(index_type);
segment = CreateSealedSegment(schema, collection_index_meta);
LoadOtherFields();
@ -254,27 +281,26 @@ TEST_P(BinlogIndexTest, DisableInterimIndex) {
EXPECT_TRUE(segment->HasFieldData(vec_field_id));
// load vector index
milvus::index::CreateIndexInfo create_index_info;
create_index_info.field_type = DataType::VECTOR_FLOAT;
create_index_info.metric_type = metricType;
create_index_info.index_type = knowhere::IndexEnum::INDEX_FAISS_IVFFLAT;
create_index_info.field_type = data_type;
create_index_info.metric_type = metric_type;
create_index_info.index_type = index_type;
create_index_info.index_engine_version =
knowhere::Version::GetCurrentVersion().VersionNumber();
auto indexing = milvus::index::IndexFactory::GetInstance().CreateIndex(
create_index_info, milvus::storage::FileManagerContext());
auto build_conf =
knowhere::Json{{knowhere::meta::METRIC_TYPE, metricType},
knowhere::Json{{knowhere::meta::METRIC_TYPE, metric_type},
{knowhere::meta::DIM, std::to_string(data_d)},
{knowhere::indexparam::NLIST, "1024"}};
auto database = knowhere::GenDataSet(data_n, data_d, vec_data.get());
indexing->BuildWithDataset(database, build_conf);
indexing->BuildWithDataset(raw_dataset, build_conf);
LoadIndexInfo load_info;
load_info.field_id = vec_field_id.get();
load_info.index = std::move(indexing);
load_info.index_params["metric_type"] = metricType;
load_info.index_params["metric_type"] = metric_type;
segment->DropFieldData(vec_field_id);
ASSERT_NO_THROW(segment->LoadIndex(load_info));

View File

@ -37,7 +37,7 @@ using namespace milvus::exec;
using namespace milvus::query;
using namespace milvus::segcore;
class TaskTest : public testing::Test {
class TaskTest : public testing::TestWithParam<DataType> {
protected:
void
SetUp() override {
@ -46,7 +46,7 @@ class TaskTest : public testing::Test {
using namespace milvus::segcore;
auto schema = std::make_shared<Schema>();
auto vec_fid = schema->AddDebugField(
"fakevec", DataType::VECTOR_FLOAT, 16, knowhere::metric::L2);
"fakevec", GetParam(), 16, knowhere::metric::L2);
auto bool_fid = schema->AddDebugField("bool", DataType::BOOL);
field_map_.insert({"bool", bool_fid});
auto bool_1_fid = schema->AddDebugField("bool1", DataType::BOOL);
@ -112,7 +112,12 @@ class TaskTest : public testing::Test {
int64_t num_rows_{0};
};
TEST_F(TaskTest, UnaryExpr) {
INSTANTIATE_TEST_SUITE_P(TaskTestSuite,
TaskTest,
::testing::Values(DataType::VECTOR_FLOAT,
DataType::VECTOR_SPARSE_FLOAT));
TEST_P(TaskTest, UnaryExpr) {
::milvus::proto::plan::GenericValue value;
value.set_int64_val(-1);
auto logical_expr = std::make_shared<milvus::expr::UnaryRangeFilterExpr>(
@ -149,7 +154,7 @@ TEST_F(TaskTest, UnaryExpr) {
EXPECT_EQ(num_rows, num_rows_);
}
TEST_F(TaskTest, LogicalExpr) {
TEST_P(TaskTest, LogicalExpr) {
::milvus::proto::plan::GenericValue value;
value.set_int64_val(-1);
auto left = std::make_shared<milvus::expr::UnaryRangeFilterExpr>(
@ -193,13 +198,13 @@ TEST_F(TaskTest, LogicalExpr) {
EXPECT_EQ(num_rows, num_rows_);
}
TEST_F(TaskTest, CompileInputs_and) {
TEST_P(TaskTest, CompileInputs_and) {
using namespace milvus;
using namespace milvus::query;
using namespace milvus::segcore;
auto schema = std::make_shared<Schema>();
auto vec_fid = schema->AddDebugField(
"fakevec", DataType::VECTOR_FLOAT, 16, knowhere::metric::L2);
auto vec_fid =
schema->AddDebugField("fakevec", GetParam(), 16, knowhere::metric::L2);
auto int64_fid = schema->AddDebugField("int64", DataType::INT64);
proto::plan::GenericValue val;
val.set_int64_val(10);
@ -236,13 +241,13 @@ TEST_F(TaskTest, CompileInputs_and) {
}
}
TEST_F(TaskTest, CompileInputs_or_with_and) {
TEST_P(TaskTest, CompileInputs_or_with_and) {
using namespace milvus;
using namespace milvus::query;
using namespace milvus::segcore;
auto schema = std::make_shared<Schema>();
auto vec_fid = schema->AddDebugField(
"fakevec", DataType::VECTOR_FLOAT, 16, knowhere::metric::L2);
auto vec_fid =
schema->AddDebugField("fakevec", GetParam(), 16, knowhere::metric::L2);
auto int64_fid = schema->AddDebugField("int64", DataType::INT64);
proto::plan::GenericValue val;
val.set_int64_val(10);

View File

@ -39,37 +39,43 @@ using namespace milvus;
using namespace milvus::query;
using namespace milvus::segcore;
TEST(Expr, Range) {
SUCCEED();
// std::string dsl_string = R"({
// "bool": {
// "must": [
// {
// "range": {
// "age": {
// "GT": 1,
// "LT": 100
// }
// }
// },
// {
// "vector": {
// "fakevec": {
// "metric_type": "L2",
// "params": {
// "nprobe": 10
// },
// "query": "$0",
// "topk": 10,
// "round_decimal": 3
// }
// }
// }
// ]
// }
// })";
class ExprTest : public ::testing::TestWithParam<
std::pair<milvus::DataType, knowhere::MetricType>> {
public:
void
SetUp() override {
auto param = GetParam();
data_type = param.first;
metric_type = param.second;
}
const char* raw_plan = R"(vector_anns: <
// replace the metric type in the plan string with the proper type
std::vector<char>
translate_text_plan_with_metric_type(std::string plan) {
return milvus::segcore::
replace_metric_and_translate_text_plan_to_binary_plan(
std::move(plan), metric_type);
}
milvus::DataType data_type;
knowhere::MetricType metric_type;
};
INSTANTIATE_TEST_SUITE_P(
ExprTestSuite,
ExprTest,
::testing::Values(
std::pair(milvus::DataType::VECTOR_FLOAT, knowhere::metric::L2),
std::pair(milvus::DataType::VECTOR_SPARSE_FLOAT, knowhere::metric::IP),
std::pair(milvus::DataType::VECTOR_BINARY, knowhere::metric::JACCARD)));
TEST_P(ExprTest, Range) {
SUCCEED();
using namespace milvus;
using namespace milvus::query;
using namespace milvus::segcore;
std::string raw_plan = R"(vector_anns: <
field_id: 100
predicates: <
binary_expr: <
@ -108,10 +114,9 @@ TEST(Expr, Range) {
>
placeholder_tag: "$0"
>)";
auto plan_str = translate_text_plan_to_binary_plan(raw_plan);
auto plan_str = translate_text_plan_with_metric_type(raw_plan);
auto schema = std::make_shared<Schema>();
schema->AddDebugField(
"fakevec", DataType::VECTOR_FLOAT, 16, knowhere::metric::L2);
schema->AddDebugField("fakevec", data_type, 16, metric_type);
schema->AddDebugField("age", DataType::INT32);
auto plan =
CreateSearchPlanByExpr(*schema, plan_str.data(), plan_str.size());
@ -120,116 +125,9 @@ TEST(Expr, Range) {
schema->get_field_id(FieldName("fakevec")));
}
TEST(Expr, RangeBinary) {
TEST_P(ExprTest, InvalidRange) {
SUCCEED();
// std::string dsl_string = R"({
// "bool": {
// "must": [
// {
// "range": {
// "age": {
// "GT": 1,
// "LT": 100
// }
// }
// },
// {
// "vector": {
// "fakevec": {
// "metric_type": "Jaccard",
// "params": {
// "nprobe": 10
// },
// "query": "$0",
// "topk": 10,
// "round_decimal": 3
// }
// }
// }
// ]
// }
// })";
const char* raw_plan = R"(vector_anns: <
field_id: 100
predicates: <
binary_expr: <
op: LogicalAnd
left: <
unary_range_expr: <
column_info: <
field_id: 101
data_type: Int32
>
op: GreaterThan
value: <
int64_val: 1
>
>
>
right: <
unary_range_expr: <
column_info: <
field_id: 101
data_type: Int32
>
op: LessThan
value: <
int64_val: 100
>
>
>
>
>
query_info: <
topk: 10
round_decimal: 3
metric_type: "JACCARD"
search_params: "{\"nprobe\": 10}"
>
placeholder_tag: "$0"
>)";
auto plan_str = translate_text_plan_to_binary_plan(raw_plan);
auto schema = std::make_shared<Schema>();
schema->AddDebugField(
"fakevec", DataType::VECTOR_BINARY, 512, knowhere::metric::JACCARD);
schema->AddDebugField("age", DataType::INT32);
auto plan =
CreateSearchPlanByExpr(*schema, plan_str.data(), plan_str.size());
ShowPlanNodeVisitor shower;
Assert(plan->tag2field_.at("$0") ==
schema->get_field_id(FieldName("fakevec")));
}
TEST(Expr, InvalidRange) {
SUCCEED();
// std::string dsl_string = R"(
// {
// "bool": {
// "must": [
// {
// "range": {
// "age": {
// "GT": 1,
// "LT": "100"
// }
// }
// },
// {
// "vector": {
// "fakevec": {
// "metric_type": "L2",
// "params": {
// "nprobe": 10
// },
// "query": "$0",
// "topk": 10
// }
// }
// }
// ]
// }
// })";
const char* raw_plan = R"(vector_anns: <
std::string raw_plan = R"(vector_anns: <
field_id: 100
predicates: <
binary_expr: <
@ -268,21 +166,19 @@ TEST(Expr, InvalidRange) {
>
placeholder_tag: "$0"
>)";
auto plan_str = translate_text_plan_to_binary_plan(raw_plan);
auto plan_str = translate_text_plan_with_metric_type(raw_plan);
auto schema = std::make_shared<Schema>();
schema->AddDebugField(
"fakevec", DataType::VECTOR_FLOAT, 16, knowhere::metric::L2);
schema->AddDebugField("fakevec", data_type, 16, metric_type);
schema->AddDebugField("age", DataType::INT32);
ASSERT_ANY_THROW(
CreateSearchPlanByExpr(*schema, plan_str.data(), plan_str.size()));
}
TEST(Expr, ShowExecutor) {
TEST_P(ExprTest, ShowExecutor) {
auto node = std::make_unique<FloatVectorANNS>();
auto schema = std::make_shared<Schema>();
auto metric_type = knowhere::metric::L2;
auto field_id = schema->AddDebugField(
"fakevec", DataType::VECTOR_FLOAT, 16, metric_type);
auto field_id =
schema->AddDebugField("fakevec", data_type, 16, metric_type);
int64_t num_queries = 100L;
auto raw_data = DataGen(schema, num_queries);
auto& info = node->search_info_;
@ -299,7 +195,7 @@ TEST(Expr, ShowExecutor) {
std::cout << dup.dump(4);
}
TEST(Expr, TestRange) {
TEST_P(ExprTest, TestRange) {
std::vector<std::tuple<std::string, std::function<bool(int)>>> testcases = {
{R"(binary_range_expr: <
column_info: <
@ -429,32 +325,6 @@ TEST(Expr, TestRange) {
[](int v) { return v != 2000; }},
};
// std::string dsl_string_tmp = R"({
// "bool": {
// "must": [
// {
// "range": {
// "age": {
// @@@@
// }
// }
// },
// {
// "vector": {
// "fakevec": {
// "metric_type": "L2",
// "params": {
// "nprobe": 10
// },
// "query": "$0",
// "topk": 10,
// "round_decimal": 3
// }
// }
// }
// ]
// }
// })";
std::string raw_plan_tmp = R"(vector_anns: <
field_id: 100
predicates: <
@ -469,8 +339,7 @@ TEST(Expr, TestRange) {
placeholder_tag: "$0"
>)";
auto schema = std::make_shared<Schema>();
auto vec_fid = schema->AddDebugField(
"fakevec", DataType::VECTOR_FLOAT, 16, knowhere::metric::L2);
auto vec_fid = schema->AddDebugField("fakevec", data_type, 16, metric_type);
auto i64_fid = schema->AddDebugField("age", DataType::INT64);
schema->set_primary_field_id(i64_fid);
@ -496,7 +365,7 @@ TEST(Expr, TestRange) {
auto loc = raw_plan_tmp.find("@@@@");
auto raw_plan = raw_plan_tmp;
raw_plan.replace(loc, 4, clause);
auto plan_str = translate_text_plan_to_binary_plan(raw_plan.c_str());
auto plan_str = translate_text_plan_with_metric_type(raw_plan);
auto plan =
CreateSearchPlanByExpr(*schema, plan_str.data(), plan_str.size());
query::ExecPlanNodeVisitor visitor(*seg_promote, MAX_TIMESTAMP);
@ -517,7 +386,7 @@ TEST(Expr, TestRange) {
}
}
TEST(Expr, TestBinaryRangeJSON) {
TEST_P(ExprTest, TestBinaryRangeJSON) {
struct Testcase {
bool lower_inclusive;
bool upper_inclusive;
@ -616,7 +485,7 @@ TEST(Expr, TestBinaryRangeJSON) {
}
}
TEST(Expr, TestExistsJson) {
TEST_P(ExprTest, TestExistsJson) {
struct Testcase {
std::vector<std::string> nested_path;
};
@ -707,7 +576,7 @@ GetValueFromProto(const milvus::proto::plan::GenericValue& value_proto) {
}
};
TEST(Expr, TestUnaryRangeJson) {
TEST_P(ExprTest, TestUnaryRangeJson) {
struct Testcase {
int64_t val;
std::vector<std::string> nested_path;
@ -876,7 +745,7 @@ TEST(Expr, TestUnaryRangeJson) {
}
}
TEST(Expr, TestTermJson) {
TEST_P(ExprTest, TestTermJson) {
struct Testcase {
std::vector<int64_t> term;
std::vector<std::string> nested_path;
@ -947,7 +816,7 @@ TEST(Expr, TestTermJson) {
}
}
TEST(Expr, TestTerm) {
TEST_P(ExprTest, TestTerm) {
auto vec_2k_3k = [] {
std::string buf;
for (int i = 2000; i < 3000; ++i) {
@ -977,33 +846,6 @@ TEST(Expr, TestTerm) {
{vec_2k_3k, [](int v) { return 2000 <= v && v < 3000; }},
};
// std::string dsl_string_tmp = R"({
// "bool": {
// "must": [
// {
// "term": {
// "age": {
// "values": @@@@,
// "is_in_field" : false
// }
// }
// },
// {
// "vector": {
// "fakevec": {
// "metric_type": "L2",
// "params": {
// "nprobe": 10
// },
// "query": "$0",
// "topk": 10,
// "round_decimal": 3
// }
// }
// }
// ]
// }
// })";
std::string raw_plan_tmp = R"(vector_anns: <
field_id: 100
predicates: <
@ -1024,8 +866,7 @@ TEST(Expr, TestTerm) {
placeholder_tag: "$0"
>)";
auto schema = std::make_shared<Schema>();
auto vec_fid = schema->AddDebugField(
"fakevec", DataType::VECTOR_FLOAT, 16, knowhere::metric::L2);
auto vec_fid = schema->AddDebugField("fakevec", data_type, 16, metric_type);
auto i64_fid = schema->AddDebugField("age", DataType::INT64);
schema->set_primary_field_id(i64_fid);
@ -1051,7 +892,7 @@ TEST(Expr, TestTerm) {
auto loc = raw_plan_tmp.find("@@@@");
auto raw_plan = raw_plan_tmp;
raw_plan.replace(loc, 4, clause);
auto plan_str = translate_text_plan_to_binary_plan(raw_plan.c_str());
auto plan_str = translate_text_plan_with_metric_type(raw_plan);
auto plan =
CreateSearchPlanByExpr(*schema, plan_str.data(), plan_str.size());
BitsetType final;
@ -1071,7 +912,7 @@ TEST(Expr, TestTerm) {
}
}
TEST(Expr, TestCompare) {
TEST_P(ExprTest, TestCompare) {
std::vector<std::tuple<std::string, std::function<bool(int, int64_t)>>>
testcases = {
{R"(LessThan)", [](int a, int64_t b) { return a < b; }},
@ -1082,33 +923,6 @@ TEST(Expr, TestCompare) {
{R"(NotEqual)", [](int a, int64_t b) { return a != b; }},
};
// std::string dsl_string_tpl = R"({
// "bool": {
// "must": [
// {
// "compare": {
// %1%: [
// "age1",
// "age2"
// ]
// }
// },
// {
// "vector": {
// "fakevec": {
// "metric_type": "L2",
// "params": {
// "nprobe": 10
// },
// "query": "$0",
// "topk": 10,
// "round_decimal": 3
// }
// }
// }
// ]
// }
// })";
std::string raw_plan_tmp = R"(vector_anns: <
field_id: 100
predicates: <
@ -1133,8 +947,7 @@ TEST(Expr, TestCompare) {
placeholder_tag: "$0"
>)";
auto schema = std::make_shared<Schema>();
auto vec_fid = schema->AddDebugField(
"fakevec", DataType::VECTOR_FLOAT, 16, knowhere::metric::L2);
auto vec_fid = schema->AddDebugField("fakevec", data_type, 16, metric_type);
auto i32_fid = schema->AddDebugField("age1", DataType::INT32);
auto i64_fid = schema->AddDebugField("age2", DataType::INT64);
schema->set_primary_field_id(i64_fid);
@ -1166,7 +979,7 @@ TEST(Expr, TestCompare) {
auto loc = raw_plan_tmp.find("@@@@");
auto raw_plan = raw_plan_tmp;
raw_plan.replace(loc, 4, clause);
auto plan_str = translate_text_plan_to_binary_plan(raw_plan.c_str());
auto plan_str = translate_text_plan_with_metric_type(raw_plan);
auto plan =
CreateSearchPlanByExpr(*schema, plan_str.data(), plan_str.size());
BitsetType final;
@ -1188,7 +1001,7 @@ TEST(Expr, TestCompare) {
}
}
TEST(Expr, TestCompareWithScalarIndex) {
TEST_P(ExprTest, TestCompareWithScalarIndex) {
std::vector<std::tuple<std::string, std::function<bool(int, int64_t)>>>
testcases = {
{R"(LessThan)", [](int a, int64_t b) { return a < b; }},
@ -1224,8 +1037,7 @@ TEST(Expr, TestCompareWithScalarIndex) {
>)";
auto schema = std::make_shared<Schema>();
auto vec_fid = schema->AddDebugField(
"fakevec", DataType::VECTOR_FLOAT, 16, knowhere::metric::L2);
auto vec_fid = schema->AddDebugField("fakevec", data_type, 16, metric_type);
auto i32_fid = schema->AddDebugField("age32", DataType::INT32);
auto i64_fid = schema->AddDebugField("age64", DataType::INT64);
schema->set_primary_field_id(i64_fid);
@ -1264,7 +1076,7 @@ TEST(Expr, TestCompareWithScalarIndex) {
i32_fid.get() % proto::schema::DataType_Name(int(DataType::INT32)) %
i64_fid.get() % proto::schema::DataType_Name(int(DataType::INT64));
auto binary_plan =
translate_text_plan_to_binary_plan(dsl_string.str().data());
translate_text_plan_with_metric_type(dsl_string.str());
auto plan = CreateSearchPlanByExpr(
*schema, binary_plan.data(), binary_plan.size());
// std::cout << ShowPlanNodeVisitor().call_child(*plan->plan_node_) << std::endl;
@ -1284,10 +1096,9 @@ TEST(Expr, TestCompareWithScalarIndex) {
}
}
TEST(Expr, TestCompareExpr) {
TEST_P(ExprTest, TestCompareExpr) {
auto schema = std::make_shared<Schema>();
auto vec_fid = schema->AddDebugField(
"fakevec", DataType::VECTOR_FLOAT, 16, knowhere::metric::L2);
auto vec_fid = schema->AddDebugField("fakevec", data_type, 16, metric_type);
auto bool_fid = schema->AddDebugField("bool", DataType::BOOL);
auto bool_1_fid = schema->AddDebugField("bool1", DataType::BOOL);
auto int8_fid = schema->AddDebugField("int8", DataType::INT8);
@ -1433,10 +1244,9 @@ TEST(Expr, TestCompareExpr) {
std::cout << "end compare test" << std::endl;
}
TEST(Expr, TestMultiLogicalExprsOptimization) {
TEST_P(ExprTest, TestMultiLogicalExprsOptimization) {
auto schema = std::make_shared<Schema>();
auto vec_fid = schema->AddDebugField(
"fakevec", DataType::VECTOR_FLOAT, 16, knowhere::metric::L2);
auto vec_fid = schema->AddDebugField("fakevec", data_type, 16, metric_type);
auto int64_fid = schema->AddDebugField("int64", DataType::INT64);
auto str1_fid = schema->AddDebugField("string1", DataType::VARCHAR);
schema->set_primary_field_id(str1_fid);
@ -1519,10 +1329,9 @@ TEST(Expr, TestMultiLogicalExprsOptimization) {
ASSERT_LT(cost_op, cost_no_op);
}
TEST(Expr, TestExprs) {
TEST_P(ExprTest, TestExprs) {
auto schema = std::make_shared<Schema>();
auto vec_fid = schema->AddDebugField(
"fakevec", DataType::VECTOR_FLOAT, 16, knowhere::metric::L2);
auto vec_fid = schema->AddDebugField("fakevec", data_type, 16, metric_type);
auto int8_fid = schema->AddDebugField("int8", DataType::INT8);
auto int8_1_fid = schema->AddDebugField("int81", DataType::INT8);
auto int16_fid = schema->AddDebugField("int16", DataType::INT16);
@ -1691,11 +1500,10 @@ TEST(Expr, TestExprs) {
// test_case(500);
}
TEST(Expr, test_term_pk) {
TEST_P(ExprTest, test_term_pk) {
auto schema = std::make_shared<Schema>();
schema->AddField(FieldName("Timestamp"), FieldId(1), DataType::INT64);
auto vec_fid = schema->AddDebugField(
"fakevec", DataType::VECTOR_FLOAT, 16, knowhere::metric::L2);
auto vec_fid = schema->AddDebugField("fakevec", data_type, 16, metric_type);
auto str1_fid = schema->AddDebugField("string1", DataType::VARCHAR);
auto int64_fid = schema->AddDebugField("int64", DataType::INT64);
schema->set_primary_field_id(int64_fid);
@ -1755,10 +1563,9 @@ TEST(Expr, test_term_pk) {
}
}
TEST(Expr, TestSealedSegmentGetBatchSize) {
TEST_P(ExprTest, TestSealedSegmentGetBatchSize) {
auto schema = std::make_shared<Schema>();
auto vec_fid = schema->AddDebugField(
"fakevec", DataType::VECTOR_FLOAT, 16, knowhere::metric::L2);
auto vec_fid = schema->AddDebugField("fakevec", data_type, 16, metric_type);
auto int8_fid = schema->AddDebugField("int8", DataType::INT8);
auto str1_fid = schema->AddDebugField("string1", DataType::VARCHAR);
schema->set_primary_field_id(str1_fid);
@ -1817,10 +1624,9 @@ TEST(Expr, TestSealedSegmentGetBatchSize) {
}
}
TEST(Expr, TestGrowingSegmentGetBatchSize) {
TEST_P(ExprTest, TestGrowingSegmentGetBatchSize) {
auto schema = std::make_shared<Schema>();
auto vec_fid = schema->AddDebugField(
"fakevec", DataType::VECTOR_FLOAT, 16, knowhere::metric::L2);
auto vec_fid = schema->AddDebugField("fakevec", data_type, 16, metric_type);
auto int8_fid = schema->AddDebugField("int8", DataType::INT8);
auto str1_fid = schema->AddDebugField("string1", DataType::VARCHAR);
schema->set_primary_field_id(str1_fid);
@ -1873,10 +1679,9 @@ TEST(Expr, TestGrowingSegmentGetBatchSize) {
}
}
TEST(Expr, TestConjuctExpr) {
TEST_P(ExprTest, TestConjuctExpr) {
auto schema = std::make_shared<Schema>();
auto vec_fid = schema->AddDebugField(
"fakevec", DataType::VECTOR_FLOAT, 16, knowhere::metric::L2);
auto vec_fid = schema->AddDebugField("fakevec", data_type, 16, metric_type);
auto int8_fid = schema->AddDebugField("int8", DataType::INT8);
auto int8_1_fid = schema->AddDebugField("int81", DataType::INT8);
auto int16_fid = schema->AddDebugField("int16", DataType::INT16);
@ -1941,10 +1746,9 @@ TEST(Expr, TestConjuctExpr) {
}
}
TEST(Expr, TestUnaryBenchTest) {
TEST_P(ExprTest, TestUnaryBenchTest) {
auto schema = std::make_shared<Schema>();
auto vec_fid = schema->AddDebugField(
"fakevec", DataType::VECTOR_FLOAT, 16, knowhere::metric::L2);
auto vec_fid = schema->AddDebugField("fakevec", data_type, 16, metric_type);
auto int8_fid = schema->AddDebugField("int8", DataType::INT8);
auto int8_1_fid = schema->AddDebugField("int81", DataType::INT8);
auto int16_fid = schema->AddDebugField("int16", DataType::INT16);
@ -2013,10 +1817,9 @@ TEST(Expr, TestUnaryBenchTest) {
}
}
TEST(Expr, TestBinaryRangeBenchTest) {
TEST_P(ExprTest, TestBinaryRangeBenchTest) {
auto schema = std::make_shared<Schema>();
auto vec_fid = schema->AddDebugField(
"fakevec", DataType::VECTOR_FLOAT, 16, knowhere::metric::L2);
auto vec_fid = schema->AddDebugField("fakevec", data_type, 16, metric_type);
auto int8_fid = schema->AddDebugField("int8", DataType::INT8);
auto int8_1_fid = schema->AddDebugField("int81", DataType::INT8);
auto int16_fid = schema->AddDebugField("int16", DataType::INT16);
@ -2094,10 +1897,9 @@ TEST(Expr, TestBinaryRangeBenchTest) {
}
}
TEST(Expr, TestLogicalUnaryBenchTest) {
TEST_P(ExprTest, TestLogicalUnaryBenchTest) {
auto schema = std::make_shared<Schema>();
auto vec_fid = schema->AddDebugField(
"fakevec", DataType::VECTOR_FLOAT, 16, knowhere::metric::L2);
auto vec_fid = schema->AddDebugField("fakevec", data_type, 16, metric_type);
auto int8_fid = schema->AddDebugField("int8", DataType::INT8);
auto int8_1_fid = schema->AddDebugField("int81", DataType::INT8);
auto int16_fid = schema->AddDebugField("int16", DataType::INT16);
@ -2169,10 +1971,9 @@ TEST(Expr, TestLogicalUnaryBenchTest) {
}
}
TEST(Expr, TestBinaryLogicalBenchTest) {
TEST_P(ExprTest, TestBinaryLogicalBenchTest) {
auto schema = std::make_shared<Schema>();
auto vec_fid = schema->AddDebugField(
"fakevec", DataType::VECTOR_FLOAT, 16, knowhere::metric::L2);
auto vec_fid = schema->AddDebugField("fakevec", data_type, 16, metric_type);
auto int8_fid = schema->AddDebugField("int8", DataType::INT8);
auto int8_1_fid = schema->AddDebugField("int81", DataType::INT8);
auto int16_fid = schema->AddDebugField("int16", DataType::INT16);
@ -2254,10 +2055,9 @@ TEST(Expr, TestBinaryLogicalBenchTest) {
}
}
TEST(Expr, TestBinaryArithOpEvalRangeBenchExpr) {
TEST_P(ExprTest, TestBinaryArithOpEvalRangeBenchExpr) {
auto schema = std::make_shared<Schema>();
auto vec_fid = schema->AddDebugField(
"fakevec", DataType::VECTOR_FLOAT, 16, knowhere::metric::L2);
auto vec_fid = schema->AddDebugField("fakevec", data_type, 16, metric_type);
auto int8_fid = schema->AddDebugField("int8", DataType::INT8);
auto int8_1_fid = schema->AddDebugField("int81", DataType::INT8);
auto int16_fid = schema->AddDebugField("int16", DataType::INT16);
@ -2335,10 +2135,9 @@ TEST(Expr, TestBinaryArithOpEvalRangeBenchExpr) {
}
}
TEST(Expr, TestCompareExprBenchTest) {
TEST_P(ExprTest, TestCompareExprBenchTest) {
auto schema = std::make_shared<Schema>();
auto vec_fid = schema->AddDebugField(
"fakevec", DataType::VECTOR_FLOAT, 16, knowhere::metric::L2);
auto vec_fid = schema->AddDebugField("fakevec", data_type, 16, metric_type);
auto int8_fid = schema->AddDebugField("int8", DataType::INT8);
auto int8_1_fid = schema->AddDebugField("int81", DataType::INT8);
auto int16_fid = schema->AddDebugField("int16", DataType::INT16);
@ -2409,10 +2208,9 @@ TEST(Expr, TestCompareExprBenchTest) {
}
}
TEST(Expr, TestRefactorExprs) {
TEST_P(ExprTest, TestRefactorExprs) {
auto schema = std::make_shared<Schema>();
auto vec_fid = schema->AddDebugField(
"fakevec", DataType::VECTOR_FLOAT, 16, knowhere::metric::L2);
auto vec_fid = schema->AddDebugField("fakevec", data_type, 16, metric_type);
auto int8_fid = schema->AddDebugField("int8", DataType::INT8);
auto int8_1_fid = schema->AddDebugField("int81", DataType::INT8);
auto int16_fid = schema->AddDebugField("int16", DataType::INT16);
@ -2579,7 +2377,7 @@ TEST(Expr, TestRefactorExprs) {
// test_case(500);
}
TEST(Expr, TestCompareWithScalarIndexMaris) {
TEST_P(ExprTest, TestCompareWithScalarIndexMaris) {
std::vector<
std::tuple<std::string, std::function<bool(std::string, std::string)>>>
testcases = {
@ -2597,7 +2395,7 @@ TEST(Expr, TestCompareWithScalarIndexMaris) {
[](std::string a, std::string b) { return a.compare(b) != 0; }},
};
const char* serialized_expr_plan = R"(vector_anns: <
std::string serialized_expr_plan = R"(vector_anns: <
field_id: %1%
predicates: <
compare_expr: <
@ -2622,8 +2420,7 @@ TEST(Expr, TestCompareWithScalarIndexMaris) {
>)";
auto schema = std::make_shared<Schema>();
auto vec_fid = schema->AddDebugField(
"fakevec", DataType::VECTOR_FLOAT, 16, knowhere::metric::L2);
auto vec_fid = schema->AddDebugField("fakevec", data_type, 16, metric_type);
auto str1_fid = schema->AddDebugField("string1", DataType::VARCHAR);
auto str2_fid = schema->AddDebugField("string2", DataType::VARCHAR);
schema->set_primary_field_id(str1_fid);
@ -2658,7 +2455,7 @@ TEST(Expr, TestCompareWithScalarIndexMaris) {
auto dsl_string = boost::format(serialized_expr_plan) % vec_fid.get() %
clause % str1_fid.get() % str2_fid.get();
auto binary_plan =
translate_text_plan_to_binary_plan(dsl_string.str().data());
translate_text_plan_with_metric_type(dsl_string.str());
auto plan = CreateSearchPlanByExpr(
*schema, binary_plan.data(), binary_plan.size());
// std::cout << ShowPlanNodeVisitor().call_child(*plan->plan_node_) << std::endl;
@ -2678,7 +2475,7 @@ TEST(Expr, TestCompareWithScalarIndexMaris) {
}
}
TEST(Expr, TestBinaryArithOpEvalRange) {
TEST_P(ExprTest, TestBinaryArithOpEvalRange) {
std::vector<std::tuple<std::string, std::function<bool(int)>, DataType>> testcases = {
// Add test cases for BinaryArithOpEvalRangeExpr EQ of various data types
{R"(binary_arith_op_eval_range_expr: <
@ -3280,31 +3077,6 @@ TEST(Expr, TestBinaryArithOpEvalRange) {
DataType::INT64},
};
// std::string dsl_string_tmp = R"({
// "bool": {
// "must": [
// {
// "range": {
// @@@@@
// }
// },
// {
// "vector": {
// "fakevec": {
// "metric_type": "L2",
// "params": {
// "nprobe": 10
// },
// "query": "$0",
// "topk": 10,
// "round_decimal": 3
// }
// }
// }
// ]
// }
// })";
std::string raw_plan_tmp = R"(vector_anns: <
field_id: 100
predicates: <
@ -3319,8 +3091,7 @@ TEST(Expr, TestBinaryArithOpEvalRange) {
placeholder_tag: "$0"
>)";
auto schema = std::make_shared<Schema>();
auto vec_fid = schema->AddDebugField(
"fakevec", DataType::VECTOR_FLOAT, 16, knowhere::metric::L2);
auto vec_fid = schema->AddDebugField("fakevec", data_type, 16, metric_type);
auto i8_fid = schema->AddDebugField("age8", DataType::INT8);
auto i16_fid = schema->AddDebugField("age16", DataType::INT16);
auto i32_fid = schema->AddDebugField("age32", DataType::INT32);
@ -3394,7 +3165,7 @@ TEST(Expr, TestBinaryArithOpEvalRange) {
// }
// loc = dsl_string.find("@@@@");
// dsl_string.replace(loc, 4, clause);
auto plan_str = translate_text_plan_to_binary_plan(raw_plan.c_str());
auto plan_str = translate_text_plan_with_metric_type(raw_plan);
auto plan =
CreateSearchPlanByExpr(*schema, plan_str.data(), plan_str.size());
BitsetType final;
@ -3438,7 +3209,7 @@ TEST(Expr, TestBinaryArithOpEvalRange) {
}
}
TEST(Expr, TestBinaryArithOpEvalRangeJSON) {
TEST_P(ExprTest, TestBinaryArithOpEvalRangeJSON) {
using namespace milvus;
using namespace milvus::query;
using namespace milvus::segcore;
@ -4250,7 +4021,7 @@ TEST(Expr, TestBinaryArithOpEvalRangeJSON) {
}
}
TEST(Expr, TestBinaryArithOpEvalRangeJSONFloat) {
TEST_P(ExprTest, TestBinaryArithOpEvalRangeJSONFloat) {
struct Testcase {
double right_operand;
double value;
@ -4376,7 +4147,7 @@ TEST(Expr, TestBinaryArithOpEvalRangeJSONFloat) {
}
}
TEST(Expr, TestBinaryArithOpEvalRangeWithScalarSortIndex) {
TEST_P(ExprTest, TestBinaryArithOpEvalRangeWithScalarSortIndex) {
std::vector<std::tuple<std::string, std::function<bool(int)>, DataType>>
testcases = {
// Add test cases for BinaryArithOpEvalRangeExpr EQ of various data types
@ -4744,8 +4515,7 @@ TEST(Expr, TestBinaryArithOpEvalRangeWithScalarSortIndex) {
@@@@)";
auto schema = std::make_shared<Schema>();
auto vec_fid = schema->AddDebugField(
"fakevec", DataType::VECTOR_FLOAT, 16, knowhere::metric::L2);
auto vec_fid = schema->AddDebugField("fakevec", data_type, 16, metric_type);
auto i8_fid = schema->AddDebugField("age8", DataType::INT8);
auto i16_fid = schema->AddDebugField("age16", DataType::INT16);
auto i32_fid = schema->AddDebugField("age32", DataType::INT32);
@ -4857,8 +4627,7 @@ TEST(Expr, TestBinaryArithOpEvalRangeWithScalarSortIndex) {
ASSERT_TRUE(false) << "No test case defined for this data type";
}
auto binary_plan =
translate_text_plan_to_binary_plan(expr.str().data());
auto binary_plan = translate_text_plan_with_metric_type(expr.str());
auto plan = CreateSearchPlanByExpr(
*schema, binary_plan.data(), binary_plan.size());
@ -4900,7 +4669,7 @@ TEST(Expr, TestBinaryArithOpEvalRangeWithScalarSortIndex) {
}
}
TEST(Expr, TestUnaryRangeWithJSON) {
TEST_P(ExprTest, TestUnaryRangeWithJSON) {
std::vector<
std::tuple<std::string,
std::function<bool(
@ -4990,8 +4759,7 @@ TEST(Expr, TestUnaryRangeWithJSON) {
@@@@)";
auto schema = std::make_shared<Schema>();
auto vec_fid = schema->AddDebugField(
"fakevec", DataType::VECTOR_FLOAT, 16, knowhere::metric::L2);
auto vec_fid = schema->AddDebugField("fakevec", data_type, 16, metric_type);
auto i64_fid = schema->AddDebugField("age64", DataType::INT64);
auto json_fid = schema->AddDebugField("json", DataType::JSON);
schema->set_primary_field_id(i64_fid);
@ -5056,7 +4824,7 @@ TEST(Expr, TestUnaryRangeWithJSON) {
}
}
auto unary_plan = translate_text_plan_to_binary_plan(expr.str().data());
auto unary_plan = translate_text_plan_with_metric_type(expr.str());
auto plan = CreateSearchPlanByExpr(
*schema, unary_plan.data(), unary_plan.size());
@ -5100,7 +4868,7 @@ TEST(Expr, TestUnaryRangeWithJSON) {
}
}
TEST(Expr, TestTermWithJSON) {
TEST_P(ExprTest, TestTermWithJSON) {
std::vector<
std::tuple<std::string,
std::function<bool(
@ -5168,8 +4936,7 @@ TEST(Expr, TestTermWithJSON) {
@@@@)";
auto schema = std::make_shared<Schema>();
auto vec_fid = schema->AddDebugField(
"fakevec", DataType::VECTOR_FLOAT, 16, knowhere::metric::L2);
auto vec_fid = schema->AddDebugField("fakevec", data_type, 16, metric_type);
auto i64_fid = schema->AddDebugField("age64", DataType::INT64);
auto json_fid = schema->AddDebugField("json", DataType::JSON);
schema->set_primary_field_id(i64_fid);
@ -5234,7 +5001,7 @@ TEST(Expr, TestTermWithJSON) {
}
}
auto unary_plan = translate_text_plan_to_binary_plan(expr.str().data());
auto unary_plan = translate_text_plan_with_metric_type(expr.str());
auto plan = CreateSearchPlanByExpr(
*schema, unary_plan.data(), unary_plan.size());
@ -5278,7 +5045,7 @@ TEST(Expr, TestTermWithJSON) {
}
}
TEST(Expr, TestExistsWithJSON) {
TEST_P(ExprTest, TestExistsWithJSON) {
std::vector<std::tuple<std::string, std::function<bool(bool)>, DataType>>
testcases = {
{R"()", [](bool v) { return v; }, DataType::BOOL},
@ -5313,8 +5080,7 @@ TEST(Expr, TestExistsWithJSON) {
@@@@)";
auto schema = std::make_shared<Schema>();
auto vec_fid = schema->AddDebugField(
"fakevec", DataType::VECTOR_FLOAT, 16, knowhere::metric::L2);
auto vec_fid = schema->AddDebugField("fakevec", data_type, 16, metric_type);
auto i64_fid = schema->AddDebugField("age64", DataType::INT64);
auto json_fid = schema->AddDebugField("json", DataType::JSON);
schema->set_primary_field_id(i64_fid);
@ -5386,7 +5152,7 @@ TEST(Expr, TestExistsWithJSON) {
}
}
auto unary_plan = translate_text_plan_to_binary_plan(expr.str().data());
auto unary_plan = translate_text_plan_with_metric_type(expr.str());
auto plan = CreateSearchPlanByExpr(
*schema, unary_plan.data(), unary_plan.size());
@ -5438,7 +5204,7 @@ struct Testcase {
bool res;
};
TEST(Expr, TestTermInFieldJson) {
TEST_P(ExprTest, TestTermInFieldJson) {
auto schema = std::make_shared<Schema>();
auto i64_fid = schema->AddDebugField("id", DataType::INT64);
auto json_fid = schema->AddDebugField("json", DataType::JSON);
@ -5654,8 +5420,8 @@ TEST(Expr, TestTermInFieldJson) {
}
}
TEST(Expr, PraseJsonContainsExpr) {
std::vector<const char*> raw_plans{
TEST_P(ExprTest, PraseJsonContainsExpr) {
std::vector<std::string> raw_plans{
R"(vector_anns:<
field_id:100
predicates:<
@ -5787,17 +5553,16 @@ TEST(Expr, PraseJsonContainsExpr) {
};
for (auto& raw_plan : raw_plans) {
auto plan_str = translate_text_plan_to_binary_plan(raw_plan);
auto plan_str = translate_text_plan_with_metric_type(raw_plan);
auto schema = std::make_shared<Schema>();
schema->AddDebugField(
"fakevec", DataType::VECTOR_FLOAT, 16, knowhere::metric::L2);
schema->AddDebugField("fakevec", data_type, 16, metric_type);
schema->AddDebugField("json", DataType::JSON);
auto plan =
CreateSearchPlanByExpr(*schema, plan_str.data(), plan_str.size());
}
}
TEST(Expr, TestJsonContainsAny) {
TEST_P(ExprTest, TestJsonContainsAny) {
auto schema = std::make_shared<Schema>();
auto i64_fid = schema->AddDebugField("id", DataType::INT64);
auto json_fid = schema->AddDebugField("json", DataType::JSON);
@ -6017,7 +5782,7 @@ TEST(Expr, TestJsonContainsAny) {
}
}
TEST(Expr, TestJsonContainsAll) {
TEST_P(ExprTest, TestJsonContainsAll) {
auto schema = std::make_shared<Schema>();
auto i64_fid = schema->AddDebugField("id", DataType::INT64);
auto json_fid = schema->AddDebugField("json", DataType::JSON);
@ -6261,7 +6026,7 @@ TEST(Expr, TestJsonContainsAll) {
}
}
TEST(Expr, TestJsonContainsArray) {
TEST_P(ExprTest, TestJsonContainsArray) {
auto schema = std::make_shared<Schema>();
auto i64_fid = schema->AddDebugField("id", DataType::INT64);
auto json_fid = schema->AddDebugField("json", DataType::JSON);
@ -6588,7 +6353,7 @@ generatedArrayWithFourDiffType(int64_t int_val,
return value;
}
TEST(Expr, TestJsonContainsDiffTypeArray) {
TEST_P(ExprTest, TestJsonContainsDiffTypeArray) {
auto schema = std::make_shared<Schema>();
auto i64_fid = schema->AddDebugField("id", DataType::INT64);
auto json_fid = schema->AddDebugField("json", DataType::JSON);
@ -6690,7 +6455,7 @@ TEST(Expr, TestJsonContainsDiffTypeArray) {
}
}
TEST(Expr, TestJsonContainsDiffType) {
TEST_P(ExprTest, TestJsonContainsDiffType) {
auto schema = std::make_shared<Schema>();
auto i64_fid = schema->AddDebugField("id", DataType::INT64);
auto json_fid = schema->AddDebugField("json", DataType::JSON);

View File

@ -97,9 +97,50 @@ TEST(Growing, RealCount) {
ASSERT_EQ(0, segment->get_real_count());
}
TEST(Growing, FillData) {
class GrowingTest
: public ::testing::TestWithParam<
std::tuple</*index type*/ std::string, knowhere::MetricType>> {
public:
void
SetUp() override {
auto index_type = std::get<0>(GetParam());
auto metric_type = std::get<1>(GetParam());
if (index_type == knowhere::IndexEnum::INDEX_FAISS_IVFFLAT ||
index_type == knowhere::IndexEnum::INDEX_FAISS_IVFFLAT_CC) {
data_type = DataType::VECTOR_FLOAT;
} else if (index_type ==
knowhere::IndexEnum::INDEX_SPARSE_INVERTED_INDEX ||
index_type == knowhere::IndexEnum::INDEX_SPARSE_WAND) {
data_type = DataType::VECTOR_SPARSE_FLOAT;
} else {
ASSERT_TRUE(false);
}
}
knowhere::MetricType metric_type;
std::string index_type;
DataType data_type;
};
INSTANTIATE_TEST_SUITE_P(
FloatGrowingTest,
GrowingTest,
::testing::Combine(
::testing::Values(knowhere::IndexEnum::INDEX_FAISS_IVFFLAT,
knowhere::IndexEnum::INDEX_FAISS_IVFFLAT_CC),
::testing::Values(knowhere::metric::L2,
knowhere::metric::IP,
knowhere::metric::COSINE)));
INSTANTIATE_TEST_SUITE_P(
SparseFloatGrowingTest,
GrowingTest,
::testing::Combine(
::testing::Values(knowhere::IndexEnum::INDEX_SPARSE_INVERTED_INDEX,
knowhere::IndexEnum::INDEX_SPARSE_WAND),
::testing::Values(knowhere::metric::IP)));
TEST_P(GrowingTest, FillData) {
auto schema = std::make_shared<Schema>();
auto metric_type = knowhere::metric::L2;
auto bool_field = schema->AddDebugField("bool", DataType::BOOL);
auto int8_field = schema->AddDebugField("int8", DataType::INT8);
auto int16_field = schema->AddDebugField("int16", DataType::INT16);
@ -121,12 +162,11 @@ TEST(Growing, FillData) {
"double_array", DataType::ARRAY, DataType::DOUBLE);
auto float_array_field =
schema->AddDebugField("float_array", DataType::ARRAY, DataType::FLOAT);
auto vec = schema->AddDebugField(
"embeddings", DataType::VECTOR_FLOAT, 128, metric_type);
auto vec = schema->AddDebugField("embeddings", data_type, 128, metric_type);
schema->set_primary_field_id(int64_field);
std::map<std::string, std::string> index_params = {
{"index_type", "IVF_FLAT"},
{"index_type", index_type},
{"metric_type", metric_type},
{"nlist", "128"}};
std::map<std::string, std::string> type_params = {{"dim", "128"}};
@ -146,25 +186,6 @@ TEST(Growing, FillData) {
int64_t dim = 128;
for (int64_t i = 0; i < n_batch; i++) {
auto dataset = DataGen(schema, per_batch);
auto bool_values = dataset.get_col<bool>(bool_field);
auto int8_values = dataset.get_col<int8_t>(int8_field);
auto int16_values = dataset.get_col<int16_t>(int16_field);
auto int32_values = dataset.get_col<int32_t>(int32_field);
auto int64_values = dataset.get_col<int64_t>(int64_field);
auto float_values = dataset.get_col<float>(float_field);
auto double_values = dataset.get_col<double>(double_field);
auto varchar_values = dataset.get_col<std::string>(varchar_field);
auto json_values = dataset.get_col<std::string>(json_field);
auto int_array_values = dataset.get_col<ScalarArray>(int_array_field);
auto long_array_values = dataset.get_col<ScalarArray>(long_array_field);
auto bool_array_values = dataset.get_col<ScalarArray>(bool_array_field);
auto string_array_values =
dataset.get_col<ScalarArray>(string_array_field);
auto double_array_values =
dataset.get_col<ScalarArray>(double_array_field);
auto float_array_values =
dataset.get_col<ScalarArray>(float_array_field);
auto vector_values = dataset.get_col<float>(vec);
auto offset = segment->PreInsert(per_batch);
segment->Insert(offset,
@ -220,8 +241,16 @@ TEST(Growing, FillData) {
EXPECT_EQ(varchar_result->scalars().string_data().data_size(),
num_inserted);
EXPECT_EQ(json_result->scalars().json_data().data_size(), num_inserted);
EXPECT_EQ(vec_result->vectors().float_vector().data_size(),
num_inserted * dim);
if (data_type == DataType::VECTOR_FLOAT) {
EXPECT_EQ(vec_result->vectors().float_vector().data_size(),
num_inserted * dim);
} else if (data_type == DataType::VECTOR_SPARSE_FLOAT) {
EXPECT_EQ(
vec_result->vectors().sparse_float_vector().contents_size(),
num_inserted);
} else {
ASSERT_TRUE(false);
}
EXPECT_EQ(int_array_result->scalars().array_data().data_size(),
num_inserted);
EXPECT_EQ(long_array_result->scalars().array_data().data_size(),

View File

@ -11,9 +11,11 @@
#include <gtest/gtest.h>
#include "common/Utils.h"
#include "pb/plan.pb.h"
#include "pb/schema.pb.h"
#include "query/Plan.h"
#include "segcore/ConcurrentVector.h"
#include "segcore/SegmentGrowing.h"
#include "segcore/SegmentGrowingImpl.h"
#include "test_utils/DataGen.h"
@ -22,16 +24,63 @@ using namespace milvus;
using namespace milvus::segcore;
namespace pb = milvus::proto;
TEST(GrowingIndex, Correctness) {
using Param = std::tuple</*index type*/ std::string, knowhere::MetricType>;
class GrowingIndexTest : public ::testing::TestWithParam<Param> {
void
SetUp() override {
auto param = GetParam();
index_type = std::get<0>(param);
metric_type = std::get<1>(param);
if (index_type == knowhere::IndexEnum::INDEX_FAISS_IVFFLAT ||
index_type == knowhere::IndexEnum::INDEX_FAISS_IVFFLAT_CC) {
data_type = DataType::VECTOR_FLOAT;
} else if (index_type ==
knowhere::IndexEnum::INDEX_SPARSE_INVERTED_INDEX ||
index_type == knowhere::IndexEnum::INDEX_SPARSE_WAND) {
data_type = DataType::VECTOR_SPARSE_FLOAT;
is_sparse = true;
} else {
ASSERT_TRUE(false);
}
}
protected:
std::string index_type;
knowhere::MetricType metric_type;
DataType data_type;
bool is_sparse = false;
};
INSTANTIATE_TEST_SUITE_P(
FloatIndexTypeParameters,
GrowingIndexTest,
::testing::Combine(
::testing::Values(knowhere::IndexEnum::INDEX_FAISS_IVFFLAT,
knowhere::IndexEnum::INDEX_FAISS_IVFFLAT_CC),
::testing::Values(knowhere::metric::L2,
knowhere::metric::COSINE,
knowhere::metric::IP)));
INSTANTIATE_TEST_SUITE_P(
SparseIndexTypeParameters,
GrowingIndexTest,
::testing::Combine(
::testing::Values(knowhere::IndexEnum::INDEX_SPARSE_INVERTED_INDEX,
knowhere::IndexEnum::INDEX_SPARSE_WAND),
::testing::Values(knowhere::metric::IP)));
TEST_P(GrowingIndexTest, Correctness) {
auto schema = std::make_shared<Schema>();
auto pk = schema->AddDebugField("pk", DataType::INT64);
auto random = schema->AddDebugField("random", DataType::DOUBLE);
auto vec = schema->AddDebugField(
"embeddings", DataType::VECTOR_FLOAT, 128, knowhere::metric::L2);
auto vec = schema->AddDebugField("embeddings", data_type, 128, metric_type);
schema->set_primary_field_id(pk);
std::map<std::string, std::string> index_params = {
{"index_type", "IVF_FLAT"}, {"metric_type", "L2"}, {"nlist", "128"}};
{"index_type", index_type},
{"metric_type", metric_type},
{"nlist", "128"}};
std::map<std::string, std::string> type_params = {{"dim", "128"}};
FieldIndexMeta fieldIndexMeta(
vec, std::move(index_params), std::move(type_params));
@ -46,28 +95,44 @@ TEST(GrowingIndex, Correctness) {
milvus::proto::plan::PlanNode plan_node;
auto vector_anns = plan_node.mutable_vector_anns();
vector_anns->set_vector_type(milvus::proto::plan::VectorType::FloatVector);
if (is_sparse) {
vector_anns->set_vector_type(
milvus::proto::plan::VectorType::SparseFloatVector);
} else {
vector_anns->set_vector_type(
milvus::proto::plan::VectorType::FloatVector);
}
vector_anns->set_placeholder_tag("$0");
vector_anns->set_field_id(102);
auto query_info = vector_anns->mutable_query_info();
query_info->set_topk(5);
query_info->set_round_decimal(3);
query_info->set_metric_type("l2");
query_info->set_metric_type(metric_type);
query_info->set_search_params(R"({"nprobe": 16})");
auto plan_str = plan_node.SerializeAsString();
milvus::proto::plan::PlanNode range_query_plan_node;
auto vector_range_querys = range_query_plan_node.mutable_vector_anns();
vector_range_querys->set_vector_type(
milvus::proto::plan::VectorType::FloatVector);
if (is_sparse) {
vector_range_querys->set_vector_type(
milvus::proto::plan::VectorType::SparseFloatVector);
} else {
vector_range_querys->set_vector_type(
milvus::proto::plan::VectorType::FloatVector);
}
vector_range_querys->set_placeholder_tag("$0");
vector_range_querys->set_field_id(102);
auto range_query_info = vector_range_querys->mutable_query_info();
range_query_info->set_topk(5);
range_query_info->set_round_decimal(3);
range_query_info->set_metric_type("l2");
range_query_info->set_search_params(
R"({"nprobe": 10, "radius": 600, "range_filter": 500})");
range_query_info->set_metric_type(metric_type);
if (PositivelyRelated(metric_type)) {
range_query_info->set_search_params(
R"({"nprobe": 10, "radius": 500, "range_filter": 600})");
} else {
range_query_info->set_search_params(
R"({"nprobe": 10, "radius": 600, "range_filter": 500})");
}
auto range_plan_str = range_query_plan_node.SerializeAsString();
int64_t per_batch = 10000;
@ -82,20 +147,32 @@ TEST(GrowingIndex, Correctness) {
dataset.row_ids_.data(),
dataset.timestamps_.data(),
dataset.raw_);
auto filed_data = segmentImplPtr->get_insert_record()
.get_field_data<milvus::FloatVector>(vec);
const VectorBase* field_data = nullptr;
if (is_sparse) {
field_data = segmentImplPtr->get_insert_record()
.get_field_data<milvus::SparseFloatVector>(vec);
} else {
field_data = segmentImplPtr->get_insert_record()
.get_field_data<milvus::FloatVector>(vec);
}
auto inserted = (i + 1) * per_batch;
//once index built, chunk data will be removed
if (i < 2) {
EXPECT_EQ(filed_data->num_chunk(),
upper_div(inserted, filed_data->get_size_per_chunk()));
// once index built, chunk data will be removed.
// growing index will only be built when num rows reached
// get_build_threshold(). This value for sparse is 0, thus sparse index
// will be built since the first chunk. Dense segment buffers the first
// 2 chunks before building an index in this test case.
if (!is_sparse && i < 2) {
EXPECT_EQ(field_data->num_chunk(),
upper_div(inserted, field_data->get_size_per_chunk()));
} else {
EXPECT_EQ(filed_data->num_chunk(), 0);
EXPECT_EQ(field_data->num_chunk(), 0);
}
auto num_queries = 5;
auto ph_group_raw = CreatePlaceholderGroup(num_queries, 128, 1024);
auto ph_group_raw =
is_sparse ? CreateSparseFloatPlaceholderGroup(num_queries)
: CreatePlaceholderGroup(num_queries, 128, 1024);
auto plan = milvus::query::CreateSearchPlanByExpr(
*schema, plan_str.data(), plan_str.size());
@ -109,6 +186,10 @@ TEST(GrowingIndex, Correctness) {
EXPECT_EQ(sr->distances_.size(), num_queries * top_k);
EXPECT_EQ(sr->seg_offsets_.size(), num_queries * top_k);
// range search for sparse is not yet supported
if (is_sparse) {
continue;
}
auto range_plan = milvus::query::CreateSearchPlanByExpr(
*schema, range_plan_str.data(), range_plan_str.size());
auto range_ph_group = ParsePlaceholderGroup(
@ -128,12 +209,11 @@ TEST(GrowingIndex, Correctness) {
}
}
TEST(GrowingIndex, MissIndexMeta) {
TEST_P(GrowingIndexTest, MissIndexMeta) {
auto schema = std::make_shared<Schema>();
auto pk = schema->AddDebugField("pk", DataType::INT64);
auto random = schema->AddDebugField("random", DataType::DOUBLE);
auto vec = schema->AddDebugField(
"embeddings", DataType::VECTOR_FLOAT, 128, knowhere::metric::L2);
auto vec = schema->AddDebugField("embeddings", data_type, 128, metric_type);
schema->set_primary_field_id(pk);
auto& config = SegcoreConfig::default_config();
@ -142,36 +222,16 @@ TEST(GrowingIndex, MissIndexMeta) {
auto segment = CreateGrowingSegment(schema, nullptr);
}
using Param = const char*;
class GrowingIndexGetVectorTest : public ::testing::TestWithParam<Param> {
void
SetUp() override {
auto param = GetParam();
metricType = param;
}
protected:
const char* metricType;
};
INSTANTIATE_TEST_SUITE_P(IndexTypeParameters,
GrowingIndexGetVectorTest,
::testing::Values(knowhere::metric::L2,
knowhere::metric::COSINE,
knowhere::metric::IP));
TEST_P(GrowingIndexGetVectorTest, GetVector) {
TEST_P(GrowingIndexTest, GetVector) {
auto schema = std::make_shared<Schema>();
auto pk = schema->AddDebugField("pk", DataType::INT64);
auto random = schema->AddDebugField("random", DataType::DOUBLE);
auto vec = schema->AddDebugField(
"embeddings", DataType::VECTOR_FLOAT, 128, metricType);
auto vec = schema->AddDebugField("embeddings", data_type, 128, metric_type);
schema->set_primary_field_id(pk);
std::map<std::string, std::string> index_params = {
{"index_type", "IVF_FLAT"},
{"metric_type", metricType},
{"index_type", index_type},
{"metric_type", metric_type},
{"nlist", "128"}};
std::map<std::string, std::string> type_params = {{"dim", "128"}};
FieldIndexMeta fieldIndexMeta(
@ -185,30 +245,74 @@ TEST_P(GrowingIndexGetVectorTest, GetVector) {
auto segment_growing = CreateGrowingSegment(schema, metaPtr);
auto segment = dynamic_cast<SegmentGrowingImpl*>(segment_growing.get());
int64_t per_batch = 5000;
int64_t n_batch = 20;
int64_t dim = 128;
for (int64_t i = 0; i < n_batch; i++) {
auto dataset = DataGen(schema, per_batch);
auto fakevec = dataset.get_col<float>(vec);
auto offset = segment->PreInsert(per_batch);
segment->Insert(offset,
per_batch,
dataset.row_ids_.data(),
dataset.timestamps_.data(),
dataset.raw_);
auto num_inserted = (i + 1) * per_batch;
auto ids_ds = GenRandomIds(num_inserted);
auto result =
segment->bulk_subscript(vec, ids_ds->GetIds(), num_inserted);
if (data_type == DataType::VECTOR_FLOAT) {
// GetVector for VECTOR_FLOAT
int64_t per_batch = 5000;
int64_t n_batch = 20;
int64_t dim = 128;
for (int64_t i = 0; i < n_batch; i++) {
auto dataset = DataGen(schema, per_batch);
auto fakevec = dataset.get_col<float>(vec);
auto offset = segment->PreInsert(per_batch);
segment->Insert(offset,
per_batch,
dataset.row_ids_.data(),
dataset.timestamps_.data(),
dataset.raw_);
auto num_inserted = (i + 1) * per_batch;
auto ids_ds = GenRandomIds(num_inserted);
auto result =
segment->bulk_subscript(vec, ids_ds->GetIds(), num_inserted);
auto vector = result.get()->mutable_vectors()->float_vector().data();
EXPECT_TRUE(vector.size() == num_inserted * dim);
for (size_t i = 0; i < num_inserted; ++i) {
auto id = ids_ds->GetIds()[i];
for (size_t j = 0; j < 128; ++j) {
EXPECT_TRUE(vector[i * dim + j] ==
fakevec[(id % per_batch) * dim + j]);
auto vector =
result.get()->mutable_vectors()->float_vector().data();
EXPECT_TRUE(vector.size() == num_inserted * dim);
for (size_t i = 0; i < num_inserted; ++i) {
auto id = ids_ds->GetIds()[i];
for (size_t j = 0; j < 128; ++j) {
EXPECT_TRUE(vector[i * dim + j] ==
fakevec[(id % per_batch) * dim + j]);
}
}
}
} else if (is_sparse) {
// GetVector for VECTOR_SPARSE_FLOAT
int64_t per_batch = 5000;
int64_t n_batch = 20;
int64_t dim = 128;
for (int64_t i = 0; i < n_batch; i++) {
auto dataset = DataGen(schema, per_batch);
auto fakevec =
dataset.get_col<knowhere::sparse::SparseRow<float>>(vec);
auto offset = segment->PreInsert(per_batch);
segment->Insert(offset,
per_batch,
dataset.row_ids_.data(),
dataset.timestamps_.data(),
dataset.raw_);
auto num_inserted = (i + 1) * per_batch;
auto ids_ds = GenRandomIds(num_inserted);
auto result =
segment->bulk_subscript(vec, ids_ds->GetIds(), num_inserted);
auto vector = result.get()
->mutable_vectors()
->sparse_float_vector()
.contents();
EXPECT_TRUE(result.get()
->mutable_vectors()
->sparse_float_vector()
.contents_size() == num_inserted);
auto sparse_rows = SparseBytesToRows(vector);
for (size_t i = 0; i < num_inserted; ++i) {
auto id = ids_ds->GetIds()[i];
auto actual_row = sparse_rows[i];
auto expected_row = fakevec[(id % per_batch)];
EXPECT_TRUE(actual_row.size() == expected_row.size());
for (size_t j = 0; j < actual_row.size(); ++j) {
EXPECT_TRUE(actual_row[j].id == expected_row[j].id);
EXPECT_TRUE(actual_row[j].val == expected_row[j].val);
}
}
}
}

View File

@ -166,12 +166,6 @@ TEST_P(IndexWrapperTest, BuildAndQuery) {
ASSERT_NO_THROW(vec_index->Load(binary_set));
if (vec_field_data_type == DataType::VECTOR_SPARSE_FLOAT) {
// TODO(SPARSE): complete test in PR adding search/query to sparse
// float vector.
return;
}
milvus::SearchInfo search_info;
search_info.topk_ = K;
search_info.metric_type_ = metric_type;

View File

@ -296,11 +296,7 @@ TEST(Indexing, Naive) {
vec_index->Query(query_ds, searchInfo, view, result);
for (int i = 0; i < TOPK; ++i) {
if (result.seg_offsets_[i] < N / 2) {
std::cout << "WRONG: ";
}
std::cout << result.seg_offsets_[i] << "->" << result.distances_[i]
<< std::endl;
ASSERT_FALSE(result.seg_offsets_[i] < N / 2);
}
}
@ -315,7 +311,6 @@ class IndexTest : public ::testing::TestWithParam<Param> {
auto param = GetParam();
index_type = param.first;
metric_type = param.second;
NB = 3000;
// try to reduce the test time,
// but the large dataset is needed for the case below.
@ -330,35 +325,43 @@ class IndexTest : public ::testing::TestWithParam<Param> {
search_conf = generate_search_conf(index_type, metric_type);
range_search_conf = generate_range_search_conf(index_type, metric_type);
std::map<knowhere::MetricType, bool> is_binary_map = {
{knowhere::IndexEnum::INDEX_FAISS_IDMAP, false},
{knowhere::IndexEnum::INDEX_FAISS_IVFPQ, false},
{knowhere::IndexEnum::INDEX_FAISS_IVFFLAT, false},
{knowhere::IndexEnum::INDEX_FAISS_IVFSQ8, false},
{knowhere::IndexEnum::INDEX_FAISS_BIN_IVFFLAT, true},
{knowhere::IndexEnum::INDEX_FAISS_BIN_IDMAP, true},
{knowhere::IndexEnum::INDEX_HNSW, false},
{knowhere::IndexEnum::INDEX_DISKANN, false},
};
is_binary = is_binary_map[index_type];
if (is_binary) {
if (index_type == knowhere::IndexEnum::INDEX_SPARSE_INVERTED_INDEX ||
index_type == knowhere::IndexEnum::INDEX_SPARSE_WAND) {
is_sparse = true;
vec_field_data_type = milvus::DataType::VECTOR_SPARSE_FLOAT;
} else if (index_type == knowhere::IndexEnum::INDEX_FAISS_BIN_IVFFLAT ||
index_type == knowhere::IndexEnum::INDEX_FAISS_BIN_IDMAP) {
is_binary = true;
vec_field_data_type = milvus::DataType::VECTOR_BINARY;
} else {
vec_field_data_type = milvus::DataType::VECTOR_FLOAT;
}
auto dataset = GenDataset(NB, metric_type, is_binary);
if (!is_binary) {
xb_data = dataset.get_col<float>(milvus::FieldId(100));
xb_dataset = knowhere::GenDataSet(NB, DIM, xb_data.data());
xq_dataset = knowhere::GenDataSet(
NQ, DIM, xb_data.data() + DIM * query_offset);
} else {
auto dataset =
GenDatasetWithDataType(NB, metric_type, vec_field_data_type);
if (is_binary) {
// binary vector
xb_bin_data = dataset.get_col<uint8_t>(milvus::FieldId(100));
xb_dataset = knowhere::GenDataSet(NB, DIM, xb_bin_data.data());
xq_dataset = knowhere::GenDataSet(
NQ, DIM, xb_bin_data.data() + DIM * query_offset);
} else if (is_sparse) {
// sparse vector
xb_sparse_data =
dataset.get_col<knowhere::sparse::SparseRow<float>>(
milvus::FieldId(100));
xb_dataset =
knowhere::GenDataSet(NB, kTestSparseDim, xb_sparse_data.data());
xb_dataset->SetIsSparse(true);
xq_dataset = knowhere::GenDataSet(
NQ, kTestSparseDim, xb_sparse_data.data() + query_offset);
xq_dataset->SetIsSparse(true);
} else {
// float vector
xb_data = dataset.get_col<float>(milvus::FieldId(100));
xb_dataset = knowhere::GenDataSet(NB, DIM, xb_data.data());
xq_dataset = knowhere::GenDataSet(
NQ, DIM, xb_data.data() + DIM * query_offset);
}
}
@ -368,7 +371,8 @@ class IndexTest : public ::testing::TestWithParam<Param> {
protected:
std::string index_type, metric_type;
bool is_binary;
bool is_binary = false;
bool is_sparse = false;
milvus::Config build_conf;
milvus::Config load_conf;
milvus::Config search_conf;
@ -377,9 +381,10 @@ class IndexTest : public ::testing::TestWithParam<Param> {
knowhere::DataSetPtr xb_dataset;
FixedVector<float> xb_data;
FixedVector<uint8_t> xb_bin_data;
FixedVector<knowhere::sparse::SparseRow<float>> xb_sparse_data;
knowhere::DataSetPtr xq_dataset;
int64_t query_offset = 100;
int64_t NB = 3000;
int64_t NB = 3000; // will be updated to 27000 for mmap+hnsw
StorageConfig storage_config_;
};
@ -397,6 +402,9 @@ INSTANTIATE_TEST_SUITE_P(
knowhere::metric::JACCARD),
std::pair(knowhere::IndexEnum::INDEX_FAISS_BIN_IDMAP,
knowhere::metric::JACCARD),
std::pair(knowhere::IndexEnum::INDEX_SPARSE_INVERTED_INDEX,
knowhere::metric::IP),
std::pair(knowhere::IndexEnum::INDEX_SPARSE_WAND, knowhere::metric::IP),
#ifdef BUILD_DISK_ANN
std::pair(knowhere::IndexEnum::INDEX_DISKANN, knowhere::metric::L2),
#endif
@ -506,7 +514,9 @@ TEST_P(IndexTest, BuildAndQuery) {
load_conf["index_files"] = index_files;
ASSERT_NO_THROW(vec_index->Load(milvus::tracer::TraceContext{}, load_conf));
EXPECT_EQ(vec_index->Count(), NB);
EXPECT_EQ(vec_index->GetDim(), DIM);
if (!is_sparse) {
EXPECT_EQ(vec_index->GetDim(), DIM);
}
milvus::SearchInfo search_info;
search_info.topk_ = K;
@ -518,11 +528,19 @@ TEST_P(IndexTest, BuildAndQuery) {
EXPECT_EQ(result.unity_topK_, K);
EXPECT_EQ(result.distances_.size(), NQ * K);
EXPECT_EQ(result.seg_offsets_.size(), NQ * K);
if (!is_binary) {
EXPECT_EQ(result.seg_offsets_[0], query_offset);
if (metric_type == knowhere::metric::L2) {
// for L2 metric each vector is closest to itself
for (int i = 0; i < NQ; i++) {
EXPECT_EQ(result.seg_offsets_[i * K], query_offset + i);
}
// for other metrics we can't verify the correctness unless we perform
// brute force search to get the ground truth.
}
if (!is_sparse) {
// sparse doesn't support range search yet
search_info.search_params_ = range_search_conf;
vec_index->Query(xq_dataset, search_info, nullptr, result);
}
search_info.search_params_ = range_search_conf;
vec_index->Query(xq_dataset, search_info, nullptr, result);
}
TEST_P(IndexTest, Mmap) {
@ -623,7 +641,9 @@ TEST_P(IndexTest, GetVector) {
} else {
vec_index->Load(milvus::tracer::TraceContext{}, load_conf);
}
EXPECT_EQ(vec_index->GetDim(), DIM);
if (!is_sparse) {
EXPECT_EQ(vec_index->GetDim(), DIM);
}
EXPECT_EQ(vec_index->Count(), NB);
if (!vec_index->HasRawData()) {
@ -631,27 +651,37 @@ TEST_P(IndexTest, GetVector) {
}
auto ids_ds = GenRandomIds(NB);
auto results = vec_index->GetVector(ids_ds);
EXPECT_TRUE(results.size() > 0);
if (!is_binary) {
std::vector<float> result_vectors(results.size() / (sizeof(float)));
memcpy(result_vectors.data(), results.data(), results.size());
EXPECT_TRUE(result_vectors.size() == xb_data.size());
for (size_t i = 0; i < NB; ++i) {
auto id = ids_ds->GetIds()[i];
for (size_t j = 0; j < DIM; ++j) {
EXPECT_TRUE(result_vectors[i * DIM + j] ==
xb_data[id * DIM + j]);
}
}
} else {
EXPECT_TRUE(results.size() == xb_bin_data.size());
if (is_binary) {
auto results = vec_index->GetVector(ids_ds);
EXPECT_EQ(results.size(), xb_bin_data.size());
const auto data_bytes = DIM / 8;
for (size_t i = 0; i < NB; ++i) {
auto id = ids_ds->GetIds()[i];
for (size_t j = 0; j < data_bytes; ++j) {
EXPECT_TRUE(results[i * data_bytes + j] ==
xb_bin_data[id * data_bytes + j]);
ASSERT_EQ(results[i * data_bytes + j],
xb_bin_data[id * data_bytes + j]);
}
}
} else if (is_sparse) {
auto sparse_rows = vec_index->GetSparseVector(ids_ds);
for (size_t i = 0; i < NB; ++i) {
auto id = ids_ds->GetIds()[i];
auto& row = sparse_rows[i];
ASSERT_EQ(row.size(), xb_sparse_data[id].size());
for (size_t j = 0; j < row.size(); ++j) {
ASSERT_EQ(row[j].id, xb_sparse_data[id][j].id);
ASSERT_EQ(row[j].val, xb_sparse_data[id][j].val);
}
}
} else {
auto results = vec_index->GetVector(ids_ds);
std::vector<float> result_vectors(results.size() / (sizeof(float)));
memcpy(result_vectors.data(), results.data(), results.size());
ASSERT_EQ(result_vectors.size(), xb_data.size());
for (size_t i = 0; i < NB; ++i) {
auto id = ids_ds->GetIds()[i];
for (size_t j = 0; j < DIM; ++j) {
ASSERT_EQ(result_vectors[i * DIM + j], xb_data[id * DIM + j]);
}
}
}

View File

@ -62,7 +62,7 @@ class TypedOffsetOrderedArrayTest : public testing::Test {
};
using TypeOfPks = testing::Types<int64_t, std::string>;
TYPED_TEST_CASE_P(TypedOffsetOrderedArrayTest);
TYPED_TEST_SUITE_P(TypedOffsetOrderedArrayTest);
TYPED_TEST_P(TypedOffsetOrderedArrayTest, find_first) {
std::vector<int64_t> offsets;
@ -117,5 +117,5 @@ TYPED_TEST_P(TypedOffsetOrderedArrayTest, find_first) {
ASSERT_EQ(0, offsets.size());
}
REGISTER_TYPED_TEST_CASE_P(TypedOffsetOrderedArrayTest, find_first);
INSTANTIATE_TYPED_TEST_CASE_P(Prefix, TypedOffsetOrderedArrayTest, TypeOfPks);
REGISTER_TYPED_TEST_SUITE_P(TypedOffsetOrderedArrayTest, find_first);
INSTANTIATE_TYPED_TEST_SUITE_P(Prefix, TypedOffsetOrderedArrayTest, TypeOfPks);

View File

@ -57,7 +57,7 @@ class TypedOffsetOrderedMapTest : public testing::Test {
};
using TypeOfPks = testing::Types<int64_t, std::string>;
TYPED_TEST_CASE_P(TypedOffsetOrderedMapTest);
TYPED_TEST_SUITE_P(TypedOffsetOrderedMapTest);
TYPED_TEST_P(TypedOffsetOrderedMapTest, find_first) {
std::vector<int64_t> offsets;
@ -110,5 +110,5 @@ TYPED_TEST_P(TypedOffsetOrderedMapTest, find_first) {
ASSERT_EQ(0, offsets.size());
}
REGISTER_TYPED_TEST_CASE_P(TypedOffsetOrderedMapTest, find_first);
INSTANTIATE_TYPED_TEST_CASE_P(Prefix, TypedOffsetOrderedMapTest, TypeOfPks);
REGISTER_TYPED_TEST_SUITE_P(TypedOffsetOrderedMapTest, find_first);
INSTANTIATE_TYPED_TEST_SUITE_P(Prefix, TypedOffsetOrderedMapTest, TypeOfPks);

View File

@ -29,12 +29,34 @@ RetrieveUsingDefaultOutputSize(SegmentInterface* segment,
return segment->Retrieve(plan, timestamp, DEFAULT_MAX_OUTPUT_SIZE);
}
TEST(Retrieve, AutoID) {
using Param = DataType;
class RetrieveTest : public ::testing::TestWithParam<Param> {
public:
void
SetUp() override {
data_type = GetParam();
metric_type = datatype_is_sparse_vector(data_type)
? knowhere::metric::IP
: knowhere::metric::L2;
is_sparse = datatype_is_sparse_vector(data_type);
}
DataType data_type;
knowhere::MetricType metric_type;
bool is_sparse = false;
};
INSTANTIATE_TEST_SUITE_P(RetrieveTest,
RetrieveTest,
::testing::Values(DataType::VECTOR_FLOAT,
DataType::VECTOR_SPARSE_FLOAT));
TEST_P(RetrieveTest, AutoID) {
auto schema = std::make_shared<Schema>();
auto fid_64 = schema->AddDebugField("i64", DataType::INT64);
auto DIM = 16;
auto fid_vec = schema->AddDebugField(
"vector_64", DataType::VECTOR_FLOAT, DIM, knowhere::metric::L2);
auto fid_vec =
schema->AddDebugField("vector_64", data_type, DIM, metric_type);
schema->set_primary_field_id(fid_64);
int64_t N = 100;
@ -48,12 +70,10 @@ TEST(Retrieve, AutoID) {
auto plan = std::make_unique<query::RetrievePlan>(*schema);
std::vector<proto::plan::GenericValue> values;
{
for (int i = 0; i < req_size; ++i) {
proto::plan::GenericValue val;
val.set_int64_val(i64_col[choose(i)]);
values.push_back(val);
}
for (int i = 0; i < req_size; ++i) {
proto::plan::GenericValue val;
val.set_int64_val(i64_col[choose(i)]);
values.push_back(val);
}
auto term_expr = std::make_shared<milvus::expr::TermFilterExpr>(
milvus::expr::ColumnInfo(
@ -72,11 +92,6 @@ TEST(Retrieve, AutoID) {
Assert(field0.has_scalars());
auto field0_data = field0.scalars().long_data();
for (int i = 0; i < req_size; ++i) {
auto index = choose(i);
auto data = field0_data.data(i);
}
for (int i = 0; i < req_size; ++i) {
auto index = choose(i);
auto data = field0_data.data(i);
@ -85,16 +100,21 @@ TEST(Retrieve, AutoID) {
auto field1 = retrieve_results->fields_data(1);
Assert(field1.has_vectors());
auto field1_data = field1.vectors().float_vector();
ASSERT_EQ(field1_data.data_size(), DIM * req_size);
if (!is_sparse) {
auto field1_data = field1.vectors().float_vector();
ASSERT_EQ(field1_data.data_size(), DIM * req_size);
} else {
auto field1_data = field1.vectors().sparse_float_vector();
ASSERT_EQ(field1_data.contents_size(), req_size);
}
}
TEST(Retrieve, AutoID2) {
TEST_P(RetrieveTest, AutoID2) {
auto schema = std::make_shared<Schema>();
auto fid_64 = schema->AddDebugField("i64", DataType::INT64);
auto DIM = 16;
auto fid_vec = schema->AddDebugField(
"vector_64", DataType::VECTOR_FLOAT, DIM, knowhere::metric::L2);
auto fid_vec =
schema->AddDebugField("vector_64", data_type, DIM, metric_type);
schema->set_primary_field_id(fid_64);
int64_t N = 100;
@ -140,16 +160,21 @@ TEST(Retrieve, AutoID2) {
auto field1 = retrieve_results->fields_data(1);
Assert(field1.has_vectors());
auto field1_data = field1.vectors().float_vector();
ASSERT_EQ(field1_data.data_size(), DIM * req_size);
if (!is_sparse) {
auto field1_data = field1.vectors().float_vector();
ASSERT_EQ(field1_data.data_size(), DIM * req_size);
} else {
auto field1_data = field1.vectors().sparse_float_vector();
ASSERT_EQ(field1_data.contents_size(), req_size);
}
}
TEST(Retrieve, NotExist) {
TEST_P(RetrieveTest, NotExist) {
auto schema = std::make_shared<Schema>();
auto fid_64 = schema->AddDebugField("i64", DataType::INT64);
auto DIM = 16;
auto fid_vec = schema->AddDebugField(
"vector_64", DataType::VECTOR_FLOAT, DIM, knowhere::metric::L2);
auto fid_vec =
schema->AddDebugField("vector_64", data_type, DIM, metric_type);
schema->set_primary_field_id(fid_64);
int64_t N = 100;
@ -200,16 +225,21 @@ TEST(Retrieve, NotExist) {
auto field1 = retrieve_results->fields_data(1);
Assert(field1.has_vectors());
auto field1_data = field1.vectors().float_vector();
ASSERT_EQ(field1_data.data_size(), DIM * req_size);
if (!is_sparse) {
auto field1_data = field1.vectors().float_vector();
ASSERT_EQ(field1_data.data_size(), DIM * req_size);
} else {
auto field1_data = field1.vectors().sparse_float_vector();
ASSERT_EQ(field1_data.contents_size(), req_size);
}
}
TEST(Retrieve, Empty) {
TEST_P(RetrieveTest, Empty) {
auto schema = std::make_shared<Schema>();
auto fid_64 = schema->AddDebugField("i64", DataType::INT64);
auto DIM = 16;
auto fid_vec = schema->AddDebugField(
"vector_64", DataType::VECTOR_FLOAT, DIM, knowhere::metric::L2);
auto fid_vec =
schema->AddDebugField("vector_64", data_type, DIM, metric_type);
schema->set_primary_field_id(fid_64);
int64_t N = 100;
@ -246,15 +276,19 @@ TEST(Retrieve, Empty) {
Assert(field0.has_scalars());
auto field0_data = field0.scalars().long_data();
Assert(field0_data.data_size() == 0);
Assert(field1.vectors().float_vector().data_size() == 0);
if (!is_sparse) {
ASSERT_EQ(field1.vectors().float_vector().data_size(), 0);
} else {
ASSERT_EQ(field1.vectors().sparse_float_vector().contents_size(), 0);
}
}
TEST(Retrieve, Limit) {
TEST_P(RetrieveTest, Limit) {
auto schema = std::make_shared<Schema>();
auto fid_64 = schema->AddDebugField("i64", DataType::INT64);
auto DIM = 16;
auto fid_vec = schema->AddDebugField(
"vector_64", DataType::VECTOR_FLOAT, DIM, knowhere::metric::L2);
auto fid_vec =
schema->AddDebugField("vector_64", data_type, DIM, metric_type);
schema->set_primary_field_id(fid_64);
int64_t N = 101;
@ -285,18 +319,22 @@ TEST(Retrieve, Limit) {
auto field0 = retrieve_results->fields_data(0);
auto field2 = retrieve_results->fields_data(2);
Assert(field0.scalars().long_data().data_size() == N);
Assert(field2.vectors().float_vector().data_size() == N * DIM);
if (!is_sparse) {
Assert(field2.vectors().float_vector().data_size() == N * DIM);
} else {
Assert(field2.vectors().sparse_float_vector().contents_size() == N);
}
}
TEST(Retrieve, FillEntry) {
TEST_P(RetrieveTest, FillEntry) {
auto schema = std::make_shared<Schema>();
auto fid_64 = schema->AddDebugField("i64", DataType::INT64);
auto DIM = 16;
auto fid_bool = schema->AddDebugField("bool", DataType::BOOL);
auto fid_f32 = schema->AddDebugField("f32", DataType::FLOAT);
auto fid_f64 = schema->AddDebugField("f64", DataType::DOUBLE);
auto fid_vec32 = schema->AddDebugField(
"vector_32", DataType::VECTOR_FLOAT, DIM, knowhere::metric::L2);
auto fid_vec =
schema->AddDebugField("vector", data_type, DIM, knowhere::metric::L2);
auto fid_vecbin = schema->AddDebugField(
"vec_bin", DataType::VECTOR_BINARY, DIM, knowhere::metric::L2);
schema->set_primary_field_id(fid_64);
@ -323,7 +361,7 @@ TEST(Retrieve, FillEntry) {
fid_bool,
fid_f32,
fid_f64,
fid_vec32,
fid_vec,
fid_vecbin};
plan->field_ids_ = target_fields;
EXPECT_THROW(segment->Retrieve(plan.get(), N, 1), std::runtime_error);
@ -333,12 +371,12 @@ TEST(Retrieve, FillEntry) {
Assert(retrieve_results->fields_data_size() == target_fields.size());
}
TEST(Retrieve, LargeTimestamp) {
TEST_P(RetrieveTest, LargeTimestamp) {
auto schema = std::make_shared<Schema>();
auto fid_64 = schema->AddDebugField("i64", DataType::INT64);
auto DIM = 16;
auto fid_vec = schema->AddDebugField(
"vector_64", DataType::VECTOR_FLOAT, DIM, knowhere::metric::L2);
auto fid_vec =
schema->AddDebugField("vector_64", data_type, DIM, metric_type);
schema->set_primary_field_id(fid_64);
int64_t N = 100;
@ -392,16 +430,21 @@ TEST(Retrieve, LargeTimestamp) {
Assert(field_data.vectors().float_vector().data_size() ==
target_num * DIM);
}
if (DataType(field_data.type()) == DataType::VECTOR_SPARSE_FLOAT) {
Assert(field_data.vectors()
.sparse_float_vector()
.contents_size() == target_num);
}
}
}
}
TEST(Retrieve, Delete) {
TEST_P(RetrieveTest, Delete) {
auto schema = std::make_shared<Schema>();
auto fid_64 = schema->AddDebugField("i64", DataType::INT64);
auto DIM = 16;
auto fid_vec = schema->AddDebugField(
"vector_64", DataType::VECTOR_FLOAT, DIM, knowhere::metric::L2);
auto fid_vec =
schema->AddDebugField("vector_64", data_type, DIM, metric_type);
schema->set_primary_field_id(fid_64);
auto fid_ts = schema->AddDebugField("Timestamp", DataType::INT64);
@ -465,8 +508,13 @@ TEST(Retrieve, Delete) {
auto field2 = retrieve_results->fields_data(2);
Assert(field2.has_vectors());
auto field2_data = field2.vectors().float_vector();
ASSERT_EQ(field2_data.data_size(), DIM * req_size);
if (!is_sparse) {
auto field2_data = field2.vectors().float_vector();
ASSERT_EQ(field2_data.data_size(), DIM * req_size);
} else {
auto field2_data = field2.vectors().sparse_float_vector();
ASSERT_EQ(field2_data.contents_size(), req_size);
}
}
int64_t row_count = 0;
@ -512,7 +560,12 @@ TEST(Retrieve, Delete) {
auto field2 = retrieve_results->fields_data(2);
Assert(field2.has_vectors());
auto field2_data = field2.vectors().float_vector();
ASSERT_EQ(field2_data.data_size(), DIM * size);
if (!is_sparse) {
auto field2_data = field2.vectors().float_vector();
ASSERT_EQ(field2_data.data_size(), DIM * size);
} else {
auto field2_data = field2.vectors().sparse_float_vector();
ASSERT_EQ(field2_data.contents_size(), size);
}
}
}

View File

@ -41,7 +41,7 @@ class TypedScalarIndexTest : public ::testing::Test {
// }
};
TYPED_TEST_CASE_P(TypedScalarIndexTest);
TYPED_TEST_SUITE_P(TypedScalarIndexTest);
TYPED_TEST_P(TypedScalarIndexTest, Dummy) {
using T = TypeParam;
@ -213,18 +213,18 @@ TYPED_TEST_P(TypedScalarIndexTest, Codec) {
using ScalarT =
::testing::Types<int8_t, int16_t, int32_t, int64_t, float, double>;
REGISTER_TYPED_TEST_CASE_P(TypedScalarIndexTest,
Dummy,
Constructor,
Count,
In,
NotIn,
Range,
Codec,
Reverse,
HasRawData);
REGISTER_TYPED_TEST_SUITE_P(TypedScalarIndexTest,
Dummy,
Constructor,
Count,
In,
NotIn,
Range,
Codec,
Reverse,
HasRawData);
INSTANTIATE_TYPED_TEST_CASE_P(ArithmeticCheck, TypedScalarIndexTest, ScalarT);
INSTANTIATE_TYPED_TEST_SUITE_P(ArithmeticCheck, TypedScalarIndexTest, ScalarT);
template <typename T>
class TypedScalarIndexTestV2 : public ::testing::Test {
@ -344,7 +344,7 @@ struct TypedScalarIndexTestV2<double>::Helper {
using C = arrow::DoubleType;
};
TYPED_TEST_CASE_P(TypedScalarIndexTestV2);
TYPED_TEST_SUITE_P(TypedScalarIndexTestV2);
TYPED_TEST_P(TypedScalarIndexTestV2, Base) {
using T = TypeParam;
@ -386,6 +386,8 @@ TYPED_TEST_P(TypedScalarIndexTestV2, Base) {
}
}
REGISTER_TYPED_TEST_CASE_P(TypedScalarIndexTestV2, Base);
REGISTER_TYPED_TEST_SUITE_P(TypedScalarIndexTestV2, Base);
INSTANTIATE_TYPED_TEST_CASE_P(ArithmeticCheck, TypedScalarIndexTestV2, ScalarT);
INSTANTIATE_TYPED_TEST_SUITE_P(ArithmeticCheck,
TypedScalarIndexTestV2,
ScalarT);

View File

@ -86,7 +86,7 @@ class TypedScalarIndexCreatorTest : public ::testing::Test {
using ScalarT = ::testing::
Types<bool, int8_t, int16_t, int32_t, int64_t, float, double, std::string>;
TYPED_TEST_CASE_P(TypedScalarIndexCreatorTest);
TYPED_TEST_SUITE_P(TypedScalarIndexCreatorTest);
TYPED_TEST_P(TypedScalarIndexCreatorTest, Dummy) {
using T = TypeParam;
@ -149,11 +149,11 @@ TYPED_TEST_P(TypedScalarIndexCreatorTest, Codec) {
}
}
REGISTER_TYPED_TEST_CASE_P(TypedScalarIndexCreatorTest,
Dummy,
Constructor,
Codec);
REGISTER_TYPED_TEST_SUITE_P(TypedScalarIndexCreatorTest,
Dummy,
Constructor,
Codec);
INSTANTIATE_TYPED_TEST_CASE_P(ArithmeticCheck,
TypedScalarIndexCreatorTest,
ScalarT);
INSTANTIATE_TYPED_TEST_SUITE_P(ArithmeticCheck,
TypedScalarIndexCreatorTest,
ScalarT);

View File

@ -34,6 +34,14 @@ using milvus::segcore::LoadIndexInfo;
const int64_t ROW_COUNT = 10 * 1000;
const int64_t BIAS = 4200;
using Param = std::string;
class SealedTest : public ::testing::TestWithParam<Param> {
public:
void
SetUp() override {
}
};
TEST(Sealed, without_predicate) {
auto schema = std::make_shared<Schema>();
auto dim = 16;

View File

@ -14,5 +14,5 @@ constexpr int64_t TestChunkSize = 32 * 1024;
constexpr char TestLocalPath[] = "/tmp/milvus/local_data/";
constexpr char TestRemotePath[] = "/tmp/milvus/remote_data";
constexpr int64_t kTestSparseDim = 10000;
constexpr float kTestSparseVectorDensity = 0.0003;
constexpr int64_t kTestSparseDim = 1000;
constexpr float kTestSparseVectorDensity = 0.003;

View File

@ -27,7 +27,6 @@
#include "index/ScalarIndexSort.h"
#include "index/StringIndexSort.h"
#include "index/VectorMemIndex.h"
#include "query/SearchOnIndex.h"
#include "segcore/Collection.h"
#include "segcore/SegmentGrowingImpl.h"
#include "segcore/SegmentSealedImpl.h"
@ -247,8 +246,8 @@ struct GeneratedData {
inline std::unique_ptr<knowhere::sparse::SparseRow<float>[]>
GenerateRandomSparseFloatVector(size_t rows,
size_t cols,
float density,
size_t cols = kTestSparseDim,
float density = kTestSparseVectorDensity,
int seed = 42) {
int32_t num_elements = static_cast<int32_t>(rows * cols * density);
@ -1144,6 +1143,23 @@ translate_text_plan_to_binary_plan(const char* text_plan) {
return ret;
}
// we have lots of tests with literal string plan with hard coded metric type,
// so creating a helper function to replace metric type for different metrics.
inline std::vector<char>
replace_metric_and_translate_text_plan_to_binary_plan(
std::string plan, knowhere::MetricType metric_type) {
if (metric_type != knowhere::metric::L2) {
std::string replace = R"(metric_type: "L2")";
std::string target = "metric_type: \"" + metric_type + "\"";
size_t pos = 0;
while ((pos = plan.find(replace, pos)) != std::string::npos) {
plan.replace(pos, replace.length(), target);
pos += target.length();
}
}
return translate_text_plan_to_binary_plan(plan.c_str());
}
inline auto
GenTss(int64_t num, int64_t begin_ts) {
std::vector<Timestamp> tss(num, 0);

View File

@ -102,6 +102,7 @@ generate_build_conf(const milvus::IndexType& index_type,
index_type == knowhere::IndexEnum::INDEX_SPARSE_WAND) {
return knowhere::Json{
{knowhere::meta::METRIC_TYPE, metric_type},
{knowhere::indexparam::DROP_RATIO_BUILD, "0.1"},
};
}
return knowhere::Json();

View File

@ -35,6 +35,7 @@ enum VectorType {
FloatVector = 1;
Float16Vector = 2;
BFloat16Vector = 3;
SparseFloatVector = 4;
};
message GenericValue {