mirror of https://github.com/milvus-io/milvus.git
enhance: Use template to remove unittest duplication (#39144)
Issue: #38666
Signed-off-by: Cai Yudong <yudong.cai@zilliz.com>
pull/39212/head
parent
032292a432
commit
2a02bbe3ee
|
@ -15,43 +15,107 @@
|
|||
// limitations under the License.
|
||||
|
||||
#pragma once
|
||||
#include "Types.h"
|
||||
|
||||
#include <string>
|
||||
#include <type_traits>
|
||||
|
||||
#include "Array.h"
|
||||
#include "Types.h"
|
||||
#include "common/type_c.h"
|
||||
#include "pb/common.pb.h"
|
||||
#include "pb/plan.pb.h"
|
||||
#include "pb/schema.pb.h"
|
||||
|
||||
namespace milvus {
|
||||
|
||||
// Declares a local alias `elem_type` in the expanding scope, chosen from the
// type named `TraitType` that must be visible at the point of expansion:
//   milvus::BFloat16Vector -> BFloat16Vector::embedded_type
//   milvus::Float16Vector  -> Float16Vector::embedded_type
//   milvus::BinaryVector   -> BinaryVector::embedded_type
//   anything else          -> FloatVector::embedded_type
// The three tests are mutually exclusive, so their order does not matter.
// The expansion already ends with ';' -- do not add another at the use site.
#define GET_ELEM_TYPE_FOR_VECTOR_TRAIT                              \
    using elem_type = std::conditional_t<                           \
        std::is_same_v<TraitType, milvus::BFloat16Vector>,          \
        BFloat16Vector::embedded_type,                              \
        std::conditional_t<                                         \
            std::is_same_v<TraitType, milvus::Float16Vector>,       \
            Float16Vector::embedded_type,                           \
            std::conditional_t<                                     \
                std::is_same_v<TraitType, milvus::BinaryVector>,    \
                BinaryVector::embedded_type,                        \
                FloatVector::embedded_type>>>;
|
||||
|
||||
// Declares a local variable `schema_data_type` in the expanding scope, holding
// the schema_data_type constant of the trait named `TraitType` (which must be
// visible at the point of expansion).
//
// The value is read straight from the trait instead of being picked through a
// fixed chain of type comparisons: a chain has to fall back to *some* trait
// for every type it does not list, which would hand SparseFloatVector (or any
// trait added later) another vector kind's tag. Each trait class already
// publishes its own schema_data_type, so that is the single source of truth.
//
// The expansion already ends with ';' -- do not add another at the use site.
#define GET_SCHEMA_DATA_TYPE_FOR_VECTOR_TRAIT \
    auto schema_data_type = TraitType::schema_data_type;
|
||||
|
||||
// Empty common base class of the vector trait classes that follow; it declares
// no members and only gives the traits a shared base type.
class VectorTrait {};
|
||||
|
||||
class FloatVector : public VectorTrait {
|
||||
public:
|
||||
using embedded_type = float;
|
||||
static constexpr auto metric_type = DataType::VECTOR_FLOAT;
|
||||
static constexpr int32_t dim_factor = 1;
|
||||
static constexpr auto data_type = DataType::VECTOR_FLOAT;
|
||||
static constexpr auto c_data_type = CDataType::FloatVector;
|
||||
static constexpr auto schema_data_type =
|
||||
proto::schema::DataType::FloatVector;
|
||||
static constexpr auto vector_type = proto::plan::VectorType::FloatVector;
|
||||
static constexpr auto placeholder_type =
|
||||
proto::common::PlaceholderType::FloatVector;
|
||||
};
|
||||
|
||||
class BinaryVector : public VectorTrait {
|
||||
public:
|
||||
using embedded_type = uint8_t;
|
||||
static constexpr auto metric_type = DataType::VECTOR_BINARY;
|
||||
static constexpr int32_t dim_factor = 8;
|
||||
static constexpr auto data_type = DataType::VECTOR_BINARY;
|
||||
static constexpr auto c_data_type = CDataType::BinaryVector;
|
||||
static constexpr auto schema_data_type =
|
||||
proto::schema::DataType::BinaryVector;
|
||||
static constexpr auto vector_type = proto::plan::VectorType::BinaryVector;
|
||||
static constexpr auto placeholder_type =
|
||||
proto::common::PlaceholderType::BinaryVector;
|
||||
};
|
||||
|
||||
class Float16Vector : public VectorTrait {
|
||||
public:
|
||||
using embedded_type = float16;
|
||||
static constexpr auto metric_type = DataType::VECTOR_FLOAT16;
|
||||
static constexpr int32_t dim_factor = 1;
|
||||
static constexpr auto data_type = DataType::VECTOR_FLOAT16;
|
||||
static constexpr auto c_data_type = CDataType::Float16Vector;
|
||||
static constexpr auto schema_data_type =
|
||||
proto::schema::DataType::Float16Vector;
|
||||
static constexpr auto vector_type = proto::plan::VectorType::Float16Vector;
|
||||
static constexpr auto placeholder_type =
|
||||
proto::common::PlaceholderType::Float16Vector;
|
||||
};
|
||||
|
||||
class BFloat16Vector : public VectorTrait {
|
||||
public:
|
||||
using embedded_type = bfloat16;
|
||||
static constexpr auto metric_type = DataType::VECTOR_BFLOAT16;
|
||||
static constexpr int32_t dim_factor = 1;
|
||||
static constexpr auto data_type = DataType::VECTOR_BFLOAT16;
|
||||
static constexpr auto c_data_type = CDataType::BFloat16Vector;
|
||||
static constexpr auto schema_data_type =
|
||||
proto::schema::DataType::BFloat16Vector;
|
||||
static constexpr auto vector_type = proto::plan::VectorType::BFloat16Vector;
|
||||
static constexpr auto placeholder_type =
|
||||
proto::common::PlaceholderType::BFloat16Vector;
|
||||
};
|
||||
|
||||
class SparseFloatVector : public VectorTrait {
|
||||
public:
|
||||
using embedded_type = float;
|
||||
static constexpr auto metric_type = DataType::VECTOR_SPARSE_FLOAT;
|
||||
static constexpr int32_t dim_factor = 1;
|
||||
static constexpr auto data_type = DataType::VECTOR_SPARSE_FLOAT;
|
||||
static constexpr auto c_data_type = CDataType::SparseFloatVector;
|
||||
static constexpr auto schema_data_type =
|
||||
proto::schema::DataType::SparseFloatVector;
|
||||
static constexpr auto vector_type =
|
||||
proto::plan::VectorType::SparseFloatVector;
|
||||
static constexpr auto placeholder_type =
|
||||
proto::common::PlaceholderType::SparseFloatVector;
|
||||
};
|
||||
|
||||
template <typename T>
|
||||
|
|
File diff suppressed because it is too large
Load Diff
|
@ -17,7 +17,7 @@ TEST(CApiTest, StreamReduce) {
|
|||
int N = 300;
|
||||
int topK = 100;
|
||||
int num_queries = 2;
|
||||
auto collection = NewCollection(get_default_schema_config());
|
||||
auto collection = NewCollection(get_default_schema_config().c_str());
|
||||
|
||||
//1. set up segments
|
||||
CSegmentInterface segment_1;
|
||||
|
|
|
@ -118,7 +118,8 @@ TEST(Float16, ExecWithoutPredicateFlat) {
|
|||
auto vec_ptr = dataset.get_col<float16>(vec_fid);
|
||||
|
||||
auto num_queries = 5;
|
||||
auto ph_group_raw = CreateFloat16PlaceholderGroup(num_queries, 32, 1024);
|
||||
auto ph_group_raw =
|
||||
CreatePlaceholderGroup<milvus::Float16Vector>(num_queries, 32, 1024);
|
||||
auto ph_group =
|
||||
ParsePlaceholderGroup(plan.get(), ph_group_raw.SerializeAsString());
|
||||
Timestamp timestamp = 1000000;
|
||||
|
@ -274,7 +275,8 @@ TEST(Float16, ExecWithPredicate) {
|
|||
auto plan =
|
||||
CreateSearchPlanByExpr(*schema, plan_str.data(), plan_str.size());
|
||||
auto num_queries = 5;
|
||||
auto ph_group_raw = CreateFloat16PlaceholderGroup(num_queries, 16, 1024);
|
||||
auto ph_group_raw =
|
||||
CreatePlaceholderGroup<milvus::Float16Vector>(num_queries, 16, 1024);
|
||||
auto ph_group =
|
||||
ParsePlaceholderGroup(plan.get(), ph_group_raw.SerializeAsString());
|
||||
|
||||
|
@ -354,7 +356,8 @@ TEST(BFloat16, ExecWithoutPredicateFlat) {
|
|||
auto vec_ptr = dataset.get_col<bfloat16>(vec_fid);
|
||||
|
||||
auto num_queries = 5;
|
||||
auto ph_group_raw = CreateBFloat16PlaceholderGroup(num_queries, 32, 1024);
|
||||
auto ph_group_raw =
|
||||
CreatePlaceholderGroup<milvus::BFloat16Vector>(num_queries, 32, 1024);
|
||||
auto ph_group =
|
||||
ParsePlaceholderGroup(plan.get(), ph_group_raw.SerializeAsString());
|
||||
Timestamp timestamp = 1000000;
|
||||
|
@ -510,7 +513,8 @@ TEST(BFloat16, ExecWithPredicate) {
|
|||
auto plan =
|
||||
CreateSearchPlanByExpr(*schema, plan_str.data(), plan_str.size());
|
||||
auto num_queries = 5;
|
||||
auto ph_group_raw = CreateBFloat16PlaceholderGroup(num_queries, 16, 1024);
|
||||
auto ph_group_raw =
|
||||
CreatePlaceholderGroup<milvus::BFloat16Vector>(num_queries, 16, 1024);
|
||||
auto ph_group =
|
||||
ParsePlaceholderGroup(plan.get(), ph_group_raw.SerializeAsString());
|
||||
Timestamp timestamp = 1000000;
|
||||
|
|
|
@ -725,7 +725,7 @@ TEST(Query, ExecWithPredicateBinary) {
|
|||
auto plan =
|
||||
CreateSearchPlanByExpr(*schema, plan_str.data(), plan_str.size());
|
||||
auto num_queries = 5;
|
||||
auto ph_group_raw = CreateBinaryPlaceholderGroupFromBlob(
|
||||
auto ph_group_raw = CreatePlaceholderGroupFromBlob<milvus::BinaryVector>(
|
||||
num_queries, 512, vec_ptr.data() + 1024 * 512 / 8);
|
||||
auto ph_group =
|
||||
ParsePlaceholderGroup(plan.get(), ph_group_raw.SerializeAsString());
|
||||
|
|
|
@ -743,26 +743,6 @@ DataGenForJsonArray(SchemaPtr schema,
|
|||
return res;
|
||||
}
|
||||
|
||||
inline auto
|
||||
CreatePlaceholderGroup(int64_t num_queries, int dim, int64_t seed = 42) {
|
||||
namespace ser = milvus::proto::common;
|
||||
ser::PlaceholderGroup raw_group;
|
||||
auto value = raw_group.add_placeholders();
|
||||
value->set_tag("$0");
|
||||
value->set_type(ser::PlaceholderType::FloatVector);
|
||||
std::normal_distribution<double> dis(0, 1);
|
||||
std::default_random_engine e(seed);
|
||||
for (int i = 0; i < num_queries; ++i) {
|
||||
std::vector<float> vec;
|
||||
for (int d = 0; d < dim; ++d) {
|
||||
vec.push_back(dis(e));
|
||||
}
|
||||
// std::string line((char*)vec.data(), (char*)vec.data() + vec.size() * sizeof(float));
|
||||
value->add_values(vec.data(), vec.size() * sizeof(float));
|
||||
}
|
||||
return raw_group;
|
||||
}
|
||||
|
||||
inline auto
|
||||
CreatePlaceholderGroup(int64_t num_queries,
|
||||
int dim,
|
||||
|
@ -782,148 +762,57 @@ CreatePlaceholderGroup(int64_t num_queries,
|
|||
return raw_group;
|
||||
}
|
||||
|
||||
inline auto
|
||||
CreatePlaceholderGroupFromBlob(int64_t num_queries, int dim, const float* src) {
|
||||
template <class TraitType = milvus::FloatVector>
|
||||
auto
|
||||
CreatePlaceholderGroup(int64_t num_queries, int dim, int64_t seed = 42) {
|
||||
if (std::is_same_v<TraitType, milvus::BinaryVector>) {
|
||||
assert(dim % 8 == 0);
|
||||
}
|
||||
namespace ser = milvus::proto::common;
|
||||
GET_ELEM_TYPE_FOR_VECTOR_TRAIT
|
||||
|
||||
ser::PlaceholderGroup raw_group;
|
||||
auto value = raw_group.add_placeholders();
|
||||
value->set_tag("$0");
|
||||
value->set_type(ser::PlaceholderType::FloatVector);
|
||||
value->set_type(TraitType::placeholder_type);
|
||||
// TODO caiyd: need update for Int8Vector
|
||||
std::normal_distribution<double> dis(0, 1);
|
||||
std::default_random_engine e(seed);
|
||||
for (int i = 0; i < num_queries; ++i) {
|
||||
std::vector<elem_type> vec;
|
||||
for (int d = 0; d < dim / TraitType::dim_factor; ++d) {
|
||||
if (std::is_same_v<TraitType, milvus::BinaryVector>) {
|
||||
vec.push_back(e());
|
||||
} else {
|
||||
vec.push_back(elem_type(dis(e)));
|
||||
}
|
||||
}
|
||||
value->add_values(vec.data(), vec.size() * sizeof(elem_type));
|
||||
}
|
||||
return raw_group;
|
||||
}
|
||||
|
||||
template <class TraitType = milvus::FloatVector>
|
||||
inline auto
|
||||
CreatePlaceholderGroupFromBlob(int64_t num_queries, int dim, const void* src) {
|
||||
if (std::is_same_v<TraitType, milvus::BinaryVector>) {
|
||||
assert(dim % 8 == 0);
|
||||
}
|
||||
namespace ser = milvus::proto::common;
|
||||
GET_ELEM_TYPE_FOR_VECTOR_TRAIT
|
||||
|
||||
ser::PlaceholderGroup raw_group;
|
||||
auto value = raw_group.add_placeholders();
|
||||
value->set_tag("$0");
|
||||
value->set_type(TraitType::placeholder_type);
|
||||
int64_t src_index = 0;
|
||||
|
||||
for (int i = 0; i < num_queries; ++i) {
|
||||
std::vector<float> vec;
|
||||
for (int d = 0; d < dim; ++d) {
|
||||
vec.push_back(src[src_index++]);
|
||||
std::vector<elem_type> vec;
|
||||
for (int d = 0; d < dim / TraitType::dim_factor; ++d) {
|
||||
vec.push_back(((elem_type*)src)[src_index++]);
|
||||
}
|
||||
// std::string line((char*)vec.data(), (char*)vec.data() + vec.size() * sizeof(float));
|
||||
value->add_values(vec.data(), vec.size() * sizeof(float));
|
||||
}
|
||||
return raw_group;
|
||||
}
|
||||
|
||||
inline auto
|
||||
CreateBinaryPlaceholderGroup(int64_t num_queries,
|
||||
int64_t dim,
|
||||
int64_t seed = 42) {
|
||||
assert(dim % 8 == 0);
|
||||
namespace ser = milvus::proto::common;
|
||||
ser::PlaceholderGroup raw_group;
|
||||
auto value = raw_group.add_placeholders();
|
||||
value->set_tag("$0");
|
||||
value->set_type(ser::PlaceholderType::BinaryVector);
|
||||
std::default_random_engine e(seed);
|
||||
for (int i = 0; i < num_queries; ++i) {
|
||||
std::vector<uint8_t> vec;
|
||||
for (int d = 0; d < dim / 8; ++d) {
|
||||
vec.push_back(e());
|
||||
}
|
||||
// std::string line((char*)vec.data(), (char*)vec.data() + vec.size() * sizeof(float));
|
||||
value->add_values(vec.data(), vec.size());
|
||||
}
|
||||
return raw_group;
|
||||
}
|
||||
|
||||
inline auto
|
||||
CreateBinaryPlaceholderGroupFromBlob(int64_t num_queries,
|
||||
int64_t dim,
|
||||
const uint8_t* ptr) {
|
||||
assert(dim % 8 == 0);
|
||||
namespace ser = milvus::proto::common;
|
||||
ser::PlaceholderGroup raw_group;
|
||||
auto value = raw_group.add_placeholders();
|
||||
value->set_tag("$0");
|
||||
value->set_type(ser::PlaceholderType::BinaryVector);
|
||||
for (int i = 0; i < num_queries; ++i) {
|
||||
std::vector<uint8_t> vec;
|
||||
for (int d = 0; d < dim / 8; ++d) {
|
||||
vec.push_back(*ptr);
|
||||
++ptr;
|
||||
}
|
||||
// std::string line((char*)vec.data(), (char*)vec.data() + vec.size() * sizeof(float));
|
||||
value->add_values(vec.data(), vec.size());
|
||||
}
|
||||
return raw_group;
|
||||
}
|
||||
|
||||
inline auto
|
||||
CreateFloat16PlaceholderGroup(int64_t num_queries,
|
||||
int64_t dim,
|
||||
int64_t seed = 42) {
|
||||
namespace ser = milvus::proto::common;
|
||||
ser::PlaceholderGroup raw_group;
|
||||
auto value = raw_group.add_placeholders();
|
||||
value->set_tag("$0");
|
||||
value->set_type(ser::PlaceholderType::Float16Vector);
|
||||
std::normal_distribution<double> dis(0, 1);
|
||||
std::default_random_engine e(seed);
|
||||
for (int i = 0; i < num_queries; ++i) {
|
||||
std::vector<float16> vec;
|
||||
for (int d = 0; d < dim; ++d) {
|
||||
vec.push_back(float16(dis(e)));
|
||||
}
|
||||
value->add_values(vec.data(), vec.size() * sizeof(float16));
|
||||
}
|
||||
return raw_group;
|
||||
}
|
||||
|
||||
inline auto
|
||||
CreateFloat16PlaceholderGroupFromBlob(int64_t num_queries,
|
||||
int64_t dim,
|
||||
const float16* ptr) {
|
||||
namespace ser = milvus::proto::common;
|
||||
ser::PlaceholderGroup raw_group;
|
||||
auto value = raw_group.add_placeholders();
|
||||
value->set_tag("$0");
|
||||
value->set_type(ser::PlaceholderType::Float16Vector);
|
||||
for (int i = 0; i < num_queries; ++i) {
|
||||
std::vector<float16> vec;
|
||||
for (int d = 0; d < dim; ++d) {
|
||||
vec.push_back(*ptr);
|
||||
++ptr;
|
||||
}
|
||||
value->add_values(vec.data(), vec.size() * sizeof(float16));
|
||||
}
|
||||
return raw_group;
|
||||
}
|
||||
|
||||
inline auto
|
||||
CreateBFloat16PlaceholderGroup(int64_t num_queries,
|
||||
int64_t dim,
|
||||
int64_t seed = 42) {
|
||||
namespace ser = milvus::proto::common;
|
||||
ser::PlaceholderGroup raw_group;
|
||||
auto value = raw_group.add_placeholders();
|
||||
value->set_tag("$0");
|
||||
value->set_type(ser::PlaceholderType::BFloat16Vector);
|
||||
std::normal_distribution<double> dis(0, 1);
|
||||
std::default_random_engine e(seed);
|
||||
for (int i = 0; i < num_queries; ++i) {
|
||||
std::vector<bfloat16> vec;
|
||||
for (int d = 0; d < dim; ++d) {
|
||||
vec.push_back(bfloat16(dis(e)));
|
||||
}
|
||||
value->add_values(vec.data(), vec.size() * sizeof(bfloat16));
|
||||
}
|
||||
return raw_group;
|
||||
}
|
||||
|
||||
inline auto
|
||||
CreateBFloat16PlaceholderGroupFromBlob(int64_t num_queries,
|
||||
int64_t dim,
|
||||
const bfloat16* ptr) {
|
||||
namespace ser = milvus::proto::common;
|
||||
ser::PlaceholderGroup raw_group;
|
||||
auto value = raw_group.add_placeholders();
|
||||
value->set_tag("$0");
|
||||
value->set_type(ser::PlaceholderType::BFloat16Vector);
|
||||
for (int i = 0; i < num_queries; ++i) {
|
||||
std::vector<bfloat16> vec;
|
||||
for (int d = 0; d < dim; ++d) {
|
||||
vec.push_back(*ptr);
|
||||
++ptr;
|
||||
}
|
||||
value->add_values(vec.data(), vec.size() * sizeof(bfloat16));
|
||||
value->add_values(vec.data(), vec.size() * sizeof(elem_type));
|
||||
}
|
||||
return raw_group;
|
||||
}
|
||||
|
|
|
@ -23,6 +23,7 @@
|
|||
|
||||
#include "common/Types.h"
|
||||
#include "common/type_c.h"
|
||||
#include "common/VectorTrait.h"
|
||||
#include "pb/plan.pb.h"
|
||||
#include "segcore/Collection.h"
|
||||
#include "segcore/reduce/Reduce.h"
|
||||
|
@ -32,7 +33,6 @@
|
|||
#include "futures/future_c.h"
|
||||
#include "DataGen.h"
|
||||
#include "PbHelper.h"
|
||||
#include "c_api_test_utils.h"
|
||||
#include "indexbuilder_test_utils.h"
|
||||
|
||||
using namespace milvus;
|
||||
|
@ -66,26 +66,30 @@ generate_max_float_query_data(int all_nq, int max_float_nq) {
|
|||
return blob;
|
||||
}
|
||||
|
||||
template <class TraitType = milvus::FloatVector>
|
||||
std::string
|
||||
generate_query_data(int nq) {
|
||||
namespace ser = milvus::proto::common;
|
||||
GET_ELEM_TYPE_FOR_VECTOR_TRAIT
|
||||
|
||||
std::default_random_engine e(67);
|
||||
int dim = DIM;
|
||||
std::normal_distribution<double> dis(0.0, 1.0);
|
||||
std::uniform_int_distribution<int8_t> dis(-128, 127);
|
||||
ser::PlaceholderGroup raw_group;
|
||||
auto value = raw_group.add_placeholders();
|
||||
value->set_tag("$0");
|
||||
value->set_type(ser::PlaceholderType::FloatVector);
|
||||
value->set_type(TraitType::placeholder_type);
|
||||
for (int i = 0; i < nq; ++i) {
|
||||
std::vector<float> vec;
|
||||
for (int d = 0; d < dim; ++d) {
|
||||
vec.push_back(dis(e));
|
||||
std::vector<elem_type> vec;
|
||||
for (int d = 0; d < dim / TraitType::dim_factor; ++d) {
|
||||
vec.push_back((elem_type)dis(e));
|
||||
}
|
||||
value->add_values(vec.data(), vec.size() * sizeof(float));
|
||||
value->add_values(vec.data(), vec.size() * sizeof(elem_type));
|
||||
}
|
||||
auto blob = raw_group.SerializeAsString();
|
||||
return blob;
|
||||
}
|
||||
|
||||
void
|
||||
CheckSearchResultDuplicate(const std::vector<CSearchResult>& results,
|
||||
int group_size = 1) {
|
||||
|
@ -117,13 +121,14 @@ CheckSearchResultDuplicate(const std::vector<CSearchResult>& results,
|
|||
}
|
||||
}
|
||||
|
||||
const char*
|
||||
template <class TraitType = milvus::FloatVector>
|
||||
const std::string
|
||||
get_default_schema_config() {
|
||||
static std::string conf = R"(name: "default-collection"
|
||||
auto fmt = boost::format(R"(name: "default-collection"
|
||||
fields: <
|
||||
fieldID: 100
|
||||
name: "fakevec"
|
||||
data_type: FloatVector
|
||||
data_type: %1%
|
||||
type_params: <
|
||||
key: "dim"
|
||||
value: "16"
|
||||
|
@ -138,9 +143,9 @@ get_default_schema_config() {
|
|||
name: "age"
|
||||
data_type: Int64
|
||||
is_primary_key: true
|
||||
>)";
|
||||
static std::string fake_conf = "";
|
||||
return conf.c_str();
|
||||
>)") %
|
||||
(int(TraitType::data_type));
|
||||
return fmt.str();
|
||||
}
|
||||
|
||||
const char*
|
||||
|
|
Loading…
Reference in New Issue