enhance: Use template to remove unittest duplication (#39144)

Issue: #38666

Signed-off-by: Cai Yudong <yudong.cai@zilliz.com>
pull/39212/head
Cai Yudong 2025-01-13 09:58:57 +08:00 committed by GitHub
parent 032292a432
commit 2a02bbe3ee
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
7 changed files with 309 additions and 1169 deletions

View File

@ -15,43 +15,107 @@
// limitations under the License.
#pragma once
#include "Types.h"
#include <string>
#include <type_traits>
#include "Array.h"
#include "Types.h"
#include "common/type_c.h"
#include "pb/common.pb.h"
#include "pb/plan.pb.h"
#include "pb/schema.pb.h"
namespace milvus {
#define GET_ELEM_TYPE_FOR_VECTOR_TRAIT \
using elem_type = std::conditional_t< \
std::is_same_v<TraitType, milvus::BinaryVector>, \
BinaryVector::embedded_type, \
std::conditional_t< \
std::is_same_v<TraitType, milvus::Float16Vector>, \
Float16Vector::embedded_type, \
std::conditional_t< \
std::is_same_v<TraitType, milvus::BFloat16Vector>, \
BFloat16Vector::embedded_type, \
FloatVector::embedded_type>>>;
#define GET_SCHEMA_DATA_TYPE_FOR_VECTOR_TRAIT \
auto schema_data_type = \
std::is_same_v<TraitType, milvus::FloatVector> \
? FloatVector::schema_data_type \
: std::is_same_v<TraitType, milvus::Float16Vector> \
? Float16Vector::schema_data_type \
: std::is_same_v<TraitType, milvus::BFloat16Vector> \
? BFloat16Vector::schema_data_type \
: BinaryVector::schema_data_type;
class VectorTrait {};
class FloatVector : public VectorTrait {
public:
using embedded_type = float;
static constexpr auto metric_type = DataType::VECTOR_FLOAT;
static constexpr int32_t dim_factor = 1;
static constexpr auto data_type = DataType::VECTOR_FLOAT;
static constexpr auto c_data_type = CDataType::FloatVector;
static constexpr auto schema_data_type =
proto::schema::DataType::FloatVector;
static constexpr auto vector_type = proto::plan::VectorType::FloatVector;
static constexpr auto placeholder_type =
proto::common::PlaceholderType::FloatVector;
};
class BinaryVector : public VectorTrait {
public:
using embedded_type = uint8_t;
static constexpr auto metric_type = DataType::VECTOR_BINARY;
static constexpr int32_t dim_factor = 8;
static constexpr auto data_type = DataType::VECTOR_BINARY;
static constexpr auto c_data_type = CDataType::BinaryVector;
static constexpr auto schema_data_type =
proto::schema::DataType::BinaryVector;
static constexpr auto vector_type = proto::plan::VectorType::BinaryVector;
static constexpr auto placeholder_type =
proto::common::PlaceholderType::BinaryVector;
};
class Float16Vector : public VectorTrait {
public:
using embedded_type = float16;
static constexpr auto metric_type = DataType::VECTOR_FLOAT16;
static constexpr int32_t dim_factor = 1;
static constexpr auto data_type = DataType::VECTOR_FLOAT16;
static constexpr auto c_data_type = CDataType::Float16Vector;
static constexpr auto schema_data_type =
proto::schema::DataType::Float16Vector;
static constexpr auto vector_type = proto::plan::VectorType::Float16Vector;
static constexpr auto placeholder_type =
proto::common::PlaceholderType::Float16Vector;
};
class BFloat16Vector : public VectorTrait {
public:
using embedded_type = bfloat16;
static constexpr auto metric_type = DataType::VECTOR_BFLOAT16;
static constexpr int32_t dim_factor = 1;
static constexpr auto data_type = DataType::VECTOR_BFLOAT16;
static constexpr auto c_data_type = CDataType::BFloat16Vector;
static constexpr auto schema_data_type =
proto::schema::DataType::BFloat16Vector;
static constexpr auto vector_type = proto::plan::VectorType::BFloat16Vector;
static constexpr auto placeholder_type =
proto::common::PlaceholderType::BFloat16Vector;
};
class SparseFloatVector : public VectorTrait {
public:
using embedded_type = float;
static constexpr auto metric_type = DataType::VECTOR_SPARSE_FLOAT;
static constexpr int32_t dim_factor = 1;
static constexpr auto data_type = DataType::VECTOR_SPARSE_FLOAT;
static constexpr auto c_data_type = CDataType::SparseFloatVector;
static constexpr auto schema_data_type =
proto::schema::DataType::SparseFloatVector;
static constexpr auto vector_type =
proto::plan::VectorType::SparseFloatVector;
static constexpr auto placeholder_type =
proto::common::PlaceholderType::SparseFloatVector;
};
template <typename T>

File diff suppressed because it is too large Load Diff

View File

@ -17,7 +17,7 @@ TEST(CApiTest, StreamReduce) {
int N = 300;
int topK = 100;
int num_queries = 2;
auto collection = NewCollection(get_default_schema_config());
auto collection = NewCollection(get_default_schema_config().c_str());
//1. set up segments
CSegmentInterface segment_1;

View File

@ -118,7 +118,8 @@ TEST(Float16, ExecWithoutPredicateFlat) {
auto vec_ptr = dataset.get_col<float16>(vec_fid);
auto num_queries = 5;
auto ph_group_raw = CreateFloat16PlaceholderGroup(num_queries, 32, 1024);
auto ph_group_raw =
CreatePlaceholderGroup<milvus::Float16Vector>(num_queries, 32, 1024);
auto ph_group =
ParsePlaceholderGroup(plan.get(), ph_group_raw.SerializeAsString());
Timestamp timestamp = 1000000;
@ -274,7 +275,8 @@ TEST(Float16, ExecWithPredicate) {
auto plan =
CreateSearchPlanByExpr(*schema, plan_str.data(), plan_str.size());
auto num_queries = 5;
auto ph_group_raw = CreateFloat16PlaceholderGroup(num_queries, 16, 1024);
auto ph_group_raw =
CreatePlaceholderGroup<milvus::Float16Vector>(num_queries, 16, 1024);
auto ph_group =
ParsePlaceholderGroup(plan.get(), ph_group_raw.SerializeAsString());
@ -354,7 +356,8 @@ TEST(BFloat16, ExecWithoutPredicateFlat) {
auto vec_ptr = dataset.get_col<bfloat16>(vec_fid);
auto num_queries = 5;
auto ph_group_raw = CreateBFloat16PlaceholderGroup(num_queries, 32, 1024);
auto ph_group_raw =
CreatePlaceholderGroup<milvus::BFloat16Vector>(num_queries, 32, 1024);
auto ph_group =
ParsePlaceholderGroup(plan.get(), ph_group_raw.SerializeAsString());
Timestamp timestamp = 1000000;
@ -510,7 +513,8 @@ TEST(BFloat16, ExecWithPredicate) {
auto plan =
CreateSearchPlanByExpr(*schema, plan_str.data(), plan_str.size());
auto num_queries = 5;
auto ph_group_raw = CreateBFloat16PlaceholderGroup(num_queries, 16, 1024);
auto ph_group_raw =
CreatePlaceholderGroup<milvus::BFloat16Vector>(num_queries, 16, 1024);
auto ph_group =
ParsePlaceholderGroup(plan.get(), ph_group_raw.SerializeAsString());
Timestamp timestamp = 1000000;

View File

@ -725,7 +725,7 @@ TEST(Query, ExecWithPredicateBinary) {
auto plan =
CreateSearchPlanByExpr(*schema, plan_str.data(), plan_str.size());
auto num_queries = 5;
auto ph_group_raw = CreateBinaryPlaceholderGroupFromBlob(
auto ph_group_raw = CreatePlaceholderGroupFromBlob<milvus::BinaryVector>(
num_queries, 512, vec_ptr.data() + 1024 * 512 / 8);
auto ph_group =
ParsePlaceholderGroup(plan.get(), ph_group_raw.SerializeAsString());

View File

@ -743,26 +743,6 @@ DataGenForJsonArray(SchemaPtr schema,
return res;
}
inline auto
CreatePlaceholderGroup(int64_t num_queries, int dim, int64_t seed = 42) {
namespace ser = milvus::proto::common;
ser::PlaceholderGroup raw_group;
auto value = raw_group.add_placeholders();
value->set_tag("$0");
value->set_type(ser::PlaceholderType::FloatVector);
std::normal_distribution<double> dis(0, 1);
std::default_random_engine e(seed);
for (int i = 0; i < num_queries; ++i) {
std::vector<float> vec;
for (int d = 0; d < dim; ++d) {
vec.push_back(dis(e));
}
// std::string line((char*)vec.data(), (char*)vec.data() + vec.size() * sizeof(float));
value->add_values(vec.data(), vec.size() * sizeof(float));
}
return raw_group;
}
inline auto
CreatePlaceholderGroup(int64_t num_queries,
int dim,
@ -782,148 +762,57 @@ CreatePlaceholderGroup(int64_t num_queries,
return raw_group;
}
inline auto
CreatePlaceholderGroupFromBlob(int64_t num_queries, int dim, const float* src) {
template <class TraitType = milvus::FloatVector>
auto
CreatePlaceholderGroup(int64_t num_queries, int dim, int64_t seed = 42) {
if (std::is_same_v<TraitType, milvus::BinaryVector>) {
assert(dim % 8 == 0);
}
namespace ser = milvus::proto::common;
GET_ELEM_TYPE_FOR_VECTOR_TRAIT
ser::PlaceholderGroup raw_group;
auto value = raw_group.add_placeholders();
value->set_tag("$0");
value->set_type(ser::PlaceholderType::FloatVector);
value->set_type(TraitType::placeholder_type);
// TODO caiyd: need update for Int8Vector
std::normal_distribution<double> dis(0, 1);
std::default_random_engine e(seed);
for (int i = 0; i < num_queries; ++i) {
std::vector<elem_type> vec;
for (int d = 0; d < dim / TraitType::dim_factor; ++d) {
if (std::is_same_v<TraitType, milvus::BinaryVector>) {
vec.push_back(e());
} else {
vec.push_back(elem_type(dis(e)));
}
}
value->add_values(vec.data(), vec.size() * sizeof(elem_type));
}
return raw_group;
}
template <class TraitType = milvus::FloatVector>
inline auto
CreatePlaceholderGroupFromBlob(int64_t num_queries, int dim, const void* src) {
if (std::is_same_v<TraitType, milvus::BinaryVector>) {
assert(dim % 8 == 0);
}
namespace ser = milvus::proto::common;
GET_ELEM_TYPE_FOR_VECTOR_TRAIT
ser::PlaceholderGroup raw_group;
auto value = raw_group.add_placeholders();
value->set_tag("$0");
value->set_type(TraitType::placeholder_type);
int64_t src_index = 0;
for (int i = 0; i < num_queries; ++i) {
std::vector<float> vec;
for (int d = 0; d < dim; ++d) {
vec.push_back(src[src_index++]);
std::vector<elem_type> vec;
for (int d = 0; d < dim / TraitType::dim_factor; ++d) {
vec.push_back(((elem_type*)src)[src_index++]);
}
// std::string line((char*)vec.data(), (char*)vec.data() + vec.size() * sizeof(float));
value->add_values(vec.data(), vec.size() * sizeof(float));
}
return raw_group;
}
inline auto
CreateBinaryPlaceholderGroup(int64_t num_queries,
int64_t dim,
int64_t seed = 42) {
assert(dim % 8 == 0);
namespace ser = milvus::proto::common;
ser::PlaceholderGroup raw_group;
auto value = raw_group.add_placeholders();
value->set_tag("$0");
value->set_type(ser::PlaceholderType::BinaryVector);
std::default_random_engine e(seed);
for (int i = 0; i < num_queries; ++i) {
std::vector<uint8_t> vec;
for (int d = 0; d < dim / 8; ++d) {
vec.push_back(e());
}
// std::string line((char*)vec.data(), (char*)vec.data() + vec.size() * sizeof(float));
value->add_values(vec.data(), vec.size());
}
return raw_group;
}
inline auto
CreateBinaryPlaceholderGroupFromBlob(int64_t num_queries,
int64_t dim,
const uint8_t* ptr) {
assert(dim % 8 == 0);
namespace ser = milvus::proto::common;
ser::PlaceholderGroup raw_group;
auto value = raw_group.add_placeholders();
value->set_tag("$0");
value->set_type(ser::PlaceholderType::BinaryVector);
for (int i = 0; i < num_queries; ++i) {
std::vector<uint8_t> vec;
for (int d = 0; d < dim / 8; ++d) {
vec.push_back(*ptr);
++ptr;
}
// std::string line((char*)vec.data(), (char*)vec.data() + vec.size() * sizeof(float));
value->add_values(vec.data(), vec.size());
}
return raw_group;
}
inline auto
CreateFloat16PlaceholderGroup(int64_t num_queries,
int64_t dim,
int64_t seed = 42) {
namespace ser = milvus::proto::common;
ser::PlaceholderGroup raw_group;
auto value = raw_group.add_placeholders();
value->set_tag("$0");
value->set_type(ser::PlaceholderType::Float16Vector);
std::normal_distribution<double> dis(0, 1);
std::default_random_engine e(seed);
for (int i = 0; i < num_queries; ++i) {
std::vector<float16> vec;
for (int d = 0; d < dim; ++d) {
vec.push_back(float16(dis(e)));
}
value->add_values(vec.data(), vec.size() * sizeof(float16));
}
return raw_group;
}
inline auto
CreateFloat16PlaceholderGroupFromBlob(int64_t num_queries,
int64_t dim,
const float16* ptr) {
namespace ser = milvus::proto::common;
ser::PlaceholderGroup raw_group;
auto value = raw_group.add_placeholders();
value->set_tag("$0");
value->set_type(ser::PlaceholderType::Float16Vector);
for (int i = 0; i < num_queries; ++i) {
std::vector<float16> vec;
for (int d = 0; d < dim; ++d) {
vec.push_back(*ptr);
++ptr;
}
value->add_values(vec.data(), vec.size() * sizeof(float16));
}
return raw_group;
}
inline auto
CreateBFloat16PlaceholderGroup(int64_t num_queries,
int64_t dim,
int64_t seed = 42) {
namespace ser = milvus::proto::common;
ser::PlaceholderGroup raw_group;
auto value = raw_group.add_placeholders();
value->set_tag("$0");
value->set_type(ser::PlaceholderType::BFloat16Vector);
std::normal_distribution<double> dis(0, 1);
std::default_random_engine e(seed);
for (int i = 0; i < num_queries; ++i) {
std::vector<bfloat16> vec;
for (int d = 0; d < dim; ++d) {
vec.push_back(bfloat16(dis(e)));
}
value->add_values(vec.data(), vec.size() * sizeof(bfloat16));
}
return raw_group;
}
inline auto
CreateBFloat16PlaceholderGroupFromBlob(int64_t num_queries,
int64_t dim,
const bfloat16* ptr) {
namespace ser = milvus::proto::common;
ser::PlaceholderGroup raw_group;
auto value = raw_group.add_placeholders();
value->set_tag("$0");
value->set_type(ser::PlaceholderType::BFloat16Vector);
for (int i = 0; i < num_queries; ++i) {
std::vector<bfloat16> vec;
for (int d = 0; d < dim; ++d) {
vec.push_back(*ptr);
++ptr;
}
value->add_values(vec.data(), vec.size() * sizeof(bfloat16));
value->add_values(vec.data(), vec.size() * sizeof(elem_type));
}
return raw_group;
}

View File

@ -23,6 +23,7 @@
#include "common/Types.h"
#include "common/type_c.h"
#include "common/VectorTrait.h"
#include "pb/plan.pb.h"
#include "segcore/Collection.h"
#include "segcore/reduce/Reduce.h"
@ -32,7 +33,6 @@
#include "futures/future_c.h"
#include "DataGen.h"
#include "PbHelper.h"
#include "c_api_test_utils.h"
#include "indexbuilder_test_utils.h"
using namespace milvus;
@ -66,26 +66,30 @@ generate_max_float_query_data(int all_nq, int max_float_nq) {
return blob;
}
template <class TraitType = milvus::FloatVector>
std::string
generate_query_data(int nq) {
namespace ser = milvus::proto::common;
GET_ELEM_TYPE_FOR_VECTOR_TRAIT
std::default_random_engine e(67);
int dim = DIM;
std::normal_distribution<double> dis(0.0, 1.0);
std::uniform_int_distribution<int8_t> dis(-128, 127);
ser::PlaceholderGroup raw_group;
auto value = raw_group.add_placeholders();
value->set_tag("$0");
value->set_type(ser::PlaceholderType::FloatVector);
value->set_type(TraitType::placeholder_type);
for (int i = 0; i < nq; ++i) {
std::vector<float> vec;
for (int d = 0; d < dim; ++d) {
vec.push_back(dis(e));
std::vector<elem_type> vec;
for (int d = 0; d < dim / TraitType::dim_factor; ++d) {
vec.push_back((elem_type)dis(e));
}
value->add_values(vec.data(), vec.size() * sizeof(float));
value->add_values(vec.data(), vec.size() * sizeof(elem_type));
}
auto blob = raw_group.SerializeAsString();
return blob;
}
void
CheckSearchResultDuplicate(const std::vector<CSearchResult>& results,
int group_size = 1) {
@ -117,13 +121,14 @@ CheckSearchResultDuplicate(const std::vector<CSearchResult>& results,
}
}
const char*
template <class TraitType = milvus::FloatVector>
const std::string
get_default_schema_config() {
static std::string conf = R"(name: "default-collection"
auto fmt = boost::format(R"(name: "default-collection"
fields: <
fieldID: 100
name: "fakevec"
data_type: FloatVector
data_type: %1%
type_params: <
key: "dim"
value: "16"
@ -138,9 +143,9 @@ get_default_schema_config() {
name: "age"
data_type: Int64
is_primary_key: true
>)";
static std::string fake_conf = "";
return conf.c_str();
>)") %
(int(TraitType::data_type));
return fmt.str();
}
const char*