diff --git a/internal/core/src/common/Types.h b/internal/core/src/common/Types.h index 6e69ddac17..729f3af09a 100644 --- a/internal/core/src/common/Types.h +++ b/internal/core/src/common/Types.h @@ -139,6 +139,16 @@ GetDataTypeSize(DataType data_type, int dim = 1) { } } +template +inline size_t +GetVecRowSize(int64_t dim) { + if constexpr (std::is_same_v) { + return (dim / 8) * sizeof(bin1); + } else { + return dim * sizeof(T); + } +} + // TODO: use magic_enum when available inline std::string GetDataTypeName(DataType data_type) { @@ -393,6 +403,18 @@ IndexIsSparse(const IndexType& index_type) { index_type == knowhere::IndexEnum::INDEX_SPARSE_WAND; } +inline bool +IsFloatVectorMetricType(const MetricType& metric_type) { + return metric_type == knowhere::metric::L2 || + metric_type == knowhere::metric::IP || + metric_type == knowhere::metric::COSINE; +} + +inline bool +IsBinaryVectorMetricType(const MetricType& metric_type) { + return !IsFloatVectorMetricType(metric_type); +} + // Plus 1 because we can't use greater(>) symbol constexpr size_t REF_SIZE_THRESHOLD = 16 + 1; diff --git a/internal/core/src/index/IndexFactory.cpp b/internal/core/src/index/IndexFactory.cpp index ac4a89f933..5e45930675 100644 --- a/internal/core/src/index/IndexFactory.cpp +++ b/internal/core/src/index/IndexFactory.cpp @@ -274,7 +274,7 @@ IndexFactory::CreateVectorIndex( index_type, metric_type, version, file_manager_context); } case DataType::VECTOR_BINARY: { - return std::make_unique>( + return std::make_unique>( index_type, metric_type, version, file_manager_context); } case DataType::VECTOR_FLOAT16: { @@ -360,7 +360,7 @@ IndexFactory::CreateVectorIndex( create_index_info, file_manager_context, space); } case DataType::VECTOR_BINARY: { - return std::make_unique>( + return std::make_unique>( create_index_info, file_manager_context, space); } case DataType::VECTOR_FLOAT16: { diff --git a/internal/core/src/index/Utils.h b/internal/core/src/index/Utils.h index 50c70d8d52..1444eeeac6 100644 
--- a/internal/core/src/index/Utils.h +++ b/internal/core/src/index/Utils.h @@ -91,6 +91,20 @@ SetValueToConfig(Config& cfg, const std::string& key, const T value) { cfg[key] = value; } +template +inline void +CheckMetricTypeSupport(const MetricType& metric_type) { + if constexpr (std::is_same_v) { + AssertInfo( + IsBinaryVectorMetricType(metric_type), + "binary vector does not support float vector metric type: " + metric_type); + } else { + AssertInfo( + IsFloatVectorMetricType(metric_type), + "float vector does not support binary vector metric type: " + metric_type); + } +} + int64_t GetDimFromConfig(const Config& config); diff --git a/internal/core/src/index/VectorDiskIndex.cpp b/internal/core/src/index/VectorDiskIndex.cpp index 344cdf565e..d6ca94b1df 100644 --- a/internal/core/src/index/VectorDiskIndex.cpp +++ b/internal/core/src/index/VectorDiskIndex.cpp @@ -42,6 +42,7 @@ VectorDiskAnnIndex::VectorDiskAnnIndex( const IndexVersion& version, const storage::FileManagerContext& file_manager_context) : VectorIndex(index_type, metric_type) { + CheckMetricTypeSupport(metric_type); file_manager_ = std::make_shared(file_manager_context); AssertInfo(file_manager_ != nullptr, "create file manager failed!"); @@ -80,6 +81,7 @@ VectorDiskAnnIndex::VectorDiskAnnIndex( std::shared_ptr space, const storage::FileManagerContext& file_manager_context) : space_(space), VectorIndex(index_type, metric_type) { + CheckMetricTypeSupport(metric_type); file_manager_ = std::make_shared( file_manager_context, file_manager_context.space_); AssertInfo(file_manager_ != nullptr, "create file manager failed!"); @@ -316,7 +318,7 @@ VectorDiskAnnIndex::BuildWithDataset(const DatasetPtr& dataset, local_chunk_manager->Write(local_data_path, offset, &dim, sizeof(dim)); offset += sizeof(dim); - auto data_size = num * dim * sizeof(T); + size_t data_size = static_cast(num) * milvus::GetVecRowSize(dim); auto raw_data = const_cast(milvus::GetDatasetTensor(dataset)); local_chunk_manager->Write(local_data_path, offset, 
raw_data, data_size); @@ -448,12 +450,7 @@ VectorDiskAnnIndex::GetVector(const DatasetPtr dataset) const { auto tensor = res.value()->GetTensor(); auto row_num = res.value()->GetRows(); auto dim = res.value()->GetDim(); - int64_t data_size; - if constexpr (std::is_same_v) { - data_size = dim / 8 * row_num; - } else { - data_size = dim * row_num * sizeof(T); - } + int64_t data_size = milvus::GetVecRowSize(dim) * row_num; std::vector raw_data; raw_data.resize(data_size); memcpy(raw_data.data(), tensor, data_size); diff --git a/internal/core/src/index/VectorMemIndex.cpp b/internal/core/src/index/VectorMemIndex.cpp index 580c568e10..149e212d84 100644 --- a/internal/core/src/index/VectorMemIndex.cpp +++ b/internal/core/src/index/VectorMemIndex.cpp @@ -61,6 +61,7 @@ VectorMemIndex::VectorMemIndex( const IndexVersion& version, const storage::FileManagerContext& file_manager_context) : VectorIndex(index_type, metric_type) { + CheckMetricTypeSupport(metric_type); AssertInfo(!is_unsupported(index_type, metric_type), index_type + " doesn't support metric: " + metric_type); if (file_manager_context.Valid()) { @@ -90,6 +91,7 @@ VectorMemIndex::VectorMemIndex( : VectorIndex(create_index_info.index_type, create_index_info.metric_type), space_(space), create_index_info_(create_index_info) { + CheckMetricTypeSupport(create_index_info.metric_type); AssertInfo(!is_unsupported(create_index_info.index_type, create_index_info.metric_type), create_index_info.index_type + @@ -668,12 +670,7 @@ VectorMemIndex::GetVector(const DatasetPtr dataset) const { auto tensor = res.value()->GetTensor(); auto row_num = res.value()->GetRows(); auto dim = res.value()->GetDim(); - int64_t data_size; - if constexpr (std::is_same_v) { - data_size = dim / 8 * row_num; - } else { - data_size = dim * row_num * sizeof(T); - } + int64_t data_size = milvus::GetVecRowSize(dim) * row_num; std::vector raw_data; raw_data.resize(data_size); memcpy(raw_data.data(), tensor, data_size); diff --git 
a/internal/core/src/storage/DiskFileManagerImpl.cpp b/internal/core/src/storage/DiskFileManagerImpl.cpp index 1d780a5225..5eb6c8a311 100644 --- a/internal/core/src/storage/DiskFileManagerImpl.cpp +++ b/internal/core/src/storage/DiskFileManagerImpl.cpp @@ -414,7 +414,7 @@ DiskFileManagerImpl::CacheRawDataToDisk( field_data->FillFieldData(col_data); dim = field_data->get_dim(); auto data_size = - field_data->get_num_rows() * index_meta_.dim * sizeof(DataType); + field_data->get_num_rows() * milvus::GetVecRowSize(dim); local_chunk_manager->Write(local_data_path, write_offset, const_cast(field_data->Data()), @@ -516,8 +516,8 @@ DiskFileManagerImpl::CacheRawDataToDisk(std::vector remote_files) { "inconsistent dim value in multi binlogs!"); dim = field_data->get_dim(); - auto data_size = - field_data->get_num_rows() * dim * sizeof(DataType); + auto data_size = field_data->get_num_rows() * + milvus::GetVecRowSize(dim); local_chunk_manager->Write( local_data_path, write_offset, diff --git a/internal/core/unittest/test_growing.cpp b/internal/core/unittest/test_growing.cpp index 20a42a143f..3a15dbbf31 100644 --- a/internal/core/unittest/test_growing.cpp +++ b/internal/core/unittest/test_growing.cpp @@ -103,8 +103,8 @@ class GrowingTest public: void SetUp() override { - auto index_type = std::get<0>(GetParam()); - auto metric_type = std::get<1>(GetParam()); + index_type = std::get<0>(GetParam()); + metric_type = std::get<1>(GetParam()); if (index_type == knowhere::IndexEnum::INDEX_FAISS_IVFFLAT || index_type == knowhere::IndexEnum::INDEX_FAISS_IVFFLAT_CC) { data_type = DataType::VECTOR_FLOAT; diff --git a/internal/core/unittest/test_indexing.cpp b/internal/core/unittest/test_indexing.cpp index c02f427736..3fa448f02d 100644 --- a/internal/core/unittest/test_indexing.cpp +++ b/internal/core/unittest/test_indexing.cpp @@ -329,8 +329,7 @@ class IndexTest : public ::testing::TestWithParam { index_type == knowhere::IndexEnum::INDEX_SPARSE_WAND) { is_sparse = true; 
vec_field_data_type = milvus::DataType::VECTOR_SPARSE_FLOAT; - } else if (index_type == knowhere::IndexEnum::INDEX_FAISS_BIN_IVFFLAT || - index_type == knowhere::IndexEnum::INDEX_FAISS_BIN_IDMAP) { + } else if (IsBinaryVectorMetricType(metric_type)) { is_binary = true; vec_field_data_type = milvus::DataType::VECTOR_BINARY; } else { diff --git a/internal/proxy/validate_util.go b/internal/proxy/validate_util.go index 53c6b89f86..65876b64ed 100644 --- a/internal/proxy/validate_util.go +++ b/internal/proxy/validate_util.go @@ -141,7 +141,7 @@ func (v *validateUtil) checkAligned(data []*schemapb.FieldData, schema *typeutil } errDimMismatch := func(fieldName string, dataDim int64, schemaDim int64) error { msg := fmt.Sprintf("the dim (%d) of field data(%s) is not equal to schema dim (%d)", dataDim, fieldName, schemaDim) - return merr.WrapErrParameterInvalid(dataDim, schemaDim, msg) + return merr.WrapErrParameterInvalid(schemaDim, dataDim, msg) } for _, field := range data { switch field.GetType() {