diff --git a/core/src/db/Utils.cpp b/core/src/db/Utils.cpp index 8aeb52df45..f08705874e 100644 --- a/core/src/db/Utils.cpp +++ b/core/src/db/Utils.cpp @@ -278,7 +278,7 @@ SplitChunk(const DataChunkPtr& chunk, int64_t segment_row_limit, std::vectorindex_type() + "_Data"; auto real_idx = dynamic_cast(index_.get()); if (real_idx == nullptr) { - KNOWHERE_THROW_MSG("dynamic_cast(index_) failed during Serialize!"); + KNOWHERE_THROW_MSG("index is not a faiss::IndexRHNSWFlat"); } - auto storage_index = dynamic_cast(real_idx->storage); - faiss::write_index(storage_index, &writer); - std::shared_ptr data(writer.data_); - res_set.Append(writer.name, data, writer.rp); + int64_t meta_info[3] = {real_idx->storage->metric_type, real_idx->storage->d, real_idx->storage->ntotal}; + auto meta_space = new uint8_t[sizeof(meta_info)]; + memcpy(meta_space, meta_info, sizeof(meta_info)); + std::shared_ptr space_sp(meta_space, std::default_delete()); + res_set.Append("META", space_sp, sizeof(meta_info)); + if (config.contains(INDEX_FILE_SLICE_SIZE_IN_MEGABYTE)) { Disassemble(config[INDEX_FILE_SLICE_SIZE_IN_MEGABYTE].get() * 1024 * 1024, res_set); } @@ -66,18 +66,16 @@ IndexRHNSWFlat::Load(const BinarySet& index_binary) { try { Assemble(const_cast(index_binary)); IndexRHNSW::Load(index_binary); - MemoryIOReader reader; - reader.name = this->index_type() + "_Data"; - auto binary = index_binary.GetByName(reader.name); - reader.total = static_cast(binary->size); - reader.data_ = binary->data.get(); + int64_t meta_info[3]; // = {metric_type, dim, ntotal} + auto meta_data = index_binary.GetByName("META"); + memcpy(meta_info, meta_data->data.get(), meta_data->size); auto real_idx = dynamic_cast(index_.get()); - if (real_idx == nullptr) { - KNOWHERE_THROW_MSG("dynamic_cast(index_) failed during Load!"); - } - real_idx->storage = faiss::read_index(&reader); + real_idx->storage = + new faiss::IndexFlat(static_cast(meta_info[1]), static_cast(meta_info[0])); + auto binary_data = index_binary.GetByName(RAW_DATA); + real_idx->storage->add(meta_info[2], reinterpret_cast(binary_data->data.get())); real_idx->init_hnsw(); } catch (std::exception& e) { KNOWHERE_THROW_MSG(e.what()); diff --git a/core/src/index/unittest/test_rhnsw_flat.cpp b/core/src/index/unittest/test_rhnsw_flat.cpp index 3007f15249..5dcb5d4afa 100644 --- a/core/src/index/unittest/test_rhnsw_flat.cpp +++ b/core/src/index/unittest/test_rhnsw_flat.cpp @@ -59,6 +59,13 @@ TEST_P(RHNSWFlatTest, HNSW_basic) { // Serialize and Load before Query milvus::knowhere::BinarySet bs = index_->Serialize(conf); + int64_t dim = base_dataset->Get(milvus::knowhere::meta::DIM); + int64_t rows = base_dataset->Get(milvus::knowhere::meta::ROWS); + auto raw_data = base_dataset->Get(milvus::knowhere::meta::TENSOR); + milvus::knowhere::BinaryPtr bptr = std::make_shared(); + bptr->data = std::shared_ptr((uint8_t*)raw_data, [&](uint8_t*) {}); + bptr->size = dim * rows * sizeof(float); + bs.Append(RAW_DATA, bptr); auto tmp_index = std::make_shared(); tmp_index->Load(bs); @@ -125,30 +132,37 @@ TEST_P(RHNSWFlatTest, HNSW_serialize) { auto binaryset = index_->Serialize(conf); std::string index_type = index_->index_type(); std::string idx_name = index_type + "_Index"; - std::string dat_name = index_type + "_Data"; + std::string met_name = "META"; if (binaryset.binary_map_.find(idx_name) == binaryset.binary_map_.end()) { std::cout << "no idx!" << std::endl; } - if (binaryset.binary_map_.find(dat_name) == binaryset.binary_map_.end()) { - std::cout << "no dat!" << std::endl; + if (binaryset.binary_map_.find(met_name) == binaryset.binary_map_.end()) { + std::cout << "no met!" << std::endl; } auto bin_idx = binaryset.GetByName(idx_name); - auto bin_dat = binaryset.GetByName(dat_name); + auto bin_met = binaryset.GetByName(met_name); std::string filename_idx = "/tmp/RHNSWFlat_test_serialize_idx.bin"; - std::string filename_dat = "/tmp/RHNSWFlat_test_serialize_dat.bin"; + std::string filename_met = "/tmp/RHNSWFlat_test_serialize_met.bin"; auto load_idx = new uint8_t[bin_idx->size]; - auto load_dat = new uint8_t[bin_dat->size]; + auto load_met = new uint8_t[bin_met->size]; serialize(filename_idx, bin_idx, load_idx); - serialize(filename_dat, bin_dat, load_dat); + serialize(filename_met, bin_met, load_met); binaryset.clear(); auto new_idx = std::make_shared(); - std::shared_ptr dat(load_dat); + std::shared_ptr met(load_met); std::shared_ptr idx(load_idx); binaryset.Append(new_idx->index_type() + "_Index", idx, bin_idx->size); - binaryset.Append(new_idx->index_type() + "_Data", dat, bin_dat->size); + binaryset.Append("META", met, bin_met->size); + int64_t dim = base_dataset->Get(milvus::knowhere::meta::DIM); + int64_t rows = base_dataset->Get(milvus::knowhere::meta::ROWS); + auto raw_data = base_dataset->Get(milvus::knowhere::meta::TENSOR); + milvus::knowhere::BinaryPtr bptr = std::make_shared(); + bptr->data = std::shared_ptr((uint8_t*)raw_data, [&](uint8_t*) {}); + bptr->size = dim * rows * sizeof(float); + binaryset.Append(RAW_DATA, bptr); new_idx->Load(binaryset); EXPECT_EQ(new_idx->Count(), nb); EXPECT_EQ(new_idx->Dim(), dim); @@ -164,6 +178,13 @@ TEST_P(RHNSWFlatTest, HNSW_slice) { index_->AddWithoutIds(base_dataset, conf); auto binaryset = index_->Serialize(conf); auto new_idx = std::make_shared(); + int64_t dim = base_dataset->Get(milvus::knowhere::meta::DIM); + int64_t rows = base_dataset->Get(milvus::knowhere::meta::ROWS); + auto raw_data = base_dataset->Get(milvus::knowhere::meta::TENSOR); + milvus::knowhere::BinaryPtr bptr = std::make_shared(); + bptr->data = std::shared_ptr((uint8_t*)raw_data, [&](uint8_t*) {}); + bptr->size = dim * rows * sizeof(float); + binaryset.Append(RAW_DATA, bptr); new_idx->Load(binaryset); EXPECT_EQ(new_idx->Count(), nb); EXPECT_EQ(new_idx->Dim(), dim);