mirror of https://github.com/milvus-io/milvus.git
Add support for getting vectors by ids (#23450)
Signed-off-by: bigsheeper <yihao.dai@zilliz.com>
pull/23610/head
parent
897ed620e4
commit
092d743917
3
Makefile
3
Makefile
|
@ -318,6 +318,9 @@ rpm: install
|
||||||
@cp -r build/rpm/services ~/rpmbuild/BUILD/
|
@cp -r build/rpm/services ~/rpmbuild/BUILD/
|
||||||
@QA_RPATHS="$$[ 0x001|0x0002|0x0020 ]" rpmbuild -ba ./build/rpm/milvus.spec
|
@QA_RPATHS="$$[ 0x001|0x0002|0x0020 ]" rpmbuild -ba ./build/rpm/milvus.spec
|
||||||
|
|
||||||
|
mock-proxy:
|
||||||
|
mockery --name=ProxyComponent --dir=$(PWD)/internal/types --output=$(PWD)/internal/mocks --filename=mock_proxy.go --structname=Proxy --with-expecter
|
||||||
|
|
||||||
mock-datanode:
|
mock-datanode:
|
||||||
mockery --name=DataNode --dir=$(PWD)/internal/types --output=$(PWD)/internal/mocks --filename=mock_datanode.go --with-expecter
|
mockery --name=DataNode --dir=$(PWD)/internal/types --output=$(PWD)/internal/mocks --filename=mock_datanode.go --with-expecter
|
||||||
|
|
||||||
|
|
|
@ -85,6 +85,16 @@ PrefixMatch(const std::string_view str, const std::string_view prefix) {
|
||||||
return true;
|
return true;
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// Builds a Dataset that describes `count` ids stored at `ids`: rows is set to
// `count`, dim to 1, and the dataset is flagged as NOT owning its buffers.
// NOTE(review): since ownership is disabled, the caller presumably has to keep
// `ids` alive for as long as the returned dataset is used - confirm.
inline DatasetPtr
GenIdsDataset(const int64_t count, const int64_t* ids) {
    auto ids_dataset = std::make_shared<Dataset>();
    ids_dataset->SetRows(count);
    ids_dataset->SetDim(1);
    ids_dataset->SetIds(ids);
    ids_dataset->SetIsOwner(false);
    return ids_dataset;
}
|
||||||
|
|
||||||
inline DatasetPtr
|
inline DatasetPtr
|
||||||
GenResultDataset(const int64_t nq,
|
GenResultDataset(const int64_t nq,
|
||||||
const int64_t topk,
|
const int64_t topk,
|
||||||
|
|
|
@ -230,6 +230,38 @@ VectorDiskAnnIndex<T>::Query(const DatasetPtr dataset,
|
||||||
return result;
|
return result;
|
||||||
}
|
}
|
||||||
|
|
||||||
|
template <typename T>
|
||||||
|
const bool
|
||||||
|
VectorDiskAnnIndex<T>::HasRawData() const {
|
||||||
|
return index_.HasRawData(GetMetricType());
|
||||||
|
}
|
||||||
|
|
||||||
|
template <typename T>
|
||||||
|
const std::vector<uint8_t>
|
||||||
|
VectorDiskAnnIndex<T>::GetVector(const DatasetPtr dataset,
|
||||||
|
const Config& config) const {
|
||||||
|
auto res = index_.GetVectorByIds(*dataset, config);
|
||||||
|
if (!res.has_value()) {
|
||||||
|
PanicCodeInfo(
|
||||||
|
ErrorCodeEnum::UnexpectedError,
|
||||||
|
"failed to get vector, " + MatchKnowhereError(res.error()));
|
||||||
|
}
|
||||||
|
auto index_type = GetIndexType();
|
||||||
|
auto tensor = res.value()->GetTensor();
|
||||||
|
auto row_num = res.value()->GetRows();
|
||||||
|
auto dim = res.value()->GetDim();
|
||||||
|
int64_t data_size;
|
||||||
|
if (is_in_bin_list(index_type)) {
|
||||||
|
data_size = dim / 8 * row_num;
|
||||||
|
} else {
|
||||||
|
data_size = dim * row_num * sizeof(float);
|
||||||
|
}
|
||||||
|
std::vector<uint8_t> raw_data;
|
||||||
|
raw_data.resize(data_size);
|
||||||
|
memcpy(raw_data.data(), tensor, data_size);
|
||||||
|
return raw_data;
|
||||||
|
}
|
||||||
|
|
||||||
template <typename T>
|
template <typename T>
|
||||||
void
|
void
|
||||||
VectorDiskAnnIndex<T>::CleanLocalData() {
|
VectorDiskAnnIndex<T>::CleanLocalData() {
|
||||||
|
|
|
@ -17,6 +17,7 @@
|
||||||
#pragma once
|
#pragma once
|
||||||
|
|
||||||
#include <memory>
|
#include <memory>
|
||||||
|
#include <vector>
|
||||||
|
|
||||||
#include "index/VectorIndex.h"
|
#include "index/VectorIndex.h"
|
||||||
#include "storage/DiskFileManagerImpl.h"
|
#include "storage/DiskFileManagerImpl.h"
|
||||||
|
@ -60,6 +61,13 @@ class VectorDiskAnnIndex : public VectorIndex {
|
||||||
const SearchInfo& search_info,
|
const SearchInfo& search_info,
|
||||||
const BitsetView& bitset) override;
|
const BitsetView& bitset) override;
|
||||||
|
|
||||||
|
const bool
|
||||||
|
HasRawData() const override;
|
||||||
|
|
||||||
|
const std::vector<uint8_t>
|
||||||
|
GetVector(const DatasetPtr dataset,
|
||||||
|
const Config& config = {}) const override;
|
||||||
|
|
||||||
void
|
void
|
||||||
CleanLocalData() override;
|
CleanLocalData() override;
|
||||||
|
|
||||||
|
|
|
@ -19,6 +19,7 @@
|
||||||
#include <map>
|
#include <map>
|
||||||
#include <memory>
|
#include <memory>
|
||||||
#include <string>
|
#include <string>
|
||||||
|
#include <vector>
|
||||||
#include <boost/dynamic_bitset.hpp>
|
#include <boost/dynamic_bitset.hpp>
|
||||||
|
|
||||||
#include "knowhere/factory.h"
|
#include "knowhere/factory.h"
|
||||||
|
@ -50,6 +51,12 @@ class VectorIndex : public IndexBase {
|
||||||
const SearchInfo& search_info,
|
const SearchInfo& search_info,
|
||||||
const BitsetView& bitset) = 0;
|
const BitsetView& bitset) = 0;
|
||||||
|
|
||||||
|
virtual const bool
|
||||||
|
HasRawData() const = 0;
|
||||||
|
|
||||||
|
virtual const std::vector<uint8_t>
|
||||||
|
GetVector(const DatasetPtr dataset, const Config& config = {}) const = 0;
|
||||||
|
|
||||||
IndexType
|
IndexType
|
||||||
GetIndexType() const {
|
GetIndexType() const {
|
||||||
return index_type_;
|
return index_type_;
|
||||||
|
|
|
@ -145,4 +145,34 @@ VectorMemIndex::Query(const DatasetPtr dataset,
|
||||||
return result;
|
return result;
|
||||||
}
|
}
|
||||||
|
|
||||||
|
const bool
|
||||||
|
VectorMemIndex::HasRawData() const {
|
||||||
|
return index_.HasRawData(GetMetricType());
|
||||||
|
}
|
||||||
|
|
||||||
|
const std::vector<uint8_t>
|
||||||
|
VectorMemIndex::GetVector(const DatasetPtr dataset,
|
||||||
|
const Config& config) const {
|
||||||
|
auto res = index_.GetVectorByIds(*dataset, config);
|
||||||
|
if (!res.has_value()) {
|
||||||
|
PanicCodeInfo(
|
||||||
|
ErrorCodeEnum::UnexpectedError,
|
||||||
|
"failed to get vector, " + MatchKnowhereError(res.error()));
|
||||||
|
}
|
||||||
|
auto index_type = GetIndexType();
|
||||||
|
auto tensor = res.value()->GetTensor();
|
||||||
|
auto row_num = res.value()->GetRows();
|
||||||
|
auto dim = res.value()->GetDim();
|
||||||
|
int64_t data_size;
|
||||||
|
if (is_in_bin_list(index_type)) {
|
||||||
|
data_size = dim / 8 * row_num;
|
||||||
|
} else {
|
||||||
|
data_size = dim * row_num * sizeof(float);
|
||||||
|
}
|
||||||
|
std::vector<uint8_t> raw_data;
|
||||||
|
raw_data.resize(data_size);
|
||||||
|
memcpy(raw_data.data(), tensor, data_size);
|
||||||
|
return raw_data;
|
||||||
|
}
|
||||||
|
|
||||||
} // namespace milvus::index
|
} // namespace milvus::index
|
||||||
|
|
|
@ -51,6 +51,13 @@ class VectorMemIndex : public VectorIndex {
|
||||||
const SearchInfo& search_info,
|
const SearchInfo& search_info,
|
||||||
const BitsetView& bitset) override;
|
const BitsetView& bitset) override;
|
||||||
|
|
||||||
|
const bool
|
||||||
|
HasRawData() const override;
|
||||||
|
|
||||||
|
const std::vector<uint8_t>
|
||||||
|
GetVector(const DatasetPtr dataset,
|
||||||
|
const Config& config = {}) const override;
|
||||||
|
|
||||||
protected:
|
protected:
|
||||||
Config config_;
|
Config config_;
|
||||||
knowhere::Index<knowhere::IndexNode> index_;
|
knowhere::Index<knowhere::IndexNode> index_;
|
||||||
|
|
|
@ -223,6 +223,11 @@ class SegmentGrowingImpl : public SegmentGrowing {
|
||||||
return true;
|
return true;
|
||||||
}
|
}
|
||||||
|
|
||||||
|
bool
|
||||||
|
HasRawData(int64_t field_id) const override {
|
||||||
|
return true;
|
||||||
|
}
|
||||||
|
|
||||||
protected:
|
protected:
|
||||||
int64_t
|
int64_t
|
||||||
num_chunk() const override;
|
num_chunk() const override;
|
||||||
|
|
|
@ -88,6 +88,9 @@ class SegmentInterface {
|
||||||
|
|
||||||
virtual SegmentType
|
virtual SegmentType
|
||||||
type() const = 0;
|
type() const = 0;
|
||||||
|
|
||||||
|
virtual bool
|
||||||
|
HasRawData(int64_t field_id) const = 0;
|
||||||
};
|
};
|
||||||
|
|
||||||
// internal API for DSL calculation
|
// internal API for DSL calculation
|
||||||
|
|
|
@ -29,6 +29,7 @@
|
||||||
#include "query/ScalarIndex.h"
|
#include "query/ScalarIndex.h"
|
||||||
#include "query/SearchBruteForce.h"
|
#include "query/SearchBruteForce.h"
|
||||||
#include "query/SearchOnSealed.h"
|
#include "query/SearchOnSealed.h"
|
||||||
|
#include "index/Utils.h"
|
||||||
|
|
||||||
namespace milvus::segcore {
|
namespace milvus::segcore {
|
||||||
|
|
||||||
|
@ -475,6 +476,35 @@ SegmentSealedImpl::vector_search(SearchInfo& search_info,
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
std::unique_ptr<DataArray>
|
||||||
|
SegmentSealedImpl::get_vector(FieldId field_id,
|
||||||
|
const int64_t* ids,
|
||||||
|
int64_t count) const {
|
||||||
|
auto& filed_meta = schema_->operator[](field_id);
|
||||||
|
AssertInfo(filed_meta.is_vector(), "vector field is not vector type");
|
||||||
|
|
||||||
|
if (get_bit(index_ready_bitset_, field_id)) {
|
||||||
|
AssertInfo(vector_indexings_.is_ready(field_id),
|
||||||
|
"vector index is not ready");
|
||||||
|
auto field_indexing = vector_indexings_.get_field_indexing(field_id);
|
||||||
|
auto vec_index =
|
||||||
|
dynamic_cast<index::VectorIndex*>(field_indexing->indexing_.get());
|
||||||
|
|
||||||
|
auto index_type = vec_index->GetIndexType();
|
||||||
|
auto metric_type = vec_index->GetMetricType();
|
||||||
|
auto has_raw_data = vec_index->HasRawData();
|
||||||
|
|
||||||
|
if (has_raw_data) {
|
||||||
|
auto ids_ds = GenIdsDataset(count, ids);
|
||||||
|
auto& vector = vec_index->GetVector(ids_ds);
|
||||||
|
return segcore::CreateVectorDataArrayFrom(
|
||||||
|
vector.data(), count, filed_meta);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
return fill_with_empty(field_id, count);
|
||||||
|
}
|
||||||
|
|
||||||
void
|
void
|
||||||
SegmentSealedImpl::DropFieldData(const FieldId field_id) {
|
SegmentSealedImpl::DropFieldData(const FieldId field_id) {
|
||||||
if (SystemProperty::Instance().IsSystem(field_id)) {
|
if (SystemProperty::Instance().IsSystem(field_id)) {
|
||||||
|
@ -666,9 +696,7 @@ SegmentSealedImpl::bulk_subscript(FieldId field_id,
|
||||||
return ReverseDataFromIndex(index, seg_offsets, count, field_meta);
|
return ReverseDataFromIndex(index, seg_offsets, count, field_meta);
|
||||||
}
|
}
|
||||||
|
|
||||||
// TODO: knowhere support reverse data from vector index
|
return get_vector(field_id, seg_offsets, count);
|
||||||
// Now, real data will be filled in data array using chunk manager
|
|
||||||
return fill_with_empty(field_id, count);
|
|
||||||
}
|
}
|
||||||
|
|
||||||
Assert(get_bit(field_data_ready_bitset_, field_id));
|
Assert(get_bit(field_data_ready_bitset_, field_id));
|
||||||
|
@ -783,6 +811,24 @@ SegmentSealedImpl::HasFieldData(FieldId field_id) const {
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
bool
|
||||||
|
SegmentSealedImpl::HasRawData(int64_t field_id) const {
|
||||||
|
std::shared_lock lck(mutex_);
|
||||||
|
auto fieldID = FieldId(field_id);
|
||||||
|
const auto& field_meta = schema_->operator[](fieldID);
|
||||||
|
if (datatype_is_vector(field_meta.get_data_type())) {
|
||||||
|
if (get_bit(index_ready_bitset_, fieldID)) {
|
||||||
|
AssertInfo(vector_indexings_.is_ready(fieldID),
|
||||||
|
"vector index is not ready");
|
||||||
|
auto field_indexing = vector_indexings_.get_field_indexing(fieldID);
|
||||||
|
auto vec_index = dynamic_cast<index::VectorIndex*>(
|
||||||
|
field_indexing->indexing_.get());
|
||||||
|
return vec_index->HasRawData();
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return true;
|
||||||
|
}
|
||||||
|
|
||||||
std::pair<std::unique_ptr<IdArray>, std::vector<SegOffset>>
|
std::pair<std::unique_ptr<IdArray>, std::vector<SegOffset>>
|
||||||
SegmentSealedImpl::search_ids(const IdArray& id_array,
|
SegmentSealedImpl::search_ids(const IdArray& id_array,
|
||||||
Timestamp timestamp) const {
|
Timestamp timestamp) const {
|
||||||
|
|
|
@ -61,6 +61,9 @@ class SegmentSealedImpl : public SegmentSealed {
|
||||||
return id_;
|
return id_;
|
||||||
}
|
}
|
||||||
|
|
||||||
|
bool
|
||||||
|
HasRawData(int64_t field_id) const override;
|
||||||
|
|
||||||
public:
|
public:
|
||||||
int64_t
|
int64_t
|
||||||
GetMemoryUsageInBytes() const override;
|
GetMemoryUsageInBytes() const override;
|
||||||
|
@ -74,6 +77,9 @@ class SegmentSealedImpl : public SegmentSealed {
|
||||||
const Schema&
|
const Schema&
|
||||||
get_schema() const override;
|
get_schema() const override;
|
||||||
|
|
||||||
|
std::unique_ptr<DataArray>
|
||||||
|
get_vector(FieldId field_id, const int64_t* ids, int64_t count) const;
|
||||||
|
|
||||||
public:
|
public:
|
||||||
int64_t
|
int64_t
|
||||||
num_chunk_index(FieldId field_id) const override;
|
num_chunk_index(FieldId field_id) const override;
|
||||||
|
|
|
@ -164,6 +164,13 @@ GetRealCount(CSegmentInterface c_segment) {
|
||||||
return segment->get_real_count();
|
return segment->get_real_count();
|
||||||
}
|
}
|
||||||
|
|
||||||
|
bool
|
||||||
|
HasRawData(CSegmentInterface c_segment, int64_t field_id) {
|
||||||
|
auto segment =
|
||||||
|
reinterpret_cast<milvus::segcore::SegmentInterface*>(c_segment);
|
||||||
|
return segment->HasRawData(field_id);
|
||||||
|
}
|
||||||
|
|
||||||
////////////////////////////// interfaces for growing segment //////////////////////////////
|
////////////////////////////// interfaces for growing segment //////////////////////////////
|
||||||
CStatus
|
CStatus
|
||||||
Insert(CSegmentInterface c_segment,
|
Insert(CSegmentInterface c_segment,
|
||||||
|
|
|
@ -67,6 +67,9 @@ GetDeletedCount(CSegmentInterface c_segment);
|
||||||
int64_t
|
int64_t
|
||||||
GetRealCount(CSegmentInterface c_segment);
|
GetRealCount(CSegmentInterface c_segment);
|
||||||
|
|
||||||
|
bool
|
||||||
|
HasRawData(CSegmentInterface c_segment, int64_t field_id);
|
||||||
|
|
||||||
////////////////////////////// interfaces for growing segment //////////////////////////////
|
////////////////////////////// interfaces for growing segment //////////////////////////////
|
||||||
CStatus
|
CStatus
|
||||||
Insert(CSegmentInterface c_segment,
|
Insert(CSegmentInterface c_segment,
|
||||||
|
|
|
@ -449,6 +449,91 @@ TEST_P(IndexTest, BuildAndQuery) {
|
||||||
vec_index->Query(xq_dataset, search_info, nullptr);
|
vec_index->Query(xq_dataset, search_info, nullptr);
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// Builds an index for the current test parameters, then checks that
// GetVector() returns the indexed vectors in the order of a shuffled id list.
// Index types that report no raw data return early without checking contents.
TEST_P(IndexTest, GetVector) {
    milvus::index::CreateIndexInfo create_index_info;
    create_index_info.index_type = index_type;
    create_index_info.metric_type = metric_type;
    create_index_info.field_type = vec_field_data_type;
    index::IndexBasePtr index;

    if (index_type == knowhere::IndexEnum::INDEX_DISKANN) {
#ifdef BUILD_DISK_ANN
        milvus::storage::FieldDataMeta field_data_meta{1, 2, 3, 100};
        milvus::storage::IndexMeta index_meta{3, 100, 1000, 1};
        auto file_manager =
            std::make_shared<milvus::storage::DiskFileManagerImpl>(
                field_data_meta, index_meta, storage_config_);
        index = milvus::index::IndexFactory::GetInstance().CreateIndex(
            create_index_info, file_manager);
#endif
    } else {
        index = milvus::index::IndexFactory::GetInstance().CreateIndex(
            create_index_info, nullptr);
    }
    // Without BUILD_DISK_ANN the DiskANN branch creates nothing; stop with a
    // test failure instead of dereferencing a null index.
    ASSERT_TRUE(index != nullptr);
    ASSERT_NO_THROW(index->BuildWithDataset(xb_dataset, build_conf));
    milvus::index::IndexBasePtr new_index;
    milvus::index::VectorIndex* vec_index = nullptr;

    if (index_type == knowhere::IndexEnum::INDEX_DISKANN) {
#ifdef BUILD_DISK_ANN
        // TODO ::diskann.query need load first, ugly
        auto binary_set = index->Serialize(milvus::Config{});
        index.reset();
        milvus::storage::FieldDataMeta field_data_meta{1, 2, 3, 100};
        milvus::storage::IndexMeta index_meta{3, 100, 1000, 1};
        auto file_manager =
            std::make_shared<milvus::storage::DiskFileManagerImpl>(
                field_data_meta, index_meta, storage_config_);
        new_index = milvus::index::IndexFactory::GetInstance().CreateIndex(
            create_index_info, file_manager);

        vec_index = dynamic_cast<milvus::index::VectorIndex*>(new_index.get());
        ASSERT_TRUE(vec_index != nullptr);

        std::vector<std::string> index_files;
        for (auto& binary : binary_set.binary_map_) {
            index_files.emplace_back(binary.first);
        }
        load_conf["index_files"] = index_files;
        vec_index->Load(binary_set, load_conf);
        EXPECT_EQ(vec_index->Count(), NB);
#endif
    } else {
        vec_index = dynamic_cast<milvus::index::VectorIndex*>(index.get());
    }
    // Covers both a failed cast and the DiskANN-not-built case.
    ASSERT_TRUE(vec_index != nullptr);
    EXPECT_EQ(vec_index->GetDim(), DIM);
    EXPECT_EQ(vec_index->Count(), NB);

    if (!vec_index->HasRawData()) {
        return;
    }

    auto ids_ds = GenRandomIds(NB);
    auto results = vec_index->GetVector(ids_ds);
    EXPECT_TRUE(results.size() > 0);
    if (!is_binary) {
        std::vector<float> result_vectors(results.size() / (sizeof(float)));
        memcpy(result_vectors.data(), results.data(), results.size());
        // Fatal on mismatch: the loops below index both buffers by NB * DIM.
        ASSERT_TRUE(result_vectors.size() == xb_data.size());
        for (size_t i = 0; i < NB; ++i) {
            auto id = ids_ds->GetIds()[i];
            for (size_t j = 0; j < DIM; ++j) {
                EXPECT_TRUE(result_vectors[i * DIM + j] ==
                            xb_data[id * DIM + j]);
            }
        }
    } else {
        // Fatal on mismatch: the loops below index both buffers by
        // NB * DIM / 8.
        ASSERT_TRUE(results.size() == xb_bin_data.size());
        const auto data_bytes = DIM / 8;
        for (size_t i = 0; i < NB; ++i) {
            auto id = ids_ds->GetIds()[i];
            for (size_t j = 0; j < data_bytes; ++j) {
                EXPECT_TRUE(results[i * data_bytes + j] ==
                            xb_bin_data[id * data_bytes + j]);
            }
        }
    }
}
|
||||||
|
|
||||||
// #ifdef BUILD_DISK_ANN
|
// #ifdef BUILD_DISK_ANN
|
||||||
// TEST(Indexing, SearchDiskAnnWithInvalidParam) {
|
// TEST(Indexing, SearchDiskAnnWithInvalidParam) {
|
||||||
// int64_t NB = 10000;
|
// int64_t NB = 10000;
|
||||||
|
|
|
@ -1067,3 +1067,52 @@ TEST(Sealed, RealCount) {
|
||||||
ASSERT_TRUE(status.ok());
|
ASSERT_TRUE(status.ok());
|
||||||
ASSERT_EQ(0, segment->get_real_count());
|
ASSERT_EQ(0, segment->get_real_count());
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// Loads a vector index into a sealed segment and checks that get_vector()
// returns, for a shuffled id list, the same floats that were generated for
// the "fakevec" field.
TEST(Sealed, GetVector) {
    auto dim = 16;
    auto N = ROW_COUNT;
    auto metric_type = knowhere::metric::L2;
    auto schema = std::make_shared<Schema>();
    auto fakevec_id = schema->AddDebugField(
        "fakevec", DataType::VECTOR_FLOAT, dim, metric_type);
    auto counter_id = schema->AddDebugField("counter", DataType::INT64);
    // The remaining fields only populate the schema; their ids are not used.
    schema->AddDebugField("double", DataType::DOUBLE);
    schema->AddDebugField("nothing", DataType::INT32);
    schema->AddDebugField("str", DataType::VARCHAR);
    schema->AddDebugField("int8", DataType::INT8);
    schema->AddDebugField("int16", DataType::INT16);
    schema->AddDebugField("float", DataType::FLOAT);
    schema->set_primary_field_id(counter_id);

    auto dataset = DataGen(schema, N);

    auto fakevec = dataset.get_col<float>(fakevec_id);

    auto indexing = GenVecIndexing(N, dim, fakevec.data());

    auto segment_sealed = CreateSealedSegment(schema);

    LoadIndexInfo vec_info;
    vec_info.field_id = fakevec_id.get();
    vec_info.index = std::move(indexing);
    vec_info.index_params["metric_type"] = knowhere::metric::L2;
    segment_sealed->LoadIndex(vec_info);

    auto segment = dynamic_cast<SegmentSealedImpl*>(segment_sealed.get());
    // A failed cast would otherwise be dereferenced on the next line.
    ASSERT_TRUE(segment != nullptr);

    auto has = segment->HasRawData(vec_info.field_id);
    EXPECT_TRUE(has);

    auto ids_ds = GenRandomIds(N);
    auto result = segment->get_vector(fakevec_id, ids_ds->GetIds(), N);
    ASSERT_TRUE(result != nullptr);

    // Bind by reference: the repeated field stays owned by `result`, so no
    // copy of the whole float array is made.
    const auto& vector = result->vectors().float_vector().data();
    // Fatal on mismatch: the loop below indexes both buffers by N * dim.
    ASSERT_TRUE(vector.size() == fakevec.size());
    for (size_t i = 0; i < N; ++i) {
        auto id = ids_ds->GetIds()[i];
        for (size_t j = 0; j < dim; ++j) {
            EXPECT_TRUE(vector[i * dim + j] == fakevec[id * dim + j]);
        }
    }
}
|
||||||
|
|
|
@ -599,4 +599,15 @@ GenPKs(const std::vector<int64_t>& pks) {
|
||||||
return GenPKs(pks.begin(), pks.end());
|
return GenPKs(pks.begin(), pks.end());
|
||||||
}
|
}
|
||||||
|
|
||||||
|
inline std::shared_ptr<knowhere::DataSet>
|
||||||
|
GenRandomIds(int rows, int64_t seed = 42) {
|
||||||
|
std::mt19937 g(seed);
|
||||||
|
auto* ids = new int64_t[rows];
|
||||||
|
for (int i = 0; i < rows; ++i) ids[i] = i;
|
||||||
|
std::shuffle(ids, ids + rows, g);
|
||||||
|
auto ids_ds = GenIdsDataset(rows, ids);
|
||||||
|
ids_ds->SetIsOwner(true);
|
||||||
|
return ids_ds;
|
||||||
|
}
|
||||||
|
|
||||||
} // namespace milvus::segcore
|
} // namespace milvus::segcore
|
||||||
|
|
File diff suppressed because it is too large
Load Diff
|
@ -2411,9 +2411,10 @@ func (node *Proxy) Search(ctx context.Context, request *milvuspb.SearchRequest)
|
||||||
ReqID: paramtable.GetNodeID(),
|
ReqID: paramtable.GetNodeID(),
|
||||||
},
|
},
|
||||||
request: request,
|
request: request,
|
||||||
qc: node.queryCoord,
|
|
||||||
tr: timerecord.NewTimeRecorder("search"),
|
tr: timerecord.NewTimeRecorder("search"),
|
||||||
shardMgr: node.shardMgr,
|
shardMgr: node.shardMgr,
|
||||||
|
qc: node.queryCoord,
|
||||||
|
node: node,
|
||||||
}
|
}
|
||||||
|
|
||||||
travelTs := request.TravelTimestamp
|
travelTs := request.TravelTimestamp
|
||||||
|
|
|
@ -9,6 +9,7 @@ import (
|
||||||
"github.com/golang/protobuf/proto"
|
"github.com/golang/protobuf/proto"
|
||||||
"github.com/milvus-io/milvus/internal/proto/planpb"
|
"github.com/milvus-io/milvus/internal/proto/planpb"
|
||||||
"github.com/milvus-io/milvus/pkg/util/tsoutil"
|
"github.com/milvus-io/milvus/pkg/util/tsoutil"
|
||||||
|
"github.com/samber/lo"
|
||||||
|
|
||||||
"github.com/cockroachdb/errors"
|
"github.com/cockroachdb/errors"
|
||||||
"github.com/milvus-io/milvus/internal/parser/planparserv2"
|
"github.com/milvus-io/milvus/internal/parser/planparserv2"
|
||||||
|
@ -492,7 +493,10 @@ func IDs2Expr(fieldName string, ids *schemapb.IDs) string {
|
||||||
case *schemapb.IDs_IntId:
|
case *schemapb.IDs_IntId:
|
||||||
idsStr = strings.Trim(strings.Join(strings.Fields(fmt.Sprint(ids.GetIntId().GetData())), ", "), "[]")
|
idsStr = strings.Trim(strings.Join(strings.Fields(fmt.Sprint(ids.GetIntId().GetData())), ", "), "[]")
|
||||||
case *schemapb.IDs_StrId:
|
case *schemapb.IDs_StrId:
|
||||||
idsStr = strings.Trim(strings.Join(ids.GetStrId().GetData(), ", "), "[]")
|
strs := lo.Map(ids.GetStrId().GetData(), func(str string, _ int) string {
|
||||||
|
return fmt.Sprintf("\"%s\"", str)
|
||||||
|
})
|
||||||
|
idsStr = strings.Trim(strings.Join(strs, ", "), "[]")
|
||||||
}
|
}
|
||||||
|
|
||||||
return fieldName + " in [ " + idsStr + " ]"
|
return fieldName + " in [ " + idsStr + " ]"
|
||||||
|
|
|
@ -869,3 +869,28 @@ func Test_queryTask_createPlan(t *testing.T) {
|
||||||
assert.Error(t, err)
|
assert.Error(t, err)
|
||||||
})
|
})
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func TestQueryTask_IDs2Expr(t *testing.T) {
|
||||||
|
fieldName := "pk"
|
||||||
|
intIDs := &schemapb.IDs{
|
||||||
|
IdField: &schemapb.IDs_IntId{
|
||||||
|
IntId: &schemapb.LongArray{
|
||||||
|
Data: []int64{1, 2, 3, 4, 5},
|
||||||
|
},
|
||||||
|
},
|
||||||
|
}
|
||||||
|
stringIDs := &schemapb.IDs{
|
||||||
|
IdField: &schemapb.IDs_StrId{
|
||||||
|
StrId: &schemapb.StringArray{
|
||||||
|
Data: []string{"a", "b", "c"},
|
||||||
|
},
|
||||||
|
},
|
||||||
|
}
|
||||||
|
idExpr := IDs2Expr(fieldName, intIDs)
|
||||||
|
expectIDExpr := "pk in [ 1, 2, 3, 4, 5 ]"
|
||||||
|
assert.Equal(t, expectIDExpr, idExpr)
|
||||||
|
|
||||||
|
strExpr := IDs2Expr(fieldName, stringIDs)
|
||||||
|
expectStrExpr := "pk in [ \"a\", \"b\", \"c\" ]"
|
||||||
|
assert.Equal(t, expectStrExpr, strExpr)
|
||||||
|
}
|
||||||
|
|
|
@ -3,11 +3,13 @@ package proxy
|
||||||
import (
|
import (
|
||||||
"context"
|
"context"
|
||||||
"fmt"
|
"fmt"
|
||||||
|
"math"
|
||||||
"regexp"
|
"regexp"
|
||||||
"strconv"
|
"strconv"
|
||||||
|
|
||||||
"github.com/cockroachdb/errors"
|
"github.com/cockroachdb/errors"
|
||||||
"github.com/golang/protobuf/proto"
|
"github.com/golang/protobuf/proto"
|
||||||
|
"github.com/samber/lo"
|
||||||
"go.opentelemetry.io/otel"
|
"go.opentelemetry.io/otel"
|
||||||
"go.uber.org/zap"
|
"go.uber.org/zap"
|
||||||
|
|
||||||
|
@ -25,6 +27,7 @@ import (
|
||||||
"github.com/milvus-io/milvus/pkg/util/commonpbutil"
|
"github.com/milvus-io/milvus/pkg/util/commonpbutil"
|
||||||
"github.com/milvus-io/milvus/pkg/util/distance"
|
"github.com/milvus-io/milvus/pkg/util/distance"
|
||||||
"github.com/milvus-io/milvus/pkg/util/funcutil"
|
"github.com/milvus-io/milvus/pkg/util/funcutil"
|
||||||
|
"github.com/milvus-io/milvus/pkg/util/merr"
|
||||||
"github.com/milvus-io/milvus/pkg/util/paramtable"
|
"github.com/milvus-io/milvus/pkg/util/paramtable"
|
||||||
"github.com/milvus-io/milvus/pkg/util/timerecord"
|
"github.com/milvus-io/milvus/pkg/util/timerecord"
|
||||||
"github.com/milvus-io/milvus/pkg/util/tsoutil"
|
"github.com/milvus-io/milvus/pkg/util/tsoutil"
|
||||||
|
@ -34,6 +37,12 @@ import (
|
||||||
const (
|
const (
|
||||||
SearchTaskName = "SearchTask"
|
SearchTaskName = "SearchTask"
|
||||||
SearchLevelKey = "level"
|
SearchLevelKey = "level"
|
||||||
|
|
||||||
|
// requeryThreshold is the estimated threshold for the size of the search results.
|
||||||
|
// If the number of estimated search results exceeds this threshold,
|
||||||
|
// a second query request will be initiated to retrieve output fields data.
|
||||||
|
// In this case, the first search will not return any output field from QueryNodes.
|
||||||
|
requeryThreshold = 0.5 * 1024 * 1024
|
||||||
)
|
)
|
||||||
|
|
||||||
type searchTask struct {
|
type searchTask struct {
|
||||||
|
@ -43,11 +52,12 @@ type searchTask struct {
|
||||||
|
|
||||||
result *milvuspb.SearchResults
|
result *milvuspb.SearchResults
|
||||||
request *milvuspb.SearchRequest
|
request *milvuspb.SearchRequest
|
||||||
qc types.QueryCoord
|
|
||||||
tr *timerecord.TimeRecorder
|
tr *timerecord.TimeRecorder
|
||||||
collectionName string
|
collectionName string
|
||||||
channelNum int32
|
channelNum int32
|
||||||
schema *schemapb.CollectionSchema
|
schema *schemapb.CollectionSchema
|
||||||
|
requery bool
|
||||||
|
|
||||||
offset int64
|
offset int64
|
||||||
resultBuf chan *internalpb.SearchResults
|
resultBuf chan *internalpb.SearchResults
|
||||||
|
@ -55,6 +65,9 @@ type searchTask struct {
|
||||||
|
|
||||||
searchShardPolicy pickShardPolicy
|
searchShardPolicy pickShardPolicy
|
||||||
shardMgr *shardClientMgr
|
shardMgr *shardClientMgr
|
||||||
|
|
||||||
|
qc types.QueryCoord
|
||||||
|
node types.ProxyComponent
|
||||||
}
|
}
|
||||||
|
|
||||||
func getPartitionIDs(ctx context.Context, collectionName string, partitionNames []string) (partitionIDs []UniqueID, err error) {
|
func getPartitionIDs(ctx context.Context, collectionName string, partitionNames []string) (partitionIDs []UniqueID, err error) {
|
||||||
|
@ -164,11 +177,7 @@ func getOutputFieldIDs(schema *schemapb.CollectionSchema, outputFields []string)
|
||||||
hitField := false
|
hitField := false
|
||||||
for _, field := range schema.GetFields() {
|
for _, field := range schema.GetFields() {
|
||||||
if field.Name == name {
|
if field.Name == name {
|
||||||
if field.DataType == schemapb.DataType_BinaryVector || field.DataType == schemapb.DataType_FloatVector {
|
|
||||||
return nil, errors.New("search doesn't support vector field as output_fields")
|
|
||||||
}
|
|
||||||
outputFieldIDs = append(outputFieldIDs, field.GetFieldID())
|
outputFieldIDs = append(outputFieldIDs, field.GetFieldID())
|
||||||
|
|
||||||
hitField = true
|
hitField = true
|
||||||
break
|
break
|
||||||
}
|
}
|
||||||
|
@ -255,6 +264,24 @@ func (t *searchTask) PreExecute(ctx context.Context) error {
|
||||||
}
|
}
|
||||||
t.SearchRequest.IgnoreGrowing = ignoreGrowing
|
t.SearchRequest.IgnoreGrowing = ignoreGrowing
|
||||||
|
|
||||||
|
// Manually update nq if not set.
|
||||||
|
nq, err := getNq(t.request)
|
||||||
|
if err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
// Check if nq is valid:
|
||||||
|
// https://milvus.io/docs/limitations.md
|
||||||
|
if err := validateLimit(nq); err != nil {
|
||||||
|
return fmt.Errorf("%s [%d] is invalid, %w", NQKey, nq, err)
|
||||||
|
}
|
||||||
|
t.SearchRequest.Nq = nq
|
||||||
|
|
||||||
|
outputFieldIDs, err := getOutputFieldIDs(t.schema, t.request.GetOutputFields())
|
||||||
|
if err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
t.SearchRequest.OutputFieldsId = outputFieldIDs
|
||||||
|
|
||||||
if t.request.GetDslType() == commonpb.DslType_BoolExprV1 {
|
if t.request.GetDslType() == commonpb.DslType_BoolExprV1 {
|
||||||
annsField, err := funcutil.GetAttrByKeyFromRepeatedKV(AnnsFieldKey, t.request.GetSearchParams())
|
annsField, err := funcutil.GetAttrByKeyFromRepeatedKV(AnnsFieldKey, t.request.GetSearchParams())
|
||||||
if err != nil {
|
if err != nil {
|
||||||
|
@ -278,17 +305,21 @@ func (t *searchTask) PreExecute(ctx context.Context) error {
|
||||||
zap.String("dsl", t.request.Dsl), // may be very large if large term passed.
|
zap.String("dsl", t.request.Dsl), // may be very large if large term passed.
|
||||||
zap.String("anns field", annsField), zap.Any("query info", queryInfo))
|
zap.String("anns field", annsField), zap.Any("query info", queryInfo))
|
||||||
|
|
||||||
outputFieldIDs, err := getOutputFieldIDs(t.schema, t.request.GetOutputFields())
|
|
||||||
if err != nil {
|
|
||||||
return err
|
|
||||||
}
|
|
||||||
|
|
||||||
t.SearchRequest.OutputFieldsId = outputFieldIDs
|
|
||||||
plan.OutputFieldIds = outputFieldIDs
|
plan.OutputFieldIds = outputFieldIDs
|
||||||
|
|
||||||
t.SearchRequest.Topk = queryInfo.GetTopk()
|
t.SearchRequest.Topk = queryInfo.GetTopk()
|
||||||
t.SearchRequest.MetricType = queryInfo.GetMetricType()
|
t.SearchRequest.MetricType = queryInfo.GetMetricType()
|
||||||
t.SearchRequest.DslType = commonpb.DslType_BoolExprV1
|
t.SearchRequest.DslType = commonpb.DslType_BoolExprV1
|
||||||
|
|
||||||
|
estimateSize, err := t.estimateResultSize(nq, t.SearchRequest.Topk)
|
||||||
|
if err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
if estimateSize >= requeryThreshold {
|
||||||
|
t.requery = true
|
||||||
|
plan.OutputFieldIds = nil
|
||||||
|
}
|
||||||
|
|
||||||
t.SearchRequest.SerializedExprPlan, err = proto.Marshal(plan)
|
t.SearchRequest.SerializedExprPlan, err = proto.Marshal(plan)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return err
|
return err
|
||||||
|
@ -319,17 +350,6 @@ func (t *searchTask) PreExecute(ctx context.Context) error {
|
||||||
|
|
||||||
t.SearchRequest.Dsl = t.request.Dsl
|
t.SearchRequest.Dsl = t.request.Dsl
|
||||||
t.SearchRequest.PlaceholderGroup = t.request.PlaceholderGroup
|
t.SearchRequest.PlaceholderGroup = t.request.PlaceholderGroup
|
||||||
// Manually update nq if not set.
|
|
||||||
nq, err := getNq(t.request)
|
|
||||||
if err != nil {
|
|
||||||
return err
|
|
||||||
}
|
|
||||||
// Check if nq is valid:
|
|
||||||
// https://milvus.io/docs/limitations.md
|
|
||||||
if err := validateLimit(nq); err != nil {
|
|
||||||
return fmt.Errorf("%s [%d] is invalid, %w", NQKey, nq, err)
|
|
||||||
}
|
|
||||||
t.SearchRequest.Nq = nq
|
|
||||||
|
|
||||||
log.Ctx(ctx).Debug("search PreExecute done.",
|
log.Ctx(ctx).Debug("search PreExecute done.",
|
||||||
zap.Uint64("travel_ts", travelTimestamp), zap.Uint64("guarantee_ts", guaranteeTs),
|
zap.Uint64("travel_ts", travelTimestamp), zap.Uint64("guarantee_ts", guaranteeTs),
|
||||||
|
@ -435,6 +455,13 @@ func (t *searchTask) PostExecute(ctx context.Context) error {
|
||||||
t.result.CollectionName = t.collectionName
|
t.result.CollectionName = t.collectionName
|
||||||
t.fillInFieldInfo()
|
t.fillInFieldInfo()
|
||||||
|
|
||||||
|
if t.requery {
|
||||||
|
err = t.Requery()
|
||||||
|
if err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
log.Ctx(ctx).Debug("Search post execute done",
|
log.Ctx(ctx).Debug("Search post execute done",
|
||||||
zap.Int64("collection", t.GetCollectionID()),
|
zap.Int64("collection", t.GetCollectionID()),
|
||||||
zap.Int64s("partitionIDs", t.GetPartitionIDs()))
|
zap.Int64s("partitionIDs", t.GetPartitionIDs()))
|
||||||
|
@ -480,6 +507,93 @@ func (t *searchTask) searchShard(ctx context.Context, nodeID int64, qn types.Que
|
||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func (t *searchTask) estimateResultSize(nq int64, topK int64) (int64, error) {
|
||||||
|
vectorOutputFields := lo.Filter(t.schema.GetFields(), func(field *schemapb.FieldSchema, _ int) bool {
|
||||||
|
return lo.Contains(t.request.GetOutputFields(), field.GetName()) && typeutil.IsVectorType(field.GetDataType())
|
||||||
|
})
|
||||||
|
// Currently, we get vectors by requery. Once we support getting vectors from search,
|
||||||
|
// searches with small result size could no longer need requery.
|
||||||
|
if len(vectorOutputFields) > 0 {
|
||||||
|
return math.MaxInt64, nil
|
||||||
|
}
|
||||||
|
// If no vector field as output, no need to requery.
|
||||||
|
return 0, nil
|
||||||
|
|
||||||
|
//outputFields := lo.Filter(t.schema.GetFields(), func(field *schemapb.FieldSchema, _ int) bool {
|
||||||
|
// return lo.Contains(t.request.GetOutputFields(), field.GetName())
|
||||||
|
//})
|
||||||
|
//sizePerRecord, err := typeutil.EstimateSizePerRecord(&schemapb.CollectionSchema{Fields: outputFields})
|
||||||
|
//if err != nil {
|
||||||
|
// return 0, err
|
||||||
|
//}
|
||||||
|
//return int64(sizePerRecord) * nq * topK, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func (t *searchTask) Requery() error {
|
||||||
|
pkField, err := typeutil.GetPrimaryFieldSchema(t.schema)
|
||||||
|
if err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
ids := t.result.GetResults().GetIds()
|
||||||
|
expr := IDs2Expr(pkField.GetName(), ids)
|
||||||
|
|
||||||
|
queryReq := &milvuspb.QueryRequest{
|
||||||
|
Base: &commonpb.MsgBase{
|
||||||
|
MsgType: commonpb.MsgType_Retrieve,
|
||||||
|
},
|
||||||
|
CollectionName: t.request.GetCollectionName(),
|
||||||
|
Expr: expr,
|
||||||
|
OutputFields: t.request.GetOutputFields(),
|
||||||
|
PartitionNames: t.request.GetPartitionNames(),
|
||||||
|
TravelTimestamp: t.request.GetTravelTimestamp(),
|
||||||
|
GuaranteeTimestamp: t.request.GetGuaranteeTimestamp(),
|
||||||
|
QueryParams: t.request.GetSearchParams(),
|
||||||
|
}
|
||||||
|
queryResult, err := t.node.Query(t.ctx, queryReq)
|
||||||
|
if err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
if queryResult.GetStatus().GetErrorCode() != commonpb.ErrorCode_Success {
|
||||||
|
return merr.Error(queryResult.GetStatus())
|
||||||
|
}
|
||||||
|
// Reorganize Results. The order of query result ids will be altered and differ from queried ids.
|
||||||
|
// We should reorganize query results to keep the order of original queried ids. For example:
|
||||||
|
// ===========================================
|
||||||
|
// 3 2 5 4 1 (query ids)
|
||||||
|
// ||
|
||||||
|
// || (query)
|
||||||
|
// \/
|
||||||
|
// 4 3 5 1 2 (result ids)
|
||||||
|
// v4 v3 v5 v1 v2 (result vectors)
|
||||||
|
// ||
|
||||||
|
// || (reorganize)
|
||||||
|
// \/
|
||||||
|
// 3 2 5 4 1 (result ids)
|
||||||
|
// v3 v2 v5 v4 v1 (result vectors)
|
||||||
|
// ===========================================
|
||||||
|
pkFieldData, err := typeutil.GetPrimaryFieldData(queryResult.GetFieldsData(), pkField)
|
||||||
|
if err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
offsets := make(map[any]int)
|
||||||
|
for i := 0; i < typeutil.GetDataSize(pkFieldData); i++ {
|
||||||
|
pk := typeutil.GetData(pkFieldData, i)
|
||||||
|
offsets[pk] = i
|
||||||
|
}
|
||||||
|
|
||||||
|
t.result.Results.FieldsData = make([]*schemapb.FieldData, len(queryResult.GetFieldsData()))
|
||||||
|
for i := 0; i < typeutil.GetSizeOfIDs(ids); i++ {
|
||||||
|
id := typeutil.GetPK(ids, int64(i))
|
||||||
|
if _, ok := offsets[id]; !ok {
|
||||||
|
return fmt.Errorf("incomplete query result, missing id %s, len(searchIDs) = %d, len(queryIDs) = %d, collection=%d",
|
||||||
|
id, typeutil.GetSizeOfIDs(ids), len(offsets), t.GetCollectionID())
|
||||||
|
}
|
||||||
|
typeutil.AppendFieldData(t.result.Results.FieldsData, queryResult.GetFieldsData(), int64(offsets[id]))
|
||||||
|
}
|
||||||
|
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
func (t *searchTask) fillInEmptyResult(numQueries int64) {
|
func (t *searchTask) fillInEmptyResult(numQueries int64) {
|
||||||
t.result = &milvuspb.SearchResults{
|
t.result = &milvuspb.SearchResults{
|
||||||
Status: &commonpb.Status{
|
Status: &commonpb.Status{
|
||||||
|
|
|
@ -17,6 +17,7 @@ import (
|
||||||
"github.com/milvus-io/milvus-proto/go-api/commonpb"
|
"github.com/milvus-io/milvus-proto/go-api/commonpb"
|
||||||
"github.com/milvus-io/milvus-proto/go-api/milvuspb"
|
"github.com/milvus-io/milvus-proto/go-api/milvuspb"
|
||||||
"github.com/milvus-io/milvus-proto/go-api/schemapb"
|
"github.com/milvus-io/milvus-proto/go-api/schemapb"
|
||||||
|
"github.com/milvus-io/milvus/internal/mocks"
|
||||||
"github.com/milvus-io/milvus/internal/proto/internalpb"
|
"github.com/milvus-io/milvus/internal/proto/internalpb"
|
||||||
"github.com/milvus-io/milvus/internal/proto/querypb"
|
"github.com/milvus-io/milvus/internal/proto/querypb"
|
||||||
"github.com/milvus-io/milvus/internal/types"
|
"github.com/milvus-io/milvus/internal/types"
|
||||||
|
@ -268,7 +269,7 @@ func TestSearchTask_PreExecute(t *testing.T) {
|
||||||
|
|
||||||
// contain vector field
|
// contain vector field
|
||||||
task.request.OutputFields = []string{testFloatVecField}
|
task.request.OutputFields = []string{testFloatVecField}
|
||||||
assert.Error(t, task.PreExecute(ctx))
|
assert.NoError(t, task.PreExecute(ctx))
|
||||||
})
|
})
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -1959,3 +1960,287 @@ func getSearchResultData(nq, topk int64) *schemapb.SearchResultData {
|
||||||
}
|
}
|
||||||
return &result
|
return &result
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func TestSearchTask_Requery(t *testing.T) {
|
||||||
|
ctx, cancel := context.WithCancel(context.Background())
|
||||||
|
defer cancel()
|
||||||
|
|
||||||
|
const (
|
||||||
|
dim = 128
|
||||||
|
rows = 5
|
||||||
|
collection = "test-requery"
|
||||||
|
|
||||||
|
pkField = "pk"
|
||||||
|
vecField = "vec"
|
||||||
|
)
|
||||||
|
|
||||||
|
ids := make([]int64, rows)
|
||||||
|
for i := range ids {
|
||||||
|
ids[i] = int64(i)
|
||||||
|
}
|
||||||
|
|
||||||
|
t.Run("Test normal", func(t *testing.T) {
|
||||||
|
schema := constructCollectionSchema(pkField, vecField, dim, collection)
|
||||||
|
node := mocks.NewProxy(t)
|
||||||
|
node.EXPECT().Query(mock.Anything, mock.Anything).
|
||||||
|
Return(&milvuspb.QueryResults{
|
||||||
|
FieldsData: []*schemapb.FieldData{{
|
||||||
|
Type: schemapb.DataType_Int64,
|
||||||
|
FieldName: pkField,
|
||||||
|
Field: &schemapb.FieldData_Scalars{
|
||||||
|
Scalars: &schemapb.ScalarField{
|
||||||
|
Data: &schemapb.ScalarField_LongData{
|
||||||
|
LongData: &schemapb.LongArray{
|
||||||
|
Data: ids,
|
||||||
|
},
|
||||||
|
},
|
||||||
|
},
|
||||||
|
},
|
||||||
|
},
|
||||||
|
newFloatVectorFieldData(vecField, rows, dim),
|
||||||
|
},
|
||||||
|
}, nil)
|
||||||
|
|
||||||
|
resultIDs := &schemapb.IDs{
|
||||||
|
IdField: &schemapb.IDs_IntId{
|
||||||
|
IntId: &schemapb.LongArray{
|
||||||
|
Data: ids,
|
||||||
|
},
|
||||||
|
},
|
||||||
|
}
|
||||||
|
|
||||||
|
qt := &searchTask{
|
||||||
|
ctx: ctx,
|
||||||
|
SearchRequest: &internalpb.SearchRequest{
|
||||||
|
Base: &commonpb.MsgBase{
|
||||||
|
MsgType: commonpb.MsgType_Search,
|
||||||
|
SourceID: paramtable.GetNodeID(),
|
||||||
|
},
|
||||||
|
},
|
||||||
|
request: &milvuspb.SearchRequest{},
|
||||||
|
result: &milvuspb.SearchResults{
|
||||||
|
Results: &schemapb.SearchResultData{
|
||||||
|
Ids: resultIDs,
|
||||||
|
},
|
||||||
|
},
|
||||||
|
schema: schema,
|
||||||
|
tr: timerecord.NewTimeRecorder("search"),
|
||||||
|
node: node,
|
||||||
|
}
|
||||||
|
|
||||||
|
err := qt.Requery()
|
||||||
|
assert.NoError(t, err)
|
||||||
|
})
|
||||||
|
|
||||||
|
t.Run("Test no primary key", func(t *testing.T) {
|
||||||
|
schema := &schemapb.CollectionSchema{}
|
||||||
|
node := mocks.NewProxy(t)
|
||||||
|
|
||||||
|
qt := &searchTask{
|
||||||
|
ctx: ctx,
|
||||||
|
SearchRequest: &internalpb.SearchRequest{
|
||||||
|
Base: &commonpb.MsgBase{
|
||||||
|
MsgType: commonpb.MsgType_Search,
|
||||||
|
SourceID: paramtable.GetNodeID(),
|
||||||
|
},
|
||||||
|
},
|
||||||
|
request: &milvuspb.SearchRequest{},
|
||||||
|
schema: schema,
|
||||||
|
tr: timerecord.NewTimeRecorder("search"),
|
||||||
|
node: node,
|
||||||
|
}
|
||||||
|
|
||||||
|
err := qt.Requery()
|
||||||
|
t.Logf("err = %s", err)
|
||||||
|
assert.Error(t, err)
|
||||||
|
})
|
||||||
|
|
||||||
|
t.Run("Test requery failed 1", func(t *testing.T) {
|
||||||
|
schema := constructCollectionSchema(pkField, vecField, dim, collection)
|
||||||
|
node := mocks.NewProxy(t)
|
||||||
|
node.EXPECT().Query(mock.Anything, mock.Anything).
|
||||||
|
Return(nil, fmt.Errorf("mock err 1"))
|
||||||
|
|
||||||
|
qt := &searchTask{
|
||||||
|
ctx: ctx,
|
||||||
|
SearchRequest: &internalpb.SearchRequest{
|
||||||
|
Base: &commonpb.MsgBase{
|
||||||
|
MsgType: commonpb.MsgType_Search,
|
||||||
|
SourceID: paramtable.GetNodeID(),
|
||||||
|
},
|
||||||
|
},
|
||||||
|
request: &milvuspb.SearchRequest{},
|
||||||
|
schema: schema,
|
||||||
|
tr: timerecord.NewTimeRecorder("search"),
|
||||||
|
node: node,
|
||||||
|
}
|
||||||
|
|
||||||
|
err := qt.Requery()
|
||||||
|
t.Logf("err = %s", err)
|
||||||
|
assert.Error(t, err)
|
||||||
|
})
|
||||||
|
|
||||||
|
t.Run("Test requery failed 2", func(t *testing.T) {
|
||||||
|
schema := constructCollectionSchema(pkField, vecField, dim, collection)
|
||||||
|
node := mocks.NewProxy(t)
|
||||||
|
node.EXPECT().Query(mock.Anything, mock.Anything).
|
||||||
|
Return(&milvuspb.QueryResults{
|
||||||
|
Status: &commonpb.Status{
|
||||||
|
ErrorCode: commonpb.ErrorCode_UnexpectedError,
|
||||||
|
Reason: "mock err 2",
|
||||||
|
},
|
||||||
|
}, nil)
|
||||||
|
|
||||||
|
qt := &searchTask{
|
||||||
|
ctx: ctx,
|
||||||
|
SearchRequest: &internalpb.SearchRequest{
|
||||||
|
Base: &commonpb.MsgBase{
|
||||||
|
MsgType: commonpb.MsgType_Search,
|
||||||
|
SourceID: paramtable.GetNodeID(),
|
||||||
|
},
|
||||||
|
},
|
||||||
|
request: &milvuspb.SearchRequest{},
|
||||||
|
schema: schema,
|
||||||
|
tr: timerecord.NewTimeRecorder("search"),
|
||||||
|
node: node,
|
||||||
|
}
|
||||||
|
|
||||||
|
err := qt.Requery()
|
||||||
|
t.Logf("err = %s", err)
|
||||||
|
assert.Error(t, err)
|
||||||
|
})
|
||||||
|
|
||||||
|
t.Run("Test get pk filed data failed", func(t *testing.T) {
|
||||||
|
schema := constructCollectionSchema(pkField, vecField, dim, collection)
|
||||||
|
node := mocks.NewProxy(t)
|
||||||
|
node.EXPECT().Query(mock.Anything, mock.Anything).
|
||||||
|
Return(&milvuspb.QueryResults{
|
||||||
|
FieldsData: []*schemapb.FieldData{},
|
||||||
|
}, nil)
|
||||||
|
|
||||||
|
qt := &searchTask{
|
||||||
|
ctx: ctx,
|
||||||
|
SearchRequest: &internalpb.SearchRequest{
|
||||||
|
Base: &commonpb.MsgBase{
|
||||||
|
MsgType: commonpb.MsgType_Search,
|
||||||
|
SourceID: paramtable.GetNodeID(),
|
||||||
|
},
|
||||||
|
},
|
||||||
|
request: &milvuspb.SearchRequest{},
|
||||||
|
schema: schema,
|
||||||
|
tr: timerecord.NewTimeRecorder("search"),
|
||||||
|
node: node,
|
||||||
|
}
|
||||||
|
|
||||||
|
err := qt.Requery()
|
||||||
|
t.Logf("err = %s", err)
|
||||||
|
assert.Error(t, err)
|
||||||
|
})
|
||||||
|
|
||||||
|
t.Run("Test incomplete query result", func(t *testing.T) {
|
||||||
|
schema := constructCollectionSchema(pkField, vecField, dim, collection)
|
||||||
|
node := mocks.NewProxy(t)
|
||||||
|
node.EXPECT().Query(mock.Anything, mock.Anything).
|
||||||
|
Return(&milvuspb.QueryResults{
|
||||||
|
FieldsData: []*schemapb.FieldData{{
|
||||||
|
Type: schemapb.DataType_Int64,
|
||||||
|
FieldName: pkField,
|
||||||
|
Field: &schemapb.FieldData_Scalars{
|
||||||
|
Scalars: &schemapb.ScalarField{
|
||||||
|
Data: &schemapb.ScalarField_LongData{
|
||||||
|
LongData: &schemapb.LongArray{
|
||||||
|
Data: ids[:len(ids)-1],
|
||||||
|
},
|
||||||
|
},
|
||||||
|
},
|
||||||
|
},
|
||||||
|
},
|
||||||
|
newFloatVectorFieldData(vecField, rows, dim),
|
||||||
|
},
|
||||||
|
}, nil)
|
||||||
|
|
||||||
|
resultIDs := &schemapb.IDs{
|
||||||
|
IdField: &schemapb.IDs_IntId{
|
||||||
|
IntId: &schemapb.LongArray{
|
||||||
|
Data: ids,
|
||||||
|
},
|
||||||
|
},
|
||||||
|
}
|
||||||
|
|
||||||
|
qt := &searchTask{
|
||||||
|
ctx: ctx,
|
||||||
|
SearchRequest: &internalpb.SearchRequest{
|
||||||
|
Base: &commonpb.MsgBase{
|
||||||
|
MsgType: commonpb.MsgType_Search,
|
||||||
|
SourceID: paramtable.GetNodeID(),
|
||||||
|
},
|
||||||
|
},
|
||||||
|
request: &milvuspb.SearchRequest{},
|
||||||
|
result: &milvuspb.SearchResults{
|
||||||
|
Results: &schemapb.SearchResultData{
|
||||||
|
Ids: resultIDs,
|
||||||
|
},
|
||||||
|
},
|
||||||
|
schema: schema,
|
||||||
|
tr: timerecord.NewTimeRecorder("search"),
|
||||||
|
node: node,
|
||||||
|
}
|
||||||
|
|
||||||
|
err := qt.Requery()
|
||||||
|
t.Logf("err = %s", err)
|
||||||
|
assert.Error(t, err)
|
||||||
|
})
|
||||||
|
|
||||||
|
t.Run("Test postExecute with requery failed", func(t *testing.T) {
|
||||||
|
schema := constructCollectionSchema(pkField, vecField, dim, collection)
|
||||||
|
node := mocks.NewProxy(t)
|
||||||
|
node.EXPECT().Query(mock.Anything, mock.Anything).
|
||||||
|
Return(nil, fmt.Errorf("mock err 1"))
|
||||||
|
|
||||||
|
resultIDs := &schemapb.IDs{
|
||||||
|
IdField: &schemapb.IDs_IntId{
|
||||||
|
IntId: &schemapb.LongArray{
|
||||||
|
Data: ids,
|
||||||
|
},
|
||||||
|
},
|
||||||
|
}
|
||||||
|
|
||||||
|
qt := &searchTask{
|
||||||
|
ctx: ctx,
|
||||||
|
SearchRequest: &internalpb.SearchRequest{
|
||||||
|
Base: &commonpb.MsgBase{
|
||||||
|
MsgType: commonpb.MsgType_Search,
|
||||||
|
SourceID: paramtable.GetNodeID(),
|
||||||
|
},
|
||||||
|
},
|
||||||
|
request: &milvuspb.SearchRequest{},
|
||||||
|
result: &milvuspb.SearchResults{
|
||||||
|
Results: &schemapb.SearchResultData{
|
||||||
|
Ids: resultIDs,
|
||||||
|
},
|
||||||
|
},
|
||||||
|
requery: true,
|
||||||
|
schema: schema,
|
||||||
|
resultBuf: make(chan *internalpb.SearchResults, 10),
|
||||||
|
tr: timerecord.NewTimeRecorder("search"),
|
||||||
|
node: node,
|
||||||
|
}
|
||||||
|
scores := make([]float32, rows)
|
||||||
|
for i := range scores {
|
||||||
|
scores[i] = float32(i)
|
||||||
|
}
|
||||||
|
partialResultData := &schemapb.SearchResultData{
|
||||||
|
Ids: resultIDs,
|
||||||
|
Scores: scores,
|
||||||
|
}
|
||||||
|
bytes, err := proto.Marshal(partialResultData)
|
||||||
|
assert.NoError(t, err)
|
||||||
|
qt.resultBuf <- &internalpb.SearchResults{
|
||||||
|
SlicedBlob: bytes,
|
||||||
|
}
|
||||||
|
|
||||||
|
err = qt.PostExecute(ctx)
|
||||||
|
t.Logf("err = %s", err)
|
||||||
|
assert.Error(t, err)
|
||||||
|
})
|
||||||
|
}
|
||||||
|
|
|
@ -282,6 +282,13 @@ func (s *LocalSegment) ExistIndex(fieldID int64) bool {
|
||||||
return fieldInfo.IndexInfo != nil && fieldInfo.IndexInfo.EnableIndex
|
return fieldInfo.IndexInfo != nil && fieldInfo.IndexInfo.EnableIndex
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func (s *LocalSegment) HasRawData(fieldID int64) bool {
|
||||||
|
s.mut.RLock()
|
||||||
|
defer s.mut.RUnlock()
|
||||||
|
ret := C.HasRawData(s.ptr, C.int64_t(fieldID))
|
||||||
|
return bool(ret)
|
||||||
|
}
|
||||||
|
|
||||||
func (s *LocalSegment) Indexes() []*IndexedFieldInfo {
|
func (s *LocalSegment) Indexes() []*IndexedFieldInfo {
|
||||||
var result []*IndexedFieldInfo
|
var result []*IndexedFieldInfo
|
||||||
s.fieldIndexes.Range(func(key int64, value *IndexedFieldInfo) bool {
|
s.fieldIndexes.Range(func(key int64, value *IndexedFieldInfo) bool {
|
||||||
|
@ -463,10 +470,18 @@ func (s *LocalSegment) FillIndexedFieldsData(ctx context.Context,
|
||||||
)
|
)
|
||||||
|
|
||||||
for _, fieldData := range result.FieldsData {
|
for _, fieldData := range result.FieldsData {
|
||||||
// If the vector field doesn't have indexed. Vector data is in memory for
|
// If the field is not vector field, no need to download data from remote.
|
||||||
// brute force search. No need to download data from remote.
|
if !typeutil.IsVectorType(fieldData.GetType()) {
|
||||||
if fieldData.GetType() != schemapb.DataType_FloatVector && fieldData.GetType() != schemapb.DataType_BinaryVector ||
|
continue
|
||||||
!s.ExistIndex(fieldData.FieldId) {
|
}
|
||||||
|
// If the vector field doesn't have indexed, vector data is in memory
|
||||||
|
// for brute force search, no need to download data from remote.
|
||||||
|
if !s.ExistIndex(fieldData.FieldId) {
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
// If the index has raw data, vector data could be obtained from index,
|
||||||
|
// no need to download data from remote.
|
||||||
|
if s.HasRawData(fieldData.FieldId) {
|
||||||
continue
|
continue
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
|
@ -117,6 +117,13 @@ func (suite *SegmentSuite) TestDelete() {
|
||||||
suite.Equal(rowNum, suite.growing.InsertCount())
|
suite.Equal(rowNum, suite.growing.InsertCount())
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func (suite *SegmentSuite) TestHasRawData() {
|
||||||
|
has := suite.growing.HasRawData(simpleFloatVecField.id)
|
||||||
|
suite.True(has)
|
||||||
|
has = suite.sealed.HasRawData(simpleFloatVecField.id)
|
||||||
|
suite.True(has)
|
||||||
|
}
|
||||||
|
|
||||||
func TestSegment(t *testing.T) {
|
func TestSegment(t *testing.T) {
|
||||||
suite.Run(t, new(SegmentSuite))
|
suite.Run(t, new(SegmentSuite))
|
||||||
}
|
}
|
||||||
|
|
|
@ -735,6 +735,30 @@ func GetSizeOfIDs(data *schemapb.IDs) int {
|
||||||
return result
|
return result
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func GetDataSize(fieldData *schemapb.FieldData) int {
|
||||||
|
switch fieldData.GetType() {
|
||||||
|
case schemapb.DataType_Bool:
|
||||||
|
return len(fieldData.GetScalars().GetBoolData().GetData())
|
||||||
|
case schemapb.DataType_Int8, schemapb.DataType_Int16, schemapb.DataType_Int32:
|
||||||
|
return len(fieldData.GetScalars().GetIntData().GetData())
|
||||||
|
case schemapb.DataType_Int64:
|
||||||
|
return len(fieldData.GetScalars().GetLongData().GetData())
|
||||||
|
case schemapb.DataType_Float:
|
||||||
|
return len(fieldData.GetScalars().GetFloatData().GetData())
|
||||||
|
case schemapb.DataType_Double:
|
||||||
|
return len(fieldData.GetScalars().GetDoubleData().GetData())
|
||||||
|
case schemapb.DataType_String:
|
||||||
|
return len(fieldData.GetScalars().GetStringData().GetData())
|
||||||
|
case schemapb.DataType_VarChar:
|
||||||
|
return len(fieldData.GetScalars().GetStringData().GetData())
|
||||||
|
case schemapb.DataType_FloatVector:
|
||||||
|
return len(fieldData.GetVectors().GetFloatVector().GetData())
|
||||||
|
case schemapb.DataType_BinaryVector:
|
||||||
|
return len(fieldData.GetVectors().GetBinaryVector())
|
||||||
|
}
|
||||||
|
return 0
|
||||||
|
}
|
||||||
|
|
||||||
func IsPrimaryFieldType(dataType schemapb.DataType) bool {
|
func IsPrimaryFieldType(dataType schemapb.DataType) bool {
|
||||||
if dataType == schemapb.DataType_Int64 || dataType == schemapb.DataType_VarChar {
|
if dataType == schemapb.DataType_Int64 || dataType == schemapb.DataType_VarChar {
|
||||||
return true
|
return true
|
||||||
|
@ -756,6 +780,33 @@ func GetPK(data *schemapb.IDs, idx int64) interface{} {
|
||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func GetData(field *schemapb.FieldData, idx int) interface{} {
|
||||||
|
switch field.GetType() {
|
||||||
|
case schemapb.DataType_Bool:
|
||||||
|
return field.GetScalars().GetBoolData().GetData()[idx]
|
||||||
|
case schemapb.DataType_Int8, schemapb.DataType_Int16, schemapb.DataType_Int32:
|
||||||
|
return field.GetScalars().GetIntData().GetData()[idx]
|
||||||
|
case schemapb.DataType_Int64:
|
||||||
|
return field.GetScalars().GetLongData().GetData()[idx]
|
||||||
|
case schemapb.DataType_Float:
|
||||||
|
return field.GetScalars().GetFloatData().GetData()[idx]
|
||||||
|
case schemapb.DataType_Double:
|
||||||
|
return field.GetScalars().GetDoubleData().GetData()[idx]
|
||||||
|
case schemapb.DataType_String:
|
||||||
|
return field.GetScalars().GetStringData().GetData()[idx]
|
||||||
|
case schemapb.DataType_VarChar:
|
||||||
|
return field.GetScalars().GetStringData().GetData()[idx]
|
||||||
|
case schemapb.DataType_FloatVector:
|
||||||
|
dim := int(field.GetVectors().GetDim())
|
||||||
|
return field.GetVectors().GetFloatVector().GetData()[idx*dim : (idx+1)*dim]
|
||||||
|
case schemapb.DataType_BinaryVector:
|
||||||
|
dim := int(field.GetVectors().GetDim())
|
||||||
|
dataBytes := dim / 8
|
||||||
|
return field.GetVectors().GetBinaryVector()[idx*dataBytes : (idx+1)*dataBytes]
|
||||||
|
}
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
func AppendPKs(pks *schemapb.IDs, pk interface{}) {
|
func AppendPKs(pks *schemapb.IDs, pk interface{}) {
|
||||||
switch realPK := pk.(type) {
|
switch realPK := pk.(type) {
|
||||||
case int64:
|
case int64:
|
||||||
|
|
|
@ -470,6 +470,21 @@ func genFieldData(fieldName string, fieldID int64, fieldType schemapb.DataType,
|
||||||
},
|
},
|
||||||
FieldId: fieldID,
|
FieldId: fieldID,
|
||||||
}
|
}
|
||||||
|
case schemapb.DataType_String:
|
||||||
|
fieldData = &schemapb.FieldData{
|
||||||
|
Type: schemapb.DataType_String,
|
||||||
|
FieldName: fieldName,
|
||||||
|
Field: &schemapb.FieldData_Scalars{
|
||||||
|
Scalars: &schemapb.ScalarField{
|
||||||
|
Data: &schemapb.ScalarField_StringData{
|
||||||
|
StringData: &schemapb.StringArray{
|
||||||
|
Data: fieldValue.([]string),
|
||||||
|
},
|
||||||
|
},
|
||||||
|
},
|
||||||
|
},
|
||||||
|
FieldId: fieldID,
|
||||||
|
}
|
||||||
case schemapb.DataType_VarChar:
|
case schemapb.DataType_VarChar:
|
||||||
fieldData = &schemapb.FieldData{
|
fieldData = &schemapb.FieldData{
|
||||||
Type: schemapb.DataType_VarChar,
|
Type: schemapb.DataType_VarChar,
|
||||||
|
@ -990,3 +1005,94 @@ func TestCalcColumnSize(t *testing.T) {
|
||||||
assert.Equal(t, expected, size, field.GetName())
|
assert.Equal(t, expected, size, field.GetName())
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func TestGetDataAndGetDataSize(t *testing.T) {
|
||||||
|
const (
|
||||||
|
Dim = 8
|
||||||
|
fieldName = "filed-0"
|
||||||
|
fieldID = 0
|
||||||
|
)
|
||||||
|
|
||||||
|
BoolArray := []bool{true, false}
|
||||||
|
Int8Array := []int8{1, 2}
|
||||||
|
Int16Array := []int16{3, 4}
|
||||||
|
Int32Array := []int32{5, 6}
|
||||||
|
Int64Array := []int64{11, 22}
|
||||||
|
FloatArray := []float32{1.0, 2.0}
|
||||||
|
DoubleArray := []float64{11.0, 22.0}
|
||||||
|
VarCharArray := []string{"a", "b"}
|
||||||
|
StringArray := []string{"c", "d"}
|
||||||
|
BinaryVector := []byte{0x12, 0x34}
|
||||||
|
FloatVector := []float32{1.0, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0, 8.0, 11.0, 22.0, 33.0, 44.0, 55.0, 66.0, 77.0, 88.0}
|
||||||
|
|
||||||
|
boolData := genFieldData(fieldName, fieldID, schemapb.DataType_Bool, BoolArray, 1)
|
||||||
|
int8Data := genFieldData(fieldName, fieldID, schemapb.DataType_Int8, Int8Array, 1)
|
||||||
|
int16Data := genFieldData(fieldName, fieldID, schemapb.DataType_Int16, Int16Array, 1)
|
||||||
|
int32Data := genFieldData(fieldName, fieldID, schemapb.DataType_Int32, Int32Array, 1)
|
||||||
|
int64Data := genFieldData(fieldName, fieldID, schemapb.DataType_Int64, Int64Array, 1)
|
||||||
|
floatData := genFieldData(fieldName, fieldID, schemapb.DataType_Float, FloatArray, 1)
|
||||||
|
doubleData := genFieldData(fieldName, fieldID, schemapb.DataType_Double, DoubleArray, 1)
|
||||||
|
varCharData := genFieldData(fieldName, fieldID, schemapb.DataType_VarChar, VarCharArray, 1)
|
||||||
|
stringData := genFieldData(fieldName, fieldID, schemapb.DataType_String, StringArray, 1)
|
||||||
|
binVecData := genFieldData(fieldName, fieldID, schemapb.DataType_BinaryVector, BinaryVector, Dim)
|
||||||
|
floatVecData := genFieldData(fieldName, fieldID, schemapb.DataType_FloatVector, FloatVector, Dim)
|
||||||
|
invalidData := &schemapb.FieldData{
|
||||||
|
Type: schemapb.DataType_None,
|
||||||
|
}
|
||||||
|
|
||||||
|
t.Run("test GetDataSize", func(t *testing.T) {
|
||||||
|
boolDataRes := GetDataSize(boolData)
|
||||||
|
int8DataRes := GetDataSize(int8Data)
|
||||||
|
int16DataRes := GetDataSize(int16Data)
|
||||||
|
int32DataRes := GetDataSize(int32Data)
|
||||||
|
int64DataRes := GetDataSize(int64Data)
|
||||||
|
floatDataRes := GetDataSize(floatData)
|
||||||
|
doubleDataRes := GetDataSize(doubleData)
|
||||||
|
varCharDataRes := GetDataSize(varCharData)
|
||||||
|
stringDataRes := GetDataSize(stringData)
|
||||||
|
binVecDataRes := GetDataSize(binVecData)
|
||||||
|
floatVecDataRes := GetDataSize(floatVecData)
|
||||||
|
invalidDataRes := GetDataSize(invalidData)
|
||||||
|
|
||||||
|
assert.Equal(t, 2, boolDataRes)
|
||||||
|
assert.Equal(t, 2, int8DataRes)
|
||||||
|
assert.Equal(t, 2, int16DataRes)
|
||||||
|
assert.Equal(t, 2, int32DataRes)
|
||||||
|
assert.Equal(t, 2, int64DataRes)
|
||||||
|
assert.Equal(t, 2, floatDataRes)
|
||||||
|
assert.Equal(t, 2, doubleDataRes)
|
||||||
|
assert.Equal(t, 2, varCharDataRes)
|
||||||
|
assert.Equal(t, 2, stringDataRes)
|
||||||
|
assert.Equal(t, 2*Dim/8, binVecDataRes)
|
||||||
|
assert.Equal(t, 2*Dim, floatVecDataRes)
|
||||||
|
assert.Equal(t, 0, invalidDataRes)
|
||||||
|
})
|
||||||
|
|
||||||
|
t.Run("test GetData", func(t *testing.T) {
|
||||||
|
boolDataRes := GetData(boolData, 0)
|
||||||
|
int8DataRes := GetData(int8Data, 0)
|
||||||
|
int16DataRes := GetData(int16Data, 0)
|
||||||
|
int32DataRes := GetData(int32Data, 0)
|
||||||
|
int64DataRes := GetData(int64Data, 0)
|
||||||
|
floatDataRes := GetData(floatData, 0)
|
||||||
|
doubleDataRes := GetData(doubleData, 0)
|
||||||
|
varCharDataRes := GetData(varCharData, 0)
|
||||||
|
stringDataRes := GetData(stringData, 0)
|
||||||
|
binVecDataRes := GetData(binVecData, 0)
|
||||||
|
floatVecDataRes := GetData(floatVecData, 0)
|
||||||
|
invalidDataRes := GetData(invalidData, 0)
|
||||||
|
|
||||||
|
assert.Equal(t, BoolArray[0], boolDataRes)
|
||||||
|
assert.Equal(t, int32(Int8Array[0]), int8DataRes)
|
||||||
|
assert.Equal(t, int32(Int16Array[0]), int16DataRes)
|
||||||
|
assert.Equal(t, Int32Array[0], int32DataRes)
|
||||||
|
assert.Equal(t, Int64Array[0], int64DataRes)
|
||||||
|
assert.Equal(t, FloatArray[0], floatDataRes)
|
||||||
|
assert.Equal(t, DoubleArray[0], doubleDataRes)
|
||||||
|
assert.Equal(t, VarCharArray[0], varCharDataRes)
|
||||||
|
assert.Equal(t, StringArray[0], stringDataRes)
|
||||||
|
assert.ElementsMatch(t, BinaryVector[:Dim/8], binVecDataRes)
|
||||||
|
assert.ElementsMatch(t, FloatVector[:Dim], floatVecDataRes)
|
||||||
|
assert.Nil(t, invalidDataRes)
|
||||||
|
})
|
||||||
|
}
|
||||||
|
|
|
@ -51,72 +51,24 @@ const (
|
||||||
// 5, load
|
// 5, load
|
||||||
// 6, search
|
// 6, search
|
||||||
func TestBulkInsert(t *testing.T) {
|
func TestBulkInsert(t *testing.T) {
|
||||||
ctx := context.Background()
|
ctx, cancel := context.WithTimeout(context.Background(), time.Second*180)
|
||||||
c, err := StartMiniCluster(ctx)
|
c, err := StartMiniCluster(ctx)
|
||||||
assert.NoError(t, err)
|
assert.NoError(t, err)
|
||||||
err = c.Start()
|
err = c.Start()
|
||||||
assert.NoError(t, err)
|
assert.NoError(t, err)
|
||||||
defer c.Stop()
|
defer func() {
|
||||||
|
err = c.Stop()
|
||||||
assert.NoError(t, err)
|
assert.NoError(t, err)
|
||||||
|
cancel()
|
||||||
|
}()
|
||||||
|
|
||||||
prefix := "TestBulkInsert"
|
prefix := "TestBulkInsert"
|
||||||
dbName := ""
|
dbName := ""
|
||||||
collectionName := prefix + funcutil.GenRandomStr()
|
collectionName := prefix + funcutil.GenRandomStr()
|
||||||
int64Field := "int64"
|
floatVecField := floatVecField
|
||||||
floatVecField := "embeddings"
|
|
||||||
scalarField := "image_path"
|
|
||||||
dim := 128
|
dim := 128
|
||||||
|
|
||||||
constructCollectionSchema := func() *schemapb.CollectionSchema {
|
schema := constructSchema(collectionName, dim, true)
|
||||||
pk := &schemapb.FieldSchema{
|
|
||||||
Name: int64Field,
|
|
||||||
IsPrimaryKey: true,
|
|
||||||
Description: "pk",
|
|
||||||
DataType: schemapb.DataType_Int64,
|
|
||||||
TypeParams: nil,
|
|
||||||
IndexParams: nil,
|
|
||||||
AutoID: true,
|
|
||||||
}
|
|
||||||
fVec := &schemapb.FieldSchema{
|
|
||||||
Name: floatVecField,
|
|
||||||
IsPrimaryKey: false,
|
|
||||||
Description: "",
|
|
||||||
DataType: schemapb.DataType_FloatVector,
|
|
||||||
TypeParams: []*commonpb.KeyValuePair{
|
|
||||||
{
|
|
||||||
Key: "dim",
|
|
||||||
Value: strconv.Itoa(dim),
|
|
||||||
},
|
|
||||||
},
|
|
||||||
IndexParams: nil,
|
|
||||||
AutoID: false,
|
|
||||||
}
|
|
||||||
scalar := &schemapb.FieldSchema{
|
|
||||||
Name: scalarField,
|
|
||||||
IsPrimaryKey: false,
|
|
||||||
Description: "",
|
|
||||||
DataType: schemapb.DataType_VarChar,
|
|
||||||
TypeParams: []*commonpb.KeyValuePair{
|
|
||||||
{
|
|
||||||
Key: "max_length",
|
|
||||||
Value: "65535",
|
|
||||||
},
|
|
||||||
},
|
|
||||||
IndexParams: nil,
|
|
||||||
AutoID: false,
|
|
||||||
}
|
|
||||||
return &schemapb.CollectionSchema{
|
|
||||||
Name: collectionName,
|
|
||||||
Description: "",
|
|
||||||
AutoID: false,
|
|
||||||
Fields: []*schemapb.FieldSchema{
|
|
||||||
pk,
|
|
||||||
fVec,
|
|
||||||
scalar,
|
|
||||||
},
|
|
||||||
}
|
|
||||||
}
|
|
||||||
schema := constructCollectionSchema()
|
|
||||||
marshaledSchema, err := proto.Marshal(schema)
|
marshaledSchema, err := proto.Marshal(schema)
|
||||||
assert.NoError(t, err)
|
assert.NoError(t, err)
|
||||||
|
|
||||||
|
@ -207,28 +159,7 @@ func TestBulkInsert(t *testing.T) {
|
||||||
CollectionName: collectionName,
|
CollectionName: collectionName,
|
||||||
FieldName: floatVecField,
|
FieldName: floatVecField,
|
||||||
IndexName: "_default",
|
IndexName: "_default",
|
||||||
ExtraParams: []*commonpb.KeyValuePair{
|
ExtraParams: constructIndexParam(dim, IndexHNSW, distance.L2),
|
||||||
{
|
|
||||||
Key: "dim",
|
|
||||||
Value: strconv.Itoa(dim),
|
|
||||||
},
|
|
||||||
{
|
|
||||||
Key: common.MetricTypeKey,
|
|
||||||
Value: distance.L2,
|
|
||||||
},
|
|
||||||
{
|
|
||||||
Key: "index_type",
|
|
||||||
Value: "HNSW",
|
|
||||||
},
|
|
||||||
{
|
|
||||||
Key: "M",
|
|
||||||
Value: "64",
|
|
||||||
},
|
|
||||||
{
|
|
||||||
Key: "efConstruction",
|
|
||||||
Value: "512",
|
|
||||||
},
|
|
||||||
},
|
|
||||||
})
|
})
|
||||||
if createIndexStatus.GetErrorCode() != commonpb.ErrorCode_Success {
|
if createIndexStatus.GetErrorCode() != commonpb.ErrorCode_Success {
|
||||||
log.Warn("createIndexStatus fail reason", zap.String("reason", createIndexStatus.GetReason()))
|
log.Warn("createIndexStatus fail reason", zap.String("reason", createIndexStatus.GetReason()))
|
||||||
|
@ -246,30 +177,17 @@ func TestBulkInsert(t *testing.T) {
|
||||||
log.Warn("loadStatus fail reason", zap.String("reason", loadStatus.GetReason()))
|
log.Warn("loadStatus fail reason", zap.String("reason", loadStatus.GetReason()))
|
||||||
}
|
}
|
||||||
assert.Equal(t, commonpb.ErrorCode_Success, loadStatus.GetErrorCode())
|
assert.Equal(t, commonpb.ErrorCode_Success, loadStatus.GetErrorCode())
|
||||||
for {
|
waitingForLoad(ctx, c, collectionName)
|
||||||
loadProgress, err := c.proxy.GetLoadingProgress(ctx, &milvuspb.GetLoadingProgressRequest{
|
|
||||||
CollectionName: collectionName,
|
|
||||||
})
|
|
||||||
if err != nil {
|
|
||||||
panic("GetLoadingProgress fail")
|
|
||||||
}
|
|
||||||
if loadProgress.GetProgress() == 100 {
|
|
||||||
break
|
|
||||||
}
|
|
||||||
time.Sleep(500 * time.Millisecond)
|
|
||||||
}
|
|
||||||
|
|
||||||
// search
|
// search
|
||||||
expr := fmt.Sprintf("%s > 0", "int64")
|
expr := fmt.Sprintf("%s > 0", int64Field)
|
||||||
nq := 10
|
nq := 10
|
||||||
topk := 10
|
topk := 10
|
||||||
roundDecimal := -1
|
roundDecimal := -1
|
||||||
nprobe := 10
|
|
||||||
params := make(map[string]int)
|
|
||||||
params["nprobe"] = nprobe
|
|
||||||
|
|
||||||
|
params := getSearchParams(IndexHNSW, distance.L2)
|
||||||
searchReq := constructSearchRequest("", collectionName, expr,
|
searchReq := constructSearchRequest("", collectionName, expr,
|
||||||
floatVecField, distance.L2, params, nq, dim, topk, roundDecimal)
|
floatVecField, schemapb.DataType_FloatVector, nil, distance.L2, params, nq, dim, topk, roundDecimal)
|
||||||
|
|
||||||
searchResult, err := c.proxy.Search(ctx, searchReq)
|
searchResult, err := c.proxy.Search(ctx, searchReq)
|
||||||
|
|
||||||
|
|
|
@ -0,0 +1,366 @@
|
||||||
|
// Licensed to the LF AI & Data foundation under one
|
||||||
|
// or more contributor license agreements. See the NOTICE file
|
||||||
|
// distributed with this work for additional information
|
||||||
|
// regarding copyright ownership. The ASF licenses this file
|
||||||
|
// to you under the Apache License, Version 2.0 (the
|
||||||
|
// "License"); you may not use this file except in compliance
|
||||||
|
// with the License. You may obtain a copy of the License at
|
||||||
|
//
|
||||||
|
// http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
//
|
||||||
|
// Unless required by applicable law or agreed to in writing, software
|
||||||
|
// distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
// See the License for the specific language governing permissions and
|
||||||
|
// limitations under the License.
|
||||||
|
|
||||||
|
package integration
|
||||||
|
|
||||||
|
import (
|
||||||
|
"context"
|
||||||
|
"fmt"
|
||||||
|
"strconv"
|
||||||
|
"testing"
|
||||||
|
"time"
|
||||||
|
|
||||||
|
"github.com/golang/protobuf/proto"
|
||||||
|
"github.com/stretchr/testify/suite"
|
||||||
|
|
||||||
|
"github.com/milvus-io/milvus-proto/go-api/commonpb"
|
||||||
|
"github.com/milvus-io/milvus-proto/go-api/milvuspb"
|
||||||
|
"github.com/milvus-io/milvus-proto/go-api/schemapb"
|
||||||
|
"github.com/milvus-io/milvus/pkg/common"
|
||||||
|
"github.com/milvus-io/milvus/pkg/util/distance"
|
||||||
|
"github.com/milvus-io/milvus/pkg/util/funcutil"
|
||||||
|
"github.com/milvus-io/milvus/pkg/util/typeutil"
|
||||||
|
)
|
||||||
|
|
||||||
|
type TestGetVectorSuite struct {
|
||||||
|
suite.Suite
|
||||||
|
|
||||||
|
ctx context.Context
|
||||||
|
cancel context.CancelFunc
|
||||||
|
cluster *MiniCluster
|
||||||
|
|
||||||
|
// test params
|
||||||
|
nq int
|
||||||
|
topK int
|
||||||
|
indexType string
|
||||||
|
metricType string
|
||||||
|
pkType schemapb.DataType
|
||||||
|
vecType schemapb.DataType
|
||||||
|
}
|
||||||
|
|
||||||
|
func (suite *TestGetVectorSuite) SetupTest() {
|
||||||
|
suite.ctx, suite.cancel = context.WithTimeout(context.Background(), time.Second*600)
|
||||||
|
|
||||||
|
var err error
|
||||||
|
suite.cluster, err = StartMiniCluster(suite.ctx)
|
||||||
|
suite.Require().NoError(err)
|
||||||
|
err = suite.cluster.Start()
|
||||||
|
suite.Require().NoError(err)
|
||||||
|
}
|
||||||
|
|
||||||
|
func (suite *TestGetVectorSuite) run() {
|
||||||
|
collection := fmt.Sprintf("TestGetVector_%d_%d_%s_%s_%s",
|
||||||
|
suite.nq, suite.topK, suite.indexType, suite.metricType, funcutil.GenRandomStr())
|
||||||
|
|
||||||
|
const (
|
||||||
|
NB = 10000
|
||||||
|
dim = 128
|
||||||
|
)
|
||||||
|
|
||||||
|
pkFieldName := "pkField"
|
||||||
|
vecFieldName := "vecField"
|
||||||
|
pk := &schemapb.FieldSchema{
|
||||||
|
FieldID: 100,
|
||||||
|
Name: pkFieldName,
|
||||||
|
IsPrimaryKey: true,
|
||||||
|
Description: "",
|
||||||
|
DataType: suite.pkType,
|
||||||
|
TypeParams: []*commonpb.KeyValuePair{
|
||||||
|
{
|
||||||
|
Key: "max_length",
|
||||||
|
Value: "100",
|
||||||
|
},
|
||||||
|
},
|
||||||
|
IndexParams: nil,
|
||||||
|
AutoID: false,
|
||||||
|
}
|
||||||
|
fVec := &schemapb.FieldSchema{
|
||||||
|
FieldID: 101,
|
||||||
|
Name: vecFieldName,
|
||||||
|
IsPrimaryKey: false,
|
||||||
|
Description: "",
|
||||||
|
DataType: suite.vecType,
|
||||||
|
TypeParams: []*commonpb.KeyValuePair{
|
||||||
|
{
|
||||||
|
Key: common.DimKey,
|
||||||
|
Value: fmt.Sprintf("%d", dim),
|
||||||
|
},
|
||||||
|
},
|
||||||
|
IndexParams: nil,
|
||||||
|
}
|
||||||
|
schema := constructSchema(collection, dim, false, pk, fVec)
|
||||||
|
marshaledSchema, err := proto.Marshal(schema)
|
||||||
|
suite.Require().NoError(err)
|
||||||
|
|
||||||
|
createCollectionStatus, err := suite.cluster.proxy.CreateCollection(suite.ctx, &milvuspb.CreateCollectionRequest{
|
||||||
|
CollectionName: collection,
|
||||||
|
Schema: marshaledSchema,
|
||||||
|
ShardsNum: 2,
|
||||||
|
})
|
||||||
|
suite.Require().NoError(err)
|
||||||
|
suite.Require().Equal(createCollectionStatus.GetErrorCode(), commonpb.ErrorCode_Success)
|
||||||
|
|
||||||
|
fieldsData := make([]*schemapb.FieldData, 0)
|
||||||
|
if suite.pkType == schemapb.DataType_Int64 {
|
||||||
|
fieldsData = append(fieldsData, newInt64FieldData(pkFieldName, NB))
|
||||||
|
} else {
|
||||||
|
fieldsData = append(fieldsData, newStringFieldData(pkFieldName, NB))
|
||||||
|
}
|
||||||
|
var vecFieldData *schemapb.FieldData
|
||||||
|
if suite.vecType == schemapb.DataType_FloatVector {
|
||||||
|
vecFieldData = newFloatVectorFieldData(vecFieldName, NB, dim)
|
||||||
|
} else {
|
||||||
|
vecFieldData = newBinaryVectorFieldData(vecFieldName, NB, dim)
|
||||||
|
}
|
||||||
|
fieldsData = append(fieldsData, vecFieldData)
|
||||||
|
hashKeys := generateHashKeys(NB)
|
||||||
|
_, err = suite.cluster.proxy.Insert(suite.ctx, &milvuspb.InsertRequest{
|
||||||
|
CollectionName: collection,
|
||||||
|
FieldsData: fieldsData,
|
||||||
|
HashKeys: hashKeys,
|
||||||
|
NumRows: uint32(NB),
|
||||||
|
})
|
||||||
|
suite.Require().NoError(err)
|
||||||
|
suite.Require().Equal(createCollectionStatus.GetErrorCode(), commonpb.ErrorCode_Success)
|
||||||
|
|
||||||
|
// flush
|
||||||
|
flushResp, err := suite.cluster.proxy.Flush(suite.ctx, &milvuspb.FlushRequest{
|
||||||
|
CollectionNames: []string{collection},
|
||||||
|
})
|
||||||
|
suite.Require().NoError(err)
|
||||||
|
segmentIDs, has := flushResp.GetCollSegIDs()[collection]
|
||||||
|
ids := segmentIDs.GetData()
|
||||||
|
suite.Require().NotEmpty(segmentIDs)
|
||||||
|
suite.Require().True(has)
|
||||||
|
|
||||||
|
segments, err := suite.cluster.metaWatcher.ShowSegments()
|
||||||
|
suite.Require().NoError(err)
|
||||||
|
suite.Require().NotEmpty(segments)
|
||||||
|
|
||||||
|
waitingForFlush(suite.ctx, suite.cluster, ids)
|
||||||
|
|
||||||
|
// create index
|
||||||
|
_, err = suite.cluster.proxy.CreateIndex(suite.ctx, &milvuspb.CreateIndexRequest{
|
||||||
|
CollectionName: collection,
|
||||||
|
FieldName: vecFieldName,
|
||||||
|
IndexName: "_default",
|
||||||
|
ExtraParams: constructIndexParam(dim, suite.indexType, suite.metricType),
|
||||||
|
})
|
||||||
|
suite.Require().NoError(err)
|
||||||
|
suite.Require().Equal(createCollectionStatus.GetErrorCode(), commonpb.ErrorCode_Success)
|
||||||
|
|
||||||
|
// load
|
||||||
|
_, err = suite.cluster.proxy.LoadCollection(suite.ctx, &milvuspb.LoadCollectionRequest{
|
||||||
|
CollectionName: collection,
|
||||||
|
})
|
||||||
|
suite.Require().NoError(err)
|
||||||
|
suite.Require().Equal(createCollectionStatus.GetErrorCode(), commonpb.ErrorCode_Success)
|
||||||
|
waitingForLoad(suite.ctx, suite.cluster, collection)
|
||||||
|
|
||||||
|
// search
|
||||||
|
nq := suite.nq
|
||||||
|
topk := suite.topK
|
||||||
|
|
||||||
|
outputFields := []string{vecFieldName}
|
||||||
|
params := getSearchParams(suite.indexType, suite.metricType)
|
||||||
|
searchReq := constructSearchRequest("", collection, "",
|
||||||
|
vecFieldName, suite.vecType, outputFields, suite.metricType, params, nq, dim, topk, -1)
|
||||||
|
|
||||||
|
searchResp, err := suite.cluster.proxy.Search(suite.ctx, searchReq)
|
||||||
|
suite.Require().NoError(err)
|
||||||
|
suite.Require().Equal(searchResp.GetStatus().GetErrorCode(), commonpb.ErrorCode_Success)
|
||||||
|
|
||||||
|
result := searchResp.GetResults()
|
||||||
|
if suite.pkType == schemapb.DataType_Int64 {
|
||||||
|
suite.Require().Len(result.GetIds().GetIntId().GetData(), nq*topk)
|
||||||
|
} else {
|
||||||
|
suite.Require().Len(result.GetIds().GetStrId().GetData(), nq*topk)
|
||||||
|
}
|
||||||
|
suite.Require().Len(result.GetScores(), nq*topk)
|
||||||
|
suite.Require().GreaterOrEqual(len(result.GetFieldsData()), 1)
|
||||||
|
var vecFieldIndex = -1
|
||||||
|
for i, fieldData := range result.GetFieldsData() {
|
||||||
|
if typeutil.IsVectorType(fieldData.GetType()) {
|
||||||
|
vecFieldIndex = i
|
||||||
|
break
|
||||||
|
}
|
||||||
|
}
|
||||||
|
suite.Require().EqualValues(nq, result.GetNumQueries())
|
||||||
|
suite.Require().EqualValues(topk, result.GetTopK())
|
||||||
|
|
||||||
|
// check output vectors
|
||||||
|
if suite.vecType == schemapb.DataType_FloatVector {
|
||||||
|
suite.Require().Len(result.GetFieldsData()[vecFieldIndex].GetVectors().GetFloatVector().GetData(), nq*topk*dim)
|
||||||
|
rawData := vecFieldData.GetVectors().GetFloatVector().GetData()
|
||||||
|
resData := result.GetFieldsData()[vecFieldIndex].GetVectors().GetFloatVector().GetData()
|
||||||
|
if suite.pkType == schemapb.DataType_Int64 {
|
||||||
|
for i, id := range result.GetIds().GetIntId().GetData() {
|
||||||
|
expect := rawData[int(id)*dim : (int(id)+1)*dim]
|
||||||
|
actual := resData[i*dim : (i+1)*dim]
|
||||||
|
suite.Require().ElementsMatch(expect, actual)
|
||||||
|
}
|
||||||
|
} else {
|
||||||
|
for i, idStr := range result.GetIds().GetStrId().GetData() {
|
||||||
|
id, err := strconv.Atoi(idStr)
|
||||||
|
suite.Require().NoError(err)
|
||||||
|
expect := rawData[id*dim : (id+1)*dim]
|
||||||
|
actual := resData[i*dim : (i+1)*dim]
|
||||||
|
suite.Require().ElementsMatch(expect, actual)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
} else {
|
||||||
|
suite.Require().Len(result.GetFieldsData()[vecFieldIndex].GetVectors().GetBinaryVector(), nq*topk*dim/8)
|
||||||
|
rawData := vecFieldData.GetVectors().GetBinaryVector()
|
||||||
|
resData := result.GetFieldsData()[vecFieldIndex].GetVectors().GetBinaryVector()
|
||||||
|
if suite.pkType == schemapb.DataType_Int64 {
|
||||||
|
for i, id := range result.GetIds().GetIntId().GetData() {
|
||||||
|
dataBytes := dim / 8
|
||||||
|
for j := 0; j < dataBytes; j++ {
|
||||||
|
expect := rawData[int(id)*dataBytes+j]
|
||||||
|
actual := resData[i*dataBytes+j]
|
||||||
|
suite.Require().Equal(expect, actual)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
} else {
|
||||||
|
for i, idStr := range result.GetIds().GetStrId().GetData() {
|
||||||
|
dataBytes := dim / 8
|
||||||
|
id, err := strconv.Atoi(idStr)
|
||||||
|
suite.Require().NoError(err)
|
||||||
|
for j := 0; j < dataBytes; j++ {
|
||||||
|
expect := rawData[id*dataBytes+j]
|
||||||
|
actual := resData[i*dataBytes+j]
|
||||||
|
suite.Require().Equal(expect, actual)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
status, err := suite.cluster.proxy.DropCollection(suite.ctx, &milvuspb.DropCollectionRequest{
|
||||||
|
CollectionName: collection,
|
||||||
|
})
|
||||||
|
suite.Require().NoError(err)
|
||||||
|
suite.Require().Equal(status.GetErrorCode(), commonpb.ErrorCode_Success)
|
||||||
|
}
|
||||||
|
|
||||||
|
func (suite *TestGetVectorSuite) TestGetVector_FLAT() {
|
||||||
|
suite.nq = 10
|
||||||
|
suite.topK = 10
|
||||||
|
suite.indexType = IndexFaissIDMap
|
||||||
|
suite.metricType = distance.L2
|
||||||
|
suite.pkType = schemapb.DataType_Int64
|
||||||
|
suite.vecType = schemapb.DataType_FloatVector
|
||||||
|
suite.run()
|
||||||
|
}
|
||||||
|
|
||||||
|
func (suite *TestGetVectorSuite) TestGetVector_IVF_FLAT() {
|
||||||
|
suite.nq = 10
|
||||||
|
suite.topK = 10
|
||||||
|
suite.indexType = IndexFaissIvfFlat
|
||||||
|
suite.metricType = distance.L2
|
||||||
|
suite.pkType = schemapb.DataType_Int64
|
||||||
|
suite.vecType = schemapb.DataType_FloatVector
|
||||||
|
suite.run()
|
||||||
|
}
|
||||||
|
|
||||||
|
func (suite *TestGetVectorSuite) TestGetVector_IVF_PQ() {
|
||||||
|
suite.nq = 10
|
||||||
|
suite.topK = 10
|
||||||
|
suite.indexType = IndexFaissIvfPQ
|
||||||
|
suite.metricType = distance.L2
|
||||||
|
suite.pkType = schemapb.DataType_Int64
|
||||||
|
suite.vecType = schemapb.DataType_FloatVector
|
||||||
|
suite.run()
|
||||||
|
}
|
||||||
|
|
||||||
|
func (suite *TestGetVectorSuite) TestGetVector_IVF_SQ8() {
|
||||||
|
suite.nq = 10
|
||||||
|
suite.topK = 10
|
||||||
|
suite.indexType = IndexFaissIvfSQ8
|
||||||
|
suite.metricType = distance.L2
|
||||||
|
suite.pkType = schemapb.DataType_Int64
|
||||||
|
suite.vecType = schemapb.DataType_FloatVector
|
||||||
|
suite.run()
|
||||||
|
}
|
||||||
|
|
||||||
|
func (suite *TestGetVectorSuite) TestGetVector_HNSW() {
|
||||||
|
suite.nq = 10
|
||||||
|
suite.topK = 10
|
||||||
|
suite.indexType = IndexHNSW
|
||||||
|
suite.metricType = distance.L2
|
||||||
|
suite.pkType = schemapb.DataType_Int64
|
||||||
|
suite.vecType = schemapb.DataType_FloatVector
|
||||||
|
suite.run()
|
||||||
|
}
|
||||||
|
|
||||||
|
func (suite *TestGetVectorSuite) TestGetVector_IP() {
|
||||||
|
suite.nq = 10
|
||||||
|
suite.topK = 10
|
||||||
|
suite.indexType = IndexHNSW
|
||||||
|
suite.metricType = distance.IP
|
||||||
|
suite.pkType = schemapb.DataType_Int64
|
||||||
|
suite.vecType = schemapb.DataType_FloatVector
|
||||||
|
suite.run()
|
||||||
|
}
|
||||||
|
|
||||||
|
func (suite *TestGetVectorSuite) TestGetVector_StringPK() {
|
||||||
|
suite.nq = 10
|
||||||
|
suite.topK = 10
|
||||||
|
suite.indexType = IndexHNSW
|
||||||
|
suite.metricType = distance.L2
|
||||||
|
suite.pkType = schemapb.DataType_VarChar
|
||||||
|
suite.vecType = schemapb.DataType_FloatVector
|
||||||
|
suite.run()
|
||||||
|
}
|
||||||
|
|
||||||
|
func (suite *TestGetVectorSuite) TestGetVector_BinaryVector() {
|
||||||
|
suite.nq = 10
|
||||||
|
suite.topK = 10
|
||||||
|
suite.indexType = IndexFaissBinIvfFlat
|
||||||
|
suite.metricType = distance.JACCARD
|
||||||
|
suite.pkType = schemapb.DataType_Int64
|
||||||
|
suite.vecType = schemapb.DataType_BinaryVector
|
||||||
|
suite.run()
|
||||||
|
}
|
||||||
|
|
||||||
|
func (suite *TestGetVectorSuite) TestGetVector_Big_NQ_TOPK() {
|
||||||
|
suite.nq = 10000
|
||||||
|
suite.topK = 200
|
||||||
|
suite.indexType = IndexHNSW
|
||||||
|
suite.metricType = distance.L2
|
||||||
|
suite.pkType = schemapb.DataType_Int64
|
||||||
|
suite.vecType = schemapb.DataType_FloatVector
|
||||||
|
suite.run()
|
||||||
|
}
|
||||||
|
|
||||||
|
//func (suite *TestGetVectorSuite) TestGetVector_DISKANN() {
|
||||||
|
// suite.nq = 10
|
||||||
|
// suite.topK = 10
|
||||||
|
// suite.indexType = IndexDISKANN
|
||||||
|
// suite.metricType = distance.L2
|
||||||
|
// suite.pkType = schemapb.DataType_Int64
|
||||||
|
// suite.vecType = schemapb.DataType_FloatVector
|
||||||
|
// suite.run()
|
||||||
|
//}
|
||||||
|
|
||||||
|
func (suite *TestGetVectorSuite) TearDownTest() {
|
||||||
|
err := suite.cluster.Stop()
|
||||||
|
suite.Require().NoError(err)
|
||||||
|
suite.cancel()
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestGetVector(t *testing.T) {
|
||||||
|
suite.Run(t, new(TestGetVectorSuite))
|
||||||
|
}
|
|
@ -17,17 +17,11 @@
|
||||||
package integration
|
package integration
|
||||||
|
|
||||||
import (
|
import (
|
||||||
"bytes"
|
|
||||||
"context"
|
"context"
|
||||||
"encoding/binary"
|
|
||||||
"encoding/json"
|
|
||||||
"fmt"
|
"fmt"
|
||||||
"math/rand"
|
|
||||||
"strconv"
|
|
||||||
"testing"
|
"testing"
|
||||||
"time"
|
"time"
|
||||||
|
|
||||||
"github.com/cockroachdb/errors"
|
|
||||||
"github.com/golang/protobuf/proto"
|
"github.com/golang/protobuf/proto"
|
||||||
"github.com/stretchr/testify/assert"
|
"github.com/stretchr/testify/assert"
|
||||||
"go.uber.org/zap"
|
"go.uber.org/zap"
|
||||||
|
@ -42,59 +36,27 @@ import (
|
||||||
)
|
)
|
||||||
|
|
||||||
func TestHelloMilvus(t *testing.T) {
|
func TestHelloMilvus(t *testing.T) {
|
||||||
ctx := context.Background()
|
ctx, cancel := context.WithTimeout(context.Background(), time.Second*180)
|
||||||
|
defer cancel()
|
||||||
c, err := StartMiniCluster(ctx)
|
c, err := StartMiniCluster(ctx)
|
||||||
assert.NoError(t, err)
|
assert.NoError(t, err)
|
||||||
err = c.Start()
|
err = c.Start()
|
||||||
assert.NoError(t, err)
|
assert.NoError(t, err)
|
||||||
defer c.Stop()
|
defer func() {
|
||||||
|
err = c.Stop()
|
||||||
assert.NoError(t, err)
|
assert.NoError(t, err)
|
||||||
|
cancel()
|
||||||
|
}()
|
||||||
|
|
||||||
prefix := "TestHelloMilvus"
|
const (
|
||||||
dbName := ""
|
dim = 128
|
||||||
collectionName := prefix + funcutil.GenRandomStr()
|
dbName = ""
|
||||||
int64Field := "int64"
|
rowNum = 3000
|
||||||
floatVecField := "fvec"
|
)
|
||||||
dim := 128
|
|
||||||
rowNum := 3000
|
|
||||||
|
|
||||||
constructCollectionSchema := func() *schemapb.CollectionSchema {
|
collectionName := "TestHelloMilvus" + funcutil.GenRandomStr()
|
||||||
pk := &schemapb.FieldSchema{
|
|
||||||
FieldID: 0,
|
schema := constructSchema(collectionName, dim, true)
|
||||||
Name: int64Field,
|
|
||||||
IsPrimaryKey: true,
|
|
||||||
Description: "",
|
|
||||||
DataType: schemapb.DataType_Int64,
|
|
||||||
TypeParams: nil,
|
|
||||||
IndexParams: nil,
|
|
||||||
AutoID: true,
|
|
||||||
}
|
|
||||||
fVec := &schemapb.FieldSchema{
|
|
||||||
FieldID: 0,
|
|
||||||
Name: floatVecField,
|
|
||||||
IsPrimaryKey: false,
|
|
||||||
Description: "",
|
|
||||||
DataType: schemapb.DataType_FloatVector,
|
|
||||||
TypeParams: []*commonpb.KeyValuePair{
|
|
||||||
{
|
|
||||||
Key: "dim",
|
|
||||||
Value: strconv.Itoa(dim),
|
|
||||||
},
|
|
||||||
},
|
|
||||||
IndexParams: nil,
|
|
||||||
AutoID: false,
|
|
||||||
}
|
|
||||||
return &schemapb.CollectionSchema{
|
|
||||||
Name: collectionName,
|
|
||||||
Description: "",
|
|
||||||
AutoID: false,
|
|
||||||
Fields: []*schemapb.FieldSchema{
|
|
||||||
pk,
|
|
||||||
fVec,
|
|
||||||
},
|
|
||||||
}
|
|
||||||
}
|
|
||||||
schema := constructCollectionSchema()
|
|
||||||
marshaledSchema, err := proto.Marshal(schema)
|
marshaledSchema, err := proto.Marshal(schema)
|
||||||
assert.NoError(t, err)
|
assert.NoError(t, err)
|
||||||
|
|
||||||
|
@ -137,6 +99,7 @@ func TestHelloMilvus(t *testing.T) {
|
||||||
segmentIDs, has := flushResp.GetCollSegIDs()[collectionName]
|
segmentIDs, has := flushResp.GetCollSegIDs()[collectionName]
|
||||||
ids := segmentIDs.GetData()
|
ids := segmentIDs.GetData()
|
||||||
assert.NotEmpty(t, segmentIDs)
|
assert.NotEmpty(t, segmentIDs)
|
||||||
|
assert.True(t, has)
|
||||||
|
|
||||||
segments, err := c.metaWatcher.ShowSegments()
|
segments, err := c.metaWatcher.ShowSegments()
|
||||||
assert.NoError(t, err)
|
assert.NoError(t, err)
|
||||||
|
@ -144,52 +107,14 @@ func TestHelloMilvus(t *testing.T) {
|
||||||
for _, segment := range segments {
|
for _, segment := range segments {
|
||||||
log.Info("ShowSegments result", zap.String("segment", segment.String()))
|
log.Info("ShowSegments result", zap.String("segment", segment.String()))
|
||||||
}
|
}
|
||||||
|
waitingForFlush(ctx, c, ids)
|
||||||
if has && len(ids) > 0 {
|
|
||||||
flushed := func() bool {
|
|
||||||
resp, err := c.proxy.GetFlushState(ctx, &milvuspb.GetFlushStateRequest{
|
|
||||||
SegmentIDs: ids,
|
|
||||||
})
|
|
||||||
if err != nil {
|
|
||||||
//panic(errors.New("GetFlushState failed"))
|
|
||||||
return false
|
|
||||||
}
|
|
||||||
return resp.GetFlushed()
|
|
||||||
}
|
|
||||||
for !flushed() {
|
|
||||||
// respect context deadline/cancel
|
|
||||||
select {
|
|
||||||
case <-ctx.Done():
|
|
||||||
panic(errors.New("deadline exceeded"))
|
|
||||||
default:
|
|
||||||
}
|
|
||||||
time.Sleep(500 * time.Millisecond)
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
// create index
|
// create index
|
||||||
createIndexStatus, err := c.proxy.CreateIndex(ctx, &milvuspb.CreateIndexRequest{
|
createIndexStatus, err := c.proxy.CreateIndex(ctx, &milvuspb.CreateIndexRequest{
|
||||||
CollectionName: collectionName,
|
CollectionName: collectionName,
|
||||||
FieldName: floatVecField,
|
FieldName: floatVecField,
|
||||||
IndexName: "_default",
|
IndexName: "_default",
|
||||||
ExtraParams: []*commonpb.KeyValuePair{
|
ExtraParams: constructIndexParam(dim, IndexFaissIvfFlat, distance.L2),
|
||||||
{
|
|
||||||
Key: "dim",
|
|
||||||
Value: strconv.Itoa(dim),
|
|
||||||
},
|
|
||||||
{
|
|
||||||
Key: common.MetricTypeKey,
|
|
||||||
Value: distance.L2,
|
|
||||||
},
|
|
||||||
{
|
|
||||||
Key: "index_type",
|
|
||||||
Value: "IVF_FLAT",
|
|
||||||
},
|
|
||||||
{
|
|
||||||
Key: "nlist",
|
|
||||||
Value: strconv.Itoa(10),
|
|
||||||
},
|
|
||||||
},
|
|
||||||
})
|
})
|
||||||
if createIndexStatus.GetErrorCode() != commonpb.ErrorCode_Success {
|
if createIndexStatus.GetErrorCode() != commonpb.ErrorCode_Success {
|
||||||
log.Warn("createIndexStatus fail reason", zap.String("reason", createIndexStatus.GetReason()))
|
log.Warn("createIndexStatus fail reason", zap.String("reason", createIndexStatus.GetReason()))
|
||||||
|
@ -207,30 +132,17 @@ func TestHelloMilvus(t *testing.T) {
|
||||||
log.Warn("loadStatus fail reason", zap.String("reason", loadStatus.GetReason()))
|
log.Warn("loadStatus fail reason", zap.String("reason", loadStatus.GetReason()))
|
||||||
}
|
}
|
||||||
assert.Equal(t, commonpb.ErrorCode_Success, loadStatus.GetErrorCode())
|
assert.Equal(t, commonpb.ErrorCode_Success, loadStatus.GetErrorCode())
|
||||||
for {
|
waitingForLoad(ctx, c, collectionName)
|
||||||
loadProgress, err := c.proxy.GetLoadingProgress(ctx, &milvuspb.GetLoadingProgressRequest{
|
|
||||||
CollectionName: collectionName,
|
|
||||||
})
|
|
||||||
if err != nil {
|
|
||||||
panic("GetLoadingProgress fail")
|
|
||||||
}
|
|
||||||
if loadProgress.GetProgress() == 100 {
|
|
||||||
break
|
|
||||||
}
|
|
||||||
time.Sleep(500 * time.Millisecond)
|
|
||||||
}
|
|
||||||
|
|
||||||
// search
|
// search
|
||||||
expr := fmt.Sprintf("%s > 0", "int64")
|
expr := fmt.Sprintf("%s > 0", int64Field)
|
||||||
nq := 10
|
nq := 10
|
||||||
topk := 10
|
topk := 10
|
||||||
roundDecimal := -1
|
roundDecimal := -1
|
||||||
nprobe := 10
|
|
||||||
params := make(map[string]int)
|
|
||||||
params["nprobe"] = nprobe
|
|
||||||
|
|
||||||
|
params := getSearchParams(IndexFaissIvfFlat, distance.L2)
|
||||||
searchReq := constructSearchRequest("", collectionName, expr,
|
searchReq := constructSearchRequest("", collectionName, expr,
|
||||||
floatVecField, distance.L2, params, nq, dim, topk, roundDecimal)
|
floatVecField, schemapb.DataType_FloatVector, nil, distance.L2, params, nq, dim, topk, roundDecimal)
|
||||||
|
|
||||||
searchResult, err := c.proxy.Search(ctx, searchReq)
|
searchResult, err := c.proxy.Search(ctx, searchReq)
|
||||||
|
|
||||||
|
@ -242,155 +154,3 @@ func TestHelloMilvus(t *testing.T) {
|
||||||
|
|
||||||
log.Info("TestHelloMilvus succeed")
|
log.Info("TestHelloMilvus succeed")
|
||||||
}
|
}
|
||||||
|
|
||||||
const (
|
|
||||||
AnnsFieldKey = "anns_field"
|
|
||||||
TopKKey = "topk"
|
|
||||||
NQKey = "nq"
|
|
||||||
MetricTypeKey = "metric_type"
|
|
||||||
SearchParamsKey = "params"
|
|
||||||
RoundDecimalKey = "round_decimal"
|
|
||||||
OffsetKey = "offset"
|
|
||||||
LimitKey = "limit"
|
|
||||||
)
|
|
||||||
|
|
||||||
func constructSearchRequest(
|
|
||||||
dbName, collectionName string,
|
|
||||||
expr string,
|
|
||||||
floatVecField string,
|
|
||||||
metricType string,
|
|
||||||
params map[string]int,
|
|
||||||
nq, dim, topk, roundDecimal int,
|
|
||||||
) *milvuspb.SearchRequest {
|
|
||||||
b, err := json.Marshal(params)
|
|
||||||
if err != nil {
|
|
||||||
panic(err)
|
|
||||||
}
|
|
||||||
plg := constructPlaceholderGroup(nq, dim)
|
|
||||||
plgBs, err := proto.Marshal(plg)
|
|
||||||
if err != nil {
|
|
||||||
panic(err)
|
|
||||||
}
|
|
||||||
|
|
||||||
return &milvuspb.SearchRequest{
|
|
||||||
Base: nil,
|
|
||||||
DbName: dbName,
|
|
||||||
CollectionName: collectionName,
|
|
||||||
PartitionNames: nil,
|
|
||||||
Dsl: expr,
|
|
||||||
PlaceholderGroup: plgBs,
|
|
||||||
DslType: commonpb.DslType_BoolExprV1,
|
|
||||||
OutputFields: nil,
|
|
||||||
SearchParams: []*commonpb.KeyValuePair{
|
|
||||||
{
|
|
||||||
Key: common.MetricTypeKey,
|
|
||||||
Value: metricType,
|
|
||||||
},
|
|
||||||
{
|
|
||||||
Key: SearchParamsKey,
|
|
||||||
Value: string(b),
|
|
||||||
},
|
|
||||||
{
|
|
||||||
Key: AnnsFieldKey,
|
|
||||||
Value: floatVecField,
|
|
||||||
},
|
|
||||||
{
|
|
||||||
Key: TopKKey,
|
|
||||||
Value: strconv.Itoa(topk),
|
|
||||||
},
|
|
||||||
{
|
|
||||||
Key: RoundDecimalKey,
|
|
||||||
Value: strconv.Itoa(roundDecimal),
|
|
||||||
},
|
|
||||||
},
|
|
||||||
TravelTimestamp: 0,
|
|
||||||
GuaranteeTimestamp: 0,
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
func constructPlaceholderGroup(
|
|
||||||
nq, dim int,
|
|
||||||
) *commonpb.PlaceholderGroup {
|
|
||||||
values := make([][]byte, 0, nq)
|
|
||||||
for i := 0; i < nq; i++ {
|
|
||||||
bs := make([]byte, 0, dim*4)
|
|
||||||
for j := 0; j < dim; j++ {
|
|
||||||
var buffer bytes.Buffer
|
|
||||||
f := rand.Float32()
|
|
||||||
err := binary.Write(&buffer, common.Endian, f)
|
|
||||||
if err != nil {
|
|
||||||
panic(err)
|
|
||||||
}
|
|
||||||
bs = append(bs, buffer.Bytes()...)
|
|
||||||
}
|
|
||||||
values = append(values, bs)
|
|
||||||
}
|
|
||||||
|
|
||||||
return &commonpb.PlaceholderGroup{
|
|
||||||
Placeholders: []*commonpb.PlaceholderValue{
|
|
||||||
{
|
|
||||||
Tag: "$0",
|
|
||||||
Type: commonpb.PlaceholderType_FloatVector,
|
|
||||||
Values: values,
|
|
||||||
},
|
|
||||||
},
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
func newFloatVectorFieldData(fieldName string, numRows, dim int) *schemapb.FieldData {
|
|
||||||
return &schemapb.FieldData{
|
|
||||||
Type: schemapb.DataType_FloatVector,
|
|
||||||
FieldName: fieldName,
|
|
||||||
Field: &schemapb.FieldData_Vectors{
|
|
||||||
Vectors: &schemapb.VectorField{
|
|
||||||
Dim: int64(dim),
|
|
||||||
Data: &schemapb.VectorField_FloatVector{
|
|
||||||
FloatVector: &schemapb.FloatArray{
|
|
||||||
Data: generateFloatVectors(numRows, dim),
|
|
||||||
},
|
|
||||||
},
|
|
||||||
},
|
|
||||||
},
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
func newInt64PrimaryKey(fieldName string, numRows int) *schemapb.FieldData {
|
|
||||||
return &schemapb.FieldData{
|
|
||||||
Type: schemapb.DataType_Int64,
|
|
||||||
FieldName: fieldName,
|
|
||||||
Field: &schemapb.FieldData_Scalars{
|
|
||||||
Scalars: &schemapb.ScalarField{
|
|
||||||
Data: &schemapb.ScalarField_LongData{
|
|
||||||
LongData: &schemapb.LongArray{
|
|
||||||
Data: generateInt64Array(numRows),
|
|
||||||
},
|
|
||||||
},
|
|
||||||
},
|
|
||||||
},
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
func generateFloatVectors(numRows, dim int) []float32 {
|
|
||||||
total := numRows * dim
|
|
||||||
ret := make([]float32, 0, total)
|
|
||||||
for i := 0; i < total; i++ {
|
|
||||||
ret = append(ret, rand.Float32())
|
|
||||||
}
|
|
||||||
return ret
|
|
||||||
}
|
|
||||||
|
|
||||||
func generateInt64Array(numRows int) []int64 {
|
|
||||||
ret := make([]int64, 0, numRows)
|
|
||||||
for i := 0; i < numRows; i++ {
|
|
||||||
ret = append(ret, int64(rand.Int()))
|
|
||||||
}
|
|
||||||
return ret
|
|
||||||
}
|
|
||||||
|
|
||||||
func generateHashKeys(numRows int) []uint32 {
|
|
||||||
ret := make([]uint32, 0, numRows)
|
|
||||||
for i := 0; i < numRows; i++ {
|
|
||||||
ret = append(ret, rand.Uint32())
|
|
||||||
}
|
|
||||||
return ret
|
|
||||||
}
|
|
||||||
|
|
|
@ -19,16 +19,13 @@ package integration
|
||||||
import (
|
import (
|
||||||
"context"
|
"context"
|
||||||
"fmt"
|
"fmt"
|
||||||
"strconv"
|
|
||||||
"testing"
|
"testing"
|
||||||
"time"
|
"time"
|
||||||
|
|
||||||
"github.com/cockroachdb/errors"
|
|
||||||
"github.com/golang/protobuf/proto"
|
"github.com/golang/protobuf/proto"
|
||||||
"github.com/stretchr/testify/assert"
|
"github.com/stretchr/testify/assert"
|
||||||
"go.uber.org/zap"
|
"go.uber.org/zap"
|
||||||
|
|
||||||
"github.com/milvus-io/milvus-proto/go-api/commonpb"
|
|
||||||
"github.com/milvus-io/milvus-proto/go-api/milvuspb"
|
"github.com/milvus-io/milvus-proto/go-api/milvuspb"
|
||||||
"github.com/milvus-io/milvus-proto/go-api/schemapb"
|
"github.com/milvus-io/milvus-proto/go-api/schemapb"
|
||||||
"github.com/milvus-io/milvus/pkg/common"
|
"github.com/milvus-io/milvus/pkg/common"
|
||||||
|
@ -39,59 +36,24 @@ import (
|
||||||
)
|
)
|
||||||
|
|
||||||
func TestRangeSearchIP(t *testing.T) {
|
func TestRangeSearchIP(t *testing.T) {
|
||||||
ctx := context.Background()
|
ctx, cancel := context.WithTimeout(context.Background(), time.Second*180)
|
||||||
c, err := StartMiniCluster(ctx)
|
c, err := StartMiniCluster(ctx)
|
||||||
assert.NoError(t, err)
|
assert.NoError(t, err)
|
||||||
err = c.Start()
|
err = c.Start()
|
||||||
assert.NoError(t, err)
|
assert.NoError(t, err)
|
||||||
defer c.Stop()
|
defer func() {
|
||||||
|
err = c.Stop()
|
||||||
assert.NoError(t, err)
|
assert.NoError(t, err)
|
||||||
|
cancel()
|
||||||
|
}()
|
||||||
|
|
||||||
prefix := "TestRangeSearchIP"
|
prefix := "TestRangeSearchIP"
|
||||||
dbName := ""
|
dbName := ""
|
||||||
collectionName := prefix + funcutil.GenRandomStr()
|
collectionName := prefix + funcutil.GenRandomStr()
|
||||||
int64Field := "int64"
|
|
||||||
floatVecField := "fvec"
|
|
||||||
dim := 128
|
dim := 128
|
||||||
rowNum := 3000
|
rowNum := 3000
|
||||||
|
|
||||||
constructCollectionSchema := func() *schemapb.CollectionSchema {
|
schema := constructSchema(collectionName, dim, true)
|
||||||
pk := &schemapb.FieldSchema{
|
|
||||||
FieldID: 0,
|
|
||||||
Name: int64Field,
|
|
||||||
IsPrimaryKey: true,
|
|
||||||
Description: "",
|
|
||||||
DataType: schemapb.DataType_Int64,
|
|
||||||
TypeParams: nil,
|
|
||||||
IndexParams: nil,
|
|
||||||
AutoID: true,
|
|
||||||
}
|
|
||||||
fVec := &schemapb.FieldSchema{
|
|
||||||
FieldID: 0,
|
|
||||||
Name: floatVecField,
|
|
||||||
IsPrimaryKey: false,
|
|
||||||
Description: "",
|
|
||||||
DataType: schemapb.DataType_FloatVector,
|
|
||||||
TypeParams: []*commonpb.KeyValuePair{
|
|
||||||
{
|
|
||||||
Key: "dim",
|
|
||||||
Value: strconv.Itoa(dim),
|
|
||||||
},
|
|
||||||
},
|
|
||||||
IndexParams: nil,
|
|
||||||
AutoID: false,
|
|
||||||
}
|
|
||||||
return &schemapb.CollectionSchema{
|
|
||||||
Name: collectionName,
|
|
||||||
Description: "",
|
|
||||||
AutoID: false,
|
|
||||||
Fields: []*schemapb.FieldSchema{
|
|
||||||
pk,
|
|
||||||
fVec,
|
|
||||||
},
|
|
||||||
}
|
|
||||||
}
|
|
||||||
schema := constructCollectionSchema()
|
|
||||||
marshaledSchema, err := proto.Marshal(schema)
|
marshaledSchema, err := proto.Marshal(schema)
|
||||||
assert.NoError(t, err)
|
assert.NoError(t, err)
|
||||||
|
|
||||||
|
@ -133,6 +95,7 @@ func TestRangeSearchIP(t *testing.T) {
|
||||||
})
|
})
|
||||||
assert.NoError(t, err)
|
assert.NoError(t, err)
|
||||||
segmentIDs, has := flushResp.GetCollSegIDs()[collectionName]
|
segmentIDs, has := flushResp.GetCollSegIDs()[collectionName]
|
||||||
|
assert.True(t, has)
|
||||||
ids := segmentIDs.GetData()
|
ids := segmentIDs.GetData()
|
||||||
assert.NotEmpty(t, segmentIDs)
|
assert.NotEmpty(t, segmentIDs)
|
||||||
|
|
||||||
|
@ -142,52 +105,14 @@ func TestRangeSearchIP(t *testing.T) {
|
||||||
for _, segment := range segments {
|
for _, segment := range segments {
|
||||||
log.Info("ShowSegments result", zap.String("segment", segment.String()))
|
log.Info("ShowSegments result", zap.String("segment", segment.String()))
|
||||||
}
|
}
|
||||||
|
waitingForFlush(ctx, c, ids)
|
||||||
if has && len(ids) > 0 {
|
|
||||||
flushed := func() bool {
|
|
||||||
resp, err := c.proxy.GetFlushState(ctx, &milvuspb.GetFlushStateRequest{
|
|
||||||
SegmentIDs: ids,
|
|
||||||
})
|
|
||||||
if err != nil {
|
|
||||||
//panic(errors.New("GetFlushState failed"))
|
|
||||||
return false
|
|
||||||
}
|
|
||||||
return resp.GetFlushed()
|
|
||||||
}
|
|
||||||
for !flushed() {
|
|
||||||
// respect context deadline/cancel
|
|
||||||
select {
|
|
||||||
case <-ctx.Done():
|
|
||||||
panic(errors.New("deadline exceeded"))
|
|
||||||
default:
|
|
||||||
}
|
|
||||||
time.Sleep(500 * time.Millisecond)
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
// create index
|
// create index
|
||||||
createIndexStatus, err := c.proxy.CreateIndex(ctx, &milvuspb.CreateIndexRequest{
|
createIndexStatus, err := c.proxy.CreateIndex(ctx, &milvuspb.CreateIndexRequest{
|
||||||
CollectionName: collectionName,
|
CollectionName: collectionName,
|
||||||
FieldName: floatVecField,
|
FieldName: floatVecField,
|
||||||
IndexName: "_default",
|
IndexName: "_default",
|
||||||
ExtraParams: []*commonpb.KeyValuePair{
|
ExtraParams: constructIndexParam(dim, IndexFaissIvfFlat, distance.IP),
|
||||||
{
|
|
||||||
Key: "dim",
|
|
||||||
Value: strconv.Itoa(dim),
|
|
||||||
},
|
|
||||||
{
|
|
||||||
Key: common.MetricTypeKey,
|
|
||||||
Value: distance.IP,
|
|
||||||
},
|
|
||||||
{
|
|
||||||
Key: "index_type",
|
|
||||||
Value: "IVF_FLAT",
|
|
||||||
},
|
|
||||||
{
|
|
||||||
Key: "nlist",
|
|
||||||
Value: strconv.Itoa(10),
|
|
||||||
},
|
|
||||||
},
|
|
||||||
})
|
})
|
||||||
assert.NoError(t, err)
|
assert.NoError(t, err)
|
||||||
err = merr.Error(createIndexStatus)
|
err = merr.Error(createIndexStatus)
|
||||||
|
@ -205,34 +130,21 @@ func TestRangeSearchIP(t *testing.T) {
|
||||||
if err != nil {
|
if err != nil {
|
||||||
log.Warn("LoadCollection fail reason", zap.Error(err))
|
log.Warn("LoadCollection fail reason", zap.Error(err))
|
||||||
}
|
}
|
||||||
for {
|
waitingForLoad(ctx, c, collectionName)
|
||||||
loadProgress, err := c.proxy.GetLoadingProgress(ctx, &milvuspb.GetLoadingProgressRequest{
|
|
||||||
CollectionName: collectionName,
|
|
||||||
})
|
|
||||||
if err != nil {
|
|
||||||
panic("GetLoadingProgress fail")
|
|
||||||
}
|
|
||||||
if loadProgress.GetProgress() == 100 {
|
|
||||||
break
|
|
||||||
}
|
|
||||||
time.Sleep(500 * time.Millisecond)
|
|
||||||
}
|
|
||||||
// search
|
// search
|
||||||
expr := fmt.Sprintf("%s > 0", "int64")
|
expr := fmt.Sprintf("%s > 0", int64Field)
|
||||||
nq := 10
|
nq := 10
|
||||||
topk := 10
|
topk := 10
|
||||||
roundDecimal := -1
|
roundDecimal := -1
|
||||||
nprobe := 10
|
|
||||||
radius := 10
|
radius := 10
|
||||||
filter := 20
|
filter := 20
|
||||||
|
|
||||||
params := make(map[string]int)
|
params := getSearchParams(IndexFaissIvfFlat, distance.IP)
|
||||||
params["nprobe"] = nprobe
|
|
||||||
|
|
||||||
// only pass in radius when range search
|
// only pass in radius when range search
|
||||||
params["radius"] = radius
|
params["radius"] = radius
|
||||||
searchReq := constructSearchRequest("", collectionName, expr,
|
searchReq := constructSearchRequest("", collectionName, expr,
|
||||||
floatVecField, distance.IP, params, nq, dim, topk, roundDecimal)
|
floatVecField, schemapb.DataType_FloatVector, nil, distance.IP, params, nq, dim, topk, roundDecimal)
|
||||||
|
|
||||||
searchResult, _ := c.proxy.Search(ctx, searchReq)
|
searchResult, _ := c.proxy.Search(ctx, searchReq)
|
||||||
|
|
||||||
|
@ -245,7 +157,7 @@ func TestRangeSearchIP(t *testing.T) {
|
||||||
// pass in radius and range_filter when range search
|
// pass in radius and range_filter when range search
|
||||||
params["range_filter"] = filter
|
params["range_filter"] = filter
|
||||||
searchReq = constructSearchRequest("", collectionName, expr,
|
searchReq = constructSearchRequest("", collectionName, expr,
|
||||||
floatVecField, distance.IP, params, nq, dim, topk, roundDecimal)
|
floatVecField, schemapb.DataType_FloatVector, nil, distance.IP, params, nq, dim, topk, roundDecimal)
|
||||||
|
|
||||||
searchResult, _ = c.proxy.Search(ctx, searchReq)
|
searchResult, _ = c.proxy.Search(ctx, searchReq)
|
||||||
|
|
||||||
|
@ -259,7 +171,7 @@ func TestRangeSearchIP(t *testing.T) {
|
||||||
params["radius"] = filter
|
params["radius"] = filter
|
||||||
params["range_filter"] = radius
|
params["range_filter"] = radius
|
||||||
searchReq = constructSearchRequest("", collectionName, expr,
|
searchReq = constructSearchRequest("", collectionName, expr,
|
||||||
floatVecField, distance.IP, params, nq, dim, topk, roundDecimal)
|
floatVecField, schemapb.DataType_FloatVector, nil, distance.IP, params, nq, dim, topk, roundDecimal)
|
||||||
|
|
||||||
searchResult, _ = c.proxy.Search(ctx, searchReq)
|
searchResult, _ = c.proxy.Search(ctx, searchReq)
|
||||||
|
|
||||||
|
@ -277,59 +189,24 @@ func TestRangeSearchIP(t *testing.T) {
|
||||||
}
|
}
|
||||||
|
|
||||||
func TestRangeSearchL2(t *testing.T) {
|
func TestRangeSearchL2(t *testing.T) {
|
||||||
ctx := context.Background()
|
ctx, cancel := context.WithTimeout(context.Background(), time.Second*180)
|
||||||
c, err := StartMiniCluster(ctx)
|
c, err := StartMiniCluster(ctx)
|
||||||
assert.NoError(t, err)
|
assert.NoError(t, err)
|
||||||
err = c.Start()
|
err = c.Start()
|
||||||
assert.NoError(t, err)
|
assert.NoError(t, err)
|
||||||
defer c.Stop()
|
defer func() {
|
||||||
|
err = c.Stop()
|
||||||
assert.NoError(t, err)
|
assert.NoError(t, err)
|
||||||
|
cancel()
|
||||||
|
}()
|
||||||
|
|
||||||
prefix := "TestRangeSearchL2"
|
prefix := "TestRangeSearchL2"
|
||||||
dbName := ""
|
dbName := ""
|
||||||
collectionName := prefix + funcutil.GenRandomStr()
|
collectionName := prefix + funcutil.GenRandomStr()
|
||||||
int64Field := "int64"
|
|
||||||
floatVecField := "fvec"
|
|
||||||
dim := 128
|
dim := 128
|
||||||
rowNum := 3000
|
rowNum := 3000
|
||||||
|
|
||||||
constructCollectionSchema := func() *schemapb.CollectionSchema {
|
schema := constructSchema(collectionName, dim, true)
|
||||||
pk := &schemapb.FieldSchema{
|
|
||||||
FieldID: 0,
|
|
||||||
Name: int64Field,
|
|
||||||
IsPrimaryKey: true,
|
|
||||||
Description: "",
|
|
||||||
DataType: schemapb.DataType_Int64,
|
|
||||||
TypeParams: nil,
|
|
||||||
IndexParams: nil,
|
|
||||||
AutoID: true,
|
|
||||||
}
|
|
||||||
fVec := &schemapb.FieldSchema{
|
|
||||||
FieldID: 0,
|
|
||||||
Name: floatVecField,
|
|
||||||
IsPrimaryKey: false,
|
|
||||||
Description: "",
|
|
||||||
DataType: schemapb.DataType_FloatVector,
|
|
||||||
TypeParams: []*commonpb.KeyValuePair{
|
|
||||||
{
|
|
||||||
Key: "dim",
|
|
||||||
Value: strconv.Itoa(dim),
|
|
||||||
},
|
|
||||||
},
|
|
||||||
IndexParams: nil,
|
|
||||||
AutoID: false,
|
|
||||||
}
|
|
||||||
return &schemapb.CollectionSchema{
|
|
||||||
Name: collectionName,
|
|
||||||
Description: "",
|
|
||||||
AutoID: false,
|
|
||||||
Fields: []*schemapb.FieldSchema{
|
|
||||||
pk,
|
|
||||||
fVec,
|
|
||||||
},
|
|
||||||
}
|
|
||||||
}
|
|
||||||
schema := constructCollectionSchema()
|
|
||||||
marshaledSchema, err := proto.Marshal(schema)
|
marshaledSchema, err := proto.Marshal(schema)
|
||||||
assert.NoError(t, err)
|
assert.NoError(t, err)
|
||||||
|
|
||||||
|
@ -371,6 +248,7 @@ func TestRangeSearchL2(t *testing.T) {
|
||||||
})
|
})
|
||||||
assert.NoError(t, err)
|
assert.NoError(t, err)
|
||||||
segmentIDs, has := flushResp.GetCollSegIDs()[collectionName]
|
segmentIDs, has := flushResp.GetCollSegIDs()[collectionName]
|
||||||
|
assert.True(t, has)
|
||||||
ids := segmentIDs.GetData()
|
ids := segmentIDs.GetData()
|
||||||
assert.NotEmpty(t, segmentIDs)
|
assert.NotEmpty(t, segmentIDs)
|
||||||
|
|
||||||
|
@ -380,52 +258,14 @@ func TestRangeSearchL2(t *testing.T) {
|
||||||
for _, segment := range segments {
|
for _, segment := range segments {
|
||||||
log.Info("ShowSegments result", zap.String("segment", segment.String()))
|
log.Info("ShowSegments result", zap.String("segment", segment.String()))
|
||||||
}
|
}
|
||||||
|
waitingForFlush(ctx, c, ids)
|
||||||
if has && len(ids) > 0 {
|
|
||||||
flushed := func() bool {
|
|
||||||
resp, err := c.proxy.GetFlushState(ctx, &milvuspb.GetFlushStateRequest{
|
|
||||||
SegmentIDs: ids,
|
|
||||||
})
|
|
||||||
if err != nil {
|
|
||||||
//panic(errors.New("GetFlushState failed"))
|
|
||||||
return false
|
|
||||||
}
|
|
||||||
return resp.GetFlushed()
|
|
||||||
}
|
|
||||||
for !flushed() {
|
|
||||||
// respect context deadline/cancel
|
|
||||||
select {
|
|
||||||
case <-ctx.Done():
|
|
||||||
panic(errors.New("deadline exceeded"))
|
|
||||||
default:
|
|
||||||
}
|
|
||||||
time.Sleep(500 * time.Millisecond)
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
// create index
|
// create index
|
||||||
createIndexStatus, err := c.proxy.CreateIndex(ctx, &milvuspb.CreateIndexRequest{
|
createIndexStatus, err := c.proxy.CreateIndex(ctx, &milvuspb.CreateIndexRequest{
|
||||||
CollectionName: collectionName,
|
CollectionName: collectionName,
|
||||||
FieldName: floatVecField,
|
FieldName: floatVecField,
|
||||||
IndexName: "_default",
|
IndexName: "_default",
|
||||||
ExtraParams: []*commonpb.KeyValuePair{
|
ExtraParams: constructIndexParam(dim, IndexFaissIvfFlat, distance.L2),
|
||||||
{
|
|
||||||
Key: "dim",
|
|
||||||
Value: strconv.Itoa(dim),
|
|
||||||
},
|
|
||||||
{
|
|
||||||
Key: common.MetricTypeKey,
|
|
||||||
Value: distance.L2,
|
|
||||||
},
|
|
||||||
{
|
|
||||||
Key: "index_type",
|
|
||||||
Value: "IVF_FLAT",
|
|
||||||
},
|
|
||||||
{
|
|
||||||
Key: "nlist",
|
|
||||||
Value: strconv.Itoa(10),
|
|
||||||
},
|
|
||||||
},
|
|
||||||
})
|
})
|
||||||
assert.NoError(t, err)
|
assert.NoError(t, err)
|
||||||
err = merr.Error(createIndexStatus)
|
err = merr.Error(createIndexStatus)
|
||||||
|
@ -443,34 +283,20 @@ func TestRangeSearchL2(t *testing.T) {
|
||||||
if err != nil {
|
if err != nil {
|
||||||
log.Warn("LoadCollection fail reason", zap.Error(err))
|
log.Warn("LoadCollection fail reason", zap.Error(err))
|
||||||
}
|
}
|
||||||
for {
|
waitingForLoad(ctx, c, collectionName)
|
||||||
loadProgress, err := c.proxy.GetLoadingProgress(ctx, &milvuspb.GetLoadingProgressRequest{
|
|
||||||
CollectionName: collectionName,
|
|
||||||
})
|
|
||||||
if err != nil {
|
|
||||||
panic("GetLoadingProgress fail")
|
|
||||||
}
|
|
||||||
if loadProgress.GetProgress() == 100 {
|
|
||||||
break
|
|
||||||
}
|
|
||||||
time.Sleep(500 * time.Millisecond)
|
|
||||||
}
|
|
||||||
// search
|
// search
|
||||||
expr := fmt.Sprintf("%s > 0", "int64")
|
expr := fmt.Sprintf("%s > 0", int64Field)
|
||||||
nq := 10
|
nq := 10
|
||||||
topk := 10
|
topk := 10
|
||||||
roundDecimal := -1
|
roundDecimal := -1
|
||||||
nprobe := 10
|
|
||||||
radius := 20
|
radius := 20
|
||||||
filter := 10
|
filter := 10
|
||||||
|
|
||||||
params := make(map[string]int)
|
params := getSearchParams(IndexFaissIvfFlat, distance.L2)
|
||||||
params["nprobe"] = nprobe
|
|
||||||
|
|
||||||
// only pass in radius when range search
|
// only pass in radius when range search
|
||||||
params["radius"] = radius
|
params["radius"] = radius
|
||||||
searchReq := constructSearchRequest("", collectionName, expr,
|
searchReq := constructSearchRequest("", collectionName, expr,
|
||||||
floatVecField, distance.L2, params, nq, dim, topk, roundDecimal)
|
floatVecField, schemapb.DataType_FloatVector, nil, distance.L2, params, nq, dim, topk, roundDecimal)
|
||||||
|
|
||||||
searchResult, _ := c.proxy.Search(ctx, searchReq)
|
searchResult, _ := c.proxy.Search(ctx, searchReq)
|
||||||
|
|
||||||
|
@ -483,7 +309,7 @@ func TestRangeSearchL2(t *testing.T) {
|
||||||
// pass in radius and range_filter when range search
|
// pass in radius and range_filter when range search
|
||||||
params["range_filter"] = filter
|
params["range_filter"] = filter
|
||||||
searchReq = constructSearchRequest("", collectionName, expr,
|
searchReq = constructSearchRequest("", collectionName, expr,
|
||||||
floatVecField, distance.L2, params, nq, dim, topk, roundDecimal)
|
floatVecField, schemapb.DataType_FloatVector, nil, distance.L2, params, nq, dim, topk, roundDecimal)
|
||||||
|
|
||||||
searchResult, _ = c.proxy.Search(ctx, searchReq)
|
searchResult, _ = c.proxy.Search(ctx, searchReq)
|
||||||
|
|
||||||
|
@ -497,7 +323,7 @@ func TestRangeSearchL2(t *testing.T) {
|
||||||
params["radius"] = filter
|
params["radius"] = filter
|
||||||
params["range_filter"] = radius
|
params["range_filter"] = radius
|
||||||
searchReq = constructSearchRequest("", collectionName, expr,
|
searchReq = constructSearchRequest("", collectionName, expr,
|
||||||
floatVecField, distance.L2, params, nq, dim, topk, roundDecimal)
|
floatVecField, schemapb.DataType_FloatVector, nil, distance.L2, params, nq, dim, topk, roundDecimal)
|
||||||
|
|
||||||
searchResult, _ = c.proxy.Search(ctx, searchReq)
|
searchResult, _ = c.proxy.Search(ctx, searchReq)
|
||||||
|
|
||||||
|
|
|
@ -19,13 +19,10 @@ package integration
|
||||||
import (
|
import (
|
||||||
"context"
|
"context"
|
||||||
"fmt"
|
"fmt"
|
||||||
"strconv"
|
|
||||||
"testing"
|
"testing"
|
||||||
"time"
|
"time"
|
||||||
|
|
||||||
"github.com/cockroachdb/errors"
|
|
||||||
"github.com/golang/protobuf/proto"
|
"github.com/golang/protobuf/proto"
|
||||||
"github.com/milvus-io/milvus-proto/go-api/commonpb"
|
|
||||||
"github.com/milvus-io/milvus-proto/go-api/milvuspb"
|
"github.com/milvus-io/milvus-proto/go-api/milvuspb"
|
||||||
"github.com/milvus-io/milvus-proto/go-api/schemapb"
|
"github.com/milvus-io/milvus-proto/go-api/schemapb"
|
||||||
"github.com/milvus-io/milvus/pkg/common"
|
"github.com/milvus-io/milvus/pkg/common"
|
||||||
|
@ -38,59 +35,25 @@ import (
|
||||||
)
|
)
|
||||||
|
|
||||||
func TestUpsert(t *testing.T) {
|
func TestUpsert(t *testing.T) {
|
||||||
ctx := context.Background()
|
ctx, cancel := context.WithTimeout(context.Background(), time.Second*180)
|
||||||
c, err := StartMiniCluster(ctx)
|
c, err := StartMiniCluster(ctx)
|
||||||
assert.NoError(t, err)
|
assert.NoError(t, err)
|
||||||
err = c.Start()
|
err = c.Start()
|
||||||
assert.NoError(t, err)
|
assert.NoError(t, err)
|
||||||
defer c.Stop()
|
defer func() {
|
||||||
|
err = c.Stop()
|
||||||
|
assert.NoError(t, err)
|
||||||
|
cancel()
|
||||||
|
}()
|
||||||
assert.NoError(t, err)
|
assert.NoError(t, err)
|
||||||
|
|
||||||
prefix := "TestUpsert"
|
prefix := "TestUpsert"
|
||||||
dbName := ""
|
dbName := ""
|
||||||
collectionName := prefix + funcutil.GenRandomStr()
|
collectionName := prefix + funcutil.GenRandomStr()
|
||||||
int64Field := "int64"
|
|
||||||
floatVecField := "fvec"
|
|
||||||
dim := 128
|
dim := 128
|
||||||
rowNum := 3000
|
rowNum := 3000
|
||||||
|
|
||||||
constructCollectionSchema := func() *schemapb.CollectionSchema {
|
schema := constructSchema(collectionName, dim, false)
|
||||||
pk := &schemapb.FieldSchema{
|
|
||||||
FieldID: 0,
|
|
||||||
Name: int64Field,
|
|
||||||
IsPrimaryKey: true,
|
|
||||||
Description: "",
|
|
||||||
DataType: schemapb.DataType_Int64,
|
|
||||||
TypeParams: nil,
|
|
||||||
IndexParams: nil,
|
|
||||||
AutoID: false,
|
|
||||||
}
|
|
||||||
fVec := &schemapb.FieldSchema{
|
|
||||||
FieldID: 0,
|
|
||||||
Name: floatVecField,
|
|
||||||
IsPrimaryKey: false,
|
|
||||||
Description: "",
|
|
||||||
DataType: schemapb.DataType_FloatVector,
|
|
||||||
TypeParams: []*commonpb.KeyValuePair{
|
|
||||||
{
|
|
||||||
Key: "dim",
|
|
||||||
Value: strconv.Itoa(dim),
|
|
||||||
},
|
|
||||||
},
|
|
||||||
IndexParams: nil,
|
|
||||||
AutoID: false,
|
|
||||||
}
|
|
||||||
return &schemapb.CollectionSchema{
|
|
||||||
Name: collectionName,
|
|
||||||
Description: "",
|
|
||||||
AutoID: false,
|
|
||||||
Fields: []*schemapb.FieldSchema{
|
|
||||||
pk,
|
|
||||||
fVec,
|
|
||||||
},
|
|
||||||
}
|
|
||||||
}
|
|
||||||
schema := constructCollectionSchema()
|
|
||||||
marshaledSchema, err := proto.Marshal(schema)
|
marshaledSchema, err := proto.Marshal(schema)
|
||||||
assert.NoError(t, err)
|
assert.NoError(t, err)
|
||||||
|
|
||||||
|
@ -113,7 +76,7 @@ func TestUpsert(t *testing.T) {
|
||||||
assert.True(t, merr.Ok(showCollectionsResp.GetStatus()))
|
assert.True(t, merr.Ok(showCollectionsResp.GetStatus()))
|
||||||
log.Info("ShowCollections result", zap.Any("showCollectionsResp", showCollectionsResp))
|
log.Info("ShowCollections result", zap.Any("showCollectionsResp", showCollectionsResp))
|
||||||
|
|
||||||
pkFieldData := newInt64PrimaryKey(int64Field, rowNum)
|
pkFieldData := newInt64FieldData(int64Field, rowNum)
|
||||||
fVecColumn := newFloatVectorFieldData(floatVecField, rowNum, dim)
|
fVecColumn := newFloatVectorFieldData(floatVecField, rowNum, dim)
|
||||||
hashKeys := generateHashKeys(rowNum)
|
hashKeys := generateHashKeys(rowNum)
|
||||||
upsertResult, err := c.proxy.Upsert(ctx, &milvuspb.UpsertRequest{
|
upsertResult, err := c.proxy.Upsert(ctx, &milvuspb.UpsertRequest{
|
||||||
|
@ -133,6 +96,7 @@ func TestUpsert(t *testing.T) {
|
||||||
})
|
})
|
||||||
assert.NoError(t, err)
|
assert.NoError(t, err)
|
||||||
segmentIDs, has := flushResp.GetCollSegIDs()[collectionName]
|
segmentIDs, has := flushResp.GetCollSegIDs()[collectionName]
|
||||||
|
assert.True(t, has)
|
||||||
ids := segmentIDs.GetData()
|
ids := segmentIDs.GetData()
|
||||||
assert.NotEmpty(t, segmentIDs)
|
assert.NotEmpty(t, segmentIDs)
|
||||||
|
|
||||||
|
@ -142,52 +106,14 @@ func TestUpsert(t *testing.T) {
|
||||||
for _, segment := range segments {
|
for _, segment := range segments {
|
||||||
log.Info("ShowSegments result", zap.String("segment", segment.String()))
|
log.Info("ShowSegments result", zap.String("segment", segment.String()))
|
||||||
}
|
}
|
||||||
|
waitingForFlush(ctx, c, ids)
|
||||||
if has && len(ids) > 0 {
|
|
||||||
flushed := func() bool {
|
|
||||||
resp, err := c.proxy.GetFlushState(ctx, &milvuspb.GetFlushStateRequest{
|
|
||||||
SegmentIDs: ids,
|
|
||||||
})
|
|
||||||
if err != nil {
|
|
||||||
//panic(errors.New("GetFlushState failed"))
|
|
||||||
return false
|
|
||||||
}
|
|
||||||
return resp.GetFlushed()
|
|
||||||
}
|
|
||||||
for !flushed() {
|
|
||||||
// respect context deadline/cancel
|
|
||||||
select {
|
|
||||||
case <-ctx.Done():
|
|
||||||
panic(errors.New("deadline exceeded"))
|
|
||||||
default:
|
|
||||||
}
|
|
||||||
time.Sleep(500 * time.Millisecond)
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
// create index
|
// create index
|
||||||
createIndexStatus, err := c.proxy.CreateIndex(ctx, &milvuspb.CreateIndexRequest{
|
createIndexStatus, err := c.proxy.CreateIndex(ctx, &milvuspb.CreateIndexRequest{
|
||||||
CollectionName: collectionName,
|
CollectionName: collectionName,
|
||||||
FieldName: floatVecField,
|
FieldName: floatVecField,
|
||||||
IndexName: "_default",
|
IndexName: "_default",
|
||||||
ExtraParams: []*commonpb.KeyValuePair{
|
ExtraParams: constructIndexParam(dim, IndexFaissIvfFlat, distance.IP),
|
||||||
{
|
|
||||||
Key: "dim",
|
|
||||||
Value: strconv.Itoa(dim),
|
|
||||||
},
|
|
||||||
{
|
|
||||||
Key: common.MetricTypeKey,
|
|
||||||
Value: distance.L2,
|
|
||||||
},
|
|
||||||
{
|
|
||||||
Key: "index_type",
|
|
||||||
Value: "IVF_FLAT",
|
|
||||||
},
|
|
||||||
{
|
|
||||||
Key: "nlist",
|
|
||||||
Value: strconv.Itoa(10),
|
|
||||||
},
|
|
||||||
},
|
|
||||||
})
|
})
|
||||||
assert.NoError(t, err)
|
assert.NoError(t, err)
|
||||||
err = merr.Error(createIndexStatus)
|
err = merr.Error(createIndexStatus)
|
||||||
|
@ -205,30 +131,16 @@ func TestUpsert(t *testing.T) {
|
||||||
if err != nil {
|
if err != nil {
|
||||||
log.Warn("LoadCollection fail reason", zap.Error(err))
|
log.Warn("LoadCollection fail reason", zap.Error(err))
|
||||||
}
|
}
|
||||||
for {
|
waitingForLoad(ctx, c, collectionName)
|
||||||
loadProgress, err := c.proxy.GetLoadingProgress(ctx, &milvuspb.GetLoadingProgressRequest{
|
|
||||||
CollectionName: collectionName,
|
|
||||||
})
|
|
||||||
if err != nil {
|
|
||||||
panic("GetLoadingProgress fail")
|
|
||||||
}
|
|
||||||
if loadProgress.GetProgress() == 100 {
|
|
||||||
break
|
|
||||||
}
|
|
||||||
time.Sleep(500 * time.Millisecond)
|
|
||||||
}
|
|
||||||
// search
|
// search
|
||||||
expr := fmt.Sprintf("%s > 0", "int64")
|
expr := fmt.Sprintf("%s > 0", int64Field)
|
||||||
nq := 10
|
nq := 10
|
||||||
topk := 10
|
topk := 10
|
||||||
roundDecimal := -1
|
roundDecimal := -1
|
||||||
nprobe := 10
|
|
||||||
|
|
||||||
params := make(map[string]int)
|
|
||||||
params["nprobe"] = nprobe
|
|
||||||
|
|
||||||
|
params := getSearchParams(IndexFaissIvfFlat, "")
|
||||||
searchReq := constructSearchRequest("", collectionName, expr,
|
searchReq := constructSearchRequest("", collectionName, expr,
|
||||||
floatVecField, distance.IP, params, nq, dim, topk, roundDecimal)
|
floatVecField, schemapb.DataType_FloatVector, nil, distance.IP, params, nq, dim, topk, roundDecimal)
|
||||||
|
|
||||||
searchResult, _ := c.proxy.Search(ctx, searchReq)
|
searchResult, _ := c.proxy.Search(ctx, searchReq)
|
||||||
|
|
||||||
|
|
|
@ -0,0 +1,108 @@
|
||||||
|
// Licensed to the LF AI & Data foundation under one
|
||||||
|
// or more contributor license agreements. See the NOTICE file
|
||||||
|
// distributed with this work for additional information
|
||||||
|
// regarding copyright ownership. The ASF licenses this file
|
||||||
|
// to you under the Apache License, Version 2.0 (the
|
||||||
|
// "License"); you may not use this file except in compliance
|
||||||
|
// with the License. You may obtain a copy of the License at
|
||||||
|
//
|
||||||
|
// http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
//
|
||||||
|
// Unless required by applicable law or agreed to in writing, software
|
||||||
|
// distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
// See the License for the specific language governing permissions and
|
||||||
|
// limitations under the License.
|
||||||
|
|
||||||
|
package integration
|
||||||
|
|
||||||
|
import (
|
||||||
|
"fmt"
|
||||||
|
"strconv"
|
||||||
|
|
||||||
|
"github.com/milvus-io/milvus-proto/go-api/commonpb"
|
||||||
|
"github.com/milvus-io/milvus/pkg/common"
|
||||||
|
"github.com/milvus-io/milvus/pkg/util/indexparamcheck"
|
||||||
|
)
|
||||||
|
|
||||||
|
const (
|
||||||
|
IndexRaftIvfFlat = indexparamcheck.IndexRaftIvfFlat
|
||||||
|
IndexRaftIvfPQ = indexparamcheck.IndexRaftIvfPQ
|
||||||
|
IndexFaissIDMap = indexparamcheck.IndexFaissIDMap
|
||||||
|
IndexFaissIvfFlat = indexparamcheck.IndexFaissIvfFlat
|
||||||
|
IndexFaissIvfPQ = indexparamcheck.IndexFaissIvfPQ
|
||||||
|
IndexFaissIvfSQ8 = indexparamcheck.IndexFaissIvfSQ8
|
||||||
|
IndexFaissBinIDMap = indexparamcheck.IndexFaissBinIDMap
|
||||||
|
IndexFaissBinIvfFlat = indexparamcheck.IndexFaissBinIvfFlat
|
||||||
|
IndexHNSW = indexparamcheck.IndexHNSW
|
||||||
|
IndexDISKANN = indexparamcheck.IndexDISKANN
|
||||||
|
)
|
||||||
|
|
||||||
|
func constructIndexParam(dim int, indexType string, metricType string) []*commonpb.KeyValuePair {
|
||||||
|
params := []*commonpb.KeyValuePair{
|
||||||
|
{
|
||||||
|
Key: common.DimKey,
|
||||||
|
Value: strconv.Itoa(dim),
|
||||||
|
},
|
||||||
|
{
|
||||||
|
Key: common.MetricTypeKey,
|
||||||
|
Value: metricType,
|
||||||
|
},
|
||||||
|
{
|
||||||
|
Key: common.IndexTypeKey,
|
||||||
|
Value: indexType,
|
||||||
|
},
|
||||||
|
}
|
||||||
|
switch indexType {
|
||||||
|
case IndexFaissIDMap, IndexFaissBinIDMap:
|
||||||
|
// no index param is required
|
||||||
|
case IndexFaissIvfFlat, IndexFaissBinIvfFlat, IndexFaissIvfSQ8:
|
||||||
|
params = append(params, &commonpb.KeyValuePair{
|
||||||
|
Key: "nlist",
|
||||||
|
Value: "100",
|
||||||
|
})
|
||||||
|
case IndexFaissIvfPQ:
|
||||||
|
params = append(params, &commonpb.KeyValuePair{
|
||||||
|
Key: "nlist",
|
||||||
|
Value: "100",
|
||||||
|
})
|
||||||
|
params = append(params, &commonpb.KeyValuePair{
|
||||||
|
Key: "m",
|
||||||
|
Value: "16",
|
||||||
|
})
|
||||||
|
params = append(params, &commonpb.KeyValuePair{
|
||||||
|
Key: "nbits",
|
||||||
|
Value: "8",
|
||||||
|
})
|
||||||
|
case IndexHNSW:
|
||||||
|
params = append(params, &commonpb.KeyValuePair{
|
||||||
|
Key: "M",
|
||||||
|
Value: "16",
|
||||||
|
})
|
||||||
|
params = append(params, &commonpb.KeyValuePair{
|
||||||
|
Key: "efConstruction",
|
||||||
|
Value: "200",
|
||||||
|
})
|
||||||
|
case IndexDISKANN:
|
||||||
|
default:
|
||||||
|
panic(fmt.Sprintf("unimplemented index param for %s, please help to improve it", indexType))
|
||||||
|
}
|
||||||
|
return params
|
||||||
|
}
|
||||||
|
|
||||||
|
func getSearchParams(indexType string, metricType string) map[string]any {
|
||||||
|
params := make(map[string]any)
|
||||||
|
switch indexType {
|
||||||
|
case IndexFaissIDMap, IndexFaissBinIDMap:
|
||||||
|
params["metric_type"] = metricType
|
||||||
|
case IndexFaissIvfFlat, IndexFaissBinIvfFlat, IndexFaissIvfSQ8, IndexFaissIvfPQ:
|
||||||
|
params["nprobe"] = 8
|
||||||
|
case IndexHNSW:
|
||||||
|
params["ef"] = 200
|
||||||
|
case IndexDISKANN:
|
||||||
|
params["search_list"] = 5
|
||||||
|
default:
|
||||||
|
panic(fmt.Sprintf("unimplemented search param for %s, please help to improve it", indexType))
|
||||||
|
}
|
||||||
|
return params
|
||||||
|
}
|
|
@ -0,0 +1,154 @@
|
||||||
|
// Licensed to the LF AI & Data foundation under one
|
||||||
|
// or more contributor license agreements. See the NOTICE file
|
||||||
|
// distributed with this work for additional information
|
||||||
|
// regarding copyright ownership. The ASF licenses this file
|
||||||
|
// to you under the Apache License, Version 2.0 (the
|
||||||
|
// "License"); you may not use this file except in compliance
|
||||||
|
// with the License. You may obtain a copy of the License at
|
||||||
|
//
|
||||||
|
// http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
//
|
||||||
|
// Unless required by applicable law or agreed to in writing, software
|
||||||
|
// distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
// See the License for the specific language governing permissions and
|
||||||
|
// limitations under the License.
|
||||||
|
|
||||||
|
package integration
|
||||||
|
|
||||||
|
import (
|
||||||
|
"context"
|
||||||
|
"fmt"
|
||||||
|
"math/rand"
|
||||||
|
"time"
|
||||||
|
|
||||||
|
"github.com/milvus-io/milvus-proto/go-api/milvuspb"
|
||||||
|
"github.com/milvus-io/milvus-proto/go-api/schemapb"
|
||||||
|
)
|
||||||
|
|
||||||
|
func waitingForFlush(ctx context.Context, cluster *MiniCluster, segIDs []int64) {
|
||||||
|
flushed := func() bool {
|
||||||
|
resp, err := cluster.proxy.GetFlushState(ctx, &milvuspb.GetFlushStateRequest{
|
||||||
|
SegmentIDs: segIDs,
|
||||||
|
})
|
||||||
|
if err != nil {
|
||||||
|
return false
|
||||||
|
}
|
||||||
|
return resp.GetFlushed()
|
||||||
|
}
|
||||||
|
for !flushed() {
|
||||||
|
select {
|
||||||
|
case <-ctx.Done():
|
||||||
|
panic("flush timeout")
|
||||||
|
default:
|
||||||
|
time.Sleep(500 * time.Millisecond)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func newInt64FieldData(fieldName string, numRows int) *schemapb.FieldData {
|
||||||
|
return &schemapb.FieldData{
|
||||||
|
Type: schemapb.DataType_Int64,
|
||||||
|
FieldName: fieldName,
|
||||||
|
Field: &schemapb.FieldData_Scalars{
|
||||||
|
Scalars: &schemapb.ScalarField{
|
||||||
|
Data: &schemapb.ScalarField_LongData{
|
||||||
|
LongData: &schemapb.LongArray{
|
||||||
|
Data: generateInt64Array(numRows),
|
||||||
|
},
|
||||||
|
},
|
||||||
|
},
|
||||||
|
},
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func newStringFieldData(fieldName string, numRows int) *schemapb.FieldData {
|
||||||
|
return &schemapb.FieldData{
|
||||||
|
Type: schemapb.DataType_Int64,
|
||||||
|
FieldName: fieldName,
|
||||||
|
Field: &schemapb.FieldData_Scalars{
|
||||||
|
Scalars: &schemapb.ScalarField{
|
||||||
|
Data: &schemapb.ScalarField_StringData{
|
||||||
|
StringData: &schemapb.StringArray{
|
||||||
|
Data: generateStringArray(numRows),
|
||||||
|
},
|
||||||
|
},
|
||||||
|
},
|
||||||
|
},
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func newFloatVectorFieldData(fieldName string, numRows, dim int) *schemapb.FieldData {
|
||||||
|
return &schemapb.FieldData{
|
||||||
|
Type: schemapb.DataType_FloatVector,
|
||||||
|
FieldName: fieldName,
|
||||||
|
Field: &schemapb.FieldData_Vectors{
|
||||||
|
Vectors: &schemapb.VectorField{
|
||||||
|
Dim: int64(dim),
|
||||||
|
Data: &schemapb.VectorField_FloatVector{
|
||||||
|
FloatVector: &schemapb.FloatArray{
|
||||||
|
Data: generateFloatVectors(numRows, dim),
|
||||||
|
},
|
||||||
|
},
|
||||||
|
},
|
||||||
|
},
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func newBinaryVectorFieldData(fieldName string, numRows, dim int) *schemapb.FieldData {
|
||||||
|
return &schemapb.FieldData{
|
||||||
|
Type: schemapb.DataType_BinaryVector,
|
||||||
|
FieldName: fieldName,
|
||||||
|
Field: &schemapb.FieldData_Vectors{
|
||||||
|
Vectors: &schemapb.VectorField{
|
||||||
|
Dim: int64(dim),
|
||||||
|
Data: &schemapb.VectorField_BinaryVector{
|
||||||
|
BinaryVector: generateBinaryVectors(numRows, dim),
|
||||||
|
},
|
||||||
|
},
|
||||||
|
},
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// generateInt64Array returns the sequence 0, 1, ..., numRows-1 as int64
// values.
func generateInt64Array(numRows int) []int64 {
	values := make([]int64, numRows)
	for idx := range values {
		values[idx] = int64(idx)
	}
	return values
}
|
||||||
|
|
||||||
|
// generateStringArray returns the decimal strings "0", "1", ...,
// up to numRows-1.
func generateStringArray(numRows int) []string {
	values := make([]string, numRows)
	for idx := range values {
		values[idx] = fmt.Sprint(idx)
	}
	return values
}
|
||||||
|
|
||||||
|
// generateFloatVectors returns numRows*dim values drawn from rand.Float32,
// laid out as one flat slice.
func generateFloatVectors(numRows, dim int) []float32 {
	values := make([]float32, numRows*dim)
	for idx := range values {
		values[idx] = rand.Float32()
	}
	return values
}
|
||||||
|
|
||||||
|
// generateBinaryVectors returns (numRows*dim)/8 bytes filled by rand.Read,
// i.e. one bit per vector dimension. It panics if rand.Read reports an error.
func generateBinaryVectors(numRows, dim int) []byte {
	buf := make([]byte, (numRows*dim)/8)
	if _, err := rand.Read(buf); err != nil {
		panic(err)
	}
	return buf
}
|
||||||
|
|
||||||
|
// generateHashKeys returns numRows values drawn from rand.Uint32.
func generateHashKeys(numRows int) []uint32 {
	keys := make([]uint32, numRows)
	for idx := range keys {
		keys[idx] = rand.Uint32()
	}
	return keys
}
|
|
@ -0,0 +1,166 @@
|
||||||
|
// Licensed to the LF AI & Data foundation under one
|
||||||
|
// or more contributor license agreements. See the NOTICE file
|
||||||
|
// distributed with this work for additional information
|
||||||
|
// regarding copyright ownership. The ASF licenses this file
|
||||||
|
// to you under the Apache License, Version 2.0 (the
|
||||||
|
// "License"); you may not use this file except in compliance
|
||||||
|
// with the License. You may obtain a copy of the License at
|
||||||
|
//
|
||||||
|
// http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
//
|
||||||
|
// Unless required by applicable law or agreed to in writing, software
|
||||||
|
// distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
// See the License for the specific language governing permissions and
|
||||||
|
// limitations under the License.
|
||||||
|
|
||||||
|
package integration
|
||||||
|
|
||||||
|
import (
|
||||||
|
"bytes"
|
||||||
|
"context"
|
||||||
|
"encoding/binary"
|
||||||
|
"encoding/json"
|
||||||
|
"math/rand"
|
||||||
|
"strconv"
|
||||||
|
"time"
|
||||||
|
|
||||||
|
"github.com/golang/protobuf/proto"
|
||||||
|
|
||||||
|
"github.com/milvus-io/milvus-proto/go-api/commonpb"
|
||||||
|
"github.com/milvus-io/milvus-proto/go-api/milvuspb"
|
||||||
|
"github.com/milvus-io/milvus-proto/go-api/schemapb"
|
||||||
|
"github.com/milvus-io/milvus/pkg/common"
|
||||||
|
)
|
||||||
|
|
||||||
|
// String keys for search-related key/value parameters.
// NOTE(review): only AnnsFieldKey, SearchParamsKey and RoundDecimalKey are
// visibly used in this file (by constructSearchRequest); the others are
// presumably consumed elsewhere in the package — confirm.
const (
	// AnnsFieldKey is paired with the name of the vector field to search.
	AnnsFieldKey = "anns_field"
	TopKKey      = "topk"
	NQKey        = "nq"

	MetricTypeKey = "metric_type"
	// SearchParamsKey is paired with the JSON-encoded search parameters.
	SearchParamsKey = "params"
	// RoundDecimalKey is paired with the decimal text of the round-decimal setting.
	RoundDecimalKey = "round_decimal"

	OffsetKey = "offset"
	LimitKey  = "limit"
)
|
||||||
|
|
||||||
|
func waitingForLoad(ctx context.Context, cluster *MiniCluster, collection string) {
|
||||||
|
getLoadingProgress := func() *milvuspb.GetLoadingProgressResponse {
|
||||||
|
loadProgress, err := cluster.proxy.GetLoadingProgress(ctx, &milvuspb.GetLoadingProgressRequest{
|
||||||
|
CollectionName: collection,
|
||||||
|
})
|
||||||
|
if err != nil {
|
||||||
|
panic("GetLoadingProgress fail")
|
||||||
|
}
|
||||||
|
return loadProgress
|
||||||
|
}
|
||||||
|
for getLoadingProgress().GetProgress() != 100 {
|
||||||
|
select {
|
||||||
|
case <-ctx.Done():
|
||||||
|
panic("load timeout")
|
||||||
|
default:
|
||||||
|
time.Sleep(500 * time.Millisecond)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func constructSearchRequest(
|
||||||
|
dbName, collectionName string,
|
||||||
|
expr string,
|
||||||
|
vecField string,
|
||||||
|
vectorType schemapb.DataType,
|
||||||
|
outputFields []string,
|
||||||
|
metricType string,
|
||||||
|
params map[string]any,
|
||||||
|
nq, dim int, topk, roundDecimal int,
|
||||||
|
) *milvuspb.SearchRequest {
|
||||||
|
b, err := json.Marshal(params)
|
||||||
|
if err != nil {
|
||||||
|
panic(err)
|
||||||
|
}
|
||||||
|
plg := constructPlaceholderGroup(nq, dim, vectorType)
|
||||||
|
plgBs, err := proto.Marshal(plg)
|
||||||
|
if err != nil {
|
||||||
|
panic(err)
|
||||||
|
}
|
||||||
|
|
||||||
|
return &milvuspb.SearchRequest{
|
||||||
|
Base: nil,
|
||||||
|
DbName: dbName,
|
||||||
|
CollectionName: collectionName,
|
||||||
|
PartitionNames: nil,
|
||||||
|
Dsl: expr,
|
||||||
|
PlaceholderGroup: plgBs,
|
||||||
|
DslType: commonpb.DslType_BoolExprV1,
|
||||||
|
OutputFields: outputFields,
|
||||||
|
SearchParams: []*commonpb.KeyValuePair{
|
||||||
|
{
|
||||||
|
Key: common.MetricTypeKey,
|
||||||
|
Value: metricType,
|
||||||
|
},
|
||||||
|
{
|
||||||
|
Key: SearchParamsKey,
|
||||||
|
Value: string(b),
|
||||||
|
},
|
||||||
|
{
|
||||||
|
Key: AnnsFieldKey,
|
||||||
|
Value: vecField,
|
||||||
|
},
|
||||||
|
{
|
||||||
|
Key: common.TopKKey,
|
||||||
|
Value: strconv.Itoa(topk),
|
||||||
|
},
|
||||||
|
{
|
||||||
|
Key: RoundDecimalKey,
|
||||||
|
Value: strconv.Itoa(roundDecimal),
|
||||||
|
},
|
||||||
|
},
|
||||||
|
TravelTimestamp: 0,
|
||||||
|
GuaranteeTimestamp: 0,
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func constructPlaceholderGroup(nq, dim int, vectorType schemapb.DataType) *commonpb.PlaceholderGroup {
|
||||||
|
values := make([][]byte, 0, nq)
|
||||||
|
var placeholderType commonpb.PlaceholderType
|
||||||
|
switch vectorType {
|
||||||
|
case schemapb.DataType_FloatVector:
|
||||||
|
placeholderType = commonpb.PlaceholderType_FloatVector
|
||||||
|
for i := 0; i < nq; i++ {
|
||||||
|
bs := make([]byte, 0, dim*4)
|
||||||
|
for j := 0; j < dim; j++ {
|
||||||
|
var buffer bytes.Buffer
|
||||||
|
f := rand.Float32()
|
||||||
|
err := binary.Write(&buffer, common.Endian, f)
|
||||||
|
if err != nil {
|
||||||
|
panic(err)
|
||||||
|
}
|
||||||
|
bs = append(bs, buffer.Bytes()...)
|
||||||
|
}
|
||||||
|
values = append(values, bs)
|
||||||
|
}
|
||||||
|
case schemapb.DataType_BinaryVector:
|
||||||
|
placeholderType = commonpb.PlaceholderType_BinaryVector
|
||||||
|
for i := 0; i < nq; i++ {
|
||||||
|
total := dim / 8
|
||||||
|
ret := make([]byte, total)
|
||||||
|
_, err := rand.Read(ret)
|
||||||
|
if err != nil {
|
||||||
|
panic(err)
|
||||||
|
}
|
||||||
|
values = append(values, ret)
|
||||||
|
}
|
||||||
|
default:
|
||||||
|
panic("invalid vector data type")
|
||||||
|
}
|
||||||
|
|
||||||
|
return &commonpb.PlaceholderGroup{
|
||||||
|
Placeholders: []*commonpb.PlaceholderValue{
|
||||||
|
{
|
||||||
|
Tag: "$0",
|
||||||
|
Type: placeholderType,
|
||||||
|
Values: values,
|
||||||
|
},
|
||||||
|
},
|
||||||
|
}
|
||||||
|
}
|
|
@ -0,0 +1,80 @@
|
||||||
|
// Licensed to the LF AI & Data foundation under one
|
||||||
|
// or more contributor license agreements. See the NOTICE file
|
||||||
|
// distributed with this work for additional information
|
||||||
|
// regarding copyright ownership. The ASF licenses this file
|
||||||
|
// to you under the Apache License, Version 2.0 (the
|
||||||
|
// "License"); you may not use this file except in compliance
|
||||||
|
// with the License. You may obtain a copy of the License at
|
||||||
|
//
|
||||||
|
// http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
//
|
||||||
|
// Unless required by applicable law or agreed to in writing, software
|
||||||
|
// distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
// See the License for the specific language governing permissions and
|
||||||
|
// limitations under the License.
|
||||||
|
|
||||||
|
package integration
|
||||||
|
|
||||||
|
import (
|
||||||
|
"fmt"
|
||||||
|
|
||||||
|
"github.com/milvus-io/milvus-proto/go-api/commonpb"
|
||||||
|
"github.com/milvus-io/milvus-proto/go-api/schemapb"
|
||||||
|
"github.com/milvus-io/milvus/pkg/common"
|
||||||
|
)
|
||||||
|
|
||||||
|
// Field names for test collection schemas.
// NOTE(review): only int64Field and floatVecField are visibly used in this
// file (by constructSchema's default schema); the rest are presumably used by
// other tests in the package — confirm.
const (
	// Scalar field names.
	boolField    = "boolField"
	int8Field    = "int8Field"
	int16Field   = "int16Field"
	int32Field   = "int32Field"
	int64Field   = "int64Field"
	floatField   = "floatField"
	doubleField  = "doubleField"
	varCharField = "varCharField"

	// Vector field names.
	floatVecField = "floatVecField"
	binVecField   = "binVecField"
)
|
||||||
|
|
||||||
|
func constructSchema(collection string, dim int, autoID bool, fields ...*schemapb.FieldSchema) *schemapb.CollectionSchema {
|
||||||
|
// if fields are specified, construct it
|
||||||
|
if len(fields) > 0 {
|
||||||
|
return &schemapb.CollectionSchema{
|
||||||
|
Name: collection,
|
||||||
|
AutoID: autoID,
|
||||||
|
Fields: fields,
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// if no field is specified, use default
|
||||||
|
pk := &schemapb.FieldSchema{
|
||||||
|
FieldID: 100,
|
||||||
|
Name: int64Field,
|
||||||
|
IsPrimaryKey: true,
|
||||||
|
Description: "",
|
||||||
|
DataType: schemapb.DataType_Int64,
|
||||||
|
TypeParams: nil,
|
||||||
|
IndexParams: nil,
|
||||||
|
AutoID: autoID,
|
||||||
|
}
|
||||||
|
fVec := &schemapb.FieldSchema{
|
||||||
|
FieldID: 101,
|
||||||
|
Name: floatVecField,
|
||||||
|
IsPrimaryKey: false,
|
||||||
|
Description: "",
|
||||||
|
DataType: schemapb.DataType_FloatVector,
|
||||||
|
TypeParams: []*commonpb.KeyValuePair{
|
||||||
|
{
|
||||||
|
Key: common.DimKey,
|
||||||
|
Value: fmt.Sprintf("%d", dim),
|
||||||
|
},
|
||||||
|
},
|
||||||
|
IndexParams: nil,
|
||||||
|
}
|
||||||
|
return &schemapb.CollectionSchema{
|
||||||
|
Name: collection,
|
||||||
|
AutoID: autoID,
|
||||||
|
Fields: []*schemapb.FieldSchema{pk, fVec},
|
||||||
|
}
|
||||||
|
}
|
|
@ -842,11 +842,7 @@ class TestCollectionSearchInvalid(TestcaseBase):
|
||||||
log.info("test_search_output_field_vector: Searching collection %s" % collection_w.name)
|
log.info("test_search_output_field_vector: Searching collection %s" % collection_w.name)
|
||||||
collection_w.search(vectors[:default_nq], default_search_field,
|
collection_w.search(vectors[:default_nq], default_search_field,
|
||||||
default_search_params, default_limit,
|
default_search_params, default_limit,
|
||||||
default_search_exp, output_fields=output_fields,
|
default_search_exp, output_fields=output_fields)
|
||||||
check_task=CheckTasks.err_res,
|
|
||||||
check_items={"err_code": 1,
|
|
||||||
"err_msg": "Search doesn't support "
|
|
||||||
"vector field as output_fields"})
|
|
||||||
|
|
||||||
@pytest.mark.tags(CaseLabel.L2)
|
@pytest.mark.tags(CaseLabel.L2)
|
||||||
@pytest.mark.parametrize("output_fields", [["*%"], ["**"], ["*", "@"]])
|
@pytest.mark.parametrize("output_fields", [["*%"], ["**"], ["*", "@"]])
|
||||||
|
|
Loading…
Reference in New Issue