From 092d743917457e84926a60ecfdb4bf4bd9594fd4 Mon Sep 17 00:00:00 2001 From: "yihao.dai" Date: Sun, 23 Apr 2023 09:00:32 +0800 Subject: [PATCH] Add support for getting vectors by ids (#23450) Signed-off-by: bigsheeper --- Makefile | 3 + internal/core/src/common/Utils.h | 10 + internal/core/src/index/VectorDiskIndex.cpp | 32 + internal/core/src/index/VectorDiskIndex.h | 8 + internal/core/src/index/VectorIndex.h | 7 + internal/core/src/index/VectorMemIndex.cpp | 30 + internal/core/src/index/VectorMemIndex.h | 7 + .../core/src/segcore/SegmentGrowingImpl.h | 5 + internal/core/src/segcore/SegmentInterface.h | 3 + .../core/src/segcore/SegmentSealedImpl.cpp | 52 +- internal/core/src/segcore/SegmentSealedImpl.h | 6 + internal/core/src/segcore/segment_c.cpp | 7 + internal/core/src/segcore/segment_c.h | 3 + internal/core/unittest/test_indexing.cpp | 85 + internal/core/unittest/test_sealed.cpp | 49 + internal/core/unittest/test_utils/DataGen.h | 11 + internal/mocks/mock_proxy.go | 4086 +++++++++++++++++ internal/proxy/impl.go | 3 +- internal/proxy/task_query.go | 6 +- internal/proxy/task_query_test.go | 25 + internal/proxy/task_search.go | 162 +- internal/proxy/task_search_test.go | 287 +- internal/querynodev2/segments/segment.go | 23 +- internal/querynodev2/segments/segment_test.go | 7 + pkg/util/typeutil/schema.go | 51 + pkg/util/typeutil/schema_test.go | 106 + tests/integration/bulkinsert_test.go | 108 +- tests/integration/get_vector_test.go | 366 ++ tests/integration/hello_milvus_test.go | 284 +- tests/integration/range_search_test.go | 238 +- tests/integration/upsert_test.go | 118 +- tests/integration/util_index.go | 108 + tests/integration/util_insert.go | 154 + tests/integration/util_query.go | 166 + tests/integration/util_schema.go | 80 + tests/python_client/testcases/test_search.py | 6 +- 36 files changed, 5997 insertions(+), 705 deletions(-) create mode 100644 internal/mocks/mock_proxy.go create mode 100644 tests/integration/get_vector_test.go create mode 100644 tests/integration/util_index.go create mode 100644 tests/integration/util_insert.go create mode 100644 tests/integration/util_query.go create mode 100644 tests/integration/util_schema.go diff --git a/Makefile b/Makefile index b048086066..51fa5f5581 100644 --- a/Makefile +++ b/Makefile @@ -318,6 +318,9 @@ rpm: install @cp -r build/rpm/services ~/rpmbuild/BUILD/ @QA_RPATHS="$$[ 0x001|0x0002|0x0020 ]" rpmbuild -ba ./build/rpm/milvus.spec +mock-proxy: + mockery --name=ProxyComponent --dir=$(PWD)/internal/types --output=$(PWD)/internal/mocks --filename=mock_proxy.go --structname=Proxy --with-expecter + mock-datanode: mockery --name=DataNode --dir=$(PWD)/internal/types --output=$(PWD)/internal/mocks --filename=mock_datanode.go --with-expecter diff --git a/internal/core/src/common/Utils.h b/internal/core/src/common/Utils.h index 8a0cb28d3c..40522d2eba 100644 --- a/internal/core/src/common/Utils.h +++ b/internal/core/src/common/Utils.h @@ -85,6 +85,16 @@ PrefixMatch(const std::string_view str, const std::string_view prefix) { return true; } +inline DatasetPtr +GenIdsDataset(const int64_t count, const int64_t* ids) { + auto ret_ds = std::make_shared(); + ret_ds->SetRows(count); + ret_ds->SetDim(1); + ret_ds->SetIds(ids); + ret_ds->SetIsOwner(false); + return ret_ds; +} + inline DatasetPtr GenResultDataset(const int64_t nq, const int64_t topk, diff --git a/internal/core/src/index/VectorDiskIndex.cpp b/internal/core/src/index/VectorDiskIndex.cpp index 66fc5ea6cb..7729d9ecf8 100644 --- a/internal/core/src/index/VectorDiskIndex.cpp +++ b/internal/core/src/index/VectorDiskIndex.cpp @@ -230,6 +230,38 @@ VectorDiskAnnIndex::Query(const DatasetPtr dataset, return result; } +template +const bool +VectorDiskAnnIndex::HasRawData() const { + return index_.HasRawData(GetMetricType()); +} + +template +const std::vector +VectorDiskAnnIndex::GetVector(const DatasetPtr dataset, + const Config& config) const { + auto res = index_.GetVectorByIds(*dataset, config); + if (!res.has_value()) { + PanicCodeInfo( + ErrorCodeEnum::UnexpectedError, + "failed to get vector, " + MatchKnowhereError(res.error())); + } + auto index_type = GetIndexType(); + auto tensor = res.value()->GetTensor(); + auto row_num = res.value()->GetRows(); + auto dim = res.value()->GetDim(); + int64_t data_size; + if (is_in_bin_list(index_type)) { + data_size = dim / 8 * row_num; + } else { + data_size = dim * row_num * sizeof(float); + } + std::vector raw_data; + raw_data.resize(data_size); + memcpy(raw_data.data(), tensor, data_size); + return raw_data; +} + template void VectorDiskAnnIndex::CleanLocalData() { diff --git a/internal/core/src/index/VectorDiskIndex.h b/internal/core/src/index/VectorDiskIndex.h index 2425ad5f16..351fee4cf5 100644 --- a/internal/core/src/index/VectorDiskIndex.h +++ b/internal/core/src/index/VectorDiskIndex.h @@ -17,6 +17,7 @@ #pragma once #include +#include #include "index/VectorIndex.h" #include "storage/DiskFileManagerImpl.h" @@ -60,6 +61,13 @@ class VectorDiskAnnIndex : public VectorIndex { const SearchInfo& search_info, const BitsetView& bitset) override; + const bool + HasRawData() const override; + + const std::vector + GetVector(const DatasetPtr dataset, + const Config& config = {}) const override; + void CleanLocalData() override; diff --git a/internal/core/src/index/VectorIndex.h b/internal/core/src/index/VectorIndex.h index e70ee79564..b92bddca93 100644 --- a/internal/core/src/index/VectorIndex.h +++ b/internal/core/src/index/VectorIndex.h @@ -19,6 +19,7 @@ #include #include #include +#include #include #include "knowhere/factory.h" @@ -50,6 +51,12 @@ class VectorIndex : public IndexBase { const SearchInfo& search_info, const BitsetView& bitset) = 0; + virtual const bool + HasRawData() const = 0; + + virtual const std::vector + GetVector(const DatasetPtr dataset, const Config& config = {}) const = 0; + IndexType GetIndexType() const { return index_type_; diff --git a/internal/core/src/index/VectorMemIndex.cpp b/internal/core/src/index/VectorMemIndex.cpp index 2acc756565..4849c51d65 100644 --- a/internal/core/src/index/VectorMemIndex.cpp +++ b/internal/core/src/index/VectorMemIndex.cpp @@ -145,4 +145,34 @@ VectorMemIndex::Query(const DatasetPtr dataset, return result; } +const bool +VectorMemIndex::HasRawData() const { + return index_.HasRawData(GetMetricType()); +} + +const std::vector +VectorMemIndex::GetVector(const DatasetPtr dataset, + const Config& config) const { + auto res = index_.GetVectorByIds(*dataset, config); + if (!res.has_value()) { + PanicCodeInfo( + ErrorCodeEnum::UnexpectedError, + "failed to get vector, " + MatchKnowhereError(res.error())); + } + auto index_type = GetIndexType(); + auto tensor = res.value()->GetTensor(); + auto row_num = res.value()->GetRows(); + auto dim = res.value()->GetDim(); + int64_t data_size; + if (is_in_bin_list(index_type)) { + data_size = dim / 8 * row_num; + } else { + data_size = dim * row_num * sizeof(float); + } + std::vector raw_data; + raw_data.resize(data_size); + memcpy(raw_data.data(), tensor, data_size); + return raw_data; +} + } // namespace milvus::index diff --git a/internal/core/src/index/VectorMemIndex.h b/internal/core/src/index/VectorMemIndex.h index dbd91e14eb..e706f34f20 100644 --- a/internal/core/src/index/VectorMemIndex.h +++ b/internal/core/src/index/VectorMemIndex.h @@ -51,6 +51,13 @@ class VectorMemIndex : public VectorIndex { const SearchInfo& search_info, const BitsetView& bitset) override; + const bool + HasRawData() const override; + + const std::vector + GetVector(const DatasetPtr dataset, + const Config& config = {}) const override; + protected: Config config_; knowhere::Index index_; diff --git a/internal/core/src/segcore/SegmentGrowingImpl.h b/internal/core/src/segcore/SegmentGrowingImpl.h index 5665a9d7dd..51a76b6ed5 100644 --- a/internal/core/src/segcore/SegmentGrowingImpl.h +++ b/internal/core/src/segcore/SegmentGrowingImpl.h @@ -223,6 +223,11 @@ class SegmentGrowingImpl : public SegmentGrowing { return true; } + bool + HasRawData(int64_t field_id) const override { + return true; + } + protected: int64_t num_chunk() const override; diff --git a/internal/core/src/segcore/SegmentInterface.h b/internal/core/src/segcore/SegmentInterface.h index 29a201096d..fd6bff82da 100644 --- a/internal/core/src/segcore/SegmentInterface.h +++ b/internal/core/src/segcore/SegmentInterface.h @@ -88,6 +88,9 @@ class SegmentInterface { virtual SegmentType type() const = 0; + + virtual bool + HasRawData(int64_t field_id) const = 0; }; // internal API for DSL calculation diff --git a/internal/core/src/segcore/SegmentSealedImpl.cpp b/internal/core/src/segcore/SegmentSealedImpl.cpp index 7ec0626a22..a86b5c61df 100644 --- a/internal/core/src/segcore/SegmentSealedImpl.cpp +++ b/internal/core/src/segcore/SegmentSealedImpl.cpp @@ -29,6 +29,7 @@ #include "query/ScalarIndex.h" #include "query/SearchBruteForce.h" #include "query/SearchOnSealed.h" +#include "index/Utils.h" namespace milvus::segcore { @@ -475,6 +476,35 @@ SegmentSealedImpl::vector_search(SearchInfo& search_info, } } +std::unique_ptr +SegmentSealedImpl::get_vector(FieldId field_id, + const int64_t* ids, + int64_t count) const { + auto& filed_meta = schema_->operator[](field_id); + AssertInfo(filed_meta.is_vector(), "vector field is not vector type"); + + if (get_bit(index_ready_bitset_, field_id)) { + AssertInfo(vector_indexings_.is_ready(field_id), + "vector index is not ready"); + auto field_indexing = vector_indexings_.get_field_indexing(field_id); + auto vec_index = + dynamic_cast(field_indexing->indexing_.get()); + + auto index_type = vec_index->GetIndexType(); + auto metric_type = vec_index->GetMetricType(); + auto has_raw_data = vec_index->HasRawData(); + + if (has_raw_data) { + auto ids_ds = GenIdsDataset(count, ids); + auto& vector = vec_index->GetVector(ids_ds); + return segcore::CreateVectorDataArrayFrom( + vector.data(), count, filed_meta); + } + } + + return fill_with_empty(field_id, count); +} + void SegmentSealedImpl::DropFieldData(const FieldId field_id) { if (SystemProperty::Instance().IsSystem(field_id)) { @@ -666,9 +696,7 @@ SegmentSealedImpl::bulk_subscript(FieldId field_id, return ReverseDataFromIndex(index, seg_offsets, count, field_meta); } - // TODO: knowhere support reverse data from vector index - // Now, real data will be filled in data array using chunk manager - return fill_with_empty(field_id, count); + return get_vector(field_id, seg_offsets, count); } Assert(get_bit(field_data_ready_bitset_, field_id)); @@ -783,6 +811,24 @@ SegmentSealedImpl::HasFieldData(FieldId field_id) const { } } +bool +SegmentSealedImpl::HasRawData(int64_t field_id) const { + std::shared_lock lck(mutex_); + auto fieldID = FieldId(field_id); + const auto& field_meta = schema_->operator[](fieldID); + if (datatype_is_vector(field_meta.get_data_type())) { + if (get_bit(index_ready_bitset_, fieldID)) { + AssertInfo(vector_indexings_.is_ready(fieldID), + "vector index is not ready"); + auto field_indexing = vector_indexings_.get_field_indexing(fieldID); + auto vec_index = dynamic_cast( + field_indexing->indexing_.get()); + return vec_index->HasRawData(); + } + } + return true; +} + std::pair, std::vector> SegmentSealedImpl::search_ids(const IdArray& id_array, Timestamp timestamp) const { diff --git a/internal/core/src/segcore/SegmentSealedImpl.h b/internal/core/src/segcore/SegmentSealedImpl.h index e9732d7752..87d9afb8e6 100644 --- a/internal/core/src/segcore/SegmentSealedImpl.h +++ b/internal/core/src/segcore/SegmentSealedImpl.h @@ -61,6 +61,9 @@ class SegmentSealedImpl : public SegmentSealed { return id_; } + bool + HasRawData(int64_t field_id) const override; + public: int64_t GetMemoryUsageInBytes() const override; @@ -74,6 +77,9 @@ class SegmentSealedImpl : public SegmentSealed { const Schema& get_schema() const override; + std::unique_ptr + get_vector(FieldId field_id, const int64_t* ids, int64_t count) const; + public: int64_t num_chunk_index(FieldId field_id) const override; diff --git a/internal/core/src/segcore/segment_c.cpp b/internal/core/src/segcore/segment_c.cpp index 30d264c9ed..09d096fc85 100644 --- a/internal/core/src/segcore/segment_c.cpp +++ b/internal/core/src/segcore/segment_c.cpp @@ -164,6 +164,13 @@ GetRealCount(CSegmentInterface c_segment) { return segment->get_real_count(); } +bool +HasRawData(CSegmentInterface c_segment, int64_t field_id) { + auto segment = + reinterpret_cast(c_segment); + return segment->HasRawData(field_id); +} + ////////////////////////////// interfaces for growing segment ////////////////////////////// CStatus Insert(CSegmentInterface c_segment, diff --git a/internal/core/src/segcore/segment_c.h b/internal/core/src/segcore/segment_c.h index 17ec12ae28..e87222990f 100644 --- a/internal/core/src/segcore/segment_c.h +++ b/internal/core/src/segcore/segment_c.h @@ -67,6 +67,9 @@ GetDeletedCount(CSegmentInterface c_segment); int64_t GetRealCount(CSegmentInterface c_segment); +bool +HasRawData(CSegmentInterface c_segment, int64_t field_id); + ////////////////////////////// interfaces for growing segment ////////////////////////////// CStatus Insert(CSegmentInterface c_segment, diff --git a/internal/core/unittest/test_indexing.cpp b/internal/core/unittest/test_indexing.cpp index 43fe28be18..a8b333b708 100644 --- a/internal/core/unittest/test_indexing.cpp +++ b/internal/core/unittest/test_indexing.cpp @@ -449,6 +449,91 @@ TEST_P(IndexTest, BuildAndQuery) { vec_index->Query(xq_dataset, search_info, nullptr); } +TEST_P(IndexTest, GetVector) { + milvus::index::CreateIndexInfo create_index_info; + create_index_info.index_type = index_type; + create_index_info.metric_type = metric_type; + create_index_info.field_type = vec_field_data_type; + index::IndexBasePtr index; + + if (index_type == knowhere::IndexEnum::INDEX_DISKANN) { +#ifdef BUILD_DISK_ANN + milvus::storage::FieldDataMeta field_data_meta{1, 2, 3, 100}; + milvus::storage::IndexMeta index_meta{3, 100, 1000, 1}; + auto file_manager = + std::make_shared( + field_data_meta, index_meta, storage_config_); + index = milvus::index::IndexFactory::GetInstance().CreateIndex( + create_index_info, file_manager); +#endif + } else { + index = milvus::index::IndexFactory::GetInstance().CreateIndex( + create_index_info, nullptr); + } + ASSERT_NO_THROW(index->BuildWithDataset(xb_dataset, build_conf)); + milvus::index::IndexBasePtr new_index; + milvus::index::VectorIndex* vec_index = nullptr; + + if (index_type == knowhere::IndexEnum::INDEX_DISKANN) { +#ifdef BUILD_DISK_ANN + // TODO ::diskann.query need load first, ugly + auto binary_set = index->Serialize(milvus::Config{}); + index.reset(); + milvus::storage::FieldDataMeta field_data_meta{1, 2, 3, 100}; + milvus::storage::IndexMeta index_meta{3, 100, 1000, 1}; + auto file_manager = + std::make_shared( + field_data_meta, index_meta, storage_config_); + new_index = milvus::index::IndexFactory::GetInstance().CreateIndex( + create_index_info, file_manager); + + vec_index = dynamic_cast(new_index.get()); + + std::vector index_files; + for (auto& binary : binary_set.binary_map_) { + index_files.emplace_back(binary.first); + } + load_conf["index_files"] = index_files; + vec_index->Load(binary_set, load_conf); + EXPECT_EQ(vec_index->Count(), NB); +#endif + } else { + vec_index = dynamic_cast(index.get()); + } + EXPECT_EQ(vec_index->GetDim(), DIM); + EXPECT_EQ(vec_index->Count(), NB); + + if (!vec_index->HasRawData()) { + return; + } + + auto ids_ds = GenRandomIds(NB); + auto results = vec_index->GetVector(ids_ds); + EXPECT_TRUE(results.size() > 0); + if (!is_binary) { + std::vector result_vectors(results.size() / (sizeof(float))); + memcpy(result_vectors.data(), results.data(), results.size()); + EXPECT_TRUE(result_vectors.size() == xb_data.size()); + for (size_t i = 0; i < NB; ++i) { + auto id = ids_ds->GetIds()[i]; + for (size_t j = 0; j < DIM; ++j) { + EXPECT_TRUE(result_vectors[i * DIM + j] == + xb_data[id * DIM + j]); + } + } + } else { + EXPECT_TRUE(results.size() == xb_bin_data.size()); + const auto data_bytes = DIM / 8; + for (size_t i = 0; i < NB; ++i) { + auto id = ids_ds->GetIds()[i]; + for (size_t j = 0; j < data_bytes; ++j) { + EXPECT_TRUE(results[i * data_bytes + j] == + xb_bin_data[id * data_bytes + j]); + } + } + } +} + // #ifdef BUILD_DISK_ANN // TEST(Indexing, SearchDiskAnnWithInvalidParam) { // int64_t NB = 10000; diff --git a/internal/core/unittest/test_sealed.cpp b/internal/core/unittest/test_sealed.cpp index bd5fb4f178..e0b480ccc6 100644 --- a/internal/core/unittest/test_sealed.cpp +++ b/internal/core/unittest/test_sealed.cpp @@ -1067,3 +1067,52 @@ TEST(Sealed, RealCount) { ASSERT_TRUE(status.ok()); ASSERT_EQ(0, segment->get_real_count()); } + +TEST(Sealed, GetVector) { + auto dim = 16; + auto topK = 5; + auto N = ROW_COUNT; + auto metric_type = knowhere::metric::L2; + auto schema = std::make_shared(); + auto fakevec_id = schema->AddDebugField( + "fakevec", DataType::VECTOR_FLOAT, dim, metric_type); + auto counter_id = schema->AddDebugField("counter", DataType::INT64); + auto double_id = schema->AddDebugField("double", DataType::DOUBLE); + auto nothing_id = schema->AddDebugField("nothing", DataType::INT32); + auto str_id = schema->AddDebugField("str", DataType::VARCHAR); + schema->AddDebugField("int8", DataType::INT8); + schema->AddDebugField("int16", DataType::INT16); + schema->AddDebugField("float", DataType::FLOAT); + schema->set_primary_field_id(counter_id); + + auto dataset = DataGen(schema, N); + + auto fakevec = dataset.get_col(fakevec_id); + + auto indexing = GenVecIndexing(N, dim, fakevec.data()); + + auto segment_sealed = CreateSealedSegment(schema); + + LoadIndexInfo vec_info; + vec_info.field_id = fakevec_id.get(); + vec_info.index = std::move(indexing); + vec_info.index_params["metric_type"] = knowhere::metric::L2; + segment_sealed->LoadIndex(vec_info); + + auto segment = dynamic_cast(segment_sealed.get()); + + auto has = segment->HasRawData(vec_info.field_id); + EXPECT_TRUE(has); + + auto ids_ds = GenRandomIds(N); + auto result = segment->get_vector(fakevec_id, ids_ds->GetIds(), N); + + auto vector = result.get()->mutable_vectors()->float_vector().data(); + EXPECT_TRUE(vector.size() == fakevec.size()); + for (size_t i = 0; i < N; ++i) { + auto id = ids_ds->GetIds()[i]; + for (size_t j = 0; j < dim; ++j) { + EXPECT_TRUE(vector[i * dim + j] == fakevec[id * dim + j]); + } + } +} diff --git a/internal/core/unittest/test_utils/DataGen.h b/internal/core/unittest/test_utils/DataGen.h index 26344bf357..53e2e9ccc1 100644 --- a/internal/core/unittest/test_utils/DataGen.h +++ b/internal/core/unittest/test_utils/DataGen.h @@ -599,4 +599,15 @@ GenPKs(const std::vector& pks) { return GenPKs(pks.begin(), pks.end()); } +inline std::shared_ptr +GenRandomIds(int rows, int64_t seed = 42) { + std::mt19937 g(seed); + auto* ids = new int64_t[rows]; + for (int i = 0; i < rows; ++i) ids[i] = i; + std::shuffle(ids, ids + rows, g); + auto ids_ds = GenIdsDataset(rows, ids); + ids_ds->SetIsOwner(true); + return ids_ds; +} + } // namespace milvus::segcore diff --git a/internal/mocks/mock_proxy.go b/internal/mocks/mock_proxy.go new file mode 100644 index 0000000000..bac17bf85e --- /dev/null +++ b/internal/mocks/mock_proxy.go @@ -0,0 +1,4086 @@ +// Code generated by mockery v2.14.0. DO NOT EDIT. + +package mocks + +import ( + context "context" + + commonpb "github.com/milvus-io/milvus-proto/go-api/commonpb" + clientv3 "go.etcd.io/etcd/client/v3" + + internalpb "github.com/milvus-io/milvus/internal/proto/internalpb" + + milvuspb "github.com/milvus-io/milvus-proto/go-api/milvuspb" + + mock "github.com/stretchr/testify/mock" + + proxypb "github.com/milvus-io/milvus/internal/proto/proxypb" + + types "github.com/milvus-io/milvus/internal/types" +) + +// Proxy is an autogenerated mock type for the ProxyComponent type +type Proxy struct { + mock.Mock +} + +type Proxy_Expecter struct { + mock *mock.Mock +} + +func (_m *Proxy) EXPECT() *Proxy_Expecter { + return &Proxy_Expecter{mock: &_m.Mock} +} + +// AlterAlias provides a mock function with given fields: ctx, request +func (_m *Proxy) AlterAlias(ctx context.Context, request *milvuspb.AlterAliasRequest) (*commonpb.Status, error) { + ret := _m.Called(ctx, request) + + var r0 *commonpb.Status + if rf, ok := ret.Get(0).(func(context.Context, *milvuspb.AlterAliasRequest) *commonpb.Status); ok { + r0 = rf(ctx, request) + } else { + if ret.Get(0) != nil { + r0 = ret.Get(0).(*commonpb.Status) + } + } + + var r1 error + if rf, ok := ret.Get(1).(func(context.Context, *milvuspb.AlterAliasRequest) error); ok { + r1 = rf(ctx, request) + } else { + r1 = ret.Error(1) + } + + return r0, r1 +} + +// Proxy_AlterAlias_Call is a *mock.Call that shadows Run/Return methods with type explicit version for method 'AlterAlias' +type Proxy_AlterAlias_Call struct { + *mock.Call +} + +// AlterAlias is a helper method to define mock.On call +// - ctx context.Context +// - request *milvuspb.AlterAliasRequest +func (_e *Proxy_Expecter) AlterAlias(ctx interface{}, request interface{}) *Proxy_AlterAlias_Call { + return &Proxy_AlterAlias_Call{Call: _e.mock.On("AlterAlias", ctx, request)} +} + +func (_c *Proxy_AlterAlias_Call) Run(run func(ctx context.Context, request *milvuspb.AlterAliasRequest)) *Proxy_AlterAlias_Call { + _c.Call.Run(func(args mock.Arguments) { + run(args[0].(context.Context), args[1].(*milvuspb.AlterAliasRequest)) + }) + return _c +} + +func (_c *Proxy_AlterAlias_Call) Return(_a0 *commonpb.Status, _a1 error) *Proxy_AlterAlias_Call { + _c.Call.Return(_a0, _a1) + return _c +} + +// AlterCollection provides a mock function with given fields: ctx, request +func (_m *Proxy) AlterCollection(ctx context.Context, request *milvuspb.AlterCollectionRequest) (*commonpb.Status, error) { + ret := _m.Called(ctx, request) + + var r0 *commonpb.Status + if rf, ok := ret.Get(0).(func(context.Context, *milvuspb.AlterCollectionRequest) *commonpb.Status); ok { + r0 = rf(ctx, request) + } else { + if ret.Get(0) != nil { + r0 = ret.Get(0).(*commonpb.Status) + } + } + + var r1 error + if rf, ok := ret.Get(1).(func(context.Context, *milvuspb.AlterCollectionRequest) error); ok { + r1 = rf(ctx, request) + } else { + r1 = ret.Error(1) + } + + return r0, r1 +} + +// Proxy_AlterCollection_Call is a *mock.Call that shadows Run/Return methods with type explicit version for method 'AlterCollection' +type Proxy_AlterCollection_Call struct { + *mock.Call +} + +// AlterCollection is a helper method to define mock.On call +// - ctx context.Context +// - request *milvuspb.AlterCollectionRequest +func (_e *Proxy_Expecter) AlterCollection(ctx interface{}, request interface{}) *Proxy_AlterCollection_Call { + return &Proxy_AlterCollection_Call{Call: _e.mock.On("AlterCollection", ctx, request)} +} + +func (_c *Proxy_AlterCollection_Call) Run(run func(ctx context.Context, request *milvuspb.AlterCollectionRequest)) *Proxy_AlterCollection_Call { + _c.Call.Run(func(args mock.Arguments) { + run(args[0].(context.Context), args[1].(*milvuspb.AlterCollectionRequest)) + }) + return _c +} + +func (_c *Proxy_AlterCollection_Call) Return(_a0 *commonpb.Status, _a1 error) *Proxy_AlterCollection_Call { + _c.Call.Return(_a0, _a1) + return _c +} + +// CalcDistance provides a mock function with given fields: ctx, request +func (_m *Proxy) CalcDistance(ctx context.Context, request *milvuspb.CalcDistanceRequest) (*milvuspb.CalcDistanceResults, error) { + ret := _m.Called(ctx, request) + + var r0 *milvuspb.CalcDistanceResults + if rf, ok := ret.Get(0).(func(context.Context, *milvuspb.CalcDistanceRequest) *milvuspb.CalcDistanceResults); ok { + r0 = rf(ctx, request) + } else { + if ret.Get(0) != nil { + r0 = ret.Get(0).(*milvuspb.CalcDistanceResults) + } + } + + var r1 error + if rf, ok := ret.Get(1).(func(context.Context, *milvuspb.CalcDistanceRequest) error); ok { + r1 = rf(ctx, request) + } else { + r1 = ret.Error(1) + } + + return r0, r1 +} + +// Proxy_CalcDistance_Call is a *mock.Call that shadows Run/Return methods with type explicit version for method 'CalcDistance' +type Proxy_CalcDistance_Call struct { + *mock.Call +} + +// CalcDistance is a helper method to define mock.On call +// - ctx context.Context +// - request *milvuspb.CalcDistanceRequest +func (_e *Proxy_Expecter) CalcDistance(ctx interface{}, request interface{}) *Proxy_CalcDistance_Call { + return &Proxy_CalcDistance_Call{Call: _e.mock.On("CalcDistance", ctx, request)} +} + +func (_c *Proxy_CalcDistance_Call) Run(run func(ctx context.Context, request *milvuspb.CalcDistanceRequest)) *Proxy_CalcDistance_Call { + _c.Call.Run(func(args mock.Arguments) { + run(args[0].(context.Context), args[1].(*milvuspb.CalcDistanceRequest)) + }) + return _c +} + +func (_c *Proxy_CalcDistance_Call) Return(_a0 *milvuspb.CalcDistanceResults, _a1 error) *Proxy_CalcDistance_Call { + _c.Call.Return(_a0, _a1) + return _c +} + +// CheckHealth provides a mock function with given fields: ctx, req +func (_m *Proxy) CheckHealth(ctx context.Context, req *milvuspb.CheckHealthRequest) (*milvuspb.CheckHealthResponse, error) { + ret := _m.Called(ctx, req) + + var r0 *milvuspb.CheckHealthResponse + if rf, ok := ret.Get(0).(func(context.Context, *milvuspb.CheckHealthRequest) *milvuspb.CheckHealthResponse); ok { + r0 = rf(ctx, req) + } else { + if ret.Get(0) != nil { + r0 = ret.Get(0).(*milvuspb.CheckHealthResponse) + } + } + + var r1 error + if rf, ok := ret.Get(1).(func(context.Context, *milvuspb.CheckHealthRequest) error); ok { + r1 = rf(ctx, req) + } else { + r1 = ret.Error(1) + } + + return r0, r1 +} + +// Proxy_CheckHealth_Call is a *mock.Call that shadows Run/Return methods with type explicit version for method 'CheckHealth' +type Proxy_CheckHealth_Call struct { + *mock.Call +} + +// CheckHealth is a helper method to define mock.On call +// - ctx context.Context +// - req *milvuspb.CheckHealthRequest +func (_e *Proxy_Expecter) CheckHealth(ctx interface{}, req interface{}) *Proxy_CheckHealth_Call { + return &Proxy_CheckHealth_Call{Call: _e.mock.On("CheckHealth", ctx, req)} +} + +func (_c *Proxy_CheckHealth_Call) Run(run func(ctx context.Context, req *milvuspb.CheckHealthRequest)) *Proxy_CheckHealth_Call { + _c.Call.Run(func(args mock.Arguments) { + run(args[0].(context.Context), args[1].(*milvuspb.CheckHealthRequest)) + }) + return _c +} + +func (_c *Proxy_CheckHealth_Call) Return(_a0 *milvuspb.CheckHealthResponse, _a1 error) *Proxy_CheckHealth_Call { + _c.Call.Return(_a0, _a1) + return _c +} + +// CreateAlias provides a mock function with given fields: ctx, request +func (_m *Proxy) CreateAlias(ctx context.Context, request *milvuspb.CreateAliasRequest) (*commonpb.Status, error) { + ret := _m.Called(ctx, request) + + var r0 *commonpb.Status + if rf, ok := ret.Get(0).(func(context.Context, *milvuspb.CreateAliasRequest) *commonpb.Status); ok { + r0 = rf(ctx, request) + } else { + if ret.Get(0) != nil { + r0 = ret.Get(0).(*commonpb.Status) + } + } + + var r1 error + if rf, ok := ret.Get(1).(func(context.Context, *milvuspb.CreateAliasRequest) error); ok { + r1 = rf(ctx, request) + } else { + r1 = ret.Error(1) + } + + return r0, r1 +} + +// Proxy_CreateAlias_Call is a *mock.Call that shadows Run/Return methods with type explicit version for method 'CreateAlias' +type Proxy_CreateAlias_Call struct { + *mock.Call +} + +// CreateAlias is a helper method to define mock.On call +// - ctx context.Context +// - request *milvuspb.CreateAliasRequest +func (_e *Proxy_Expecter) CreateAlias(ctx interface{}, request interface{}) *Proxy_CreateAlias_Call { + return &Proxy_CreateAlias_Call{Call: _e.mock.On("CreateAlias", ctx, request)} +} + +func (_c *Proxy_CreateAlias_Call) Run(run func(ctx context.Context, request *milvuspb.CreateAliasRequest)) *Proxy_CreateAlias_Call { + _c.Call.Run(func(args mock.Arguments) { + run(args[0].(context.Context), args[1].(*milvuspb.CreateAliasRequest)) + }) + return _c +} + +func (_c *Proxy_CreateAlias_Call) Return(_a0 *commonpb.Status, _a1 error) *Proxy_CreateAlias_Call { + _c.Call.Return(_a0, _a1) + return _c +} + +// CreateCollection provides a mock function with given fields: ctx, request +func (_m *Proxy) CreateCollection(ctx context.Context, request *milvuspb.CreateCollectionRequest) (*commonpb.Status, error) { + ret := _m.Called(ctx, request) + + var r0 *commonpb.Status + if rf, ok := ret.Get(0).(func(context.Context, *milvuspb.CreateCollectionRequest) *commonpb.Status); ok { + r0 = rf(ctx, request) + } else { + if ret.Get(0) != nil { + r0 = ret.Get(0).(*commonpb.Status) + } + } + + var r1 error + if rf, ok := ret.Get(1).(func(context.Context, *milvuspb.CreateCollectionRequest) error); ok { + r1 = rf(ctx, request) + } else { + r1 = ret.Error(1) + } + + return r0, r1 +} + +// Proxy_CreateCollection_Call is a *mock.Call that shadows Run/Return methods with type explicit version for method 'CreateCollection' +type Proxy_CreateCollection_Call struct { + *mock.Call +} + +// CreateCollection is a helper method to define mock.On call +// - ctx context.Context +// - request *milvuspb.CreateCollectionRequest +func (_e *Proxy_Expecter) CreateCollection(ctx interface{}, request interface{}) *Proxy_CreateCollection_Call { + return &Proxy_CreateCollection_Call{Call: _e.mock.On("CreateCollection", ctx, request)} +} + +func (_c *Proxy_CreateCollection_Call) Run(run func(ctx context.Context, request *milvuspb.CreateCollectionRequest)) *Proxy_CreateCollection_Call { + _c.Call.Run(func(args mock.Arguments) { + run(args[0].(context.Context), args[1].(*milvuspb.CreateCollectionRequest)) + }) + return _c +} + +func (_c *Proxy_CreateCollection_Call) Return(_a0 *commonpb.Status, _a1 error) *Proxy_CreateCollection_Call { + _c.Call.Return(_a0, _a1) + return _c +} + +// CreateCredential provides a mock function with given fields: ctx, req +func (_m *Proxy) CreateCredential(ctx context.Context, req *milvuspb.CreateCredentialRequest) (*commonpb.Status, error) { + ret := _m.Called(ctx, req) + + var r0 *commonpb.Status + if rf, ok := ret.Get(0).(func(context.Context, *milvuspb.CreateCredentialRequest) *commonpb.Status); ok { + r0 = rf(ctx, req) + } else { + if ret.Get(0) != nil { + r0 = ret.Get(0).(*commonpb.Status) + } + } + + var r1 error + if rf, ok := ret.Get(1).(func(context.Context, *milvuspb.CreateCredentialRequest) error); ok { + r1 = rf(ctx, req) + } else { + r1 = ret.Error(1) + } + + return r0, r1 +} + +// Proxy_CreateCredential_Call is a *mock.Call that shadows Run/Return methods with type explicit version for method 'CreateCredential' +type Proxy_CreateCredential_Call struct { + *mock.Call +} + +// CreateCredential is a helper method to define mock.On call +// - ctx context.Context +// - req *milvuspb.CreateCredentialRequest +func (_e *Proxy_Expecter) CreateCredential(ctx interface{}, req interface{}) *Proxy_CreateCredential_Call { + return &Proxy_CreateCredential_Call{Call: _e.mock.On("CreateCredential", ctx, req)} +} + +func (_c *Proxy_CreateCredential_Call) Run(run func(ctx context.Context, req *milvuspb.CreateCredentialRequest)) *Proxy_CreateCredential_Call { + _c.Call.Run(func(args mock.Arguments) { + run(args[0].(context.Context), args[1].(*milvuspb.CreateCredentialRequest)) + }) + return _c +} + +func (_c *Proxy_CreateCredential_Call) Return(_a0 *commonpb.Status, _a1 error) *Proxy_CreateCredential_Call { + _c.Call.Return(_a0, _a1) + return _c +} + +// CreateIndex provides a mock function with given fields: ctx, request +func (_m *Proxy) CreateIndex(ctx context.Context, request *milvuspb.CreateIndexRequest) (*commonpb.Status, error) { + ret := _m.Called(ctx, request) + + var r0 *commonpb.Status + if rf, ok := ret.Get(0).(func(context.Context, *milvuspb.CreateIndexRequest) *commonpb.Status); ok { + r0 = rf(ctx, request) + } else { + if ret.Get(0) != nil { + r0 = ret.Get(0).(*commonpb.Status) + } + } + + var r1 error + if rf, ok := ret.Get(1).(func(context.Context, *milvuspb.CreateIndexRequest) error); ok { + r1 = rf(ctx, request) + } else { + r1 = ret.Error(1) + } + + return r0, r1 +} + +// Proxy_CreateIndex_Call is a *mock.Call that shadows Run/Return methods with type explicit version for method 'CreateIndex' +type Proxy_CreateIndex_Call struct { + *mock.Call +} + +// CreateIndex is a helper method to define mock.On call +// - ctx context.Context +// - request *milvuspb.CreateIndexRequest +func (_e *Proxy_Expecter) CreateIndex(ctx interface{}, request interface{}) *Proxy_CreateIndex_Call { + return &Proxy_CreateIndex_Call{Call: _e.mock.On("CreateIndex", ctx, request)} +} + +func (_c *Proxy_CreateIndex_Call) Run(run func(ctx context.Context, request *milvuspb.CreateIndexRequest)) *Proxy_CreateIndex_Call { + _c.Call.Run(func(args mock.Arguments) { + run(args[0].(context.Context), args[1].(*milvuspb.CreateIndexRequest)) + }) + return _c +} + +func (_c *Proxy_CreateIndex_Call) Return(_a0 *commonpb.Status, _a1 error) *Proxy_CreateIndex_Call { + _c.Call.Return(_a0, _a1) + return _c +} + +// CreatePartition provides a mock function with given fields: ctx, request +func (_m *Proxy) CreatePartition(ctx context.Context, request *milvuspb.CreatePartitionRequest) (*commonpb.Status, error) { + ret := _m.Called(ctx, request) + + var r0 *commonpb.Status + if rf, ok := ret.Get(0).(func(context.Context, *milvuspb.CreatePartitionRequest) *commonpb.Status); ok { + r0 = rf(ctx, request) + } else { + if ret.Get(0) != nil { + r0 = ret.Get(0).(*commonpb.Status) + } + } + + var r1 error + if rf, ok := ret.Get(1).(func(context.Context, *milvuspb.CreatePartitionRequest) error); ok { + r1 = rf(ctx, request) + } else { + r1 = ret.Error(1) + } + + return r0, r1 +} + +// Proxy_CreatePartition_Call is a *mock.Call that shadows Run/Return methods with type explicit version for method 'CreatePartition' +type Proxy_CreatePartition_Call struct { + *mock.Call +} + +// CreatePartition is a helper method to define mock.On call +// - ctx context.Context +// - request *milvuspb.CreatePartitionRequest +func (_e *Proxy_Expecter) CreatePartition(ctx interface{}, request interface{}) *Proxy_CreatePartition_Call { + return &Proxy_CreatePartition_Call{Call: _e.mock.On("CreatePartition", ctx, request)} +} + +func (_c *Proxy_CreatePartition_Call) Run(run func(ctx context.Context, request *milvuspb.CreatePartitionRequest)) *Proxy_CreatePartition_Call { + _c.Call.Run(func(args mock.Arguments) { + run(args[0].(context.Context), args[1].(*milvuspb.CreatePartitionRequest)) + }) + return _c +} + +func (_c *Proxy_CreatePartition_Call) Return(_a0 *commonpb.Status, _a1 error) *Proxy_CreatePartition_Call { + _c.Call.Return(_a0, _a1) + return _c +} + +// CreateResourceGroup provides a mock function with given fields: ctx, req +func (_m *Proxy) CreateResourceGroup(ctx context.Context, req *milvuspb.CreateResourceGroupRequest) (*commonpb.Status, error) { + ret := _m.Called(ctx, req) + + var r0 *commonpb.Status + if rf, ok := ret.Get(0).(func(context.Context, *milvuspb.CreateResourceGroupRequest) *commonpb.Status); ok { + r0 = rf(ctx, req) + } else { + if ret.Get(0) != nil { + r0 = ret.Get(0).(*commonpb.Status) + } + } + + var r1 error + if rf, ok := ret.Get(1).(func(context.Context, *milvuspb.CreateResourceGroupRequest) error); ok { + r1 = rf(ctx, req) + } else { + r1 = ret.Error(1) + } + + return r0, r1 +} + +// Proxy_CreateResourceGroup_Call is a *mock.Call that shadows Run/Return methods with type explicit version for method 'CreateResourceGroup' +type Proxy_CreateResourceGroup_Call struct { + *mock.Call +} + +// CreateResourceGroup is a helper method to define mock.On call +// - ctx context.Context +// - req *milvuspb.CreateResourceGroupRequest +func (_e *Proxy_Expecter) CreateResourceGroup(ctx interface{}, req interface{}) *Proxy_CreateResourceGroup_Call { + return &Proxy_CreateResourceGroup_Call{Call: _e.mock.On("CreateResourceGroup", ctx, req)} +} + +func (_c *Proxy_CreateResourceGroup_Call) Run(run func(ctx context.Context, req *milvuspb.CreateResourceGroupRequest)) *Proxy_CreateResourceGroup_Call { + _c.Call.Run(func(args mock.Arguments) { + run(args[0].(context.Context), args[1].(*milvuspb.CreateResourceGroupRequest)) + }) + return _c +} + +func (_c *Proxy_CreateResourceGroup_Call) Return(_a0 *commonpb.Status, _a1 error) *Proxy_CreateResourceGroup_Call { + _c.Call.Return(_a0, _a1) + return _c +} + +// CreateRole provides a mock function with given fields: ctx, req +func (_m *Proxy) CreateRole(ctx context.Context, req *milvuspb.CreateRoleRequest) (*commonpb.Status, error) { + ret := _m.Called(ctx, req) + + var r0 *commonpb.Status + if rf, ok := ret.Get(0).(func(context.Context, *milvuspb.CreateRoleRequest) *commonpb.Status); ok { + r0 = rf(ctx, req) + } else { + if ret.Get(0) != nil { + r0 = ret.Get(0).(*commonpb.Status) + } + } + + var r1 error + if rf, ok := ret.Get(1).(func(context.Context, *milvuspb.CreateRoleRequest) error); ok { + r1 = rf(ctx, req) + } else { + r1 = ret.Error(1) + } + + return r0, r1 +} + +// Proxy_CreateRole_Call is a *mock.Call that shadows Run/Return methods with type explicit version for method 'CreateRole' +type Proxy_CreateRole_Call struct { + *mock.Call +} + +// CreateRole is a helper method to define mock.On call +// - ctx context.Context +// - req *milvuspb.CreateRoleRequest +func (_e *Proxy_Expecter) CreateRole(ctx interface{}, req interface{}) *Proxy_CreateRole_Call { + return &Proxy_CreateRole_Call{Call: _e.mock.On("CreateRole", ctx, req)} +} + +func (_c *Proxy_CreateRole_Call) Run(run func(ctx context.Context, req *milvuspb.CreateRoleRequest)) *Proxy_CreateRole_Call { + _c.Call.Run(func(args mock.Arguments) { + run(args[0].(context.Context), args[1].(*milvuspb.CreateRoleRequest)) + }) + return _c +} + +func (_c *Proxy_CreateRole_Call) Return(_a0 *commonpb.Status, _a1 error) *Proxy_CreateRole_Call { + _c.Call.Return(_a0, _a1) + return _c +} + +// Delete provides a mock function with given fields: ctx, request +func (_m *Proxy) Delete(ctx context.Context, request *milvuspb.DeleteRequest) (*milvuspb.MutationResult, error) { + ret := _m.Called(ctx, request) + + var r0 *milvuspb.MutationResult + if rf, ok := ret.Get(0).(func(context.Context, *milvuspb.DeleteRequest) *milvuspb.MutationResult); ok { + r0 = rf(ctx, request) + } else { + if ret.Get(0) != nil { + r0 = ret.Get(0).(*milvuspb.MutationResult) + } + } + + var r1 error + if rf, ok := ret.Get(1).(func(context.Context, *milvuspb.DeleteRequest) error); ok { + r1 = rf(ctx, request) + } else { + r1 = ret.Error(1) + } + + return r0, r1 +} + +// Proxy_Delete_Call is a *mock.Call that shadows Run/Return methods with type explicit version for method 'Delete' +type Proxy_Delete_Call struct { + *mock.Call +} + +// Delete is a helper method to define mock.On call +// - ctx context.Context +// - request *milvuspb.DeleteRequest +func (_e *Proxy_Expecter) Delete(ctx interface{}, request interface{}) *Proxy_Delete_Call { + return &Proxy_Delete_Call{Call: _e.mock.On("Delete", ctx, request)} +} + +func (_c *Proxy_Delete_Call) Run(run func(ctx context.Context, request *milvuspb.DeleteRequest)) *Proxy_Delete_Call { + _c.Call.Run(func(args mock.Arguments) { + run(args[0].(context.Context), args[1].(*milvuspb.DeleteRequest)) + }) + return _c +} + +func (_c *Proxy_Delete_Call) Return(_a0 *milvuspb.MutationResult, _a1 error) *Proxy_Delete_Call { + _c.Call.Return(_a0, _a1) + return _c +} + +// DeleteCredential provides a mock function with given fields: ctx, req +func (_m *Proxy) DeleteCredential(ctx context.Context, req *milvuspb.DeleteCredentialRequest) (*commonpb.Status, error) { + ret := _m.Called(ctx, req) + + var r0 *commonpb.Status + if rf, ok := ret.Get(0).(func(context.Context, *milvuspb.DeleteCredentialRequest) *commonpb.Status); ok { + r0 = rf(ctx, req) + } else { + if ret.Get(0) != nil { + r0 = ret.Get(0).(*commonpb.Status) + } + } + + var r1 error + if rf, ok := ret.Get(1).(func(context.Context, *milvuspb.DeleteCredentialRequest) error); ok { + r1 = rf(ctx, req) + } else { + r1 = ret.Error(1) + } + + return r0, r1 +} + +// Proxy_DeleteCredential_Call is a *mock.Call that shadows Run/Return methods with type explicit version for method 'DeleteCredential' +type Proxy_DeleteCredential_Call struct { + *mock.Call +} + +// DeleteCredential is a helper method to define mock.On call +// - ctx context.Context +// - req *milvuspb.DeleteCredentialRequest +func (_e *Proxy_Expecter) DeleteCredential(ctx interface{}, req interface{}) *Proxy_DeleteCredential_Call { + return &Proxy_DeleteCredential_Call{Call: _e.mock.On("DeleteCredential", ctx, req)} +} + +func (_c *Proxy_DeleteCredential_Call) Run(run func(ctx context.Context, req *milvuspb.DeleteCredentialRequest)) *Proxy_DeleteCredential_Call { + _c.Call.Run(func(args mock.Arguments) { + run(args[0].(context.Context), args[1].(*milvuspb.DeleteCredentialRequest)) + }) + return _c +} + +func (_c *Proxy_DeleteCredential_Call) Return(_a0 *commonpb.Status, _a1 error) *Proxy_DeleteCredential_Call { + _c.Call.Return(_a0, _a1) + return _c +} + +// DescribeCollection provides a mock function with given fields: ctx, request +func (_m *Proxy) DescribeCollection(ctx context.Context, request *milvuspb.DescribeCollectionRequest) (*milvuspb.DescribeCollectionResponse, error) { + ret := _m.Called(ctx, request) + + var r0 *milvuspb.DescribeCollectionResponse + if rf, ok := ret.Get(0).(func(context.Context, *milvuspb.DescribeCollectionRequest) *milvuspb.DescribeCollectionResponse); ok { + r0 = rf(ctx, request) + } else { + if ret.Get(0) != nil { + r0 = ret.Get(0).(*milvuspb.DescribeCollectionResponse) + } + } + + var r1 error + if rf, ok := ret.Get(1).(func(context.Context, *milvuspb.DescribeCollectionRequest) error); ok { + r1 = rf(ctx, request) + } else { + r1 = ret.Error(1) + } + + return r0, r1 +} + +// Proxy_DescribeCollection_Call is a *mock.Call that shadows Run/Return methods with type explicit version for method 'DescribeCollection' +type Proxy_DescribeCollection_Call struct { + *mock.Call +} + +// DescribeCollection is a helper method to define mock.On call +// - ctx context.Context +// - request *milvuspb.DescribeCollectionRequest +func (_e *Proxy_Expecter) DescribeCollection(ctx interface{}, request interface{}) *Proxy_DescribeCollection_Call { + return &Proxy_DescribeCollection_Call{Call: _e.mock.On("DescribeCollection", ctx, request)} +} + +func (_c *Proxy_DescribeCollection_Call) Run(run func(ctx context.Context, request *milvuspb.DescribeCollectionRequest)) *Proxy_DescribeCollection_Call { + _c.Call.Run(func(args mock.Arguments) { + run(args[0].(context.Context), args[1].(*milvuspb.DescribeCollectionRequest)) + }) + return _c +} + +func (_c *Proxy_DescribeCollection_Call) Return(_a0 *milvuspb.DescribeCollectionResponse, _a1 error) *Proxy_DescribeCollection_Call { + _c.Call.Return(_a0, _a1) + return _c +} + +// DescribeIndex provides a mock function with given fields: ctx, request +func (_m *Proxy) DescribeIndex(ctx context.Context, request *milvuspb.DescribeIndexRequest) (*milvuspb.DescribeIndexResponse, error) { + ret := _m.Called(ctx, request) + + var r0 *milvuspb.DescribeIndexResponse + if rf, ok := ret.Get(0).(func(context.Context, *milvuspb.DescribeIndexRequest) *milvuspb.DescribeIndexResponse); ok { + r0 = rf(ctx, request) + } else { + if ret.Get(0) != nil { + r0 = ret.Get(0).(*milvuspb.DescribeIndexResponse) + } + } + + var r1 error + if rf, ok := ret.Get(1).(func(context.Context, *milvuspb.DescribeIndexRequest) error); ok { + r1 = rf(ctx, request) + } else { + r1 = ret.Error(1) + } + + return r0, r1 +} + +// Proxy_DescribeIndex_Call is a *mock.Call that shadows Run/Return methods with type explicit version for method 'DescribeIndex' +type Proxy_DescribeIndex_Call struct { + *mock.Call +} + +// DescribeIndex is a helper method to define mock.On call +// - ctx context.Context +// - request *milvuspb.DescribeIndexRequest +func (_e *Proxy_Expecter) DescribeIndex(ctx interface{}, request interface{}) *Proxy_DescribeIndex_Call { + return &Proxy_DescribeIndex_Call{Call: _e.mock.On("DescribeIndex", ctx, request)} +} + +func (_c *Proxy_DescribeIndex_Call) Run(run func(ctx context.Context, request *milvuspb.DescribeIndexRequest)) *Proxy_DescribeIndex_Call { + _c.Call.Run(func(args mock.Arguments) { + run(args[0].(context.Context), args[1].(*milvuspb.DescribeIndexRequest)) + }) + return _c +} + +func (_c *Proxy_DescribeIndex_Call) Return(_a0 *milvuspb.DescribeIndexResponse, _a1 error) *Proxy_DescribeIndex_Call { + _c.Call.Return(_a0, _a1) + return _c +} + +// DescribeResourceGroup provides a mock function with given fields: ctx, req +func (_m *Proxy) DescribeResourceGroup(ctx context.Context, req *milvuspb.DescribeResourceGroupRequest) (*milvuspb.DescribeResourceGroupResponse, error) { + ret := _m.Called(ctx, req) + + var r0 *milvuspb.DescribeResourceGroupResponse + if rf, ok := ret.Get(0).(func(context.Context, *milvuspb.DescribeResourceGroupRequest) *milvuspb.DescribeResourceGroupResponse); ok { + r0 = rf(ctx, req) + } else { + if ret.Get(0) != nil { + r0 = ret.Get(0).(*milvuspb.DescribeResourceGroupResponse) + } + } + + var r1 error + if rf, ok := ret.Get(1).(func(context.Context, *milvuspb.DescribeResourceGroupRequest) error); ok { + r1 = rf(ctx, req) + } else { + r1 = ret.Error(1) + } + + return r0, r1 +} + +// Proxy_DescribeResourceGroup_Call is a *mock.Call that shadows Run/Return methods with type explicit version for method 'DescribeResourceGroup' +type Proxy_DescribeResourceGroup_Call struct { + *mock.Call +} + +// DescribeResourceGroup is a helper method to define mock.On call +// - ctx context.Context +// - req *milvuspb.DescribeResourceGroupRequest +func (_e *Proxy_Expecter) DescribeResourceGroup(ctx interface{}, req interface{}) *Proxy_DescribeResourceGroup_Call { + return &Proxy_DescribeResourceGroup_Call{Call: _e.mock.On("DescribeResourceGroup", ctx, req)} +} + +func (_c *Proxy_DescribeResourceGroup_Call) Run(run func(ctx context.Context, req *milvuspb.DescribeResourceGroupRequest)) *Proxy_DescribeResourceGroup_Call { + _c.Call.Run(func(args mock.Arguments) { + run(args[0].(context.Context), args[1].(*milvuspb.DescribeResourceGroupRequest)) + }) + return _c +} + +func (_c *Proxy_DescribeResourceGroup_Call) Return(_a0 *milvuspb.DescribeResourceGroupResponse, _a1 error) *Proxy_DescribeResourceGroup_Call { + _c.Call.Return(_a0, _a1) + return _c +} + +// DropAlias provides a mock function with given fields: ctx, request +func (_m *Proxy) DropAlias(ctx context.Context, request *milvuspb.DropAliasRequest) (*commonpb.Status, error) { + ret := _m.Called(ctx, request) + + var r0 *commonpb.Status + if rf, ok := ret.Get(0).(func(context.Context, *milvuspb.DropAliasRequest) *commonpb.Status); ok { + r0 = rf(ctx, request) + } else { + if ret.Get(0) != nil { + r0 = ret.Get(0).(*commonpb.Status) + } + } + + var r1 error + if rf, ok := ret.Get(1).(func(context.Context, *milvuspb.DropAliasRequest) error); ok { + r1 = rf(ctx, request) + } else { + r1 = ret.Error(1) + } + + return r0, r1 +} + +// Proxy_DropAlias_Call is a *mock.Call that shadows Run/Return methods with type explicit version for method 'DropAlias' +type Proxy_DropAlias_Call struct { + *mock.Call +} + +// DropAlias is a helper method to define mock.On call +// - ctx context.Context +// - request *milvuspb.DropAliasRequest +func (_e *Proxy_Expecter) DropAlias(ctx interface{}, request interface{}) *Proxy_DropAlias_Call { + return &Proxy_DropAlias_Call{Call: _e.mock.On("DropAlias", ctx, request)} +} + +func (_c *Proxy_DropAlias_Call) Run(run func(ctx context.Context, request *milvuspb.DropAliasRequest)) *Proxy_DropAlias_Call { + _c.Call.Run(func(args mock.Arguments) { + run(args[0].(context.Context), args[1].(*milvuspb.DropAliasRequest)) + }) + return _c +} + +func (_c *Proxy_DropAlias_Call) Return(_a0 *commonpb.Status, _a1 error) *Proxy_DropAlias_Call { + _c.Call.Return(_a0, _a1) + return _c +} + +// DropCollection provides a mock function with given fields: ctx, request +func (_m *Proxy) DropCollection(ctx context.Context, request *milvuspb.DropCollectionRequest) (*commonpb.Status, error) { + ret := _m.Called(ctx, request) + + var r0 *commonpb.Status + if rf, ok := ret.Get(0).(func(context.Context, *milvuspb.DropCollectionRequest) *commonpb.Status); ok { + r0 = rf(ctx, request) + } else { + if ret.Get(0) != nil { + r0 = ret.Get(0).(*commonpb.Status) + } + } + + var r1 error + if rf, ok := ret.Get(1).(func(context.Context, *milvuspb.DropCollectionRequest) error); ok { + r1 = rf(ctx, request) + } else { + r1 = ret.Error(1) + } + + return r0, r1 +} + +// Proxy_DropCollection_Call is a *mock.Call that shadows Run/Return methods with type explicit version for method 'DropCollection' +type Proxy_DropCollection_Call struct { + *mock.Call +} + +// DropCollection is a helper method to define mock.On call +// - ctx context.Context +// - request *milvuspb.DropCollectionRequest +func (_e *Proxy_Expecter) DropCollection(ctx interface{}, request interface{}) *Proxy_DropCollection_Call { + return &Proxy_DropCollection_Call{Call: _e.mock.On("DropCollection", ctx, request)} +} + +func (_c *Proxy_DropCollection_Call) Run(run func(ctx context.Context, request *milvuspb.DropCollectionRequest)) *Proxy_DropCollection_Call { + _c.Call.Run(func(args mock.Arguments) { + run(args[0].(context.Context), args[1].(*milvuspb.DropCollectionRequest)) + }) + return _c +} + +func (_c *Proxy_DropCollection_Call) Return(_a0 *commonpb.Status, _a1 error) *Proxy_DropCollection_Call { + _c.Call.Return(_a0, _a1) + return _c +} + +// DropIndex provides a mock function with given fields: ctx, request +func (_m *Proxy) DropIndex(ctx context.Context, request *milvuspb.DropIndexRequest) (*commonpb.Status, error) { + ret := _m.Called(ctx, request) + + var r0 *commonpb.Status + if rf, ok := ret.Get(0).(func(context.Context, *milvuspb.DropIndexRequest) *commonpb.Status); ok { + r0 = rf(ctx, request) + } else { + if ret.Get(0) != nil { + r0 = ret.Get(0).(*commonpb.Status) + } + } + + var r1 error + if rf, ok := ret.Get(1).(func(context.Context, *milvuspb.DropIndexRequest) error); ok { + r1 = rf(ctx, request) + } else { + r1 = ret.Error(1) + } + + return r0, r1 +} + +// Proxy_DropIndex_Call is a *mock.Call that shadows Run/Return methods with type explicit version for method 'DropIndex' +type Proxy_DropIndex_Call struct { + *mock.Call +} + +// DropIndex is a helper method to define mock.On call +// - ctx context.Context +// - request *milvuspb.DropIndexRequest +func (_e *Proxy_Expecter) DropIndex(ctx interface{}, request interface{}) *Proxy_DropIndex_Call { + return &Proxy_DropIndex_Call{Call: _e.mock.On("DropIndex", ctx, request)} +} + +func (_c *Proxy_DropIndex_Call) Run(run func(ctx context.Context, request *milvuspb.DropIndexRequest)) *Proxy_DropIndex_Call { + _c.Call.Run(func(args mock.Arguments) { + run(args[0].(context.Context), args[1].(*milvuspb.DropIndexRequest)) + }) + return _c +} + +func (_c *Proxy_DropIndex_Call) Return(_a0 *commonpb.Status, _a1 error) *Proxy_DropIndex_Call { + _c.Call.Return(_a0, _a1) + return _c +} + +// DropPartition provides a mock function with given fields: ctx, request +func (_m *Proxy) DropPartition(ctx context.Context, request *milvuspb.DropPartitionRequest) (*commonpb.Status, error) { + ret := _m.Called(ctx, request) + + var r0 *commonpb.Status + if rf, ok := ret.Get(0).(func(context.Context, *milvuspb.DropPartitionRequest) *commonpb.Status); ok { + r0 = rf(ctx, request) + } else { + if ret.Get(0) != nil { + r0 = ret.Get(0).(*commonpb.Status) + } + } + + var r1 error + if rf, ok := ret.Get(1).(func(context.Context, *milvuspb.DropPartitionRequest) error); ok { + r1 = rf(ctx, request) + } else { + r1 = ret.Error(1) + } + + return r0, r1 +} + +// Proxy_DropPartition_Call is a *mock.Call that shadows Run/Return methods with type explicit version for method 'DropPartition' +type Proxy_DropPartition_Call struct { + *mock.Call +} + +// DropPartition is a helper method to define mock.On call +// - ctx context.Context +// - request *milvuspb.DropPartitionRequest +func (_e *Proxy_Expecter) DropPartition(ctx interface{}, request interface{}) *Proxy_DropPartition_Call { + return &Proxy_DropPartition_Call{Call: _e.mock.On("DropPartition", ctx, request)} +} + +func (_c *Proxy_DropPartition_Call) Run(run func(ctx context.Context, request *milvuspb.DropPartitionRequest)) *Proxy_DropPartition_Call { + _c.Call.Run(func(args mock.Arguments) { + run(args[0].(context.Context), args[1].(*milvuspb.DropPartitionRequest)) + }) + return _c +} + +func (_c *Proxy_DropPartition_Call) Return(_a0 *commonpb.Status, _a1 error) *Proxy_DropPartition_Call { + _c.Call.Return(_a0, _a1) + return _c +} + +// DropResourceGroup provides a mock function with given fields: ctx, req +func (_m *Proxy) DropResourceGroup(ctx context.Context, req *milvuspb.DropResourceGroupRequest) (*commonpb.Status, error) { + ret := _m.Called(ctx, req) + + var r0 *commonpb.Status + if rf, ok := ret.Get(0).(func(context.Context, *milvuspb.DropResourceGroupRequest) *commonpb.Status); ok { + r0 = rf(ctx, req) + } else { + if ret.Get(0) != nil { + r0 = ret.Get(0).(*commonpb.Status) + } + } + + var r1 error + if rf, ok := ret.Get(1).(func(context.Context, *milvuspb.DropResourceGroupRequest) error); ok { + r1 = rf(ctx, req) + } else { + r1 = ret.Error(1) + } + + return r0, r1 +} + +// Proxy_DropResourceGroup_Call is a *mock.Call that shadows Run/Return methods with type explicit version for method 'DropResourceGroup' +type Proxy_DropResourceGroup_Call struct { + *mock.Call +} + +// DropResourceGroup is a helper method to define mock.On call +// - ctx context.Context +// - req *milvuspb.DropResourceGroupRequest +func (_e *Proxy_Expecter) DropResourceGroup(ctx interface{}, req interface{}) *Proxy_DropResourceGroup_Call { + return &Proxy_DropResourceGroup_Call{Call: _e.mock.On("DropResourceGroup", ctx, req)} +} + +func (_c *Proxy_DropResourceGroup_Call) Run(run func(ctx context.Context, req *milvuspb.DropResourceGroupRequest)) *Proxy_DropResourceGroup_Call { + _c.Call.Run(func(args mock.Arguments) { + run(args[0].(context.Context), args[1].(*milvuspb.DropResourceGroupRequest)) + }) + return _c +} + +func (_c *Proxy_DropResourceGroup_Call) Return(_a0 *commonpb.Status, _a1 error) *Proxy_DropResourceGroup_Call { + _c.Call.Return(_a0, _a1) + return _c +} + +// DropRole provides a mock function with given fields: ctx, req +func (_m *Proxy) DropRole(ctx context.Context, req *milvuspb.DropRoleRequest) (*commonpb.Status, error) { + ret := _m.Called(ctx, req) + + var r0 *commonpb.Status + if rf, ok := ret.Get(0).(func(context.Context, *milvuspb.DropRoleRequest) *commonpb.Status); ok { + r0 = rf(ctx, req) + } else { + if ret.Get(0) != nil { + r0 = ret.Get(0).(*commonpb.Status) + } + } + + var r1 error + if rf, ok := ret.Get(1).(func(context.Context, *milvuspb.DropRoleRequest) error); ok { + r1 = rf(ctx, req) + } else { + r1 = ret.Error(1) + } + + return r0, r1 +} + +// Proxy_DropRole_Call is a *mock.Call that shadows Run/Return methods with type explicit version for method 'DropRole' +type Proxy_DropRole_Call struct { + *mock.Call +} + +// DropRole is a helper method to define mock.On call +// - ctx context.Context +// - req *milvuspb.DropRoleRequest +func (_e *Proxy_Expecter) DropRole(ctx interface{}, req interface{}) *Proxy_DropRole_Call { + return &Proxy_DropRole_Call{Call: _e.mock.On("DropRole", ctx, req)} +} + +func (_c *Proxy_DropRole_Call) Run(run func(ctx context.Context, req *milvuspb.DropRoleRequest)) *Proxy_DropRole_Call { + _c.Call.Run(func(args mock.Arguments) { + run(args[0].(context.Context), args[1].(*milvuspb.DropRoleRequest)) + }) + return _c +} + +func (_c *Proxy_DropRole_Call) Return(_a0 *commonpb.Status, _a1 error) *Proxy_DropRole_Call { + _c.Call.Return(_a0, _a1) + return _c +} + +// Dummy provides a mock function with given fields: ctx, request +func (_m *Proxy) Dummy(ctx context.Context, request *milvuspb.DummyRequest) (*milvuspb.DummyResponse, error) { + ret := _m.Called(ctx, request) + + var r0 *milvuspb.DummyResponse + if rf, ok := ret.Get(0).(func(context.Context, *milvuspb.DummyRequest) *milvuspb.DummyResponse); ok { + r0 = rf(ctx, request) + } else { + if ret.Get(0) != nil { + r0 = ret.Get(0).(*milvuspb.DummyResponse) + } + } + + var r1 error + if rf, ok := ret.Get(1).(func(context.Context, *milvuspb.DummyRequest) error); ok { + r1 = rf(ctx, request) + } else { + r1 = ret.Error(1) + } + + return r0, r1 +} + +// Proxy_Dummy_Call is a *mock.Call that shadows Run/Return methods with type explicit version for method 'Dummy' +type Proxy_Dummy_Call struct { + *mock.Call +} + +// Dummy is a helper method to define mock.On call +// - ctx context.Context +// - request *milvuspb.DummyRequest +func (_e *Proxy_Expecter) Dummy(ctx interface{}, request interface{}) *Proxy_Dummy_Call { + return &Proxy_Dummy_Call{Call: _e.mock.On("Dummy", ctx, request)} +} + +func (_c *Proxy_Dummy_Call) Run(run func(ctx context.Context, request *milvuspb.DummyRequest)) *Proxy_Dummy_Call { + _c.Call.Run(func(args mock.Arguments) { + run(args[0].(context.Context), args[1].(*milvuspb.DummyRequest)) + }) + return _c +} + +func (_c *Proxy_Dummy_Call) Return(_a0 *milvuspb.DummyResponse, _a1 error) *Proxy_Dummy_Call { + _c.Call.Return(_a0, _a1) + return _c +} + +// Flush provides a mock function with given fields: ctx, request +func (_m *Proxy) Flush(ctx context.Context, request *milvuspb.FlushRequest) (*milvuspb.FlushResponse, error) { + ret := _m.Called(ctx, request) + + var r0 *milvuspb.FlushResponse + if rf, ok := ret.Get(0).(func(context.Context, *milvuspb.FlushRequest) *milvuspb.FlushResponse); ok { + r0 = rf(ctx, request) + } else { + if ret.Get(0) != nil { + r0 = ret.Get(0).(*milvuspb.FlushResponse) + } + } + + var r1 error + if rf, ok := ret.Get(1).(func(context.Context, *milvuspb.FlushRequest) error); ok { + r1 = rf(ctx, request) + } else { + r1 = ret.Error(1) + } + + return r0, r1 +} + +// Proxy_Flush_Call is a *mock.Call that shadows Run/Return methods with type explicit version for method 'Flush' +type Proxy_Flush_Call struct { + *mock.Call +} + +// Flush is a helper method to define mock.On call +// - ctx context.Context +// - request *milvuspb.FlushRequest +func (_e *Proxy_Expecter) Flush(ctx interface{}, request interface{}) *Proxy_Flush_Call { + return &Proxy_Flush_Call{Call: _e.mock.On("Flush", ctx, request)} +} + +func (_c *Proxy_Flush_Call) Run(run func(ctx context.Context, request *milvuspb.FlushRequest)) *Proxy_Flush_Call { + _c.Call.Run(func(args mock.Arguments) { + run(args[0].(context.Context), args[1].(*milvuspb.FlushRequest)) + }) + return _c +} + +func (_c *Proxy_Flush_Call) Return(_a0 *milvuspb.FlushResponse, _a1 error) *Proxy_Flush_Call { + _c.Call.Return(_a0, _a1) + return _c +} + +// FlushAll provides a mock function with given fields: ctx, request +func (_m *Proxy) FlushAll(ctx context.Context, request *milvuspb.FlushAllRequest) (*milvuspb.FlushAllResponse, error) { + ret := _m.Called(ctx, request) + + var r0 *milvuspb.FlushAllResponse + if rf, ok := ret.Get(0).(func(context.Context, *milvuspb.FlushAllRequest) *milvuspb.FlushAllResponse); ok { + r0 = rf(ctx, request) + } else { + if ret.Get(0) != nil { + r0 = ret.Get(0).(*milvuspb.FlushAllResponse) + } + } + + var r1 error + if rf, ok := ret.Get(1).(func(context.Context, *milvuspb.FlushAllRequest) error); ok { + r1 = rf(ctx, request) + } else { + r1 = ret.Error(1) + } + + return r0, r1 +} + +// Proxy_FlushAll_Call is a *mock.Call that shadows Run/Return methods with type explicit version for method 'FlushAll' +type Proxy_FlushAll_Call struct { + *mock.Call +} + +// FlushAll is a helper method to define mock.On call +// - ctx context.Context +// - request *milvuspb.FlushAllRequest +func (_e *Proxy_Expecter) FlushAll(ctx interface{}, request interface{}) *Proxy_FlushAll_Call { + return &Proxy_FlushAll_Call{Call: _e.mock.On("FlushAll", ctx, request)} +} + +func (_c *Proxy_FlushAll_Call) Run(run func(ctx context.Context, request *milvuspb.FlushAllRequest)) *Proxy_FlushAll_Call { + _c.Call.Run(func(args mock.Arguments) { + run(args[0].(context.Context), args[1].(*milvuspb.FlushAllRequest)) + }) + return _c +} + +func (_c *Proxy_FlushAll_Call) Return(_a0 *milvuspb.FlushAllResponse, _a1 error) *Proxy_FlushAll_Call { + _c.Call.Return(_a0, _a1) + return _c +} + +// GetAddress provides a mock function with given fields: +func (_m *Proxy) GetAddress() string { + ret := _m.Called() + + var r0 string + if rf, ok := ret.Get(0).(func() string); ok { + r0 = rf() + } else { + r0 = ret.Get(0).(string) + } + + return r0 +} + +// Proxy_GetAddress_Call is a *mock.Call that shadows Run/Return methods with type explicit version for method 'GetAddress' +type Proxy_GetAddress_Call struct { + *mock.Call +} + +// GetAddress is a helper method to define mock.On call +func (_e *Proxy_Expecter) GetAddress() *Proxy_GetAddress_Call { + return &Proxy_GetAddress_Call{Call: _e.mock.On("GetAddress")} +} + +func (_c *Proxy_GetAddress_Call) Run(run func()) *Proxy_GetAddress_Call { + _c.Call.Run(func(args mock.Arguments) { + run() + }) + return _c +} + +func (_c *Proxy_GetAddress_Call) Return(_a0 string) *Proxy_GetAddress_Call { + _c.Call.Return(_a0) + return _c +} + +// GetCollectionStatistics provides a mock function with given fields: ctx, request +func (_m *Proxy) GetCollectionStatistics(ctx context.Context, request *milvuspb.GetCollectionStatisticsRequest) (*milvuspb.GetCollectionStatisticsResponse, error) { + ret := _m.Called(ctx, request) + + var r0 *milvuspb.GetCollectionStatisticsResponse + if rf, ok := ret.Get(0).(func(context.Context, *milvuspb.GetCollectionStatisticsRequest) *milvuspb.GetCollectionStatisticsResponse); ok { + r0 = rf(ctx, request) + } else { + if ret.Get(0) != nil { + r0 = ret.Get(0).(*milvuspb.GetCollectionStatisticsResponse) + } + } + + var r1 error + if rf, ok := ret.Get(1).(func(context.Context, *milvuspb.GetCollectionStatisticsRequest) error); ok { + r1 = rf(ctx, request) + } else { + r1 = ret.Error(1) + } + + return r0, r1 +} + +// Proxy_GetCollectionStatistics_Call is a *mock.Call that shadows Run/Return methods with type explicit version for method 'GetCollectionStatistics' +type Proxy_GetCollectionStatistics_Call struct { + *mock.Call +} + +// GetCollectionStatistics is a helper method to define mock.On call +// - ctx context.Context +// - request *milvuspb.GetCollectionStatisticsRequest +func (_e *Proxy_Expecter) GetCollectionStatistics(ctx interface{}, request interface{}) *Proxy_GetCollectionStatistics_Call { + return &Proxy_GetCollectionStatistics_Call{Call: _e.mock.On("GetCollectionStatistics", ctx, request)} +} + +func (_c *Proxy_GetCollectionStatistics_Call) Run(run func(ctx context.Context, request *milvuspb.GetCollectionStatisticsRequest)) *Proxy_GetCollectionStatistics_Call { + _c.Call.Run(func(args mock.Arguments) { + run(args[0].(context.Context), args[1].(*milvuspb.GetCollectionStatisticsRequest)) + }) + return _c +} + +func (_c *Proxy_GetCollectionStatistics_Call) Return(_a0 *milvuspb.GetCollectionStatisticsResponse, _a1 error) *Proxy_GetCollectionStatistics_Call { + _c.Call.Return(_a0, _a1) + return _c +} + +// GetCompactionState provides a mock function with given fields: ctx, req +func (_m *Proxy) GetCompactionState(ctx context.Context, req *milvuspb.GetCompactionStateRequest) (*milvuspb.GetCompactionStateResponse, error) { + ret := _m.Called(ctx, req) + + var r0 *milvuspb.GetCompactionStateResponse + if rf, ok := ret.Get(0).(func(context.Context, *milvuspb.GetCompactionStateRequest) *milvuspb.GetCompactionStateResponse); ok { + r0 = rf(ctx, req) + } else { + if ret.Get(0) != nil { + r0 = ret.Get(0).(*milvuspb.GetCompactionStateResponse) + } + } + + var r1 error + if rf, ok := ret.Get(1).(func(context.Context, *milvuspb.GetCompactionStateRequest) error); ok { + r1 = rf(ctx, req) + } else { + r1 = ret.Error(1) + } + + return r0, r1 +} + +// Proxy_GetCompactionState_Call is a *mock.Call that shadows Run/Return methods with type explicit version for method 'GetCompactionState' +type Proxy_GetCompactionState_Call struct { + *mock.Call +} + +// GetCompactionState is a helper method to define mock.On call +// - ctx context.Context +// - req *milvuspb.GetCompactionStateRequest +func (_e *Proxy_Expecter) GetCompactionState(ctx interface{}, req interface{}) *Proxy_GetCompactionState_Call { + return &Proxy_GetCompactionState_Call{Call: _e.mock.On("GetCompactionState", ctx, req)} +} + +func (_c *Proxy_GetCompactionState_Call) Run(run func(ctx context.Context, req *milvuspb.GetCompactionStateRequest)) *Proxy_GetCompactionState_Call { + _c.Call.Run(func(args mock.Arguments) { + run(args[0].(context.Context), args[1].(*milvuspb.GetCompactionStateRequest)) + }) + return _c +} + +func (_c *Proxy_GetCompactionState_Call) Return(_a0 *milvuspb.GetCompactionStateResponse, _a1 error) *Proxy_GetCompactionState_Call { + _c.Call.Return(_a0, _a1) + return _c +} + +// GetCompactionStateWithPlans provides a mock function with given fields: ctx, req +func (_m *Proxy) GetCompactionStateWithPlans(ctx context.Context, req *milvuspb.GetCompactionPlansRequest) (*milvuspb.GetCompactionPlansResponse, error) { + ret := _m.Called(ctx, req) + + var r0 *milvuspb.GetCompactionPlansResponse + if rf, ok := ret.Get(0).(func(context.Context, *milvuspb.GetCompactionPlansRequest) *milvuspb.GetCompactionPlansResponse); ok { + r0 = rf(ctx, req) + } else { + if ret.Get(0) != nil { + r0 = ret.Get(0).(*milvuspb.GetCompactionPlansResponse) + } + } + + var r1 error + if rf, ok := ret.Get(1).(func(context.Context, *milvuspb.GetCompactionPlansRequest) error); ok { + r1 = rf(ctx, req) + } else { + r1 = ret.Error(1) + } + + return r0, r1 +} + +// Proxy_GetCompactionStateWithPlans_Call is a *mock.Call that shadows Run/Return methods with type explicit version for method 'GetCompactionStateWithPlans' +type Proxy_GetCompactionStateWithPlans_Call struct { + *mock.Call +} + +// GetCompactionStateWithPlans is a helper method to define mock.On call +// - ctx context.Context +// - req *milvuspb.GetCompactionPlansRequest +func (_e *Proxy_Expecter) GetCompactionStateWithPlans(ctx interface{}, req interface{}) *Proxy_GetCompactionStateWithPlans_Call { + return &Proxy_GetCompactionStateWithPlans_Call{Call: _e.mock.On("GetCompactionStateWithPlans", ctx, req)} +} + +func (_c *Proxy_GetCompactionStateWithPlans_Call) Run(run func(ctx context.Context, req *milvuspb.GetCompactionPlansRequest)) *Proxy_GetCompactionStateWithPlans_Call { + _c.Call.Run(func(args mock.Arguments) { + run(args[0].(context.Context), args[1].(*milvuspb.GetCompactionPlansRequest)) + }) + return _c +} + +func (_c *Proxy_GetCompactionStateWithPlans_Call) Return(_a0 *milvuspb.GetCompactionPlansResponse, _a1 error) *Proxy_GetCompactionStateWithPlans_Call { + _c.Call.Return(_a0, _a1) + return _c +} + +// GetComponentStates provides a mock function with given fields: ctx +func (_m *Proxy) GetComponentStates(ctx context.Context) (*milvuspb.ComponentStates, error) { + ret := _m.Called(ctx) + + var r0 *milvuspb.ComponentStates + if rf, ok := ret.Get(0).(func(context.Context) *milvuspb.ComponentStates); ok { + r0 = rf(ctx) + } else { + if ret.Get(0) != nil { + r0 = ret.Get(0).(*milvuspb.ComponentStates) + } + } + + var r1 error + if rf, ok := ret.Get(1).(func(context.Context) error); ok { + r1 = rf(ctx) + } else { + r1 = ret.Error(1) + } + + return r0, r1 +} + +// Proxy_GetComponentStates_Call is a *mock.Call that shadows Run/Return methods with type explicit version for method 'GetComponentStates' +type Proxy_GetComponentStates_Call struct { + *mock.Call +} + +// GetComponentStates is a helper method to define mock.On call +// - ctx context.Context +func (_e *Proxy_Expecter) GetComponentStates(ctx interface{}) *Proxy_GetComponentStates_Call { + return &Proxy_GetComponentStates_Call{Call: _e.mock.On("GetComponentStates", ctx)} +} + +func (_c *Proxy_GetComponentStates_Call) Run(run func(ctx context.Context)) *Proxy_GetComponentStates_Call { + _c.Call.Run(func(args mock.Arguments) { + run(args[0].(context.Context)) + }) + return _c +} + +func (_c *Proxy_GetComponentStates_Call) Return(_a0 *milvuspb.ComponentStates, _a1 error) *Proxy_GetComponentStates_Call { + _c.Call.Return(_a0, _a1) + return _c +} + +// GetDdChannel provides a mock function with given fields: ctx, request +func (_m *Proxy) GetDdChannel(ctx context.Context, request *internalpb.GetDdChannelRequest) (*milvuspb.StringResponse, error) { + ret := _m.Called(ctx, request) + + var r0 *milvuspb.StringResponse + if rf, ok := ret.Get(0).(func(context.Context, *internalpb.GetDdChannelRequest) *milvuspb.StringResponse); ok { + r0 = rf(ctx, request) + } else { + if ret.Get(0) != nil { + r0 = ret.Get(0).(*milvuspb.StringResponse) + } + } + + var r1 error + if rf, ok := ret.Get(1).(func(context.Context, *internalpb.GetDdChannelRequest) error); ok { + r1 = rf(ctx, request) + } else { + r1 = ret.Error(1) + } + + return r0, r1 +} + +// Proxy_GetDdChannel_Call is a *mock.Call that shadows Run/Return methods with type explicit version for method 'GetDdChannel' +type Proxy_GetDdChannel_Call struct { + *mock.Call +} + +// GetDdChannel is a helper method to define mock.On call +// - ctx context.Context +// - request *internalpb.GetDdChannelRequest +func (_e *Proxy_Expecter) GetDdChannel(ctx interface{}, request interface{}) *Proxy_GetDdChannel_Call { + return &Proxy_GetDdChannel_Call{Call: _e.mock.On("GetDdChannel", ctx, request)} +} + +func (_c *Proxy_GetDdChannel_Call) Run(run func(ctx context.Context, request *internalpb.GetDdChannelRequest)) *Proxy_GetDdChannel_Call { + _c.Call.Run(func(args mock.Arguments) { + run(args[0].(context.Context), args[1].(*internalpb.GetDdChannelRequest)) + }) + return _c +} + +func (_c *Proxy_GetDdChannel_Call) Return(_a0 *milvuspb.StringResponse, _a1 error) *Proxy_GetDdChannel_Call { + _c.Call.Return(_a0, _a1) + return _c +} + +// GetFlushAllState provides a mock function with given fields: ctx, req +func (_m *Proxy) GetFlushAllState(ctx context.Context, req *milvuspb.GetFlushAllStateRequest) (*milvuspb.GetFlushAllStateResponse, error) { + ret := _m.Called(ctx, req) + + var r0 *milvuspb.GetFlushAllStateResponse + if rf, ok := ret.Get(0).(func(context.Context, *milvuspb.GetFlushAllStateRequest) *milvuspb.GetFlushAllStateResponse); ok { + r0 = rf(ctx, req) + } else { + if ret.Get(0) != nil { + r0 = ret.Get(0).(*milvuspb.GetFlushAllStateResponse) + } + } + + var r1 error + if rf, ok := ret.Get(1).(func(context.Context, *milvuspb.GetFlushAllStateRequest) error); ok { + r1 = rf(ctx, req) + } else { + r1 = ret.Error(1) + } + + return r0, r1 +} + +// Proxy_GetFlushAllState_Call is a *mock.Call that shadows Run/Return methods with type explicit version for method 'GetFlushAllState' +type Proxy_GetFlushAllState_Call struct { + *mock.Call +} + +// GetFlushAllState is a helper method to define mock.On call +// - ctx context.Context +// - req *milvuspb.GetFlushAllStateRequest +func (_e *Proxy_Expecter) GetFlushAllState(ctx interface{}, req interface{}) *Proxy_GetFlushAllState_Call { + return &Proxy_GetFlushAllState_Call{Call: _e.mock.On("GetFlushAllState", ctx, req)} +} + +func (_c *Proxy_GetFlushAllState_Call) Run(run func(ctx context.Context, req *milvuspb.GetFlushAllStateRequest)) *Proxy_GetFlushAllState_Call { + _c.Call.Run(func(args mock.Arguments) { + run(args[0].(context.Context), args[1].(*milvuspb.GetFlushAllStateRequest)) + }) + return _c +} + +func (_c *Proxy_GetFlushAllState_Call) Return(_a0 *milvuspb.GetFlushAllStateResponse, _a1 error) *Proxy_GetFlushAllState_Call { + _c.Call.Return(_a0, _a1) + return _c +} + +// GetFlushState provides a mock function with given fields: ctx, req +func (_m *Proxy) GetFlushState(ctx context.Context, req *milvuspb.GetFlushStateRequest) (*milvuspb.GetFlushStateResponse, error) { + ret := _m.Called(ctx, req) + + var r0 *milvuspb.GetFlushStateResponse + if rf, ok := ret.Get(0).(func(context.Context, *milvuspb.GetFlushStateRequest) *milvuspb.GetFlushStateResponse); ok { + r0 = rf(ctx, req) + } else { + if ret.Get(0) != nil { + r0 = ret.Get(0).(*milvuspb.GetFlushStateResponse) + } + } + + var r1 error + if rf, ok := ret.Get(1).(func(context.Context, *milvuspb.GetFlushStateRequest) error); ok { + r1 = rf(ctx, req) + } else { + r1 = ret.Error(1) + } + + return r0, r1 +} + +// Proxy_GetFlushState_Call is a *mock.Call that shadows Run/Return methods with type explicit version for method 'GetFlushState' +type Proxy_GetFlushState_Call struct { + *mock.Call +} + +// GetFlushState is a helper method to define mock.On call +// - ctx context.Context +// - req *milvuspb.GetFlushStateRequest +func (_e *Proxy_Expecter) GetFlushState(ctx interface{}, req interface{}) *Proxy_GetFlushState_Call { + return &Proxy_GetFlushState_Call{Call: _e.mock.On("GetFlushState", ctx, req)} +} + +func (_c *Proxy_GetFlushState_Call) Run(run func(ctx context.Context, req *milvuspb.GetFlushStateRequest)) *Proxy_GetFlushState_Call { + _c.Call.Run(func(args mock.Arguments) { + run(args[0].(context.Context), args[1].(*milvuspb.GetFlushStateRequest)) + }) + return _c +} + +func (_c *Proxy_GetFlushState_Call) Return(_a0 *milvuspb.GetFlushStateResponse, _a1 error) *Proxy_GetFlushState_Call { + _c.Call.Return(_a0, _a1) + return _c +} + +// GetImportState provides a mock function with given fields: ctx, req +func (_m *Proxy) GetImportState(ctx context.Context, req *milvuspb.GetImportStateRequest) (*milvuspb.GetImportStateResponse, error) { + ret := _m.Called(ctx, req) + + var r0 *milvuspb.GetImportStateResponse + if rf, ok := ret.Get(0).(func(context.Context, *milvuspb.GetImportStateRequest) *milvuspb.GetImportStateResponse); ok { + r0 = rf(ctx, req) + } else { + if ret.Get(0) != nil { + r0 = ret.Get(0).(*milvuspb.GetImportStateResponse) + } + } + + var r1 error + if rf, ok := ret.Get(1).(func(context.Context, *milvuspb.GetImportStateRequest) error); ok { + r1 = rf(ctx, req) + } else { + r1 = ret.Error(1) + } + + return r0, r1 +} + +// Proxy_GetImportState_Call is a *mock.Call that shadows Run/Return methods with type explicit version for method 'GetImportState' +type Proxy_GetImportState_Call struct { + *mock.Call +} + +// GetImportState is a helper method to define mock.On call +// - ctx context.Context +// - req *milvuspb.GetImportStateRequest +func (_e *Proxy_Expecter) GetImportState(ctx interface{}, req interface{}) *Proxy_GetImportState_Call { + return &Proxy_GetImportState_Call{Call: _e.mock.On("GetImportState", ctx, req)} +} + +func (_c *Proxy_GetImportState_Call) Run(run func(ctx context.Context, req *milvuspb.GetImportStateRequest)) *Proxy_GetImportState_Call { + _c.Call.Run(func(args mock.Arguments) { + run(args[0].(context.Context), args[1].(*milvuspb.GetImportStateRequest)) + }) + return _c +} + +func (_c *Proxy_GetImportState_Call) Return(_a0 *milvuspb.GetImportStateResponse, _a1 error) *Proxy_GetImportState_Call { + _c.Call.Return(_a0, _a1) + return _c +} + +// GetIndexBuildProgress provides a mock function with given fields: ctx, request +func (_m *Proxy) GetIndexBuildProgress(ctx context.Context, request *milvuspb.GetIndexBuildProgressRequest) (*milvuspb.GetIndexBuildProgressResponse, error) { + ret := _m.Called(ctx, request) + + var r0 *milvuspb.GetIndexBuildProgressResponse + if rf, ok := ret.Get(0).(func(context.Context, *milvuspb.GetIndexBuildProgressRequest) *milvuspb.GetIndexBuildProgressResponse); ok { + r0 = rf(ctx, request) + } else { + if ret.Get(0) != nil { + r0 = ret.Get(0).(*milvuspb.GetIndexBuildProgressResponse) + } + } + + var r1 error + if rf, ok := ret.Get(1).(func(context.Context, *milvuspb.GetIndexBuildProgressRequest) error); ok { + r1 = rf(ctx, request) + } else { + r1 = ret.Error(1) + } + + return r0, r1 +} + +// Proxy_GetIndexBuildProgress_Call is a *mock.Call that shadows Run/Return methods with type explicit version for method 'GetIndexBuildProgress' +type Proxy_GetIndexBuildProgress_Call struct { + *mock.Call +} + +// GetIndexBuildProgress is a helper method to define mock.On call +// - ctx context.Context +// - request *milvuspb.GetIndexBuildProgressRequest +func (_e *Proxy_Expecter) GetIndexBuildProgress(ctx interface{}, request interface{}) *Proxy_GetIndexBuildProgress_Call { + return &Proxy_GetIndexBuildProgress_Call{Call: _e.mock.On("GetIndexBuildProgress", ctx, request)} +} + +func (_c *Proxy_GetIndexBuildProgress_Call) Run(run func(ctx context.Context, request *milvuspb.GetIndexBuildProgressRequest)) *Proxy_GetIndexBuildProgress_Call { + _c.Call.Run(func(args mock.Arguments) { + run(args[0].(context.Context), args[1].(*milvuspb.GetIndexBuildProgressRequest)) + }) + return _c +} + +func (_c *Proxy_GetIndexBuildProgress_Call) Return(_a0 *milvuspb.GetIndexBuildProgressResponse, _a1 error) *Proxy_GetIndexBuildProgress_Call { + _c.Call.Return(_a0, _a1) + return _c +} + +// GetIndexState provides a mock function with given fields: ctx, request +func (_m *Proxy) GetIndexState(ctx context.Context, request *milvuspb.GetIndexStateRequest) (*milvuspb.GetIndexStateResponse, error) { + ret := _m.Called(ctx, request) + + var r0 *milvuspb.GetIndexStateResponse + if rf, ok := ret.Get(0).(func(context.Context, *milvuspb.GetIndexStateRequest) *milvuspb.GetIndexStateResponse); ok { + r0 = rf(ctx, request) + } else { + if ret.Get(0) != nil { + r0 = ret.Get(0).(*milvuspb.GetIndexStateResponse) + } + } + + var r1 error + if rf, ok := ret.Get(1).(func(context.Context, *milvuspb.GetIndexStateRequest) error); ok { + r1 = rf(ctx, request) + } else { + r1 = ret.Error(1) + } + + return r0, r1 +} + +// Proxy_GetIndexState_Call is a *mock.Call that shadows Run/Return methods with type explicit version for method 'GetIndexState' +type Proxy_GetIndexState_Call struct { + *mock.Call +} + +// GetIndexState is a helper method to define mock.On call +// - ctx context.Context +// - request *milvuspb.GetIndexStateRequest +func (_e *Proxy_Expecter) GetIndexState(ctx interface{}, request interface{}) *Proxy_GetIndexState_Call { + return &Proxy_GetIndexState_Call{Call: _e.mock.On("GetIndexState", ctx, request)} +} + +func (_c *Proxy_GetIndexState_Call) Run(run func(ctx context.Context, request *milvuspb.GetIndexStateRequest)) *Proxy_GetIndexState_Call { + _c.Call.Run(func(args mock.Arguments) { + run(args[0].(context.Context), args[1].(*milvuspb.GetIndexStateRequest)) + }) + return _c +} + +func (_c *Proxy_GetIndexState_Call) Return(_a0 *milvuspb.GetIndexStateResponse, _a1 error) *Proxy_GetIndexState_Call { + _c.Call.Return(_a0, _a1) + return _c +} + +// GetLoadState provides a mock function with given fields: ctx, request +func (_m *Proxy) GetLoadState(ctx context.Context, request *milvuspb.GetLoadStateRequest) (*milvuspb.GetLoadStateResponse, error) { + ret := _m.Called(ctx, request) + + var r0 *milvuspb.GetLoadStateResponse + if rf, ok := ret.Get(0).(func(context.Context, *milvuspb.GetLoadStateRequest) *milvuspb.GetLoadStateResponse); ok { + r0 = rf(ctx, request) + } else { + if ret.Get(0) != nil { + r0 = ret.Get(0).(*milvuspb.GetLoadStateResponse) + } + } + + var r1 error + if rf, ok := ret.Get(1).(func(context.Context, *milvuspb.GetLoadStateRequest) error); ok { + r1 = rf(ctx, request) + } else { + r1 = ret.Error(1) + } + + return r0, r1 +} + +// Proxy_GetLoadState_Call is a *mock.Call that shadows Run/Return methods with type explicit version for method 'GetLoadState' +type Proxy_GetLoadState_Call struct { + *mock.Call +} + +// GetLoadState is a helper method to define mock.On call +// - ctx context.Context +// - request *milvuspb.GetLoadStateRequest +func (_e *Proxy_Expecter) GetLoadState(ctx interface{}, request interface{}) *Proxy_GetLoadState_Call { + return &Proxy_GetLoadState_Call{Call: _e.mock.On("GetLoadState", ctx, request)} +} + +func (_c *Proxy_GetLoadState_Call) Run(run func(ctx context.Context, request *milvuspb.GetLoadStateRequest)) *Proxy_GetLoadState_Call { + _c.Call.Run(func(args mock.Arguments) { + run(args[0].(context.Context), args[1].(*milvuspb.GetLoadStateRequest)) + }) + return _c +} + +func (_c *Proxy_GetLoadState_Call) Return(_a0 *milvuspb.GetLoadStateResponse, _a1 error) *Proxy_GetLoadState_Call { + _c.Call.Return(_a0, _a1) + return _c +} + +// GetLoadingProgress provides a mock function with given fields: ctx, request +func (_m *Proxy) GetLoadingProgress(ctx context.Context, request *milvuspb.GetLoadingProgressRequest) (*milvuspb.GetLoadingProgressResponse, error) { + ret := _m.Called(ctx, request) + + var r0 *milvuspb.GetLoadingProgressResponse + if rf, ok := ret.Get(0).(func(context.Context, *milvuspb.GetLoadingProgressRequest) *milvuspb.GetLoadingProgressResponse); ok { + r0 = rf(ctx, request) + } else { + if ret.Get(0) != nil { + r0 = ret.Get(0).(*milvuspb.GetLoadingProgressResponse) + } + } + + var r1 error + if rf, ok := ret.Get(1).(func(context.Context, *milvuspb.GetLoadingProgressRequest) error); ok { + r1 = rf(ctx, request) + } else { + r1 = ret.Error(1) + } + + return r0, r1 +} + +// Proxy_GetLoadingProgress_Call is a *mock.Call that shadows Run/Return methods with type explicit version for method 'GetLoadingProgress' +type Proxy_GetLoadingProgress_Call struct { + *mock.Call +} + +// GetLoadingProgress is a helper method to define mock.On call +// - ctx context.Context +// - request *milvuspb.GetLoadingProgressRequest +func (_e *Proxy_Expecter) GetLoadingProgress(ctx interface{}, request interface{}) *Proxy_GetLoadingProgress_Call { + return &Proxy_GetLoadingProgress_Call{Call: _e.mock.On("GetLoadingProgress", ctx, request)} +} + +func (_c *Proxy_GetLoadingProgress_Call) Run(run func(ctx context.Context, request *milvuspb.GetLoadingProgressRequest)) *Proxy_GetLoadingProgress_Call { + _c.Call.Run(func(args mock.Arguments) { + run(args[0].(context.Context), args[1].(*milvuspb.GetLoadingProgressRequest)) + }) + return _c +} + +func (_c *Proxy_GetLoadingProgress_Call) Return(_a0 *milvuspb.GetLoadingProgressResponse, _a1 error) *Proxy_GetLoadingProgress_Call { + _c.Call.Return(_a0, _a1) + return _c +} + +// GetMetrics provides a mock function with given fields: ctx, request +func (_m *Proxy) GetMetrics(ctx context.Context, request *milvuspb.GetMetricsRequest) (*milvuspb.GetMetricsResponse, error) { + ret := _m.Called(ctx, request) + + var r0 *milvuspb.GetMetricsResponse + if rf, ok := ret.Get(0).(func(context.Context, *milvuspb.GetMetricsRequest) *milvuspb.GetMetricsResponse); ok { + r0 = rf(ctx, request) + } else { + if ret.Get(0) != nil { + r0 = ret.Get(0).(*milvuspb.GetMetricsResponse) + } + } + + var r1 error + if rf, ok := ret.Get(1).(func(context.Context, *milvuspb.GetMetricsRequest) error); ok { + r1 = rf(ctx, request) + } else { + r1 = ret.Error(1) + } + + return r0, r1 +} + +// Proxy_GetMetrics_Call is a *mock.Call that shadows Run/Return methods with type explicit version for method 'GetMetrics' +type Proxy_GetMetrics_Call struct { + *mock.Call +} + +// GetMetrics is a helper method to define mock.On call +// - ctx context.Context +// - request *milvuspb.GetMetricsRequest +func (_e *Proxy_Expecter) GetMetrics(ctx interface{}, request interface{}) *Proxy_GetMetrics_Call { + return &Proxy_GetMetrics_Call{Call: _e.mock.On("GetMetrics", ctx, request)} +} + +func (_c *Proxy_GetMetrics_Call) Run(run func(ctx context.Context, request *milvuspb.GetMetricsRequest)) *Proxy_GetMetrics_Call { + _c.Call.Run(func(args mock.Arguments) { + run(args[0].(context.Context), args[1].(*milvuspb.GetMetricsRequest)) + }) + return _c +} + +func (_c *Proxy_GetMetrics_Call) Return(_a0 *milvuspb.GetMetricsResponse, _a1 error) *Proxy_GetMetrics_Call { + _c.Call.Return(_a0, _a1) + return _c +} + +// GetPartitionStatistics provides a mock function with given fields: ctx, request +func (_m *Proxy) GetPartitionStatistics(ctx context.Context, request *milvuspb.GetPartitionStatisticsRequest) (*milvuspb.GetPartitionStatisticsResponse, error) { + ret := _m.Called(ctx, request) + + var r0 *milvuspb.GetPartitionStatisticsResponse + if rf, ok := ret.Get(0).(func(context.Context, *milvuspb.GetPartitionStatisticsRequest) *milvuspb.GetPartitionStatisticsResponse); ok { + r0 = rf(ctx, request) + } else { + if ret.Get(0) != nil { + r0 = ret.Get(0).(*milvuspb.GetPartitionStatisticsResponse) + } + } + + var r1 error + if rf, ok := ret.Get(1).(func(context.Context, *milvuspb.GetPartitionStatisticsRequest) error); ok { + r1 = rf(ctx, request) + } else { + r1 = ret.Error(1) + } + + return r0, r1 +} + +// Proxy_GetPartitionStatistics_Call is a *mock.Call that shadows Run/Return methods with type explicit version for method 'GetPartitionStatistics' +type Proxy_GetPartitionStatistics_Call struct { + *mock.Call +} + +// GetPartitionStatistics is a helper method to define mock.On call +// - ctx context.Context +// - request *milvuspb.GetPartitionStatisticsRequest +func (_e *Proxy_Expecter) GetPartitionStatistics(ctx interface{}, request interface{}) *Proxy_GetPartitionStatistics_Call { + return &Proxy_GetPartitionStatistics_Call{Call: _e.mock.On("GetPartitionStatistics", ctx, request)} +} + +func (_c *Proxy_GetPartitionStatistics_Call) Run(run func(ctx context.Context, request *milvuspb.GetPartitionStatisticsRequest)) *Proxy_GetPartitionStatistics_Call { + _c.Call.Run(func(args mock.Arguments) { + run(args[0].(context.Context), args[1].(*milvuspb.GetPartitionStatisticsRequest)) + }) + return _c +} + +func (_c *Proxy_GetPartitionStatistics_Call) Return(_a0 *milvuspb.GetPartitionStatisticsResponse, _a1 error) *Proxy_GetPartitionStatistics_Call { + _c.Call.Return(_a0, _a1) + return _c +} + +// GetPersistentSegmentInfo provides a mock function with given fields: ctx, request +func (_m *Proxy) GetPersistentSegmentInfo(ctx context.Context, request *milvuspb.GetPersistentSegmentInfoRequest) (*milvuspb.GetPersistentSegmentInfoResponse, error) { + ret := _m.Called(ctx, request) + + var r0 *milvuspb.GetPersistentSegmentInfoResponse + if rf, ok := ret.Get(0).(func(context.Context, *milvuspb.GetPersistentSegmentInfoRequest) *milvuspb.GetPersistentSegmentInfoResponse); ok { + r0 = rf(ctx, request) + } else { + if ret.Get(0) != nil { + r0 = ret.Get(0).(*milvuspb.GetPersistentSegmentInfoResponse) + } + } + + var r1 error + if rf, ok := ret.Get(1).(func(context.Context, *milvuspb.GetPersistentSegmentInfoRequest) error); ok { + r1 = rf(ctx, request) + } else { + r1 = ret.Error(1) + } + + return r0, r1 +} + +// Proxy_GetPersistentSegmentInfo_Call is a *mock.Call that shadows Run/Return methods with type explicit version for method 'GetPersistentSegmentInfo' +type Proxy_GetPersistentSegmentInfo_Call struct { + *mock.Call +} + +// GetPersistentSegmentInfo is a helper method to define mock.On call +// - ctx context.Context +// - request *milvuspb.GetPersistentSegmentInfoRequest +func (_e *Proxy_Expecter) GetPersistentSegmentInfo(ctx interface{}, request interface{}) *Proxy_GetPersistentSegmentInfo_Call { + return &Proxy_GetPersistentSegmentInfo_Call{Call: _e.mock.On("GetPersistentSegmentInfo", ctx, request)} +} + +func (_c *Proxy_GetPersistentSegmentInfo_Call) Run(run func(ctx context.Context, request *milvuspb.GetPersistentSegmentInfoRequest)) *Proxy_GetPersistentSegmentInfo_Call { + _c.Call.Run(func(args mock.Arguments) { + run(args[0].(context.Context), args[1].(*milvuspb.GetPersistentSegmentInfoRequest)) + }) + return _c +} + +func (_c *Proxy_GetPersistentSegmentInfo_Call) Return(_a0 *milvuspb.GetPersistentSegmentInfoResponse, _a1 error) *Proxy_GetPersistentSegmentInfo_Call { + _c.Call.Return(_a0, _a1) + return _c +} + +// GetProxyMetrics provides a mock function with given fields: ctx, request +func (_m *Proxy) GetProxyMetrics(ctx context.Context, request *milvuspb.GetMetricsRequest) (*milvuspb.GetMetricsResponse, error) { + ret := _m.Called(ctx, request) + + var r0 *milvuspb.GetMetricsResponse + if rf, ok := ret.Get(0).(func(context.Context, *milvuspb.GetMetricsRequest) *milvuspb.GetMetricsResponse); ok { + r0 = rf(ctx, request) + } else { + if ret.Get(0) != nil { + r0 = ret.Get(0).(*milvuspb.GetMetricsResponse) + } + } + + var r1 error + if rf, ok := ret.Get(1).(func(context.Context, *milvuspb.GetMetricsRequest) error); ok { + r1 = rf(ctx, request) + } else { + r1 = ret.Error(1) + } + + return r0, r1 +} + +// Proxy_GetProxyMetrics_Call is a *mock.Call that shadows Run/Return methods with type explicit version for method 'GetProxyMetrics' +type Proxy_GetProxyMetrics_Call struct { + *mock.Call +} + +// GetProxyMetrics is a helper method to define mock.On call +// - ctx context.Context +// - request *milvuspb.GetMetricsRequest +func (_e *Proxy_Expecter) GetProxyMetrics(ctx interface{}, request interface{}) *Proxy_GetProxyMetrics_Call { + return &Proxy_GetProxyMetrics_Call{Call: _e.mock.On("GetProxyMetrics", ctx, request)} +} + +func (_c *Proxy_GetProxyMetrics_Call) Run(run func(ctx context.Context, request *milvuspb.GetMetricsRequest)) *Proxy_GetProxyMetrics_Call { + _c.Call.Run(func(args mock.Arguments) { + run(args[0].(context.Context), args[1].(*milvuspb.GetMetricsRequest)) + }) + return _c +} + +func (_c *Proxy_GetProxyMetrics_Call) Return(_a0 *milvuspb.GetMetricsResponse, _a1 error) *Proxy_GetProxyMetrics_Call { + _c.Call.Return(_a0, _a1) + return _c +} + +// GetQuerySegmentInfo provides a mock function with given fields: ctx, request +func (_m *Proxy) GetQuerySegmentInfo(ctx context.Context, request *milvuspb.GetQuerySegmentInfoRequest) (*milvuspb.GetQuerySegmentInfoResponse, error) { + ret := _m.Called(ctx, request) + + var r0 *milvuspb.GetQuerySegmentInfoResponse + if rf, ok := ret.Get(0).(func(context.Context, *milvuspb.GetQuerySegmentInfoRequest) *milvuspb.GetQuerySegmentInfoResponse); ok { + r0 = rf(ctx, request) + } else { + if ret.Get(0) != nil { + r0 = ret.Get(0).(*milvuspb.GetQuerySegmentInfoResponse) + } + } + + var r1 error + if rf, ok := ret.Get(1).(func(context.Context, *milvuspb.GetQuerySegmentInfoRequest) error); ok { + r1 = rf(ctx, request) + } else { + r1 = ret.Error(1) + } + + return r0, r1 +} + +// Proxy_GetQuerySegmentInfo_Call is a *mock.Call that shadows Run/Return methods with type explicit version for method 'GetQuerySegmentInfo' +type Proxy_GetQuerySegmentInfo_Call struct { + *mock.Call +} + +// GetQuerySegmentInfo is a helper method to define mock.On call +// - ctx context.Context +// - request *milvuspb.GetQuerySegmentInfoRequest +func (_e *Proxy_Expecter) GetQuerySegmentInfo(ctx interface{}, request interface{}) *Proxy_GetQuerySegmentInfo_Call { + return &Proxy_GetQuerySegmentInfo_Call{Call: _e.mock.On("GetQuerySegmentInfo", ctx, request)} +} + +func (_c *Proxy_GetQuerySegmentInfo_Call) Run(run func(ctx context.Context, request *milvuspb.GetQuerySegmentInfoRequest)) *Proxy_GetQuerySegmentInfo_Call { + _c.Call.Run(func(args mock.Arguments) { + run(args[0].(context.Context), args[1].(*milvuspb.GetQuerySegmentInfoRequest)) + }) + return _c +} + +func (_c *Proxy_GetQuerySegmentInfo_Call) Return(_a0 *milvuspb.GetQuerySegmentInfoResponse, _a1 error) *Proxy_GetQuerySegmentInfo_Call { + _c.Call.Return(_a0, _a1) + return _c +} + +// GetRateLimiter provides a mock function with given fields: +func (_m *Proxy) GetRateLimiter() (types.Limiter, error) { + ret := _m.Called() + + var r0 types.Limiter + if rf, ok := ret.Get(0).(func() types.Limiter); ok { + r0 = rf() + } else { + if ret.Get(0) != nil { + r0 = ret.Get(0).(types.Limiter) + } + } + + var r1 error + if rf, ok := ret.Get(1).(func() error); ok { + r1 = rf() + } else { + r1 = ret.Error(1) + } + + return r0, r1 +} + +// Proxy_GetRateLimiter_Call is a *mock.Call that shadows Run/Return methods with type explicit version for method 'GetRateLimiter' +type Proxy_GetRateLimiter_Call struct { + *mock.Call +} + +// GetRateLimiter is a helper method to define mock.On call +func (_e *Proxy_Expecter) GetRateLimiter() *Proxy_GetRateLimiter_Call { + return &Proxy_GetRateLimiter_Call{Call: _e.mock.On("GetRateLimiter")} +} + +func (_c *Proxy_GetRateLimiter_Call) Run(run func()) *Proxy_GetRateLimiter_Call { + _c.Call.Run(func(args mock.Arguments) { + run() + }) + return _c +} + +func (_c *Proxy_GetRateLimiter_Call) Return(_a0 types.Limiter, _a1 error) *Proxy_GetRateLimiter_Call { + _c.Call.Return(_a0, _a1) + return _c +} + +// GetReplicas provides a mock function with given fields: ctx, req +func (_m *Proxy) GetReplicas(ctx context.Context, req *milvuspb.GetReplicasRequest) (*milvuspb.GetReplicasResponse, error) { + ret := _m.Called(ctx, req) + + var r0 *milvuspb.GetReplicasResponse + if rf, ok := ret.Get(0).(func(context.Context, *milvuspb.GetReplicasRequest) *milvuspb.GetReplicasResponse); ok { + r0 = rf(ctx, req) + } else { + if ret.Get(0) != nil { + r0 = ret.Get(0).(*milvuspb.GetReplicasResponse) + } + } + + var r1 error + if rf, ok := ret.Get(1).(func(context.Context, *milvuspb.GetReplicasRequest) error); ok { + r1 = rf(ctx, req) + } else { + r1 = ret.Error(1) + } + + return r0, r1 +} + +// Proxy_GetReplicas_Call is a *mock.Call that shadows Run/Return methods with type explicit version for method 'GetReplicas' +type Proxy_GetReplicas_Call struct { + *mock.Call +} + +// GetReplicas is a helper method to define mock.On call +// - ctx context.Context +// - req *milvuspb.GetReplicasRequest +func (_e *Proxy_Expecter) GetReplicas(ctx interface{}, req interface{}) *Proxy_GetReplicas_Call { + return &Proxy_GetReplicas_Call{Call: _e.mock.On("GetReplicas", ctx, req)} +} + +func (_c *Proxy_GetReplicas_Call) Run(run func(ctx context.Context, req *milvuspb.GetReplicasRequest)) *Proxy_GetReplicas_Call { + _c.Call.Run(func(args mock.Arguments) { + run(args[0].(context.Context), args[1].(*milvuspb.GetReplicasRequest)) + }) + return _c +} + +func (_c *Proxy_GetReplicas_Call) Return(_a0 *milvuspb.GetReplicasResponse, _a1 error) *Proxy_GetReplicas_Call { + _c.Call.Return(_a0, _a1) + return _c +} + +// GetStatisticsChannel provides a mock function with given fields: ctx +func (_m *Proxy) GetStatisticsChannel(ctx context.Context) (*milvuspb.StringResponse, error) { + ret := _m.Called(ctx) + + var r0 *milvuspb.StringResponse + if rf, ok := ret.Get(0).(func(context.Context) *milvuspb.StringResponse); ok { + r0 = rf(ctx) + } else { + if ret.Get(0) != nil { + r0 = ret.Get(0).(*milvuspb.StringResponse) + } + } + + var r1 error + if rf, ok := ret.Get(1).(func(context.Context) error); ok { + r1 = rf(ctx) + } else { + r1 = ret.Error(1) + } + + return r0, r1 +} + +// Proxy_GetStatisticsChannel_Call is a *mock.Call that shadows Run/Return methods with type explicit version for method 'GetStatisticsChannel' +type Proxy_GetStatisticsChannel_Call struct { + *mock.Call +} + +// GetStatisticsChannel is a helper method to define mock.On call +// - ctx context.Context +func (_e *Proxy_Expecter) GetStatisticsChannel(ctx interface{}) *Proxy_GetStatisticsChannel_Call { + return &Proxy_GetStatisticsChannel_Call{Call: _e.mock.On("GetStatisticsChannel", ctx)} +} + +func (_c *Proxy_GetStatisticsChannel_Call) Run(run func(ctx context.Context)) *Proxy_GetStatisticsChannel_Call { + _c.Call.Run(func(args mock.Arguments) { + run(args[0].(context.Context)) + }) + return _c +} + +func (_c *Proxy_GetStatisticsChannel_Call) Return(_a0 *milvuspb.StringResponse, _a1 error) *Proxy_GetStatisticsChannel_Call { + _c.Call.Return(_a0, _a1) + return _c +} + +// HasCollection provides a mock function with given fields: ctx, request +func (_m *Proxy) HasCollection(ctx context.Context, request *milvuspb.HasCollectionRequest) (*milvuspb.BoolResponse, error) { + ret := _m.Called(ctx, request) + + var r0 *milvuspb.BoolResponse + if rf, ok := ret.Get(0).(func(context.Context, *milvuspb.HasCollectionRequest) *milvuspb.BoolResponse); ok { + r0 = rf(ctx, request) + } else { + if ret.Get(0) != nil { + r0 = ret.Get(0).(*milvuspb.BoolResponse) + } + } + + var r1 error + if rf, ok := ret.Get(1).(func(context.Context, *milvuspb.HasCollectionRequest) error); ok { + r1 = rf(ctx, request) + } else { + r1 = ret.Error(1) + } + + return r0, r1 +} + +// Proxy_HasCollection_Call is a *mock.Call that shadows Run/Return methods with type explicit version for method 'HasCollection' +type Proxy_HasCollection_Call struct { + *mock.Call +} + +// HasCollection is a helper method to define mock.On call +// - ctx context.Context +// - request *milvuspb.HasCollectionRequest +func (_e *Proxy_Expecter) HasCollection(ctx interface{}, request interface{}) *Proxy_HasCollection_Call { + return &Proxy_HasCollection_Call{Call: _e.mock.On("HasCollection", ctx, request)} +} + +func (_c *Proxy_HasCollection_Call) Run(run func(ctx context.Context, request *milvuspb.HasCollectionRequest)) *Proxy_HasCollection_Call { + _c.Call.Run(func(args mock.Arguments) { + run(args[0].(context.Context), args[1].(*milvuspb.HasCollectionRequest)) + }) + return _c +} + +func (_c *Proxy_HasCollection_Call) Return(_a0 *milvuspb.BoolResponse, _a1 error) *Proxy_HasCollection_Call { + _c.Call.Return(_a0, _a1) + return _c +} + +// HasPartition provides a mock function with given fields: ctx, request +func (_m *Proxy) HasPartition(ctx context.Context, request *milvuspb.HasPartitionRequest) (*milvuspb.BoolResponse, error) { + ret := _m.Called(ctx, request) + + var r0 *milvuspb.BoolResponse + if rf, ok := ret.Get(0).(func(context.Context, *milvuspb.HasPartitionRequest) *milvuspb.BoolResponse); ok { + r0 = rf(ctx, request) + } else { + if ret.Get(0) != nil { + r0 = ret.Get(0).(*milvuspb.BoolResponse) + } + } + + var r1 error + if rf, ok := ret.Get(1).(func(context.Context, *milvuspb.HasPartitionRequest) error); ok { + r1 = rf(ctx, request) + } else { + r1 = ret.Error(1) + } + + return r0, r1 +} + +// Proxy_HasPartition_Call is a *mock.Call that shadows Run/Return methods with type explicit version for method 'HasPartition' +type Proxy_HasPartition_Call struct { + *mock.Call +} + +// HasPartition is a helper method to define mock.On call +// - ctx context.Context +// - request *milvuspb.HasPartitionRequest +func (_e *Proxy_Expecter) HasPartition(ctx interface{}, request interface{}) *Proxy_HasPartition_Call { + return &Proxy_HasPartition_Call{Call: _e.mock.On("HasPartition", ctx, request)} +} + +func (_c *Proxy_HasPartition_Call) Run(run func(ctx context.Context, request *milvuspb.HasPartitionRequest)) *Proxy_HasPartition_Call { + _c.Call.Run(func(args mock.Arguments) { + run(args[0].(context.Context), args[1].(*milvuspb.HasPartitionRequest)) + }) + return _c +} + +func (_c *Proxy_HasPartition_Call) Return(_a0 *milvuspb.BoolResponse, _a1 error) *Proxy_HasPartition_Call { + _c.Call.Return(_a0, _a1) + return _c +} + +// Import provides a mock function with given fields: ctx, req +func (_m *Proxy) Import(ctx context.Context, req *milvuspb.ImportRequest) (*milvuspb.ImportResponse, error) { + ret := _m.Called(ctx, req) + + var r0 *milvuspb.ImportResponse + if rf, ok := ret.Get(0).(func(context.Context, *milvuspb.ImportRequest) *milvuspb.ImportResponse); ok { + r0 = rf(ctx, req) + } else { + if ret.Get(0) != nil { + r0 = ret.Get(0).(*milvuspb.ImportResponse) + } + } + + var r1 error + if rf, ok := ret.Get(1).(func(context.Context, *milvuspb.ImportRequest) error); ok { + r1 = rf(ctx, req) + } else { + r1 = ret.Error(1) + } + + return r0, r1 +} + +// Proxy_Import_Call is a *mock.Call that shadows Run/Return methods with type explicit version for method 'Import' +type Proxy_Import_Call struct { + *mock.Call +} + +// Import is a helper method to define mock.On call +// - ctx context.Context +// - req *milvuspb.ImportRequest +func (_e *Proxy_Expecter) Import(ctx interface{}, req interface{}) *Proxy_Import_Call { + return &Proxy_Import_Call{Call: _e.mock.On("Import", ctx, req)} +} + +func (_c *Proxy_Import_Call) Run(run func(ctx context.Context, req *milvuspb.ImportRequest)) *Proxy_Import_Call { + _c.Call.Run(func(args mock.Arguments) { + run(args[0].(context.Context), args[1].(*milvuspb.ImportRequest)) + }) + return _c +} + +func (_c *Proxy_Import_Call) Return(_a0 *milvuspb.ImportResponse, _a1 error) *Proxy_Import_Call { + _c.Call.Return(_a0, _a1) + return _c +} + +// Init provides a mock function with given fields: +func (_m *Proxy) Init() error { + ret := _m.Called() + + var r0 error + if rf, ok := ret.Get(0).(func() error); ok { + r0 = rf() + } else { + r0 = ret.Error(0) + } + + return r0 +} + +// Proxy_Init_Call is a *mock.Call that shadows Run/Return methods with type explicit version for method 'Init' +type Proxy_Init_Call struct { + *mock.Call +} + +// Init is a helper method to define mock.On call +func (_e *Proxy_Expecter) Init() *Proxy_Init_Call { + return &Proxy_Init_Call{Call: _e.mock.On("Init")} +} + +func (_c *Proxy_Init_Call) Run(run func()) *Proxy_Init_Call { + _c.Call.Run(func(args mock.Arguments) { + run() + }) + return _c +} + +func (_c *Proxy_Init_Call) Return(_a0 error) *Proxy_Init_Call { + _c.Call.Return(_a0) + return _c +} + +// Insert provides a mock function with given fields: ctx, request +func (_m *Proxy) Insert(ctx context.Context, request *milvuspb.InsertRequest) (*milvuspb.MutationResult, error) { + ret := _m.Called(ctx, request) + + var r0 *milvuspb.MutationResult + if rf, ok := ret.Get(0).(func(context.Context, *milvuspb.InsertRequest) *milvuspb.MutationResult); ok { + r0 = rf(ctx, request) + } else { + if ret.Get(0) != nil { + r0 = ret.Get(0).(*milvuspb.MutationResult) + } + } + + var r1 error + if rf, ok := ret.Get(1).(func(context.Context, *milvuspb.InsertRequest) error); ok { + r1 = rf(ctx, request) + } else { + r1 = ret.Error(1) + } + + return r0, r1 +} + +// Proxy_Insert_Call is a *mock.Call that shadows Run/Return methods with type explicit version for method 'Insert' +type Proxy_Insert_Call struct { + *mock.Call +} + +// Insert is a helper method to define mock.On call +// - ctx context.Context +// - request *milvuspb.InsertRequest +func (_e *Proxy_Expecter) Insert(ctx interface{}, request interface{}) *Proxy_Insert_Call { + return &Proxy_Insert_Call{Call: _e.mock.On("Insert", ctx, request)} +} + +func (_c *Proxy_Insert_Call) Run(run func(ctx context.Context, request *milvuspb.InsertRequest)) *Proxy_Insert_Call { + _c.Call.Run(func(args mock.Arguments) { + run(args[0].(context.Context), args[1].(*milvuspb.InsertRequest)) + }) + return _c +} + +func (_c *Proxy_Insert_Call) Return(_a0 *milvuspb.MutationResult, _a1 error) *Proxy_Insert_Call { + _c.Call.Return(_a0, _a1) + return _c +} + +// InvalidateCollectionMetaCache provides a mock function with given fields: ctx, request +func (_m *Proxy) InvalidateCollectionMetaCache(ctx context.Context, request *proxypb.InvalidateCollMetaCacheRequest) (*commonpb.Status, error) { + ret := _m.Called(ctx, request) + + var r0 *commonpb.Status + if rf, ok := ret.Get(0).(func(context.Context, *proxypb.InvalidateCollMetaCacheRequest) *commonpb.Status); ok { + r0 = rf(ctx, request) + } else { + if ret.Get(0) != nil { + r0 = ret.Get(0).(*commonpb.Status) + } + } + + var r1 error + if rf, ok := ret.Get(1).(func(context.Context, *proxypb.InvalidateCollMetaCacheRequest) error); ok { + r1 = rf(ctx, request) + } else { + r1 = ret.Error(1) + } + + return r0, r1 +} + +// Proxy_InvalidateCollectionMetaCache_Call is a *mock.Call that shadows Run/Return methods with type explicit version for method 'InvalidateCollectionMetaCache' +type Proxy_InvalidateCollectionMetaCache_Call struct { + *mock.Call +} + +// InvalidateCollectionMetaCache is a helper method to define mock.On call +// - ctx context.Context +// - request *proxypb.InvalidateCollMetaCacheRequest +func (_e *Proxy_Expecter) InvalidateCollectionMetaCache(ctx interface{}, request interface{}) *Proxy_InvalidateCollectionMetaCache_Call { + return &Proxy_InvalidateCollectionMetaCache_Call{Call: _e.mock.On("InvalidateCollectionMetaCache", ctx, request)} +} + +func (_c *Proxy_InvalidateCollectionMetaCache_Call) Run(run func(ctx context.Context, request *proxypb.InvalidateCollMetaCacheRequest)) *Proxy_InvalidateCollectionMetaCache_Call { + _c.Call.Run(func(args mock.Arguments) { + run(args[0].(context.Context), args[1].(*proxypb.InvalidateCollMetaCacheRequest)) + }) + return _c +} + +func (_c *Proxy_InvalidateCollectionMetaCache_Call) Return(_a0 *commonpb.Status, _a1 error) *Proxy_InvalidateCollectionMetaCache_Call { + _c.Call.Return(_a0, _a1) + return _c +} + +// InvalidateCredentialCache provides a mock function with given fields: ctx, request +func (_m *Proxy) InvalidateCredentialCache(ctx context.Context, request *proxypb.InvalidateCredCacheRequest) (*commonpb.Status, error) { + ret := _m.Called(ctx, request) + + var r0 *commonpb.Status + if rf, ok := ret.Get(0).(func(context.Context, *proxypb.InvalidateCredCacheRequest) *commonpb.Status); ok { + r0 = rf(ctx, request) + } else { + if ret.Get(0) != nil { + r0 = ret.Get(0).(*commonpb.Status) + } + } + + var r1 error + if rf, ok := ret.Get(1).(func(context.Context, *proxypb.InvalidateCredCacheRequest) error); ok { + r1 = rf(ctx, request) + } else { + r1 = ret.Error(1) + } + + return r0, r1 +} + +// Proxy_InvalidateCredentialCache_Call is a *mock.Call that shadows Run/Return methods with type explicit version for method 'InvalidateCredentialCache' +type Proxy_InvalidateCredentialCache_Call struct { + *mock.Call +} + +// InvalidateCredentialCache is a helper method to define mock.On call +// - ctx context.Context +// - request *proxypb.InvalidateCredCacheRequest +func (_e *Proxy_Expecter) InvalidateCredentialCache(ctx interface{}, request interface{}) *Proxy_InvalidateCredentialCache_Call { + return &Proxy_InvalidateCredentialCache_Call{Call: _e.mock.On("InvalidateCredentialCache", ctx, request)} +} + +func (_c *Proxy_InvalidateCredentialCache_Call) Run(run func(ctx context.Context, request *proxypb.InvalidateCredCacheRequest)) *Proxy_InvalidateCredentialCache_Call { + _c.Call.Run(func(args mock.Arguments) { + run(args[0].(context.Context), args[1].(*proxypb.InvalidateCredCacheRequest)) + }) + return _c +} + +func (_c *Proxy_InvalidateCredentialCache_Call) Return(_a0 *commonpb.Status, _a1 error) *Proxy_InvalidateCredentialCache_Call { + _c.Call.Return(_a0, _a1) + return _c +} + +// ListCredUsers provides a mock function with given fields: ctx, req +func (_m *Proxy) ListCredUsers(ctx context.Context, req *milvuspb.ListCredUsersRequest) (*milvuspb.ListCredUsersResponse, error) { + ret := _m.Called(ctx, req) + + var r0 *milvuspb.ListCredUsersResponse + if rf, ok := ret.Get(0).(func(context.Context, *milvuspb.ListCredUsersRequest) *milvuspb.ListCredUsersResponse); ok { + r0 = rf(ctx, req) + } else { + if ret.Get(0) != nil { + r0 = ret.Get(0).(*milvuspb.ListCredUsersResponse) + } + } + + var r1 error + if rf, ok := ret.Get(1).(func(context.Context, *milvuspb.ListCredUsersRequest) error); ok { + r1 = rf(ctx, req) + } else { + r1 = ret.Error(1) + } + + return r0, r1 +} + +// Proxy_ListCredUsers_Call is a *mock.Call that shadows Run/Return methods with type explicit version for method 'ListCredUsers' +type Proxy_ListCredUsers_Call struct { + *mock.Call +} + +// ListCredUsers is a helper method to define mock.On call +// - ctx context.Context +// - req *milvuspb.ListCredUsersRequest +func (_e *Proxy_Expecter) ListCredUsers(ctx interface{}, req interface{}) *Proxy_ListCredUsers_Call { + return &Proxy_ListCredUsers_Call{Call: _e.mock.On("ListCredUsers", ctx, req)} +} + +func (_c *Proxy_ListCredUsers_Call) Run(run func(ctx context.Context, req *milvuspb.ListCredUsersRequest)) *Proxy_ListCredUsers_Call { + _c.Call.Run(func(args mock.Arguments) { + run(args[0].(context.Context), args[1].(*milvuspb.ListCredUsersRequest)) + }) + return _c +} + +func (_c *Proxy_ListCredUsers_Call) Return(_a0 *milvuspb.ListCredUsersResponse, _a1 error) *Proxy_ListCredUsers_Call { + _c.Call.Return(_a0, _a1) + return _c +} + +// ListImportTasks provides a mock function with given fields: ctx, req +func (_m *Proxy) ListImportTasks(ctx context.Context, req *milvuspb.ListImportTasksRequest) (*milvuspb.ListImportTasksResponse, error) { + ret := _m.Called(ctx, req) + + var r0 *milvuspb.ListImportTasksResponse + if rf, ok := ret.Get(0).(func(context.Context, *milvuspb.ListImportTasksRequest) *milvuspb.ListImportTasksResponse); ok { + r0 = rf(ctx, req) + } else { + if ret.Get(0) != nil { + r0 = ret.Get(0).(*milvuspb.ListImportTasksResponse) + } + } + + var r1 error + if rf, ok := ret.Get(1).(func(context.Context, *milvuspb.ListImportTasksRequest) error); ok { + r1 = rf(ctx, req) + } else { + r1 = ret.Error(1) + } + + return r0, r1 +} + +// Proxy_ListImportTasks_Call is a *mock.Call that shadows Run/Return methods with type explicit version for method 'ListImportTasks' +type Proxy_ListImportTasks_Call struct { + *mock.Call +} + +// ListImportTasks is a helper method to define mock.On call +// - ctx context.Context +// - req *milvuspb.ListImportTasksRequest +func (_e *Proxy_Expecter) ListImportTasks(ctx interface{}, req interface{}) *Proxy_ListImportTasks_Call { + return &Proxy_ListImportTasks_Call{Call: _e.mock.On("ListImportTasks", ctx, req)} +} + +func (_c *Proxy_ListImportTasks_Call) Run(run func(ctx context.Context, req *milvuspb.ListImportTasksRequest)) *Proxy_ListImportTasks_Call { + _c.Call.Run(func(args mock.Arguments) { + run(args[0].(context.Context), args[1].(*milvuspb.ListImportTasksRequest)) + }) + return _c +} + +func (_c *Proxy_ListImportTasks_Call) Return(_a0 *milvuspb.ListImportTasksResponse, _a1 error) *Proxy_ListImportTasks_Call { + _c.Call.Return(_a0, _a1) + return _c +} + +// ListResourceGroups provides a mock function with given fields: ctx, req +func (_m *Proxy) ListResourceGroups(ctx context.Context, req *milvuspb.ListResourceGroupsRequest) (*milvuspb.ListResourceGroupsResponse, error) { + ret := _m.Called(ctx, req) + + var r0 *milvuspb.ListResourceGroupsResponse + if rf, ok := ret.Get(0).(func(context.Context, *milvuspb.ListResourceGroupsRequest) *milvuspb.ListResourceGroupsResponse); ok { + r0 = rf(ctx, req) + } else { + if ret.Get(0) != nil { + r0 = ret.Get(0).(*milvuspb.ListResourceGroupsResponse) + } + } + + var r1 error + if rf, ok := ret.Get(1).(func(context.Context, *milvuspb.ListResourceGroupsRequest) error); ok { + r1 = rf(ctx, req) + } else { + r1 = ret.Error(1) + } + + return r0, r1 +} + +// Proxy_ListResourceGroups_Call is a *mock.Call that shadows Run/Return methods with type explicit version for method 'ListResourceGroups' +type Proxy_ListResourceGroups_Call struct { + *mock.Call +} + +// ListResourceGroups is a helper method to define mock.On call +// - ctx context.Context +// - req *milvuspb.ListResourceGroupsRequest +func (_e *Proxy_Expecter) ListResourceGroups(ctx interface{}, req interface{}) *Proxy_ListResourceGroups_Call { + return &Proxy_ListResourceGroups_Call{Call: _e.mock.On("ListResourceGroups", ctx, req)} +} + +func (_c *Proxy_ListResourceGroups_Call) Run(run func(ctx context.Context, req *milvuspb.ListResourceGroupsRequest)) *Proxy_ListResourceGroups_Call { + _c.Call.Run(func(args mock.Arguments) { + run(args[0].(context.Context), args[1].(*milvuspb.ListResourceGroupsRequest)) + }) + return _c +} + +func (_c *Proxy_ListResourceGroups_Call) Return(_a0 *milvuspb.ListResourceGroupsResponse, _a1 error) *Proxy_ListResourceGroups_Call { + _c.Call.Return(_a0, _a1) + return _c +} + +// LoadBalance provides a mock function with given fields: ctx, request +func (_m *Proxy) LoadBalance(ctx context.Context, request *milvuspb.LoadBalanceRequest) (*commonpb.Status, error) { + ret := _m.Called(ctx, request) + + var r0 *commonpb.Status + if rf, ok := ret.Get(0).(func(context.Context, *milvuspb.LoadBalanceRequest) *commonpb.Status); ok { + r0 = rf(ctx, request) + } else { + if ret.Get(0) != nil { + r0 = ret.Get(0).(*commonpb.Status) + } + } + + var r1 error + if rf, ok := ret.Get(1).(func(context.Context, *milvuspb.LoadBalanceRequest) error); ok { + r1 = rf(ctx, request) + } else { + r1 = ret.Error(1) + } + + return r0, r1 +} + +// Proxy_LoadBalance_Call is a *mock.Call that shadows Run/Return methods with type explicit version for method 'LoadBalance' +type Proxy_LoadBalance_Call struct { + *mock.Call +} + +// LoadBalance is a helper method to define mock.On call +// - ctx context.Context +// - request *milvuspb.LoadBalanceRequest +func (_e *Proxy_Expecter) LoadBalance(ctx interface{}, request interface{}) *Proxy_LoadBalance_Call { + return &Proxy_LoadBalance_Call{Call: _e.mock.On("LoadBalance", ctx, request)} +} + +func (_c *Proxy_LoadBalance_Call) Run(run func(ctx context.Context, request *milvuspb.LoadBalanceRequest)) *Proxy_LoadBalance_Call { + _c.Call.Run(func(args mock.Arguments) { + run(args[0].(context.Context), args[1].(*milvuspb.LoadBalanceRequest)) + }) + return _c +} + +func (_c *Proxy_LoadBalance_Call) Return(_a0 *commonpb.Status, _a1 error) *Proxy_LoadBalance_Call { + _c.Call.Return(_a0, _a1) + return _c +} + +// LoadCollection provides a mock function with given fields: ctx, request +func (_m *Proxy) LoadCollection(ctx context.Context, request *milvuspb.LoadCollectionRequest) (*commonpb.Status, error) { + ret := _m.Called(ctx, request) + + var r0 *commonpb.Status + if rf, ok := ret.Get(0).(func(context.Context, *milvuspb.LoadCollectionRequest) *commonpb.Status); ok { + r0 = rf(ctx, request) + } else { + if ret.Get(0) != nil { + r0 = ret.Get(0).(*commonpb.Status) + } + } + + var r1 error + if rf, ok := ret.Get(1).(func(context.Context, *milvuspb.LoadCollectionRequest) error); ok { + r1 = rf(ctx, request) + } else { + r1 = ret.Error(1) + } + + return r0, r1 +} + +// Proxy_LoadCollection_Call is a *mock.Call that shadows Run/Return methods with type explicit version for method 'LoadCollection' +type Proxy_LoadCollection_Call struct { + *mock.Call +} + +// LoadCollection is a helper method to define mock.On call +// - ctx context.Context +// - request *milvuspb.LoadCollectionRequest +func (_e *Proxy_Expecter) LoadCollection(ctx interface{}, request interface{}) *Proxy_LoadCollection_Call { + return &Proxy_LoadCollection_Call{Call: _e.mock.On("LoadCollection", ctx, request)} +} + +func (_c *Proxy_LoadCollection_Call) Run(run func(ctx context.Context, request *milvuspb.LoadCollectionRequest)) *Proxy_LoadCollection_Call { + _c.Call.Run(func(args mock.Arguments) { + run(args[0].(context.Context), args[1].(*milvuspb.LoadCollectionRequest)) + }) + return _c +} + +func (_c *Proxy_LoadCollection_Call) Return(_a0 *commonpb.Status, _a1 error) *Proxy_LoadCollection_Call { + _c.Call.Return(_a0, _a1) + return _c +} + +// LoadPartitions provides a mock function with given fields: ctx, request +func (_m *Proxy) LoadPartitions(ctx context.Context, request *milvuspb.LoadPartitionsRequest) (*commonpb.Status, error) { + ret := _m.Called(ctx, request) + + var r0 *commonpb.Status + if rf, ok := ret.Get(0).(func(context.Context, *milvuspb.LoadPartitionsRequest) *commonpb.Status); ok { + r0 = rf(ctx, request) + } else { + if ret.Get(0) != nil { + r0 = ret.Get(0).(*commonpb.Status) + } + } + + var r1 error + if rf, ok := ret.Get(1).(func(context.Context, *milvuspb.LoadPartitionsRequest) error); ok { + r1 = rf(ctx, request) + } else { + r1 = ret.Error(1) + } + + return r0, r1 +} + +// Proxy_LoadPartitions_Call is a *mock.Call that shadows Run/Return methods with type explicit version for method 'LoadPartitions' +type Proxy_LoadPartitions_Call struct { + *mock.Call +} + +// LoadPartitions is a helper method to define mock.On call +// - ctx context.Context +// - request *milvuspb.LoadPartitionsRequest +func (_e *Proxy_Expecter) LoadPartitions(ctx interface{}, request interface{}) *Proxy_LoadPartitions_Call { + return &Proxy_LoadPartitions_Call{Call: _e.mock.On("LoadPartitions", ctx, request)} +} + +func (_c *Proxy_LoadPartitions_Call) Run(run func(ctx context.Context, request *milvuspb.LoadPartitionsRequest)) *Proxy_LoadPartitions_Call { + _c.Call.Run(func(args mock.Arguments) { + run(args[0].(context.Context), args[1].(*milvuspb.LoadPartitionsRequest)) + }) + return _c +} + +func (_c *Proxy_LoadPartitions_Call) Return(_a0 *commonpb.Status, _a1 error) *Proxy_LoadPartitions_Call { + _c.Call.Return(_a0, _a1) + return _c +} + +// ManualCompaction provides a mock function with given fields: ctx, req +func (_m *Proxy) ManualCompaction(ctx context.Context, req *milvuspb.ManualCompactionRequest) (*milvuspb.ManualCompactionResponse, error) { + ret := _m.Called(ctx, req) + + var r0 *milvuspb.ManualCompactionResponse + if rf, ok := ret.Get(0).(func(context.Context, *milvuspb.ManualCompactionRequest) *milvuspb.ManualCompactionResponse); ok { + r0 = rf(ctx, req) + } else { + if ret.Get(0) != nil { + r0 = ret.Get(0).(*milvuspb.ManualCompactionResponse) + } + } + + var r1 error + if rf, ok := ret.Get(1).(func(context.Context, *milvuspb.ManualCompactionRequest) error); ok { + r1 = rf(ctx, req) + } else { + r1 = ret.Error(1) + } + + return r0, r1 +} + +// Proxy_ManualCompaction_Call is a *mock.Call that shadows Run/Return methods with type explicit version for method 'ManualCompaction' +type Proxy_ManualCompaction_Call struct { + *mock.Call +} + +// ManualCompaction is a helper method to define mock.On call +// - ctx context.Context +// - req *milvuspb.ManualCompactionRequest +func (_e *Proxy_Expecter) ManualCompaction(ctx interface{}, req interface{}) *Proxy_ManualCompaction_Call { + return &Proxy_ManualCompaction_Call{Call: _e.mock.On("ManualCompaction", ctx, req)} +} + +func (_c *Proxy_ManualCompaction_Call) Run(run func(ctx context.Context, req *milvuspb.ManualCompactionRequest)) *Proxy_ManualCompaction_Call { + _c.Call.Run(func(args mock.Arguments) { + run(args[0].(context.Context), args[1].(*milvuspb.ManualCompactionRequest)) + }) + return _c +} + +func (_c *Proxy_ManualCompaction_Call) Return(_a0 *milvuspb.ManualCompactionResponse, _a1 error) *Proxy_ManualCompaction_Call { + _c.Call.Return(_a0, _a1) + return _c +} + +// OperatePrivilege provides a mock function with given fields: ctx, req +func (_m *Proxy) OperatePrivilege(ctx context.Context, req *milvuspb.OperatePrivilegeRequest) (*commonpb.Status, error) { + ret := _m.Called(ctx, req) + + var r0 *commonpb.Status + if rf, ok := ret.Get(0).(func(context.Context, *milvuspb.OperatePrivilegeRequest) *commonpb.Status); ok { + r0 = rf(ctx, req) + } else { + if ret.Get(0) != nil { + r0 = ret.Get(0).(*commonpb.Status) + } + } + + var r1 error + if rf, ok := ret.Get(1).(func(context.Context, *milvuspb.OperatePrivilegeRequest) error); ok { + r1 = rf(ctx, req) + } else { + r1 = ret.Error(1) + } + + return r0, r1 +} + +// Proxy_OperatePrivilege_Call is a *mock.Call that shadows Run/Return methods with type explicit version for method 'OperatePrivilege' +type Proxy_OperatePrivilege_Call struct { + *mock.Call +} + +// OperatePrivilege is a helper method to define mock.On call +// - ctx context.Context +// - req *milvuspb.OperatePrivilegeRequest +func (_e *Proxy_Expecter) OperatePrivilege(ctx interface{}, req interface{}) *Proxy_OperatePrivilege_Call { + return &Proxy_OperatePrivilege_Call{Call: _e.mock.On("OperatePrivilege", ctx, req)} +} + +func (_c *Proxy_OperatePrivilege_Call) Run(run func(ctx context.Context, req *milvuspb.OperatePrivilegeRequest)) *Proxy_OperatePrivilege_Call { + _c.Call.Run(func(args mock.Arguments) { + run(args[0].(context.Context), args[1].(*milvuspb.OperatePrivilegeRequest)) + }) + return _c +} + +func (_c *Proxy_OperatePrivilege_Call) Return(_a0 *commonpb.Status, _a1 error) *Proxy_OperatePrivilege_Call { + _c.Call.Return(_a0, _a1) + return _c +} + +// OperateUserRole provides a mock function with given fields: ctx, req +func (_m *Proxy) OperateUserRole(ctx context.Context, req *milvuspb.OperateUserRoleRequest) (*commonpb.Status, error) { + ret := _m.Called(ctx, req) + + var r0 *commonpb.Status + if rf, ok := ret.Get(0).(func(context.Context, *milvuspb.OperateUserRoleRequest) *commonpb.Status); ok { + r0 = rf(ctx, req) + } else { + if ret.Get(0) != nil { + r0 = ret.Get(0).(*commonpb.Status) + } + } + + var r1 error + if rf, ok := ret.Get(1).(func(context.Context, *milvuspb.OperateUserRoleRequest) error); ok { + r1 = rf(ctx, req) + } else { + r1 = ret.Error(1) + } + + return r0, r1 +} + +// Proxy_OperateUserRole_Call is a *mock.Call that shadows Run/Return methods with type explicit version for method 'OperateUserRole' +type Proxy_OperateUserRole_Call struct { + *mock.Call +} + +// OperateUserRole is a helper method to define mock.On call +// - ctx context.Context +// - req *milvuspb.OperateUserRoleRequest +func (_e *Proxy_Expecter) OperateUserRole(ctx interface{}, req interface{}) *Proxy_OperateUserRole_Call { + return &Proxy_OperateUserRole_Call{Call: _e.mock.On("OperateUserRole", ctx, req)} +} + +func (_c *Proxy_OperateUserRole_Call) Run(run func(ctx context.Context, req *milvuspb.OperateUserRoleRequest)) *Proxy_OperateUserRole_Call { + _c.Call.Run(func(args mock.Arguments) { + run(args[0].(context.Context), args[1].(*milvuspb.OperateUserRoleRequest)) + }) + return _c +} + +func (_c *Proxy_OperateUserRole_Call) Return(_a0 *commonpb.Status, _a1 error) *Proxy_OperateUserRole_Call { + _c.Call.Return(_a0, _a1) + return _c +} + +// Query provides a mock function with given fields: ctx, request +func (_m *Proxy) Query(ctx context.Context, request *milvuspb.QueryRequest) (*milvuspb.QueryResults, error) { + ret := _m.Called(ctx, request) + + var r0 *milvuspb.QueryResults + if rf, ok := ret.Get(0).(func(context.Context, *milvuspb.QueryRequest) *milvuspb.QueryResults); ok { + r0 = rf(ctx, request) + } else { + if ret.Get(0) != nil { + r0 = ret.Get(0).(*milvuspb.QueryResults) + } + } + + var r1 error + if rf, ok := ret.Get(1).(func(context.Context, *milvuspb.QueryRequest) error); ok { + r1 = rf(ctx, request) + } else { + r1 = ret.Error(1) + } + + return r0, r1 +} + +// Proxy_Query_Call is a *mock.Call that shadows Run/Return methods with type explicit version for method 'Query' +type Proxy_Query_Call struct { + *mock.Call +} + +// Query is a helper method to define mock.On call +// - ctx context.Context +// - request *milvuspb.QueryRequest +func (_e *Proxy_Expecter) Query(ctx interface{}, request interface{}) *Proxy_Query_Call { + return &Proxy_Query_Call{Call: _e.mock.On("Query", ctx, request)} +} + +func (_c *Proxy_Query_Call) Run(run func(ctx context.Context, request *milvuspb.QueryRequest)) *Proxy_Query_Call { + _c.Call.Run(func(args mock.Arguments) { + run(args[0].(context.Context), args[1].(*milvuspb.QueryRequest)) + }) + return _c +} + +func (_c *Proxy_Query_Call) Return(_a0 *milvuspb.QueryResults, _a1 error) *Proxy_Query_Call { + _c.Call.Return(_a0, _a1) + return _c +} + +// RefreshPolicyInfoCache provides a mock function with given fields: ctx, req +func (_m *Proxy) RefreshPolicyInfoCache(ctx context.Context, req *proxypb.RefreshPolicyInfoCacheRequest) (*commonpb.Status, error) { + ret := _m.Called(ctx, req) + + var r0 *commonpb.Status + if rf, ok := ret.Get(0).(func(context.Context, *proxypb.RefreshPolicyInfoCacheRequest) *commonpb.Status); ok { + r0 = rf(ctx, req) + } else { + if ret.Get(0) != nil { + r0 = ret.Get(0).(*commonpb.Status) + } + } + + var r1 error + if rf, ok := ret.Get(1).(func(context.Context, *proxypb.RefreshPolicyInfoCacheRequest) error); ok { + r1 = rf(ctx, req) + } else { + r1 = ret.Error(1) + } + + return r0, r1 +} + +// Proxy_RefreshPolicyInfoCache_Call is a *mock.Call that shadows Run/Return methods with type explicit version for method 'RefreshPolicyInfoCache' +type Proxy_RefreshPolicyInfoCache_Call struct { + *mock.Call +} + +// RefreshPolicyInfoCache is a helper method to define mock.On call +// - ctx context.Context +// - req *proxypb.RefreshPolicyInfoCacheRequest +func (_e *Proxy_Expecter) RefreshPolicyInfoCache(ctx interface{}, req interface{}) *Proxy_RefreshPolicyInfoCache_Call { + return &Proxy_RefreshPolicyInfoCache_Call{Call: _e.mock.On("RefreshPolicyInfoCache", ctx, req)} +} + +func (_c *Proxy_RefreshPolicyInfoCache_Call) Run(run func(ctx context.Context, req *proxypb.RefreshPolicyInfoCacheRequest)) *Proxy_RefreshPolicyInfoCache_Call { + _c.Call.Run(func(args mock.Arguments) { + run(args[0].(context.Context), args[1].(*proxypb.RefreshPolicyInfoCacheRequest)) + }) + return _c +} + +func (_c *Proxy_RefreshPolicyInfoCache_Call) Return(_a0 *commonpb.Status, _a1 error) *Proxy_RefreshPolicyInfoCache_Call { + _c.Call.Return(_a0, _a1) + return _c +} + +// Register provides a mock function with given fields: +func (_m *Proxy) Register() error { + ret := _m.Called() + + var r0 error + if rf, ok := ret.Get(0).(func() error); ok { + r0 = rf() + } else { + r0 = ret.Error(0) + } + + return r0 +} + +// Proxy_Register_Call is a *mock.Call that shadows Run/Return methods with type explicit version for method 'Register' +type Proxy_Register_Call struct { + *mock.Call +} + +// Register is a helper method to define mock.On call +func (_e *Proxy_Expecter) Register() *Proxy_Register_Call { + return &Proxy_Register_Call{Call: _e.mock.On("Register")} +} + +func (_c *Proxy_Register_Call) Run(run func()) *Proxy_Register_Call { + _c.Call.Run(func(args mock.Arguments) { + run() + }) + return _c +} + +func (_c *Proxy_Register_Call) Return(_a0 error) *Proxy_Register_Call { + _c.Call.Return(_a0) + return _c +} + +// RegisterLink provides a mock function with given fields: ctx, request +func (_m *Proxy) RegisterLink(ctx context.Context, request *milvuspb.RegisterLinkRequest) (*milvuspb.RegisterLinkResponse, error) { + ret := _m.Called(ctx, request) + + var r0 *milvuspb.RegisterLinkResponse + if rf, ok := ret.Get(0).(func(context.Context, *milvuspb.RegisterLinkRequest) *milvuspb.RegisterLinkResponse); ok { + r0 = rf(ctx, request) + } else { + if ret.Get(0) != nil { + r0 = ret.Get(0).(*milvuspb.RegisterLinkResponse) + } + } + + var r1 error + if rf, ok := ret.Get(1).(func(context.Context, *milvuspb.RegisterLinkRequest) error); ok { + r1 = rf(ctx, request) + } else { + r1 = ret.Error(1) + } + + return r0, r1 +} + +// Proxy_RegisterLink_Call is a *mock.Call that shadows Run/Return methods with type explicit version for method 'RegisterLink' +type Proxy_RegisterLink_Call struct { + *mock.Call +} + +// RegisterLink is a helper method to define mock.On call +// - ctx context.Context +// - request *milvuspb.RegisterLinkRequest +func (_e *Proxy_Expecter) RegisterLink(ctx interface{}, request interface{}) *Proxy_RegisterLink_Call { + return &Proxy_RegisterLink_Call{Call: _e.mock.On("RegisterLink", ctx, request)} +} + +func (_c *Proxy_RegisterLink_Call) Run(run func(ctx context.Context, request *milvuspb.RegisterLinkRequest)) *Proxy_RegisterLink_Call { + _c.Call.Run(func(args mock.Arguments) { + run(args[0].(context.Context), args[1].(*milvuspb.RegisterLinkRequest)) + }) + return _c +} + +func (_c *Proxy_RegisterLink_Call) Return(_a0 *milvuspb.RegisterLinkResponse, _a1 error) *Proxy_RegisterLink_Call { + _c.Call.Return(_a0, _a1) + return _c +} + +// ReleaseCollection provides a mock function with given fields: ctx, request +func (_m *Proxy) ReleaseCollection(ctx context.Context, request *milvuspb.ReleaseCollectionRequest) (*commonpb.Status, error) { + ret := _m.Called(ctx, request) + + var r0 *commonpb.Status + if rf, ok := ret.Get(0).(func(context.Context, *milvuspb.ReleaseCollectionRequest) *commonpb.Status); ok { + r0 = rf(ctx, request) + } else { + if ret.Get(0) != nil { + r0 = ret.Get(0).(*commonpb.Status) + } + } + + var r1 error + if rf, ok := ret.Get(1).(func(context.Context, *milvuspb.ReleaseCollectionRequest) error); ok { + r1 = rf(ctx, request) + } else { + r1 = ret.Error(1) + } + + return r0, r1 +} + +// Proxy_ReleaseCollection_Call is a *mock.Call that shadows Run/Return methods with type explicit version for method 'ReleaseCollection' +type Proxy_ReleaseCollection_Call struct { + *mock.Call +} + +// ReleaseCollection is a helper method to define mock.On call +// - ctx context.Context +// - request *milvuspb.ReleaseCollectionRequest +func (_e *Proxy_Expecter) ReleaseCollection(ctx interface{}, request interface{}) *Proxy_ReleaseCollection_Call { + return &Proxy_ReleaseCollection_Call{Call: _e.mock.On("ReleaseCollection", ctx, request)} +} + +func (_c *Proxy_ReleaseCollection_Call) Run(run func(ctx context.Context, request *milvuspb.ReleaseCollectionRequest)) *Proxy_ReleaseCollection_Call { + _c.Call.Run(func(args mock.Arguments) { + run(args[0].(context.Context), args[1].(*milvuspb.ReleaseCollectionRequest)) + }) + return _c +} + +func (_c *Proxy_ReleaseCollection_Call) Return(_a0 *commonpb.Status, _a1 error) *Proxy_ReleaseCollection_Call { + _c.Call.Return(_a0, _a1) + return _c +} + +// ReleasePartitions provides a mock function with given fields: ctx, request +func (_m *Proxy) ReleasePartitions(ctx context.Context, request *milvuspb.ReleasePartitionsRequest) (*commonpb.Status, error) { + ret := _m.Called(ctx, request) + + var r0 *commonpb.Status + if rf, ok := ret.Get(0).(func(context.Context, *milvuspb.ReleasePartitionsRequest) *commonpb.Status); ok { + r0 = rf(ctx, request) + } else { + if ret.Get(0) != nil { + r0 = ret.Get(0).(*commonpb.Status) + } + } + + var r1 error + if rf, ok := ret.Get(1).(func(context.Context, *milvuspb.ReleasePartitionsRequest) error); ok { + r1 = rf(ctx, request) + } else { + r1 = ret.Error(1) + } + + return r0, r1 +} + +// Proxy_ReleasePartitions_Call is a *mock.Call that shadows Run/Return methods with type explicit version for method 'ReleasePartitions' +type Proxy_ReleasePartitions_Call struct { + *mock.Call +} + +// ReleasePartitions is a helper method to define mock.On call +// - ctx context.Context +// - request *milvuspb.ReleasePartitionsRequest +func (_e *Proxy_Expecter) ReleasePartitions(ctx interface{}, request interface{}) *Proxy_ReleasePartitions_Call { + return &Proxy_ReleasePartitions_Call{Call: _e.mock.On("ReleasePartitions", ctx, request)} +} + +func (_c *Proxy_ReleasePartitions_Call) Run(run func(ctx context.Context, request *milvuspb.ReleasePartitionsRequest)) *Proxy_ReleasePartitions_Call { + _c.Call.Run(func(args mock.Arguments) { + run(args[0].(context.Context), args[1].(*milvuspb.ReleasePartitionsRequest)) + }) + return _c +} + +func (_c *Proxy_ReleasePartitions_Call) Return(_a0 *commonpb.Status, _a1 error) *Proxy_ReleasePartitions_Call { + _c.Call.Return(_a0, _a1) + return _c +} + +// RenameCollection provides a mock function with given fields: ctx, req +func (_m *Proxy) RenameCollection(ctx context.Context, req *milvuspb.RenameCollectionRequest) (*commonpb.Status, error) { + ret := _m.Called(ctx, req) + + var r0 *commonpb.Status + if rf, ok := ret.Get(0).(func(context.Context, *milvuspb.RenameCollectionRequest) *commonpb.Status); ok { + r0 = rf(ctx, req) + } else { + if ret.Get(0) != nil { + r0 = ret.Get(0).(*commonpb.Status) + } + } + + var r1 error + if rf, ok := ret.Get(1).(func(context.Context, *milvuspb.RenameCollectionRequest) error); ok { + r1 = rf(ctx, req) + } else { + r1 = ret.Error(1) + } + + return r0, r1 +} + +// Proxy_RenameCollection_Call is a *mock.Call that shadows Run/Return methods with type explicit version for method 'RenameCollection' +type Proxy_RenameCollection_Call struct { + *mock.Call +} + +// RenameCollection is a helper method to define mock.On call +// - ctx context.Context +// - req *milvuspb.RenameCollectionRequest +func (_e *Proxy_Expecter) RenameCollection(ctx interface{}, req interface{}) *Proxy_RenameCollection_Call { + return &Proxy_RenameCollection_Call{Call: _e.mock.On("RenameCollection", ctx, req)} +} + +func (_c *Proxy_RenameCollection_Call) Run(run func(ctx context.Context, req *milvuspb.RenameCollectionRequest)) *Proxy_RenameCollection_Call { + _c.Call.Run(func(args mock.Arguments) { + run(args[0].(context.Context), args[1].(*milvuspb.RenameCollectionRequest)) + }) + return _c +} + +func (_c *Proxy_RenameCollection_Call) Return(_a0 *commonpb.Status, _a1 error) *Proxy_RenameCollection_Call { + _c.Call.Return(_a0, _a1) + return _c +} + +// Search provides a mock function with given fields: ctx, request +func (_m *Proxy) Search(ctx context.Context, request *milvuspb.SearchRequest) (*milvuspb.SearchResults, error) { + ret := _m.Called(ctx, request) + + var r0 *milvuspb.SearchResults + if rf, ok := ret.Get(0).(func(context.Context, *milvuspb.SearchRequest) *milvuspb.SearchResults); ok { + r0 = rf(ctx, request) + } else { + if ret.Get(0) != nil { + r0 = ret.Get(0).(*milvuspb.SearchResults) + } + } + + var r1 error + if rf, ok := ret.Get(1).(func(context.Context, *milvuspb.SearchRequest) error); ok { + r1 = rf(ctx, request) + } else { + r1 = ret.Error(1) + } + + return r0, r1 +} + +// Proxy_Search_Call is a *mock.Call that shadows Run/Return methods with type explicit version for method 'Search' +type Proxy_Search_Call struct { + *mock.Call +} + +// Search is a helper method to define mock.On call +// - ctx context.Context +// - request *milvuspb.SearchRequest +func (_e *Proxy_Expecter) Search(ctx interface{}, request interface{}) *Proxy_Search_Call { + return &Proxy_Search_Call{Call: _e.mock.On("Search", ctx, request)} +} + +func (_c *Proxy_Search_Call) Run(run func(ctx context.Context, request *milvuspb.SearchRequest)) *Proxy_Search_Call { + _c.Call.Run(func(args mock.Arguments) { + run(args[0].(context.Context), args[1].(*milvuspb.SearchRequest)) + }) + return _c +} + +func (_c *Proxy_Search_Call) Return(_a0 *milvuspb.SearchResults, _a1 error) *Proxy_Search_Call { + _c.Call.Return(_a0, _a1) + return _c +} + +// SelectGrant provides a mock function with given fields: ctx, req +func (_m *Proxy) SelectGrant(ctx context.Context, req *milvuspb.SelectGrantRequest) (*milvuspb.SelectGrantResponse, error) { + ret := _m.Called(ctx, req) + + var r0 *milvuspb.SelectGrantResponse + if rf, ok := ret.Get(0).(func(context.Context, *milvuspb.SelectGrantRequest) *milvuspb.SelectGrantResponse); ok { + r0 = rf(ctx, req) + } else { + if ret.Get(0) != nil { + r0 = ret.Get(0).(*milvuspb.SelectGrantResponse) + } + } + + var r1 error + if rf, ok := ret.Get(1).(func(context.Context, *milvuspb.SelectGrantRequest) error); ok { + r1 = rf(ctx, req) + } else { + r1 = ret.Error(1) + } + + return r0, r1 +} + +// Proxy_SelectGrant_Call is a *mock.Call that shadows Run/Return methods with type explicit version for method 'SelectGrant' +type Proxy_SelectGrant_Call struct { + *mock.Call +} + +// SelectGrant is a helper method to define mock.On call +// - ctx context.Context +// - req *milvuspb.SelectGrantRequest +func (_e *Proxy_Expecter) SelectGrant(ctx interface{}, req interface{}) *Proxy_SelectGrant_Call { + return &Proxy_SelectGrant_Call{Call: _e.mock.On("SelectGrant", ctx, req)} +} + +func (_c *Proxy_SelectGrant_Call) Run(run func(ctx context.Context, req *milvuspb.SelectGrantRequest)) *Proxy_SelectGrant_Call { + _c.Call.Run(func(args mock.Arguments) { + run(args[0].(context.Context), args[1].(*milvuspb.SelectGrantRequest)) + }) + return _c +} + +func (_c *Proxy_SelectGrant_Call) Return(_a0 *milvuspb.SelectGrantResponse, _a1 error) *Proxy_SelectGrant_Call { + _c.Call.Return(_a0, _a1) + return _c +} + +// SelectRole provides a mock function with given fields: ctx, req +func (_m *Proxy) SelectRole(ctx context.Context, req *milvuspb.SelectRoleRequest) (*milvuspb.SelectRoleResponse, error) { + ret := _m.Called(ctx, req) + + var r0 *milvuspb.SelectRoleResponse + if rf, ok := ret.Get(0).(func(context.Context, *milvuspb.SelectRoleRequest) *milvuspb.SelectRoleResponse); ok { + r0 = rf(ctx, req) + } else { + if ret.Get(0) != nil { + r0 = ret.Get(0).(*milvuspb.SelectRoleResponse) + } + } + + var r1 error + if rf, ok := ret.Get(1).(func(context.Context, *milvuspb.SelectRoleRequest) error); ok { + r1 = rf(ctx, req) + } else { + r1 = ret.Error(1) + } + + return r0, r1 +} + +// Proxy_SelectRole_Call is a *mock.Call that shadows Run/Return methods with type explicit version for method 'SelectRole' +type Proxy_SelectRole_Call struct { + *mock.Call +} + +// SelectRole is a helper method to define mock.On call +// - ctx context.Context +// - req *milvuspb.SelectRoleRequest +func (_e *Proxy_Expecter) SelectRole(ctx interface{}, req interface{}) *Proxy_SelectRole_Call { + return &Proxy_SelectRole_Call{Call: _e.mock.On("SelectRole", ctx, req)} +} + +func (_c *Proxy_SelectRole_Call) Run(run func(ctx context.Context, req *milvuspb.SelectRoleRequest)) *Proxy_SelectRole_Call { + _c.Call.Run(func(args mock.Arguments) { + run(args[0].(context.Context), args[1].(*milvuspb.SelectRoleRequest)) + }) + return _c +} + +func (_c *Proxy_SelectRole_Call) Return(_a0 *milvuspb.SelectRoleResponse, _a1 error) *Proxy_SelectRole_Call { + _c.Call.Return(_a0, _a1) + return _c +} + +// SelectUser provides a mock function with given fields: ctx, req +func (_m *Proxy) SelectUser(ctx context.Context, req *milvuspb.SelectUserRequest) (*milvuspb.SelectUserResponse, error) { + ret := _m.Called(ctx, req) + + var r0 *milvuspb.SelectUserResponse + if rf, ok := ret.Get(0).(func(context.Context, *milvuspb.SelectUserRequest) *milvuspb.SelectUserResponse); ok { + r0 = rf(ctx, req) + } else { + if ret.Get(0) != nil { + r0 = ret.Get(0).(*milvuspb.SelectUserResponse) + } + } + + var r1 error + if rf, ok := ret.Get(1).(func(context.Context, *milvuspb.SelectUserRequest) error); ok { + r1 = rf(ctx, req) + } else { + r1 = ret.Error(1) + } + + return r0, r1 +} + +// Proxy_SelectUser_Call is a *mock.Call that shadows Run/Return methods with type explicit version for method 'SelectUser' +type Proxy_SelectUser_Call struct { + *mock.Call +} + +// SelectUser is a helper method to define mock.On call +// - ctx context.Context +// - req *milvuspb.SelectUserRequest +func (_e *Proxy_Expecter) SelectUser(ctx interface{}, req interface{}) *Proxy_SelectUser_Call { + return &Proxy_SelectUser_Call{Call: _e.mock.On("SelectUser", ctx, req)} +} + +func (_c *Proxy_SelectUser_Call) Run(run func(ctx context.Context, req *milvuspb.SelectUserRequest)) *Proxy_SelectUser_Call { + _c.Call.Run(func(args mock.Arguments) { + run(args[0].(context.Context), args[1].(*milvuspb.SelectUserRequest)) + }) + return _c +} + +func (_c *Proxy_SelectUser_Call) Return(_a0 *milvuspb.SelectUserResponse, _a1 error) *Proxy_SelectUser_Call { + _c.Call.Return(_a0, _a1) + return _c +} + +// SetAddress provides a mock function with given fields: address +func (_m *Proxy) SetAddress(address string) { + _m.Called(address) +} + +// Proxy_SetAddress_Call is a *mock.Call that shadows Run/Return methods with type explicit version for method 'SetAddress' +type Proxy_SetAddress_Call struct { + *mock.Call +} + +// SetAddress is a helper method to define mock.On call +// - address string +func (_e *Proxy_Expecter) SetAddress(address interface{}) *Proxy_SetAddress_Call { + return &Proxy_SetAddress_Call{Call: _e.mock.On("SetAddress", address)} +} + +func (_c *Proxy_SetAddress_Call) Run(run func(address string)) *Proxy_SetAddress_Call { + _c.Call.Run(func(args mock.Arguments) { + run(args[0].(string)) + }) + return _c +} + +func (_c *Proxy_SetAddress_Call) Return() *Proxy_SetAddress_Call { + _c.Call.Return() + return _c +} + +// SetDataCoordClient provides a mock function with given fields: dataCoord +func (_m *Proxy) SetDataCoordClient(dataCoord types.DataCoord) { + _m.Called(dataCoord) +} + +// Proxy_SetDataCoordClient_Call is a *mock.Call that shadows Run/Return methods with type explicit version for method 'SetDataCoordClient' +type Proxy_SetDataCoordClient_Call struct { + *mock.Call +} + +// SetDataCoordClient is a helper method to define mock.On call +// - dataCoord types.DataCoord +func (_e *Proxy_Expecter) SetDataCoordClient(dataCoord interface{}) *Proxy_SetDataCoordClient_Call { + return &Proxy_SetDataCoordClient_Call{Call: _e.mock.On("SetDataCoordClient", dataCoord)} +} + +func (_c *Proxy_SetDataCoordClient_Call) Run(run func(dataCoord types.DataCoord)) *Proxy_SetDataCoordClient_Call { + _c.Call.Run(func(args mock.Arguments) { + run(args[0].(types.DataCoord)) + }) + return _c +} + +func (_c *Proxy_SetDataCoordClient_Call) Return() *Proxy_SetDataCoordClient_Call { + _c.Call.Return() + return _c +} + +// SetEtcdClient provides a mock function with given fields: etcdClient +func (_m *Proxy) SetEtcdClient(etcdClient *clientv3.Client) { + _m.Called(etcdClient) +} + +// Proxy_SetEtcdClient_Call is a *mock.Call that shadows Run/Return methods with type explicit version for method 'SetEtcdClient' +type Proxy_SetEtcdClient_Call struct { + *mock.Call +} + +// SetEtcdClient is a helper method to define mock.On call +// - etcdClient *clientv3.Client +func (_e *Proxy_Expecter) SetEtcdClient(etcdClient interface{}) *Proxy_SetEtcdClient_Call { + return &Proxy_SetEtcdClient_Call{Call: _e.mock.On("SetEtcdClient", etcdClient)} +} + +func (_c *Proxy_SetEtcdClient_Call) Run(run func(etcdClient *clientv3.Client)) *Proxy_SetEtcdClient_Call { + _c.Call.Run(func(args mock.Arguments) { + run(args[0].(*clientv3.Client)) + }) + return _c +} + +func (_c *Proxy_SetEtcdClient_Call) Return() *Proxy_SetEtcdClient_Call { + _c.Call.Return() + return _c +} + +// SetQueryCoordClient provides a mock function with given fields: queryCoord +func (_m *Proxy) SetQueryCoordClient(queryCoord types.QueryCoord) { + _m.Called(queryCoord) +} + +// Proxy_SetQueryCoordClient_Call is a *mock.Call that shadows Run/Return methods with type explicit version for method 'SetQueryCoordClient' +type Proxy_SetQueryCoordClient_Call struct { + *mock.Call +} + +// SetQueryCoordClient is a helper method to define mock.On call +// - queryCoord types.QueryCoord +func (_e *Proxy_Expecter) SetQueryCoordClient(queryCoord interface{}) *Proxy_SetQueryCoordClient_Call { + return &Proxy_SetQueryCoordClient_Call{Call: _e.mock.On("SetQueryCoordClient", queryCoord)} +} + +func (_c *Proxy_SetQueryCoordClient_Call) Run(run func(queryCoord types.QueryCoord)) *Proxy_SetQueryCoordClient_Call { + _c.Call.Run(func(args mock.Arguments) { + run(args[0].(types.QueryCoord)) + }) + return _c +} + +func (_c *Proxy_SetQueryCoordClient_Call) Return() *Proxy_SetQueryCoordClient_Call { + _c.Call.Return() + return _c +} + +// SetQueryNodeCreator provides a mock function with given fields: _a0 +func (_m *Proxy) SetQueryNodeCreator(_a0 func(context.Context, string) (types.QueryNode, error)) { + _m.Called(_a0) +} + +// Proxy_SetQueryNodeCreator_Call is a *mock.Call that shadows Run/Return methods with type explicit version for method 'SetQueryNodeCreator' +type Proxy_SetQueryNodeCreator_Call struct { + *mock.Call +} + +// SetQueryNodeCreator is a helper method to define mock.On call +// - _a0 func(context.Context , string)(types.QueryNode , error) +func (_e *Proxy_Expecter) SetQueryNodeCreator(_a0 interface{}) *Proxy_SetQueryNodeCreator_Call { + return &Proxy_SetQueryNodeCreator_Call{Call: _e.mock.On("SetQueryNodeCreator", _a0)} +} + +func (_c *Proxy_SetQueryNodeCreator_Call) Run(run func(_a0 func(context.Context, string) (types.QueryNode, error))) *Proxy_SetQueryNodeCreator_Call { + _c.Call.Run(func(args mock.Arguments) { + run(args[0].(func(context.Context, string) (types.QueryNode, error))) + }) + return _c +} + +func (_c *Proxy_SetQueryNodeCreator_Call) Return() *Proxy_SetQueryNodeCreator_Call { + _c.Call.Return() + return _c +} + +// SetRates provides a mock function with given fields: ctx, req +func (_m *Proxy) SetRates(ctx context.Context, req *proxypb.SetRatesRequest) (*commonpb.Status, error) { + ret := _m.Called(ctx, req) + + var r0 *commonpb.Status + if rf, ok := ret.Get(0).(func(context.Context, *proxypb.SetRatesRequest) *commonpb.Status); ok { + r0 = rf(ctx, req) + } else { + if ret.Get(0) != nil { + r0 = ret.Get(0).(*commonpb.Status) + } + } + + var r1 error + if rf, ok := ret.Get(1).(func(context.Context, *proxypb.SetRatesRequest) error); ok { + r1 = rf(ctx, req) + } else { + r1 = ret.Error(1) + } + + return r0, r1 +} + +// Proxy_SetRates_Call is a *mock.Call that shadows Run/Return methods with type explicit version for method 'SetRates' +type Proxy_SetRates_Call struct { + *mock.Call +} + +// SetRates is a helper method to define mock.On call +// - ctx context.Context +// - req *proxypb.SetRatesRequest +func (_e *Proxy_Expecter) SetRates(ctx interface{}, req interface{}) *Proxy_SetRates_Call { + return &Proxy_SetRates_Call{Call: _e.mock.On("SetRates", ctx, req)} +} + +func (_c *Proxy_SetRates_Call) Run(run func(ctx context.Context, req *proxypb.SetRatesRequest)) *Proxy_SetRates_Call { + _c.Call.Run(func(args mock.Arguments) { + run(args[0].(context.Context), args[1].(*proxypb.SetRatesRequest)) + }) + return _c +} + +func (_c *Proxy_SetRates_Call) Return(_a0 *commonpb.Status, _a1 error) *Proxy_SetRates_Call { + _c.Call.Return(_a0, _a1) + return _c +} + +// SetRootCoordClient provides a mock function with given fields: rootCoord +func (_m *Proxy) SetRootCoordClient(rootCoord types.RootCoord) { + _m.Called(rootCoord) +} + +// Proxy_SetRootCoordClient_Call is a *mock.Call that shadows Run/Return methods with type explicit version for method 'SetRootCoordClient' +type Proxy_SetRootCoordClient_Call struct { + *mock.Call +} + +// SetRootCoordClient is a helper method to define mock.On call +// - rootCoord types.RootCoord +func (_e *Proxy_Expecter) SetRootCoordClient(rootCoord interface{}) *Proxy_SetRootCoordClient_Call { + return &Proxy_SetRootCoordClient_Call{Call: _e.mock.On("SetRootCoordClient", rootCoord)} +} + +func (_c *Proxy_SetRootCoordClient_Call) Run(run func(rootCoord types.RootCoord)) *Proxy_SetRootCoordClient_Call { + _c.Call.Run(func(args mock.Arguments) { + run(args[0].(types.RootCoord)) + }) + return _c +} + +func (_c *Proxy_SetRootCoordClient_Call) Return() *Proxy_SetRootCoordClient_Call { + _c.Call.Return() + return _c +} + +// ShowCollections provides a mock function with given fields: ctx, request +func (_m *Proxy) ShowCollections(ctx context.Context, request *milvuspb.ShowCollectionsRequest) (*milvuspb.ShowCollectionsResponse, error) { + ret := _m.Called(ctx, request) + + var r0 *milvuspb.ShowCollectionsResponse + if rf, ok := ret.Get(0).(func(context.Context, *milvuspb.ShowCollectionsRequest) *milvuspb.ShowCollectionsResponse); ok { + r0 = rf(ctx, request) + } else { + if ret.Get(0) != nil { + r0 = ret.Get(0).(*milvuspb.ShowCollectionsResponse) + } + } + + var r1 error + if rf, ok := ret.Get(1).(func(context.Context, *milvuspb.ShowCollectionsRequest) error); ok { + r1 = rf(ctx, request) + } else { + r1 = ret.Error(1) + } + + return r0, r1 +} + +// Proxy_ShowCollections_Call is a *mock.Call that shadows Run/Return methods with type explicit version for method 'ShowCollections' +type Proxy_ShowCollections_Call struct { + *mock.Call +} + +// ShowCollections is a helper method to define mock.On call +// - ctx context.Context +// - request *milvuspb.ShowCollectionsRequest +func (_e *Proxy_Expecter) ShowCollections(ctx interface{}, request interface{}) *Proxy_ShowCollections_Call { + return &Proxy_ShowCollections_Call{Call: _e.mock.On("ShowCollections", ctx, request)} +} + +func (_c *Proxy_ShowCollections_Call) Run(run func(ctx context.Context, request *milvuspb.ShowCollectionsRequest)) *Proxy_ShowCollections_Call { + _c.Call.Run(func(args mock.Arguments) { + run(args[0].(context.Context), args[1].(*milvuspb.ShowCollectionsRequest)) + }) + return _c +} + +func (_c *Proxy_ShowCollections_Call) Return(_a0 *milvuspb.ShowCollectionsResponse, _a1 error) *Proxy_ShowCollections_Call { + _c.Call.Return(_a0, _a1) + return _c +} + +// ShowPartitions provides a mock function with given fields: ctx, request +func (_m *Proxy) ShowPartitions(ctx context.Context, request *milvuspb.ShowPartitionsRequest) (*milvuspb.ShowPartitionsResponse, error) { + ret := _m.Called(ctx, request) + + var r0 *milvuspb.ShowPartitionsResponse + if rf, ok := ret.Get(0).(func(context.Context, *milvuspb.ShowPartitionsRequest) *milvuspb.ShowPartitionsResponse); ok { + r0 = rf(ctx, request) + } else { + if ret.Get(0) != nil { + r0 = ret.Get(0).(*milvuspb.ShowPartitionsResponse) + } + } + + var r1 error + if rf, ok := ret.Get(1).(func(context.Context, *milvuspb.ShowPartitionsRequest) error); ok { + r1 = rf(ctx, request) + } else { + r1 = ret.Error(1) + } + + return r0, r1 +} + +// Proxy_ShowPartitions_Call is a *mock.Call that shadows Run/Return methods with type explicit version for method 'ShowPartitions' +type Proxy_ShowPartitions_Call struct { + *mock.Call +} + +// ShowPartitions is a helper method to define mock.On call +// - ctx context.Context +// - request *milvuspb.ShowPartitionsRequest +func (_e *Proxy_Expecter) ShowPartitions(ctx interface{}, request interface{}) *Proxy_ShowPartitions_Call { + return &Proxy_ShowPartitions_Call{Call: _e.mock.On("ShowPartitions", ctx, request)} +} + +func (_c *Proxy_ShowPartitions_Call) Run(run func(ctx context.Context, request *milvuspb.ShowPartitionsRequest)) *Proxy_ShowPartitions_Call { + _c.Call.Run(func(args mock.Arguments) { + run(args[0].(context.Context), args[1].(*milvuspb.ShowPartitionsRequest)) + }) + return _c +} + +func (_c *Proxy_ShowPartitions_Call) Return(_a0 *milvuspb.ShowPartitionsResponse, _a1 error) *Proxy_ShowPartitions_Call { + _c.Call.Return(_a0, _a1) + return _c +} + +// Start provides a mock function with given fields: +func (_m *Proxy) Start() error { + ret := _m.Called() + + var r0 error + if rf, ok := ret.Get(0).(func() error); ok { + r0 = rf() + } else { + r0 = ret.Error(0) + } + + return r0 +} + +// Proxy_Start_Call is a *mock.Call that shadows Run/Return methods with type explicit version for method 'Start' +type Proxy_Start_Call struct { + *mock.Call +} + +// Start is a helper method to define mock.On call +func (_e *Proxy_Expecter) Start() *Proxy_Start_Call { + return &Proxy_Start_Call{Call: _e.mock.On("Start")} +} + +func (_c *Proxy_Start_Call) Run(run func()) *Proxy_Start_Call { + _c.Call.Run(func(args mock.Arguments) { + run() + }) + return _c +} + +func (_c *Proxy_Start_Call) Return(_a0 error) *Proxy_Start_Call { + _c.Call.Return(_a0) + return _c +} + +// Stop provides a mock function with given fields: +func (_m *Proxy) Stop() error { + ret := _m.Called() + + var r0 error + if rf, ok := ret.Get(0).(func() error); ok { + r0 = rf() + } else { + r0 = ret.Error(0) + } + + return r0 +} + +// Proxy_Stop_Call is a *mock.Call that shadows Run/Return methods with type explicit version for method 'Stop' +type Proxy_Stop_Call struct { + *mock.Call +} + +// Stop is a helper method to define mock.On call +func (_e *Proxy_Expecter) Stop() *Proxy_Stop_Call { + return &Proxy_Stop_Call{Call: _e.mock.On("Stop")} +} + +func (_c *Proxy_Stop_Call) Run(run func()) *Proxy_Stop_Call { + _c.Call.Run(func(args mock.Arguments) { + run() + }) + return _c +} + +func (_c *Proxy_Stop_Call) Return(_a0 error) *Proxy_Stop_Call { + _c.Call.Return(_a0) + return _c +} + +// TransferNode provides a mock function with given fields: ctx, req +func (_m *Proxy) TransferNode(ctx context.Context, req *milvuspb.TransferNodeRequest) (*commonpb.Status, error) { + ret := _m.Called(ctx, req) + + var r0 *commonpb.Status + if rf, ok := ret.Get(0).(func(context.Context, *milvuspb.TransferNodeRequest) *commonpb.Status); ok { + r0 = rf(ctx, req) + } else { + if ret.Get(0) != nil { + r0 = ret.Get(0).(*commonpb.Status) + } + } + + var r1 error + if rf, ok := ret.Get(1).(func(context.Context, *milvuspb.TransferNodeRequest) error); ok { + r1 = rf(ctx, req) + } else { + r1 = ret.Error(1) + } + + return r0, r1 +} + +// Proxy_TransferNode_Call is a *mock.Call that shadows Run/Return methods with type explicit version for method 'TransferNode' +type Proxy_TransferNode_Call struct { + *mock.Call +} + +// TransferNode is a helper method to define mock.On call +// - ctx context.Context +// - req *milvuspb.TransferNodeRequest +func (_e *Proxy_Expecter) TransferNode(ctx interface{}, req interface{}) *Proxy_TransferNode_Call { + return &Proxy_TransferNode_Call{Call: _e.mock.On("TransferNode", ctx, req)} +} + +func (_c *Proxy_TransferNode_Call) Run(run func(ctx context.Context, req *milvuspb.TransferNodeRequest)) *Proxy_TransferNode_Call { + _c.Call.Run(func(args mock.Arguments) { + run(args[0].(context.Context), args[1].(*milvuspb.TransferNodeRequest)) + }) + return _c +} + +func (_c *Proxy_TransferNode_Call) Return(_a0 *commonpb.Status, _a1 error) *Proxy_TransferNode_Call { + _c.Call.Return(_a0, _a1) + return _c +} + +// TransferReplica provides a mock function with given fields: ctx, req +func (_m *Proxy) TransferReplica(ctx context.Context, req *milvuspb.TransferReplicaRequest) (*commonpb.Status, error) { + ret := _m.Called(ctx, req) + + var r0 *commonpb.Status + if rf, ok := ret.Get(0).(func(context.Context, *milvuspb.TransferReplicaRequest) *commonpb.Status); ok { + r0 = rf(ctx, req) + } else { + if ret.Get(0) != nil { + r0 = ret.Get(0).(*commonpb.Status) + } + } + + var r1 error + if rf, ok := ret.Get(1).(func(context.Context, *milvuspb.TransferReplicaRequest) error); ok { + r1 = rf(ctx, req) + } else { + r1 = ret.Error(1) + } + + return r0, r1 +} + +// Proxy_TransferReplica_Call is a *mock.Call that shadows Run/Return methods with type explicit version for method 'TransferReplica' +type Proxy_TransferReplica_Call struct { + *mock.Call +} + +// TransferReplica is a helper method to define mock.On call +// - ctx context.Context +// - req *milvuspb.TransferReplicaRequest +func (_e *Proxy_Expecter) TransferReplica(ctx interface{}, req interface{}) *Proxy_TransferReplica_Call { + return &Proxy_TransferReplica_Call{Call: _e.mock.On("TransferReplica", ctx, req)} +} + +func (_c *Proxy_TransferReplica_Call) Run(run func(ctx context.Context, req *milvuspb.TransferReplicaRequest)) *Proxy_TransferReplica_Call { + _c.Call.Run(func(args mock.Arguments) { + run(args[0].(context.Context), args[1].(*milvuspb.TransferReplicaRequest)) + }) + return _c +} + +func (_c *Proxy_TransferReplica_Call) Return(_a0 *commonpb.Status, _a1 error) *Proxy_TransferReplica_Call { + _c.Call.Return(_a0, _a1) + return _c +} + +// UpdateCredential provides a mock function with given fields: ctx, req +func (_m *Proxy) UpdateCredential(ctx context.Context, req *milvuspb.UpdateCredentialRequest) (*commonpb.Status, error) { + ret := _m.Called(ctx, req) + + var r0 *commonpb.Status + if rf, ok := ret.Get(0).(func(context.Context, *milvuspb.UpdateCredentialRequest) *commonpb.Status); ok { + r0 = rf(ctx, req) + } else { + if ret.Get(0) != nil { + r0 = ret.Get(0).(*commonpb.Status) + } + } + + var r1 error + if rf, ok := ret.Get(1).(func(context.Context, *milvuspb.UpdateCredentialRequest) error); ok { + r1 = rf(ctx, req) + } else { + r1 = ret.Error(1) + } + + return r0, r1 +} + +// Proxy_UpdateCredential_Call is a *mock.Call that shadows Run/Return methods with type explicit version for method 'UpdateCredential' +type Proxy_UpdateCredential_Call struct { + *mock.Call +} + +// UpdateCredential is a helper method to define mock.On call +// - ctx context.Context +// - req *milvuspb.UpdateCredentialRequest +func (_e *Proxy_Expecter) UpdateCredential(ctx interface{}, req interface{}) *Proxy_UpdateCredential_Call { + return &Proxy_UpdateCredential_Call{Call: _e.mock.On("UpdateCredential", ctx, req)} +} + +func (_c *Proxy_UpdateCredential_Call) Run(run func(ctx context.Context, req *milvuspb.UpdateCredentialRequest)) *Proxy_UpdateCredential_Call { + _c.Call.Run(func(args mock.Arguments) { + run(args[0].(context.Context), args[1].(*milvuspb.UpdateCredentialRequest)) + }) + return _c +} + +func (_c *Proxy_UpdateCredential_Call) Return(_a0 *commonpb.Status, _a1 error) *Proxy_UpdateCredential_Call { + _c.Call.Return(_a0, _a1) + return _c +} + +// UpdateCredentialCache provides a mock function with given fields: ctx, request +func (_m *Proxy) UpdateCredentialCache(ctx context.Context, request *proxypb.UpdateCredCacheRequest) (*commonpb.Status, error) { + ret := _m.Called(ctx, request) + + var r0 *commonpb.Status + if rf, ok := ret.Get(0).(func(context.Context, *proxypb.UpdateCredCacheRequest) *commonpb.Status); ok { + r0 = rf(ctx, request) + } else { + if ret.Get(0) != nil { + r0 = ret.Get(0).(*commonpb.Status) + } + } + + var r1 error + if rf, ok := ret.Get(1).(func(context.Context, *proxypb.UpdateCredCacheRequest) error); ok { + r1 = rf(ctx, request) + } else { + r1 = ret.Error(1) + } + + return r0, r1 +} + +// Proxy_UpdateCredentialCache_Call is a *mock.Call that shadows Run/Return methods with type explicit version for method 'UpdateCredentialCache' +type Proxy_UpdateCredentialCache_Call struct { + *mock.Call +} + +// UpdateCredentialCache is a helper method to define mock.On call +// - ctx context.Context +// - request *proxypb.UpdateCredCacheRequest +func (_e *Proxy_Expecter) UpdateCredentialCache(ctx interface{}, request interface{}) *Proxy_UpdateCredentialCache_Call { + return &Proxy_UpdateCredentialCache_Call{Call: _e.mock.On("UpdateCredentialCache", ctx, request)} +} + +func (_c *Proxy_UpdateCredentialCache_Call) Run(run func(ctx context.Context, request *proxypb.UpdateCredCacheRequest)) *Proxy_UpdateCredentialCache_Call { + _c.Call.Run(func(args mock.Arguments) { + run(args[0].(context.Context), args[1].(*proxypb.UpdateCredCacheRequest)) + }) + return _c +} + +func (_c *Proxy_UpdateCredentialCache_Call) Return(_a0 *commonpb.Status, _a1 error) *Proxy_UpdateCredentialCache_Call { + _c.Call.Return(_a0, _a1) + return _c +} + +// UpdateStateCode provides a mock function with given fields: stateCode +func (_m *Proxy) UpdateStateCode(stateCode commonpb.StateCode) { + _m.Called(stateCode) +} + +// Proxy_UpdateStateCode_Call is a *mock.Call that shadows Run/Return methods with type explicit version for method 'UpdateStateCode' +type Proxy_UpdateStateCode_Call struct { + *mock.Call +} + +// UpdateStateCode is a helper method to define mock.On call +// - stateCode commonpb.StateCode +func (_e *Proxy_Expecter) UpdateStateCode(stateCode interface{}) *Proxy_UpdateStateCode_Call { + return &Proxy_UpdateStateCode_Call{Call: _e.mock.On("UpdateStateCode", stateCode)} +} + +func (_c *Proxy_UpdateStateCode_Call) Run(run func(stateCode commonpb.StateCode)) *Proxy_UpdateStateCode_Call { + _c.Call.Run(func(args mock.Arguments) { + run(args[0].(commonpb.StateCode)) + }) + return _c +} + +func (_c *Proxy_UpdateStateCode_Call) Return() *Proxy_UpdateStateCode_Call { + _c.Call.Return() + return _c +} + +// Upsert provides a mock function with given fields: ctx, request +func (_m *Proxy) Upsert(ctx context.Context, request *milvuspb.UpsertRequest) (*milvuspb.MutationResult, error) { + ret := _m.Called(ctx, request) + + var r0 *milvuspb.MutationResult + if rf, ok := ret.Get(0).(func(context.Context, *milvuspb.UpsertRequest) *milvuspb.MutationResult); ok { + r0 = rf(ctx, request) + } else { + if ret.Get(0) != nil { + r0 = ret.Get(0).(*milvuspb.MutationResult) + } + } + + var r1 error + if rf, ok := ret.Get(1).(func(context.Context, *milvuspb.UpsertRequest) error); ok { + r1 = rf(ctx, request) + } else { + r1 = ret.Error(1) + } + + return r0, r1 +} + +// Proxy_Upsert_Call is a *mock.Call that shadows Run/Return methods with type explicit version for method 'Upsert' +type Proxy_Upsert_Call struct { + *mock.Call +} + +// Upsert is a helper method to define mock.On call +// - ctx context.Context +// - request *milvuspb.UpsertRequest +func (_e *Proxy_Expecter) Upsert(ctx interface{}, request interface{}) *Proxy_Upsert_Call { + return &Proxy_Upsert_Call{Call: _e.mock.On("Upsert", ctx, request)} +} + +func (_c *Proxy_Upsert_Call) Run(run func(ctx context.Context, request *milvuspb.UpsertRequest)) *Proxy_Upsert_Call { + _c.Call.Run(func(args mock.Arguments) { + run(args[0].(context.Context), args[1].(*milvuspb.UpsertRequest)) + }) + return _c +} + +func (_c *Proxy_Upsert_Call) Return(_a0 *milvuspb.MutationResult, _a1 error) *Proxy_Upsert_Call { + _c.Call.Return(_a0, _a1) + return _c +} + +type mockConstructorTestingTNewProxy interface { + mock.TestingT + Cleanup(func()) +} + +// NewProxy creates a new instance of Proxy. It also registers a testing interface on the mock and a cleanup function to assert the mocks expectations. +func NewProxy(t mockConstructorTestingTNewProxy) *Proxy { + mock := &Proxy{} + mock.Mock.Test(t) + + t.Cleanup(func() { mock.AssertExpectations(t) }) + + return mock +} diff --git a/internal/proxy/impl.go b/internal/proxy/impl.go index d3e1788fca..c0f3df3f07 100644 --- a/internal/proxy/impl.go +++ b/internal/proxy/impl.go @@ -2411,9 +2411,10 @@ func (node *Proxy) Search(ctx context.Context, request *milvuspb.SearchRequest) ReqID: paramtable.GetNodeID(), }, request: request, - qc: node.queryCoord, tr: timerecord.NewTimeRecorder("search"), shardMgr: node.shardMgr, + qc: node.queryCoord, + node: node, } travelTs := request.TravelTimestamp diff --git a/internal/proxy/task_query.go b/internal/proxy/task_query.go index 1d0831a324..72a90c7199 100644 --- a/internal/proxy/task_query.go +++ b/internal/proxy/task_query.go @@ -9,6 +9,7 @@ import ( "github.com/golang/protobuf/proto" "github.com/milvus-io/milvus/internal/proto/planpb" "github.com/milvus-io/milvus/pkg/util/tsoutil" + "github.com/samber/lo" "github.com/cockroachdb/errors" "github.com/milvus-io/milvus/internal/parser/planparserv2" @@ -492,7 +493,10 @@ func IDs2Expr(fieldName string, ids *schemapb.IDs) string { case *schemapb.IDs_IntId: idsStr = strings.Trim(strings.Join(strings.Fields(fmt.Sprint(ids.GetIntId().GetData())), ", "), "[]") case *schemapb.IDs_StrId: - idsStr = strings.Trim(strings.Join(ids.GetStrId().GetData(), ", "), "[]") + strs := lo.Map(ids.GetStrId().GetData(), func(str string, _ int) string { + return fmt.Sprintf("\"%s\"", str) + }) + idsStr = strings.Trim(strings.Join(strs, ", "), "[]") } return fieldName + " in [ " + idsStr + " ]" diff --git a/internal/proxy/task_query_test.go b/internal/proxy/task_query_test.go index 5ed058599b..3c7f670a7b 100644 --- a/internal/proxy/task_query_test.go +++ b/internal/proxy/task_query_test.go @@ -869,3 +869,28 @@ func Test_queryTask_createPlan(t *testing.T) { assert.Error(t, err) }) } + +func TestQueryTask_IDs2Expr(t *testing.T) { + fieldName := "pk" + intIDs := &schemapb.IDs{ + IdField: &schemapb.IDs_IntId{ + IntId: &schemapb.LongArray{ + Data: []int64{1, 2, 3, 4, 5}, + }, + }, + } + stringIDs := &schemapb.IDs{ + IdField: &schemapb.IDs_StrId{ + StrId: &schemapb.StringArray{ + Data: []string{"a", "b", "c"}, + }, + }, + } + idExpr := IDs2Expr(fieldName, intIDs) + expectIDExpr := "pk in [ 1, 2, 3, 4, 5 ]" + assert.Equal(t, expectIDExpr, idExpr) + + strExpr := IDs2Expr(fieldName, stringIDs) + expectStrExpr := "pk in [ \"a\", \"b\", \"c\" ]" + assert.Equal(t, expectStrExpr, strExpr) +} diff --git a/internal/proxy/task_search.go b/internal/proxy/task_search.go index 63bb20d8f1..afa07f3cff 100644 --- a/internal/proxy/task_search.go +++ b/internal/proxy/task_search.go @@ -3,11 +3,13 @@ package proxy import ( "context" "fmt" + "math" "regexp" "strconv" "github.com/cockroachdb/errors" "github.com/golang/protobuf/proto" + "github.com/samber/lo" "go.opentelemetry.io/otel" "go.uber.org/zap" @@ -25,6 +27,7 @@ import ( "github.com/milvus-io/milvus/pkg/util/commonpbutil" "github.com/milvus-io/milvus/pkg/util/distance" "github.com/milvus-io/milvus/pkg/util/funcutil" + "github.com/milvus-io/milvus/pkg/util/merr" "github.com/milvus-io/milvus/pkg/util/paramtable" "github.com/milvus-io/milvus/pkg/util/timerecord" "github.com/milvus-io/milvus/pkg/util/tsoutil" @@ -34,6 +37,12 @@ import ( const ( SearchTaskName = "SearchTask" SearchLevelKey = "level" + + // requeryThreshold is the estimated threshold for the size of the search results. + // If the number of estimated search results exceeds this threshold, + // a second query request will be initiated to retrieve output fields data. + // In this case, the first search will not return any output field from QueryNodes. + requeryThreshold = 0.5 * 1024 * 1024 ) type searchTask struct { @@ -41,13 +50,14 @@ type searchTask struct { *internalpb.SearchRequest ctx context.Context - result *milvuspb.SearchResults - request *milvuspb.SearchRequest - qc types.QueryCoord + result *milvuspb.SearchResults + request *milvuspb.SearchRequest + tr *timerecord.TimeRecorder collectionName string channelNum int32 schema *schemapb.CollectionSchema + requery bool offset int64 resultBuf chan *internalpb.SearchResults @@ -55,6 +65,9 @@ type searchTask struct { searchShardPolicy pickShardPolicy shardMgr *shardClientMgr + + qc types.QueryCoord + node types.ProxyComponent } func getPartitionIDs(ctx context.Context, collectionName string, partitionNames []string) (partitionIDs []UniqueID, err error) { @@ -164,11 +177,7 @@ func getOutputFieldIDs(schema *schemapb.CollectionSchema, outputFields []string) hitField := false for _, field := range schema.GetFields() { if field.Name == name { - if field.DataType == schemapb.DataType_BinaryVector || field.DataType == schemapb.DataType_FloatVector { - return nil, errors.New("search doesn't support vector field as output_fields") - } outputFieldIDs = append(outputFieldIDs, field.GetFieldID()) - hitField = true break } @@ -255,6 +264,24 @@ func (t *searchTask) PreExecute(ctx context.Context) error { } t.SearchRequest.IgnoreGrowing = ignoreGrowing + // Manually update nq if not set. + nq, err := getNq(t.request) + if err != nil { + return err + } + // Check if nq is valid: + // https://milvus.io/docs/limitations.md + if err := validateLimit(nq); err != nil { + return fmt.Errorf("%s [%d] is invalid, %w", NQKey, nq, err) + } + t.SearchRequest.Nq = nq + + outputFieldIDs, err := getOutputFieldIDs(t.schema, t.request.GetOutputFields()) + if err != nil { + return err + } + t.SearchRequest.OutputFieldsId = outputFieldIDs + if t.request.GetDslType() == commonpb.DslType_BoolExprV1 { annsField, err := funcutil.GetAttrByKeyFromRepeatedKV(AnnsFieldKey, t.request.GetSearchParams()) if err != nil { @@ -278,17 +305,21 @@ func (t *searchTask) PreExecute(ctx context.Context) error { zap.String("dsl", t.request.Dsl), // may be very large if large term passed. zap.String("anns field", annsField), zap.Any("query info", queryInfo)) - outputFieldIDs, err := getOutputFieldIDs(t.schema, t.request.GetOutputFields()) - if err != nil { - return err - } - - t.SearchRequest.OutputFieldsId = outputFieldIDs plan.OutputFieldIds = outputFieldIDs t.SearchRequest.Topk = queryInfo.GetTopk() t.SearchRequest.MetricType = queryInfo.GetMetricType() t.SearchRequest.DslType = commonpb.DslType_BoolExprV1 + + estimateSize, err := t.estimateResultSize(nq, t.SearchRequest.Topk) + if err != nil { + return err + } + if estimateSize >= requeryThreshold { + t.requery = true + plan.OutputFieldIds = nil + } + t.SearchRequest.SerializedExprPlan, err = proto.Marshal(plan) if err != nil { return err @@ -319,17 +350,6 @@ func (t *searchTask) PreExecute(ctx context.Context) error { t.SearchRequest.Dsl = t.request.Dsl t.SearchRequest.PlaceholderGroup = t.request.PlaceholderGroup - // Manually update nq if not set. - nq, err := getNq(t.request) - if err != nil { - return err - } - // Check if nq is valid: - // https://milvus.io/docs/limitations.md - if err := validateLimit(nq); err != nil { - return fmt.Errorf("%s [%d] is invalid, %w", NQKey, nq, err) - } - t.SearchRequest.Nq = nq log.Ctx(ctx).Debug("search PreExecute done.", zap.Uint64("travel_ts", travelTimestamp), zap.Uint64("guarantee_ts", guaranteeTs), @@ -435,6 +455,13 @@ func (t *searchTask) PostExecute(ctx context.Context) error { t.result.CollectionName = t.collectionName t.fillInFieldInfo() + if t.requery { + err = t.Requery() + if err != nil { + return err + } + } + log.Ctx(ctx).Debug("Search post execute done", zap.Int64("collection", t.GetCollectionID()), zap.Int64s("partitionIDs", t.GetPartitionIDs())) @@ -480,6 +507,93 @@ func (t *searchTask) searchShard(ctx context.Context, nodeID int64, qn types.Que return nil } +func (t *searchTask) estimateResultSize(nq int64, topK int64) (int64, error) { + vectorOutputFields := lo.Filter(t.schema.GetFields(), func(field *schemapb.FieldSchema, _ int) bool { + return lo.Contains(t.request.GetOutputFields(), field.GetName()) && typeutil.IsVectorType(field.GetDataType()) + }) + // Currently, we get vectors by requery. Once we support getting vectors from search, + // searches with small result size could no longer need requery. + if len(vectorOutputFields) > 0 { + return math.MaxInt64, nil + } + // If no vector field as output, no need to requery. + return 0, nil + + //outputFields := lo.Filter(t.schema.GetFields(), func(field *schemapb.FieldSchema, _ int) bool { + // return lo.Contains(t.request.GetOutputFields(), field.GetName()) + //}) + //sizePerRecord, err := typeutil.EstimateSizePerRecord(&schemapb.CollectionSchema{Fields: outputFields}) + //if err != nil { + // return 0, err + //} + //return int64(sizePerRecord) * nq * topK, nil +} + +func (t *searchTask) Requery() error { + pkField, err := typeutil.GetPrimaryFieldSchema(t.schema) + if err != nil { + return err + } + ids := t.result.GetResults().GetIds() + expr := IDs2Expr(pkField.GetName(), ids) + + queryReq := &milvuspb.QueryRequest{ + Base: &commonpb.MsgBase{ + MsgType: commonpb.MsgType_Retrieve, + }, + CollectionName: t.request.GetCollectionName(), + Expr: expr, + OutputFields: t.request.GetOutputFields(), + PartitionNames: t.request.GetPartitionNames(), + TravelTimestamp: t.request.GetTravelTimestamp(), + GuaranteeTimestamp: t.request.GetGuaranteeTimestamp(), + QueryParams: t.request.GetSearchParams(), + } + queryResult, err := t.node.Query(t.ctx, queryReq) + if err != nil { + return err + } + if queryResult.GetStatus().GetErrorCode() != commonpb.ErrorCode_Success { + return merr.Error(queryResult.GetStatus()) + } + // Reorganize Results. The order of query result ids will be altered and differ from queried ids. + // We should reorganize query results to keep the order of original queried ids. For example: + // =========================================== + // 3 2 5 4 1 (query ids) + // || + // || (query) + // \/ + // 4 3 5 1 2 (result ids) + // v4 v3 v5 v1 v2 (result vectors) + // || + // || (reorganize) + // \/ + // 3 2 5 4 1 (result ids) + // v3 v2 v5 v4 v1 (result vectors) + // =========================================== + pkFieldData, err := typeutil.GetPrimaryFieldData(queryResult.GetFieldsData(), pkField) + if err != nil { + return err + } + offsets := make(map[any]int) + for i := 0; i < typeutil.GetDataSize(pkFieldData); i++ { + pk := typeutil.GetData(pkFieldData, i) + offsets[pk] = i + } + + t.result.Results.FieldsData = make([]*schemapb.FieldData, len(queryResult.GetFieldsData())) + for i := 0; i < typeutil.GetSizeOfIDs(ids); i++ { + id := typeutil.GetPK(ids, int64(i)) + if _, ok := offsets[id]; !ok { + return fmt.Errorf("incomplete query result, missing id %s, len(searchIDs) = %d, len(queryIDs) = %d, collection=%d", + id, typeutil.GetSizeOfIDs(ids), len(offsets), t.GetCollectionID()) + } + typeutil.AppendFieldData(t.result.Results.FieldsData, queryResult.GetFieldsData(), int64(offsets[id])) + } + + return nil +} + func (t *searchTask) fillInEmptyResult(numQueries int64) { t.result = &milvuspb.SearchResults{ Status: &commonpb.Status{ diff --git a/internal/proxy/task_search_test.go b/internal/proxy/task_search_test.go index bbda89cd68..8a249e09df 100644 --- a/internal/proxy/task_search_test.go +++ b/internal/proxy/task_search_test.go @@ -17,6 +17,7 @@ import ( "github.com/milvus-io/milvus-proto/go-api/commonpb" "github.com/milvus-io/milvus-proto/go-api/milvuspb" "github.com/milvus-io/milvus-proto/go-api/schemapb" + "github.com/milvus-io/milvus/internal/mocks" "github.com/milvus-io/milvus/internal/proto/internalpb" "github.com/milvus-io/milvus/internal/proto/querypb" "github.com/milvus-io/milvus/internal/types" @@ -268,7 +269,7 @@ func TestSearchTask_PreExecute(t *testing.T) { // contain vector field task.request.OutputFields = []string{testFloatVecField} - assert.Error(t, task.PreExecute(ctx)) + assert.NoError(t, task.PreExecute(ctx)) }) } @@ -1959,3 +1960,287 @@ func getSearchResultData(nq, topk int64) *schemapb.SearchResultData { } return &result } + +func TestSearchTask_Requery(t *testing.T) { + ctx, cancel := context.WithCancel(context.Background()) + defer cancel() + + const ( + dim = 128 + rows = 5 + collection = "test-requery" + + pkField = "pk" + vecField = "vec" + ) + + ids := make([]int64, rows) + for i := range ids { + ids[i] = int64(i) + } + + t.Run("Test normal", func(t *testing.T) { + schema := constructCollectionSchema(pkField, vecField, dim, collection) + node := mocks.NewProxy(t) + node.EXPECT().Query(mock.Anything, mock.Anything). + Return(&milvuspb.QueryResults{ + FieldsData: []*schemapb.FieldData{{ + Type: schemapb.DataType_Int64, + FieldName: pkField, + Field: &schemapb.FieldData_Scalars{ + Scalars: &schemapb.ScalarField{ + Data: &schemapb.ScalarField_LongData{ + LongData: &schemapb.LongArray{ + Data: ids, + }, + }, + }, + }, + }, + newFloatVectorFieldData(vecField, rows, dim), + }, + }, nil) + + resultIDs := &schemapb.IDs{ + IdField: &schemapb.IDs_IntId{ + IntId: &schemapb.LongArray{ + Data: ids, + }, + }, + } + + qt := &searchTask{ + ctx: ctx, + SearchRequest: &internalpb.SearchRequest{ + Base: &commonpb.MsgBase{ + MsgType: commonpb.MsgType_Search, + SourceID: paramtable.GetNodeID(), + }, + }, + request: &milvuspb.SearchRequest{}, + result: &milvuspb.SearchResults{ + Results: &schemapb.SearchResultData{ + Ids: resultIDs, + }, + }, + schema: schema, + tr: timerecord.NewTimeRecorder("search"), + node: node, + } + + err := qt.Requery() + assert.NoError(t, err) + }) + + t.Run("Test no primary key", func(t *testing.T) { + schema := &schemapb.CollectionSchema{} + node := mocks.NewProxy(t) + + qt := &searchTask{ + ctx: ctx, + SearchRequest: &internalpb.SearchRequest{ + Base: &commonpb.MsgBase{ + MsgType: commonpb.MsgType_Search, + SourceID: paramtable.GetNodeID(), + }, + }, + request: &milvuspb.SearchRequest{}, + schema: schema, + tr: timerecord.NewTimeRecorder("search"), + node: node, + } + + err := qt.Requery() + t.Logf("err = %s", err) + assert.Error(t, err) + }) + + t.Run("Test requery failed 1", func(t *testing.T) { + schema := constructCollectionSchema(pkField, vecField, dim, collection) + node := mocks.NewProxy(t) + node.EXPECT().Query(mock.Anything, mock.Anything). + Return(nil, fmt.Errorf("mock err 1")) + + qt := &searchTask{ + ctx: ctx, + SearchRequest: &internalpb.SearchRequest{ + Base: &commonpb.MsgBase{ + MsgType: commonpb.MsgType_Search, + SourceID: paramtable.GetNodeID(), + }, + }, + request: &milvuspb.SearchRequest{}, + schema: schema, + tr: timerecord.NewTimeRecorder("search"), + node: node, + } + + err := qt.Requery() + t.Logf("err = %s", err) + assert.Error(t, err) + }) + + t.Run("Test requery failed 2", func(t *testing.T) { + schema := constructCollectionSchema(pkField, vecField, dim, collection) + node := mocks.NewProxy(t) + node.EXPECT().Query(mock.Anything, mock.Anything). + Return(&milvuspb.QueryResults{ + Status: &commonpb.Status{ + ErrorCode: commonpb.ErrorCode_UnexpectedError, + Reason: "mock err 2", + }, + }, nil) + + qt := &searchTask{ + ctx: ctx, + SearchRequest: &internalpb.SearchRequest{ + Base: &commonpb.MsgBase{ + MsgType: commonpb.MsgType_Search, + SourceID: paramtable.GetNodeID(), + }, + }, + request: &milvuspb.SearchRequest{}, + schema: schema, + tr: timerecord.NewTimeRecorder("search"), + node: node, + } + + err := qt.Requery() + t.Logf("err = %s", err) + assert.Error(t, err) + }) + + t.Run("Test get pk filed data failed", func(t *testing.T) { + schema := constructCollectionSchema(pkField, vecField, dim, collection) + node := mocks.NewProxy(t) + node.EXPECT().Query(mock.Anything, mock.Anything). + Return(&milvuspb.QueryResults{ + FieldsData: []*schemapb.FieldData{}, + }, nil) + + qt := &searchTask{ + ctx: ctx, + SearchRequest: &internalpb.SearchRequest{ + Base: &commonpb.MsgBase{ + MsgType: commonpb.MsgType_Search, + SourceID: paramtable.GetNodeID(), + }, + }, + request: &milvuspb.SearchRequest{}, + schema: schema, + tr: timerecord.NewTimeRecorder("search"), + node: node, + } + + err := qt.Requery() + t.Logf("err = %s", err) + assert.Error(t, err) + }) + + t.Run("Test incomplete query result", func(t *testing.T) { + schema := constructCollectionSchema(pkField, vecField, dim, collection) + node := mocks.NewProxy(t) + node.EXPECT().Query(mock.Anything, mock.Anything). + Return(&milvuspb.QueryResults{ + FieldsData: []*schemapb.FieldData{{ + Type: schemapb.DataType_Int64, + FieldName: pkField, + Field: &schemapb.FieldData_Scalars{ + Scalars: &schemapb.ScalarField{ + Data: &schemapb.ScalarField_LongData{ + LongData: &schemapb.LongArray{ + Data: ids[:len(ids)-1], + }, + }, + }, + }, + }, + newFloatVectorFieldData(vecField, rows, dim), + }, + }, nil) + + resultIDs := &schemapb.IDs{ + IdField: &schemapb.IDs_IntId{ + IntId: &schemapb.LongArray{ + Data: ids, + }, + }, + } + + qt := &searchTask{ + ctx: ctx, + SearchRequest: &internalpb.SearchRequest{ + Base: &commonpb.MsgBase{ + MsgType: commonpb.MsgType_Search, + SourceID: paramtable.GetNodeID(), + }, + }, + request: &milvuspb.SearchRequest{}, + result: &milvuspb.SearchResults{ + Results: &schemapb.SearchResultData{ + Ids: resultIDs, + }, + }, + schema: schema, + tr: timerecord.NewTimeRecorder("search"), + node: node, + } + + err := qt.Requery() + t.Logf("err = %s", err) + assert.Error(t, err) + }) + + t.Run("Test postExecute with requery failed", func(t *testing.T) { + schema := constructCollectionSchema(pkField, vecField, dim, collection) + node := mocks.NewProxy(t) + node.EXPECT().Query(mock.Anything, mock.Anything). + Return(nil, fmt.Errorf("mock err 1")) + + resultIDs := &schemapb.IDs{ + IdField: &schemapb.IDs_IntId{ + IntId: &schemapb.LongArray{ + Data: ids, + }, + }, + } + + qt := &searchTask{ + ctx: ctx, + SearchRequest: &internalpb.SearchRequest{ + Base: &commonpb.MsgBase{ + MsgType: commonpb.MsgType_Search, + SourceID: paramtable.GetNodeID(), + }, + }, + request: &milvuspb.SearchRequest{}, + result: &milvuspb.SearchResults{ + Results: &schemapb.SearchResultData{ + Ids: resultIDs, + }, + }, + requery: true, + schema: schema, + resultBuf: make(chan *internalpb.SearchResults, 10), + tr: timerecord.NewTimeRecorder("search"), + node: node, + } + scores := make([]float32, rows) + for i := range scores { + scores[i] = float32(i) + } + partialResultData := &schemapb.SearchResultData{ + Ids: resultIDs, + Scores: scores, + } + bytes, err := proto.Marshal(partialResultData) + assert.NoError(t, err) + qt.resultBuf <- &internalpb.SearchResults{ + SlicedBlob: bytes, + } + + err = qt.PostExecute(ctx) + t.Logf("err = %s", err) + assert.Error(t, err) + }) +} diff --git a/internal/querynodev2/segments/segment.go b/internal/querynodev2/segments/segment.go index f4a14485db..e836f5c9b0 100644 --- a/internal/querynodev2/segments/segment.go +++ b/internal/querynodev2/segments/segment.go @@ -282,6 +282,13 @@ func (s *LocalSegment) ExistIndex(fieldID int64) bool { return fieldInfo.IndexInfo != nil && fieldInfo.IndexInfo.EnableIndex } +func (s *LocalSegment) HasRawData(fieldID int64) bool { + s.mut.RLock() + defer s.mut.RUnlock() + ret := C.HasRawData(s.ptr, C.int64_t(fieldID)) + return bool(ret) +} + func (s *LocalSegment) Indexes() []*IndexedFieldInfo { var result []*IndexedFieldInfo s.fieldIndexes.Range(func(key int64, value *IndexedFieldInfo) bool { @@ -463,10 +470,18 @@ func (s *LocalSegment) FillIndexedFieldsData(ctx context.Context, ) for _, fieldData := range result.FieldsData { - // If the vector field doesn't have indexed. Vector data is in memory for - // brute force search. No need to download data from remote. - if fieldData.GetType() != schemapb.DataType_FloatVector && fieldData.GetType() != schemapb.DataType_BinaryVector || - !s.ExistIndex(fieldData.FieldId) { + // If the field is not vector field, no need to download data from remote. + if !typeutil.IsVectorType(fieldData.GetType()) { + continue + } + // If the vector field doesn't have indexed, vector data is in memory + // for brute force search, no need to download data from remote. + if !s.ExistIndex(fieldData.FieldId) { + continue + } + // If the index has raw data, vector data could be obtained from index, + // no need to download data from remote. + if s.HasRawData(fieldData.FieldId) { continue } diff --git a/internal/querynodev2/segments/segment_test.go b/internal/querynodev2/segments/segment_test.go index a7c50cd323..affb901aa7 100644 --- a/internal/querynodev2/segments/segment_test.go +++ b/internal/querynodev2/segments/segment_test.go @@ -117,6 +117,13 @@ func (suite *SegmentSuite) TestDelete() { suite.Equal(rowNum, suite.growing.InsertCount()) } +func (suite *SegmentSuite) TestHasRawData() { + has := suite.growing.HasRawData(simpleFloatVecField.id) + suite.True(has) + has = suite.sealed.HasRawData(simpleFloatVecField.id) + suite.True(has) +} + func TestSegment(t *testing.T) { suite.Run(t, new(SegmentSuite)) } diff --git a/pkg/util/typeutil/schema.go b/pkg/util/typeutil/schema.go index 21a45d7a33..17432e3c91 100644 --- a/pkg/util/typeutil/schema.go +++ b/pkg/util/typeutil/schema.go @@ -735,6 +735,30 @@ func GetSizeOfIDs(data *schemapb.IDs) int { return result } +func GetDataSize(fieldData *schemapb.FieldData) int { + switch fieldData.GetType() { + case schemapb.DataType_Bool: + return len(fieldData.GetScalars().GetBoolData().GetData()) + case schemapb.DataType_Int8, schemapb.DataType_Int16, schemapb.DataType_Int32: + return len(fieldData.GetScalars().GetIntData().GetData()) + case schemapb.DataType_Int64: + return len(fieldData.GetScalars().GetLongData().GetData()) + case schemapb.DataType_Float: + return len(fieldData.GetScalars().GetFloatData().GetData()) + case schemapb.DataType_Double: + return len(fieldData.GetScalars().GetDoubleData().GetData()) + case schemapb.DataType_String: + return len(fieldData.GetScalars().GetStringData().GetData()) + case schemapb.DataType_VarChar: + return len(fieldData.GetScalars().GetStringData().GetData()) + case schemapb.DataType_FloatVector: + return len(fieldData.GetVectors().GetFloatVector().GetData()) + case schemapb.DataType_BinaryVector: + return len(fieldData.GetVectors().GetBinaryVector()) + } + return 0 +} + func IsPrimaryFieldType(dataType schemapb.DataType) bool { if dataType == schemapb.DataType_Int64 || dataType == schemapb.DataType_VarChar { return true @@ -756,6 +780,33 @@ func GetPK(data *schemapb.IDs, idx int64) interface{} { return nil } +func GetData(field *schemapb.FieldData, idx int) interface{} { + switch field.GetType() { + case schemapb.DataType_Bool: + return field.GetScalars().GetBoolData().GetData()[idx] + case schemapb.DataType_Int8, schemapb.DataType_Int16, schemapb.DataType_Int32: + return field.GetScalars().GetIntData().GetData()[idx] + case schemapb.DataType_Int64: + return field.GetScalars().GetLongData().GetData()[idx] + case schemapb.DataType_Float: + return field.GetScalars().GetFloatData().GetData()[idx] + case schemapb.DataType_Double: + return field.GetScalars().GetDoubleData().GetData()[idx] + case schemapb.DataType_String: + return field.GetScalars().GetStringData().GetData()[idx] + case schemapb.DataType_VarChar: + return field.GetScalars().GetStringData().GetData()[idx] + case schemapb.DataType_FloatVector: + dim := int(field.GetVectors().GetDim()) + return field.GetVectors().GetFloatVector().GetData()[idx*dim : (idx+1)*dim] + case schemapb.DataType_BinaryVector: + dim := int(field.GetVectors().GetDim()) + dataBytes := dim / 8 + return field.GetVectors().GetBinaryVector()[idx*dataBytes : (idx+1)*dataBytes] + } + return nil +} + func AppendPKs(pks *schemapb.IDs, pk interface{}) { switch realPK := pk.(type) { case int64: diff --git a/pkg/util/typeutil/schema_test.go b/pkg/util/typeutil/schema_test.go index e40b4da6aa..9ec9e2c1cb 100644 --- a/pkg/util/typeutil/schema_test.go +++ b/pkg/util/typeutil/schema_test.go @@ -470,6 +470,21 @@ func genFieldData(fieldName string, fieldID int64, fieldType schemapb.DataType, }, FieldId: fieldID, } + case schemapb.DataType_String: + fieldData = &schemapb.FieldData{ + Type: schemapb.DataType_String, + FieldName: fieldName, + Field: &schemapb.FieldData_Scalars{ + Scalars: &schemapb.ScalarField{ + Data: &schemapb.ScalarField_StringData{ + StringData: &schemapb.StringArray{ + Data: fieldValue.([]string), + }, + }, + }, + }, + FieldId: fieldID, + } case schemapb.DataType_VarChar: fieldData = &schemapb.FieldData{ Type: schemapb.DataType_VarChar, @@ -990,3 +1005,94 @@ func TestCalcColumnSize(t *testing.T) { assert.Equal(t, expected, size, field.GetName()) } } + +func TestGetDataAndGetDataSize(t *testing.T) { + const ( + Dim = 8 + fieldName = "filed-0" + fieldID = 0 + ) + + BoolArray := []bool{true, false} + Int8Array := []int8{1, 2} + Int16Array := []int16{3, 4} + Int32Array := []int32{5, 6} + Int64Array := []int64{11, 22} + FloatArray := []float32{1.0, 2.0} + DoubleArray := []float64{11.0, 22.0} + VarCharArray := []string{"a", "b"} + StringArray := []string{"c", "d"} + BinaryVector := []byte{0x12, 0x34} + FloatVector := []float32{1.0, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0, 8.0, 11.0, 22.0, 33.0, 44.0, 55.0, 66.0, 77.0, 88.0} + + boolData := genFieldData(fieldName, fieldID, schemapb.DataType_Bool, BoolArray, 1) + int8Data := genFieldData(fieldName, fieldID, schemapb.DataType_Int8, Int8Array, 1) + int16Data := genFieldData(fieldName, fieldID, schemapb.DataType_Int16, Int16Array, 1) + int32Data := genFieldData(fieldName, fieldID, schemapb.DataType_Int32, Int32Array, 1) + int64Data := genFieldData(fieldName, fieldID, schemapb.DataType_Int64, Int64Array, 1) + floatData := genFieldData(fieldName, fieldID, schemapb.DataType_Float, FloatArray, 1) + doubleData := genFieldData(fieldName, fieldID, schemapb.DataType_Double, DoubleArray, 1) + varCharData := genFieldData(fieldName, fieldID, schemapb.DataType_VarChar, VarCharArray, 1) + stringData := genFieldData(fieldName, fieldID, schemapb.DataType_String, StringArray, 1) + binVecData := genFieldData(fieldName, fieldID, schemapb.DataType_BinaryVector, BinaryVector, Dim) + floatVecData := genFieldData(fieldName, fieldID, schemapb.DataType_FloatVector, FloatVector, Dim) + invalidData := &schemapb.FieldData{ + Type: schemapb.DataType_None, + } + + t.Run("test GetDataSize", func(t *testing.T) { + boolDataRes := GetDataSize(boolData) + int8DataRes := GetDataSize(int8Data) + int16DataRes := GetDataSize(int16Data) + int32DataRes := GetDataSize(int32Data) + int64DataRes := GetDataSize(int64Data) + floatDataRes := GetDataSize(floatData) + doubleDataRes := GetDataSize(doubleData) + varCharDataRes := GetDataSize(varCharData) + stringDataRes := GetDataSize(stringData) + binVecDataRes := GetDataSize(binVecData) + floatVecDataRes := GetDataSize(floatVecData) + invalidDataRes := GetDataSize(invalidData) + + assert.Equal(t, 2, boolDataRes) + assert.Equal(t, 2, int8DataRes) + assert.Equal(t, 2, int16DataRes) + assert.Equal(t, 2, int32DataRes) + assert.Equal(t, 2, int64DataRes) + assert.Equal(t, 2, floatDataRes) + assert.Equal(t, 2, doubleDataRes) + assert.Equal(t, 2, varCharDataRes) + assert.Equal(t, 2, stringDataRes) + assert.Equal(t, 2*Dim/8, binVecDataRes) + assert.Equal(t, 2*Dim, floatVecDataRes) + assert.Equal(t, 0, invalidDataRes) + }) + + t.Run("test GetData", func(t *testing.T) { + boolDataRes := GetData(boolData, 0) + int8DataRes := GetData(int8Data, 0) + int16DataRes := GetData(int16Data, 0) + int32DataRes := GetData(int32Data, 0) + int64DataRes := GetData(int64Data, 0) + floatDataRes := GetData(floatData, 0) + doubleDataRes := GetData(doubleData, 0) + varCharDataRes := GetData(varCharData, 0) + stringDataRes := GetData(stringData, 0) + binVecDataRes := GetData(binVecData, 0) + floatVecDataRes := GetData(floatVecData, 0) + invalidDataRes := GetData(invalidData, 0) + + assert.Equal(t, BoolArray[0], boolDataRes) + assert.Equal(t, int32(Int8Array[0]), int8DataRes) + assert.Equal(t, int32(Int16Array[0]), int16DataRes) + assert.Equal(t, Int32Array[0], int32DataRes) + assert.Equal(t, Int64Array[0], int64DataRes) + assert.Equal(t, FloatArray[0], floatDataRes) + assert.Equal(t, DoubleArray[0], doubleDataRes) + assert.Equal(t, VarCharArray[0], varCharDataRes) + assert.Equal(t, StringArray[0], stringDataRes) + assert.ElementsMatch(t, BinaryVector[:Dim/8], binVecDataRes) + assert.ElementsMatch(t, FloatVector[:Dim], floatVecDataRes) + assert.Nil(t, invalidDataRes) + }) +} diff --git a/tests/integration/bulkinsert_test.go b/tests/integration/bulkinsert_test.go index 262e41c809..2b26165e1a 100644 --- a/tests/integration/bulkinsert_test.go +++ b/tests/integration/bulkinsert_test.go @@ -51,72 +51,24 @@ const ( // 5, load // 6, search func TestBulkInsert(t *testing.T) { - ctx := context.Background() + ctx, cancel := context.WithTimeout(context.Background(), time.Second*180) c, err := StartMiniCluster(ctx) assert.NoError(t, err) err = c.Start() assert.NoError(t, err) - defer c.Stop() - assert.NoError(t, err) + defer func() { + err = c.Stop() + assert.NoError(t, err) + cancel() + }() prefix := "TestBulkInsert" dbName := "" collectionName := prefix + funcutil.GenRandomStr() - int64Field := "int64" - floatVecField := "embeddings" - scalarField := "image_path" + floatVecField := floatVecField dim := 128 - constructCollectionSchema := func() *schemapb.CollectionSchema { - pk := &schemapb.FieldSchema{ - Name: int64Field, - IsPrimaryKey: true, - Description: "pk", - DataType: schemapb.DataType_Int64, - TypeParams: nil, - IndexParams: nil, - AutoID: true, - } - fVec := &schemapb.FieldSchema{ - Name: floatVecField, - IsPrimaryKey: false, - Description: "", - DataType: schemapb.DataType_FloatVector, - TypeParams: []*commonpb.KeyValuePair{ - { - Key: "dim", - Value: strconv.Itoa(dim), - }, - }, - IndexParams: nil, - AutoID: false, - } - scalar := &schemapb.FieldSchema{ - Name: scalarField, - IsPrimaryKey: false, - Description: "", - DataType: schemapb.DataType_VarChar, - TypeParams: []*commonpb.KeyValuePair{ - { - Key: "max_length", - Value: "65535", - }, - }, - IndexParams: nil, - AutoID: false, - } - return &schemapb.CollectionSchema{ - Name: collectionName, - Description: "", - AutoID: false, - Fields: []*schemapb.FieldSchema{ - pk, - fVec, - scalar, - }, - } - } - schema := constructCollectionSchema() + schema := constructSchema(collectionName, dim, true) marshaledSchema, err := proto.Marshal(schema) assert.NoError(t, err) @@ -207,28 +159,7 @@ func TestBulkInsert(t *testing.T) { CollectionName: collectionName, FieldName: floatVecField, IndexName: "_default", - ExtraParams: []*commonpb.KeyValuePair{ - { - Key: "dim", - Value: strconv.Itoa(dim), - }, - { - Key: common.MetricTypeKey, - Value: distance.L2, - }, - { - Key: "index_type", - Value: "HNSW", - }, - { - Key: "M", - Value: "64", - }, - { - Key: "efConstruction", - Value: "512", - }, - }, + ExtraParams: constructIndexParam(dim, IndexHNSW, distance.L2), }) if createIndexStatus.GetErrorCode() != commonpb.ErrorCode_Success { log.Warn("createIndexStatus fail reason", zap.String("reason", createIndexStatus.GetReason())) @@ -246,30 +177,17 @@ func TestBulkInsert(t *testing.T) { log.Warn("loadStatus fail reason", zap.String("reason", loadStatus.GetReason())) } assert.Equal(t, commonpb.ErrorCode_Success, loadStatus.GetErrorCode()) - for { - loadProgress, err := c.proxy.GetLoadingProgress(ctx, &milvuspb.GetLoadingProgressRequest{ - CollectionName: collectionName, - }) - if err != nil { - panic("GetLoadingProgress fail") - } - if loadProgress.GetProgress() == 100 { - break - } - time.Sleep(500 * time.Millisecond) - } + waitingForLoad(ctx, c, collectionName) // search - expr := fmt.Sprintf("%s > 0", "int64") + expr := fmt.Sprintf("%s > 0", int64Field) nq := 10 topk := 10 roundDecimal := -1 - nprobe := 10 - params := make(map[string]int) - params["nprobe"] = nprobe + params := getSearchParams(IndexHNSW, distance.L2) searchReq := constructSearchRequest("", collectionName, expr, - floatVecField, distance.L2, params, nq, dim, topk, roundDecimal) + floatVecField, schemapb.DataType_FloatVector, nil, distance.L2, params, nq, dim, topk, roundDecimal) searchResult, err := c.proxy.Search(ctx, searchReq) diff --git a/tests/integration/get_vector_test.go b/tests/integration/get_vector_test.go new file mode 100644 index 0000000000..130768e659 --- /dev/null +++ b/tests/integration/get_vector_test.go @@ -0,0 +1,366 @@ +// Licensed to the LF AI & Data foundation under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package integration + +import ( + "context" + "fmt" + "strconv" + "testing" + "time" + + "github.com/golang/protobuf/proto" + "github.com/stretchr/testify/suite" + + "github.com/milvus-io/milvus-proto/go-api/commonpb" + "github.com/milvus-io/milvus-proto/go-api/milvuspb" + "github.com/milvus-io/milvus-proto/go-api/schemapb" + "github.com/milvus-io/milvus/pkg/common" + "github.com/milvus-io/milvus/pkg/util/distance" + "github.com/milvus-io/milvus/pkg/util/funcutil" + "github.com/milvus-io/milvus/pkg/util/typeutil" +) + +type TestGetVectorSuite struct { + suite.Suite + + ctx context.Context + cancel context.CancelFunc + cluster *MiniCluster + + // test params + nq int + topK int + indexType string + metricType string + pkType schemapb.DataType + vecType schemapb.DataType +} + +func (suite *TestGetVectorSuite) SetupTest() { + suite.ctx, suite.cancel = context.WithTimeout(context.Background(), time.Second*600) + + var err error + suite.cluster, err = StartMiniCluster(suite.ctx) + suite.Require().NoError(err) + err = suite.cluster.Start() + suite.Require().NoError(err) +} + +func (suite *TestGetVectorSuite) run() { + collection := fmt.Sprintf("TestGetVector_%d_%d_%s_%s_%s", + suite.nq, suite.topK, suite.indexType, suite.metricType, funcutil.GenRandomStr()) + + const ( + NB = 10000 + dim = 128 + ) + + pkFieldName := "pkField" + vecFieldName := "vecField" + pk := &schemapb.FieldSchema{ + FieldID: 100, + Name: pkFieldName, + IsPrimaryKey: true, + Description: "", + DataType: suite.pkType, + TypeParams: []*commonpb.KeyValuePair{ + { + Key: "max_length", + Value: "100", + }, + }, + IndexParams: nil, + AutoID: false, + } + fVec := &schemapb.FieldSchema{ + FieldID: 101, + Name: vecFieldName, + IsPrimaryKey: false, + Description: "", + DataType: suite.vecType, + TypeParams: []*commonpb.KeyValuePair{ + { + Key: common.DimKey, + Value: fmt.Sprintf("%d", dim), + }, + }, + IndexParams: nil, + } + schema := constructSchema(collection, dim, false, pk, fVec) + marshaledSchema, err := proto.Marshal(schema) + suite.Require().NoError(err) + + createCollectionStatus, err := suite.cluster.proxy.CreateCollection(suite.ctx, &milvuspb.CreateCollectionRequest{ + CollectionName: collection, + Schema: marshaledSchema, + ShardsNum: 2, + }) + suite.Require().NoError(err) + suite.Require().Equal(createCollectionStatus.GetErrorCode(), commonpb.ErrorCode_Success) + + fieldsData := make([]*schemapb.FieldData, 0) + if suite.pkType == schemapb.DataType_Int64 { + fieldsData = append(fieldsData, newInt64FieldData(pkFieldName, NB)) + } else { + fieldsData = append(fieldsData, newStringFieldData(pkFieldName, NB)) + } + var vecFieldData *schemapb.FieldData + if suite.vecType == schemapb.DataType_FloatVector { + vecFieldData = newFloatVectorFieldData(vecFieldName, NB, dim) + } else { + vecFieldData = newBinaryVectorFieldData(vecFieldName, NB, dim) + } + fieldsData = append(fieldsData, vecFieldData) + hashKeys := generateHashKeys(NB) + _, err = suite.cluster.proxy.Insert(suite.ctx, &milvuspb.InsertRequest{ + CollectionName: collection, + FieldsData: fieldsData, + HashKeys: hashKeys, + NumRows: uint32(NB), + }) + suite.Require().NoError(err) + suite.Require().Equal(createCollectionStatus.GetErrorCode(), commonpb.ErrorCode_Success) + + // flush + flushResp, err := suite.cluster.proxy.Flush(suite.ctx, &milvuspb.FlushRequest{ + CollectionNames: []string{collection}, + }) + suite.Require().NoError(err) + segmentIDs, has := flushResp.GetCollSegIDs()[collection] + ids := segmentIDs.GetData() + suite.Require().NotEmpty(segmentIDs) + suite.Require().True(has) + + segments, err := suite.cluster.metaWatcher.ShowSegments() + suite.Require().NoError(err) + suite.Require().NotEmpty(segments) + + waitingForFlush(suite.ctx, suite.cluster, ids) + + // create index + _, err = suite.cluster.proxy.CreateIndex(suite.ctx, &milvuspb.CreateIndexRequest{ + CollectionName: collection, + FieldName: vecFieldName, + IndexName: "_default", + ExtraParams: constructIndexParam(dim, suite.indexType, suite.metricType), + }) + suite.Require().NoError(err) + suite.Require().Equal(createCollectionStatus.GetErrorCode(), commonpb.ErrorCode_Success) + + // load + _, err = suite.cluster.proxy.LoadCollection(suite.ctx, &milvuspb.LoadCollectionRequest{ + CollectionName: collection, + }) + suite.Require().NoError(err) + suite.Require().Equal(createCollectionStatus.GetErrorCode(), commonpb.ErrorCode_Success) + waitingForLoad(suite.ctx, suite.cluster, collection) + + // search + nq := suite.nq + topk := suite.topK + + outputFields := []string{vecFieldName} + params := getSearchParams(suite.indexType, suite.metricType) + searchReq := constructSearchRequest("", collection, "", + vecFieldName, suite.vecType, outputFields, suite.metricType, params, nq, dim, topk, -1) + + searchResp, err := suite.cluster.proxy.Search(suite.ctx, searchReq) + suite.Require().NoError(err) + suite.Require().Equal(searchResp.GetStatus().GetErrorCode(), commonpb.ErrorCode_Success) + + result := searchResp.GetResults() + if suite.pkType == schemapb.DataType_Int64 { + suite.Require().Len(result.GetIds().GetIntId().GetData(), nq*topk) + } else { + suite.Require().Len(result.GetIds().GetStrId().GetData(), nq*topk) + } + suite.Require().Len(result.GetScores(), nq*topk) + suite.Require().GreaterOrEqual(len(result.GetFieldsData()), 1) + var vecFieldIndex = -1 + for i, fieldData := range result.GetFieldsData() { + if typeutil.IsVectorType(fieldData.GetType()) { + vecFieldIndex = i + break + } + } + suite.Require().EqualValues(nq, result.GetNumQueries()) + suite.Require().EqualValues(topk, result.GetTopK()) + + // check output vectors + if suite.vecType == schemapb.DataType_FloatVector { + suite.Require().Len(result.GetFieldsData()[vecFieldIndex].GetVectors().GetFloatVector().GetData(), nq*topk*dim) + rawData := vecFieldData.GetVectors().GetFloatVector().GetData() + resData := result.GetFieldsData()[vecFieldIndex].GetVectors().GetFloatVector().GetData() + if suite.pkType == schemapb.DataType_Int64 { + for i, id := range result.GetIds().GetIntId().GetData() { + expect := rawData[int(id)*dim : (int(id)+1)*dim] + actual := resData[i*dim : (i+1)*dim] + suite.Require().ElementsMatch(expect, actual) + } + } else { + for i, idStr := range result.GetIds().GetStrId().GetData() { + id, err := strconv.Atoi(idStr) + suite.Require().NoError(err) + expect := rawData[id*dim : (id+1)*dim] + actual := resData[i*dim : (i+1)*dim] + suite.Require().ElementsMatch(expect, actual) + } + } + } else { + suite.Require().Len(result.GetFieldsData()[vecFieldIndex].GetVectors().GetBinaryVector(), nq*topk*dim/8) + rawData := vecFieldData.GetVectors().GetBinaryVector() + resData := result.GetFieldsData()[vecFieldIndex].GetVectors().GetBinaryVector() + if suite.pkType == schemapb.DataType_Int64 { + for i, id := range result.GetIds().GetIntId().GetData() { + dataBytes := dim / 8 + for j := 0; j < dataBytes; j++ { + expect := rawData[int(id)*dataBytes+j] + actual := resData[i*dataBytes+j] + suite.Require().Equal(expect, actual) + } + } + } else { + for i, idStr := range result.GetIds().GetStrId().GetData() { + dataBytes := dim / 8 + id, err := strconv.Atoi(idStr) + suite.Require().NoError(err) + for j := 0; j < dataBytes; j++ { + expect := rawData[id*dataBytes+j] + actual := resData[i*dataBytes+j] + suite.Require().Equal(expect, actual) + } + } + } + } + + status, err := suite.cluster.proxy.DropCollection(suite.ctx, &milvuspb.DropCollectionRequest{ + CollectionName: collection, + }) + suite.Require().NoError(err) + suite.Require().Equal(status.GetErrorCode(), commonpb.ErrorCode_Success) +} + +func (suite *TestGetVectorSuite) TestGetVector_FLAT() { + suite.nq = 10 + suite.topK = 10 + suite.indexType = IndexFaissIDMap + suite.metricType = distance.L2 + suite.pkType = schemapb.DataType_Int64 + suite.vecType = schemapb.DataType_FloatVector + suite.run() +} + +func (suite *TestGetVectorSuite) TestGetVector_IVF_FLAT() { + suite.nq = 10 + suite.topK = 10 + suite.indexType = IndexFaissIvfFlat + suite.metricType = distance.L2 + suite.pkType = schemapb.DataType_Int64 + suite.vecType = schemapb.DataType_FloatVector + suite.run() +} + +func (suite *TestGetVectorSuite) TestGetVector_IVF_PQ() { + suite.nq = 10 + suite.topK = 10 + suite.indexType = IndexFaissIvfPQ + suite.metricType = distance.L2 + suite.pkType = schemapb.DataType_Int64 + suite.vecType = schemapb.DataType_FloatVector + suite.run() +} + +func (suite *TestGetVectorSuite) TestGetVector_IVF_SQ8() { + suite.nq = 10 + suite.topK = 10 + suite.indexType = IndexFaissIvfSQ8 + suite.metricType = distance.L2 + suite.pkType = schemapb.DataType_Int64 + suite.vecType = schemapb.DataType_FloatVector + suite.run() +} + +func (suite *TestGetVectorSuite) TestGetVector_HNSW() { + suite.nq = 10 + suite.topK = 10 + suite.indexType = IndexHNSW + suite.metricType = distance.L2 + suite.pkType = schemapb.DataType_Int64 + suite.vecType = schemapb.DataType_FloatVector + suite.run() +} + +func (suite *TestGetVectorSuite) TestGetVector_IP() { + suite.nq = 10 + suite.topK = 10 + suite.indexType = IndexHNSW + suite.metricType = distance.IP + suite.pkType = schemapb.DataType_Int64 + suite.vecType = schemapb.DataType_FloatVector + suite.run() +} + +func (suite *TestGetVectorSuite) TestGetVector_StringPK() { + suite.nq = 10 + suite.topK = 10 + suite.indexType = IndexHNSW + suite.metricType = distance.L2 + suite.pkType = schemapb.DataType_VarChar + suite.vecType = schemapb.DataType_FloatVector + suite.run() +} + +func (suite *TestGetVectorSuite) TestGetVector_BinaryVector() { + suite.nq = 10 + suite.topK = 10 + suite.indexType = IndexFaissBinIvfFlat + suite.metricType = distance.JACCARD + suite.pkType = schemapb.DataType_Int64 + suite.vecType = schemapb.DataType_BinaryVector + suite.run() +} + +func (suite *TestGetVectorSuite) TestGetVector_Big_NQ_TOPK() { + suite.nq = 10000 + suite.topK = 200 + suite.indexType = IndexHNSW + suite.metricType = distance.L2 + suite.pkType = schemapb.DataType_Int64 + suite.vecType = schemapb.DataType_FloatVector + suite.run() +} + +//func (suite *TestGetVectorSuite) TestGetVector_DISKANN() { +// suite.nq = 10 +// suite.topK = 10 +// suite.indexType = IndexDISKANN +// suite.metricType = distance.L2 +// suite.pkType = schemapb.DataType_Int64 +// suite.vecType = schemapb.DataType_FloatVector +// suite.run() +//} + +func (suite *TestGetVectorSuite) TearDownTest() { + err := suite.cluster.Stop() + suite.Require().NoError(err) + suite.cancel() +} + +func TestGetVector(t *testing.T) { + suite.Run(t, new(TestGetVectorSuite)) +} diff --git a/tests/integration/hello_milvus_test.go b/tests/integration/hello_milvus_test.go index ae0577e280..487869f8f2 100644 --- a/tests/integration/hello_milvus_test.go +++ b/tests/integration/hello_milvus_test.go @@ -17,17 +17,11 @@ package integration import ( - "bytes" "context" - "encoding/binary" - "encoding/json" "fmt" - "math/rand" - "strconv" "testing" "time" - "github.com/cockroachdb/errors" "github.com/golang/protobuf/proto" "github.com/stretchr/testify/assert" "go.uber.org/zap" @@ -42,59 +36,27 @@ import ( ) func TestHelloMilvus(t *testing.T) { - ctx := context.Background() + ctx, cancel := context.WithTimeout(context.Background(), time.Second*180) + defer cancel() c, err := StartMiniCluster(ctx) assert.NoError(t, err) err = c.Start() assert.NoError(t, err) - defer c.Stop() - assert.NoError(t, err) + defer func() { + err = c.Stop() + assert.NoError(t, err) + cancel() + }() - prefix := "TestHelloMilvus" - dbName := "" - collectionName := prefix + funcutil.GenRandomStr() - int64Field := "int64" - floatVecField := "fvec" - dim := 128 - rowNum := 3000 + const ( + dim = 128 + dbName = "" + rowNum = 3000 + ) - constructCollectionSchema := func() *schemapb.CollectionSchema { - pk := &schemapb.FieldSchema{ - FieldID: 0, - Name: int64Field, - IsPrimaryKey: true, - Description: "", - DataType: schemapb.DataType_Int64, - TypeParams: nil, - IndexParams: nil, - AutoID: true, - } - fVec := &schemapb.FieldSchema{ - FieldID: 0, - Name: floatVecField, - IsPrimaryKey: false, - Description: "", - DataType: schemapb.DataType_FloatVector, - TypeParams: []*commonpb.KeyValuePair{ - { - Key: "dim", - Value: strconv.Itoa(dim), - }, - }, - IndexParams: nil, - AutoID: false, - } - return &schemapb.CollectionSchema{ - Name: collectionName, - Description: "", - AutoID: false, - Fields: []*schemapb.FieldSchema{ - pk, - fVec, - }, - } - } - schema := constructCollectionSchema() + collectionName := "TestHelloMilvus" + funcutil.GenRandomStr() + + schema := constructSchema(collectionName, dim, true) marshaledSchema, err := proto.Marshal(schema) assert.NoError(t, err) @@ -137,6 +99,7 @@ func TestHelloMilvus(t *testing.T) { segmentIDs, has := flushResp.GetCollSegIDs()[collectionName] ids := segmentIDs.GetData() assert.NotEmpty(t, segmentIDs) + assert.True(t, has) segments, err := c.metaWatcher.ShowSegments() assert.NoError(t, err) @@ -144,52 +107,14 @@ func TestHelloMilvus(t *testing.T) { for _, segment := range segments { log.Info("ShowSegments result", zap.String("segment", segment.String())) } - - if has && len(ids) > 0 { - flushed := func() bool { - resp, err := c.proxy.GetFlushState(ctx, &milvuspb.GetFlushStateRequest{ - SegmentIDs: ids, - }) - if err != nil { - //panic(errors.New("GetFlushState failed")) - return false - } - return resp.GetFlushed() - } - for !flushed() { - // respect context deadline/cancel - select { - case <-ctx.Done(): - panic(errors.New("deadline exceeded")) - default: - } - time.Sleep(500 * time.Millisecond) - } - } + waitingForFlush(ctx, c, ids) // create index createIndexStatus, err := c.proxy.CreateIndex(ctx, &milvuspb.CreateIndexRequest{ CollectionName: collectionName, FieldName: floatVecField, IndexName: "_default", - ExtraParams: []*commonpb.KeyValuePair{ - { - Key: "dim", - Value: strconv.Itoa(dim), - }, - { - Key: common.MetricTypeKey, - Value: distance.L2, - }, - { - Key: "index_type", - Value: "IVF_FLAT", - }, - { - Key: "nlist", - Value: strconv.Itoa(10), - }, - }, + ExtraParams: constructIndexParam(dim, IndexFaissIvfFlat, distance.L2), }) if createIndexStatus.GetErrorCode() != commonpb.ErrorCode_Success { log.Warn("createIndexStatus fail reason", zap.String("reason", createIndexStatus.GetReason())) @@ -207,30 +132,17 @@ func TestHelloMilvus(t *testing.T) { log.Warn("loadStatus fail reason", zap.String("reason", loadStatus.GetReason())) } assert.Equal(t, commonpb.ErrorCode_Success, loadStatus.GetErrorCode()) - for { - loadProgress, err := c.proxy.GetLoadingProgress(ctx, &milvuspb.GetLoadingProgressRequest{ - CollectionName: collectionName, - }) - if err != nil { - panic("GetLoadingProgress fail") - } - if loadProgress.GetProgress() == 100 { - break - } - time.Sleep(500 * time.Millisecond) - } + waitingForLoad(ctx, c, collectionName) // search - expr := fmt.Sprintf("%s > 0", "int64") + expr := fmt.Sprintf("%s > 0", int64Field) nq := 10 topk := 10 roundDecimal := -1 - nprobe := 10 - params := make(map[string]int) - params["nprobe"] = nprobe + params := getSearchParams(IndexFaissIvfFlat, distance.L2) searchReq := constructSearchRequest("", collectionName, expr, - floatVecField, distance.L2, params, nq, dim, topk, roundDecimal) + floatVecField, schemapb.DataType_FloatVector, nil, distance.L2, params, nq, dim, topk, roundDecimal) searchResult, err := c.proxy.Search(ctx, searchReq) @@ -242,155 +154,3 @@ func TestHelloMilvus(t *testing.T) { log.Info("TestHelloMilvus succeed") } - -const ( - AnnsFieldKey = "anns_field" - TopKKey = "topk" - NQKey = "nq" - MetricTypeKey = "metric_type" - SearchParamsKey = "params" - RoundDecimalKey = "round_decimal" - OffsetKey = "offset" - LimitKey = "limit" -) - -func constructSearchRequest( - dbName, collectionName string, - expr string, - floatVecField string, - metricType string, - params map[string]int, - nq, dim, topk, roundDecimal int, -) *milvuspb.SearchRequest { - b, err := json.Marshal(params) - if err != nil { - panic(err) - } - plg := constructPlaceholderGroup(nq, dim) - plgBs, err := proto.Marshal(plg) - if err != nil { - panic(err) - } - - return &milvuspb.SearchRequest{ - Base: nil, - DbName: dbName, - CollectionName: collectionName, - PartitionNames: nil, - Dsl: expr, - PlaceholderGroup: plgBs, - DslType: commonpb.DslType_BoolExprV1, - OutputFields: nil, - SearchParams: []*commonpb.KeyValuePair{ - { - Key: common.MetricTypeKey, - Value: metricType, - }, - { - Key: SearchParamsKey, - Value: string(b), - }, - { - Key: AnnsFieldKey, - Value: floatVecField, - }, - { - Key: TopKKey, - Value: strconv.Itoa(topk), - }, - { - Key: RoundDecimalKey, - Value: strconv.Itoa(roundDecimal), - }, - }, - TravelTimestamp: 0, - GuaranteeTimestamp: 0, - } -} - -func constructPlaceholderGroup( - nq, dim int, -) *commonpb.PlaceholderGroup { - values := make([][]byte, 0, nq) - for i := 0; i < nq; i++ { - bs := make([]byte, 0, dim*4) - for j := 0; j < dim; j++ { - var buffer bytes.Buffer - f := rand.Float32() - err := binary.Write(&buffer, common.Endian, f) - if err != nil { - panic(err) - } - bs = append(bs, buffer.Bytes()...) - } - values = append(values, bs) - } - - return &commonpb.PlaceholderGroup{ - Placeholders: []*commonpb.PlaceholderValue{ - { - Tag: "$0", - Type: commonpb.PlaceholderType_FloatVector, - Values: values, - }, - }, - } -} - -func newFloatVectorFieldData(fieldName string, numRows, dim int) *schemapb.FieldData { - return &schemapb.FieldData{ - Type: schemapb.DataType_FloatVector, - FieldName: fieldName, - Field: &schemapb.FieldData_Vectors{ - Vectors: &schemapb.VectorField{ - Dim: int64(dim), - Data: &schemapb.VectorField_FloatVector{ - FloatVector: &schemapb.FloatArray{ - Data: generateFloatVectors(numRows, dim), - }, - }, - }, - }, - } -} - -func newInt64PrimaryKey(fieldName string, numRows int) *schemapb.FieldData { - return &schemapb.FieldData{ - Type: schemapb.DataType_Int64, - FieldName: fieldName, - Field: &schemapb.FieldData_Scalars{ - Scalars: &schemapb.ScalarField{ - Data: &schemapb.ScalarField_LongData{ - LongData: &schemapb.LongArray{ - Data: generateInt64Array(numRows), - }, - }, - }, - }, - } -} - -func generateFloatVectors(numRows, dim int) []float32 { - total := numRows * dim - ret := make([]float32, 0, total) - for i := 0; i < total; i++ { - ret = append(ret, rand.Float32()) - } - return ret -} - -func generateInt64Array(numRows int) []int64 { - ret := make([]int64, 0, numRows) - for i := 0; i < numRows; i++ { - ret = append(ret, int64(rand.Int())) - } - return ret -} - -func generateHashKeys(numRows int) []uint32 { - ret := make([]uint32, 0, numRows) - for i := 0; i < numRows; i++ { - ret = append(ret, rand.Uint32()) - } - return ret -} diff --git a/tests/integration/range_search_test.go b/tests/integration/range_search_test.go index 3fb419faba..79a4d76420 100644 --- a/tests/integration/range_search_test.go +++ b/tests/integration/range_search_test.go @@ -19,16 +19,13 @@ package integration import ( "context" "fmt" - "strconv" "testing" "time" - "github.com/cockroachdb/errors" "github.com/golang/protobuf/proto" "github.com/stretchr/testify/assert" "go.uber.org/zap" - "github.com/milvus-io/milvus-proto/go-api/commonpb" "github.com/milvus-io/milvus-proto/go-api/milvuspb" "github.com/milvus-io/milvus-proto/go-api/schemapb" "github.com/milvus-io/milvus/pkg/common" @@ -39,59 +36,24 @@ import ( ) func TestRangeSearchIP(t *testing.T) { - ctx := context.Background() + ctx, cancel := context.WithTimeout(context.Background(), time.Second*180) c, err := StartMiniCluster(ctx) assert.NoError(t, err) err = c.Start() assert.NoError(t, err) - defer c.Stop() - assert.NoError(t, err) + defer func() { + err = c.Stop() + assert.NoError(t, err) + cancel() + }() prefix := "TestRangeSearchIP" dbName := "" collectionName := prefix + funcutil.GenRandomStr() - int64Field := "int64" - floatVecField := "fvec" dim := 128 rowNum := 3000 - constructCollectionSchema := func() *schemapb.CollectionSchema { - pk := &schemapb.FieldSchema{ - FieldID: 0, - Name: int64Field, - IsPrimaryKey: true, - Description: "", - DataType: schemapb.DataType_Int64, - TypeParams: nil, - IndexParams: nil, - AutoID: true, - } - fVec := &schemapb.FieldSchema{ - FieldID: 0, - Name: floatVecField, - IsPrimaryKey: false, - Description: "", - DataType: schemapb.DataType_FloatVector, - TypeParams: []*commonpb.KeyValuePair{ - { - Key: "dim", - Value: strconv.Itoa(dim), - }, - }, - IndexParams: nil, - AutoID: false, - } - return &schemapb.CollectionSchema{ - Name: collectionName, - Description: "", - AutoID: false, - Fields: []*schemapb.FieldSchema{ - pk, - fVec, - }, - } - } - schema := constructCollectionSchema() + schema := constructSchema(collectionName, dim, true) marshaledSchema, err := proto.Marshal(schema) assert.NoError(t, err) @@ -133,6 +95,7 @@ func TestRangeSearchIP(t *testing.T) { }) assert.NoError(t, err) segmentIDs, has := flushResp.GetCollSegIDs()[collectionName] + assert.True(t, has) ids := segmentIDs.GetData() assert.NotEmpty(t, segmentIDs) @@ -142,52 +105,14 @@ func TestRangeSearchIP(t *testing.T) { for _, segment := range segments { log.Info("ShowSegments result", zap.String("segment", segment.String())) } - - if has && len(ids) > 0 { - flushed := func() bool { - resp, err := c.proxy.GetFlushState(ctx, &milvuspb.GetFlushStateRequest{ - SegmentIDs: ids, - }) - if err != nil { - //panic(errors.New("GetFlushState failed")) - return false - } - return resp.GetFlushed() - } - for !flushed() { - // respect context deadline/cancel - select { - case <-ctx.Done(): - panic(errors.New("deadline exceeded")) - default: - } - time.Sleep(500 * time.Millisecond) - } - } + waitingForFlush(ctx, c, ids) // create index createIndexStatus, err := c.proxy.CreateIndex(ctx, &milvuspb.CreateIndexRequest{ CollectionName: collectionName, FieldName: floatVecField, IndexName: "_default", - ExtraParams: []*commonpb.KeyValuePair{ - { - Key: "dim", - Value: strconv.Itoa(dim), - }, - { - Key: common.MetricTypeKey, - Value: distance.IP, - }, - { - Key: "index_type", - Value: "IVF_FLAT", - }, - { - Key: "nlist", - Value: strconv.Itoa(10), - }, - }, + ExtraParams: constructIndexParam(dim, IndexFaissIvfFlat, distance.IP), }) assert.NoError(t, err) err = merr.Error(createIndexStatus) @@ -205,34 +130,21 @@ func TestRangeSearchIP(t *testing.T) { if err != nil { log.Warn("LoadCollection fail reason", zap.Error(err)) } - for { - loadProgress, err := c.proxy.GetLoadingProgress(ctx, &milvuspb.GetLoadingProgressRequest{ - CollectionName: collectionName, - }) - if err != nil { - panic("GetLoadingProgress fail") - } - if loadProgress.GetProgress() == 100 { - break - } - time.Sleep(500 * time.Millisecond) - } + waitingForLoad(ctx, c, collectionName) // search - expr := fmt.Sprintf("%s > 0", "int64") + expr := fmt.Sprintf("%s > 0", int64Field) nq := 10 topk := 10 roundDecimal := -1 - nprobe := 10 radius := 10 filter := 20 - params := make(map[string]int) - params["nprobe"] = nprobe + params := getSearchParams(IndexFaissIvfFlat, distance.IP) // only pass in radius when range search params["radius"] = radius searchReq := constructSearchRequest("", collectionName, expr, - floatVecField, distance.IP, params, nq, dim, topk, roundDecimal) + floatVecField, schemapb.DataType_FloatVector, nil, distance.IP, params, nq, dim, topk, roundDecimal) searchResult, _ := c.proxy.Search(ctx, searchReq) @@ -245,7 +157,7 @@ func TestRangeSearchIP(t *testing.T) { // pass in radius and range_filter when range search params["range_filter"] = filter searchReq = constructSearchRequest("", collectionName, expr, - floatVecField, distance.IP, params, nq, dim, topk, roundDecimal) + floatVecField, schemapb.DataType_FloatVector, nil, distance.IP, params, nq, dim, topk, roundDecimal) searchResult, _ = c.proxy.Search(ctx, searchReq) @@ -259,7 +171,7 @@ func TestRangeSearchIP(t *testing.T) { params["radius"] = filter params["range_filter"] = radius searchReq = constructSearchRequest("", collectionName, expr, - floatVecField, distance.IP, params, nq, dim, topk, roundDecimal) + floatVecField, schemapb.DataType_FloatVector, nil, distance.IP, params, nq, dim, topk, roundDecimal) searchResult, _ = c.proxy.Search(ctx, searchReq) @@ -277,59 +189,24 @@ func TestRangeSearchIP(t *testing.T) { } func TestRangeSearchL2(t *testing.T) { - ctx := context.Background() + ctx, cancel := context.WithTimeout(context.Background(), time.Second*180) c, err := StartMiniCluster(ctx) assert.NoError(t, err) err = c.Start() assert.NoError(t, err) - defer c.Stop() - assert.NoError(t, err) + defer func() { + err = c.Stop() + assert.NoError(t, err) + cancel() + }() prefix := "TestRangeSearchL2" dbName := "" collectionName := prefix + funcutil.GenRandomStr() - int64Field := "int64" - floatVecField := "fvec" dim := 128 rowNum := 3000 - constructCollectionSchema := func() *schemapb.CollectionSchema { - pk := &schemapb.FieldSchema{ - FieldID: 0, - Name: int64Field, - IsPrimaryKey: true, - Description: "", - DataType: schemapb.DataType_Int64, - TypeParams: nil, - IndexParams: nil, - AutoID: true, - } - fVec := &schemapb.FieldSchema{ - FieldID: 0, - Name: floatVecField, - IsPrimaryKey: false, - Description: "", - DataType: schemapb.DataType_FloatVector, - TypeParams: []*commonpb.KeyValuePair{ - { - Key: "dim", - Value: strconv.Itoa(dim), - }, - }, - IndexParams: nil, - AutoID: false, - } - return &schemapb.CollectionSchema{ - Name: collectionName, - Description: "", - AutoID: false, - Fields: []*schemapb.FieldSchema{ - pk, - fVec, - }, - } - } - schema := constructCollectionSchema() + schema := constructSchema(collectionName, dim, true) marshaledSchema, err := proto.Marshal(schema) assert.NoError(t, err) @@ -371,6 +248,7 @@ func TestRangeSearchL2(t *testing.T) { }) assert.NoError(t, err) segmentIDs, has := flushResp.GetCollSegIDs()[collectionName] + assert.True(t, has) ids := segmentIDs.GetData() assert.NotEmpty(t, segmentIDs) @@ -380,52 +258,14 @@ func TestRangeSearchL2(t *testing.T) { for _, segment := range segments { log.Info("ShowSegments result", zap.String("segment", segment.String())) } - - if has && len(ids) > 0 { - flushed := func() bool { - resp, err := c.proxy.GetFlushState(ctx, &milvuspb.GetFlushStateRequest{ - SegmentIDs: ids, - }) - if err != nil { - //panic(errors.New("GetFlushState failed")) - return false - } - return resp.GetFlushed() - } - for !flushed() { - // respect context deadline/cancel - select { - case <-ctx.Done(): - panic(errors.New("deadline exceeded")) - default: - } - time.Sleep(500 * time.Millisecond) - } - } + waitingForFlush(ctx, c, ids) // create index createIndexStatus, err := c.proxy.CreateIndex(ctx, &milvuspb.CreateIndexRequest{ CollectionName: collectionName, FieldName: floatVecField, IndexName: "_default", - ExtraParams: []*commonpb.KeyValuePair{ - { - Key: "dim", - Value: strconv.Itoa(dim), - }, - { - Key: common.MetricTypeKey, - Value: distance.L2, - }, - { - Key: "index_type", - Value: "IVF_FLAT", - }, - { - Key: "nlist", - Value: strconv.Itoa(10), - }, - }, + ExtraParams: constructIndexParam(dim, IndexFaissIvfFlat, distance.L2), }) assert.NoError(t, err) err = merr.Error(createIndexStatus) @@ -443,34 +283,20 @@ func TestRangeSearchL2(t *testing.T) { if err != nil { log.Warn("LoadCollection fail reason", zap.Error(err)) } - for { - loadProgress, err := c.proxy.GetLoadingProgress(ctx, &milvuspb.GetLoadingProgressRequest{ - CollectionName: collectionName, - }) - if err != nil { - panic("GetLoadingProgress fail") - } - if loadProgress.GetProgress() == 100 { - break - } - time.Sleep(500 * time.Millisecond) - } + waitingForLoad(ctx, c, collectionName) // search - expr := fmt.Sprintf("%s > 0", "int64") + expr := fmt.Sprintf("%s > 0", int64Field) nq := 10 topk := 10 roundDecimal := -1 - nprobe := 10 radius := 20 filter := 10 - params := make(map[string]int) - params["nprobe"] = nprobe - + params := getSearchParams(IndexFaissIvfFlat, distance.L2) // only pass in radius when range search params["radius"] = radius searchReq := constructSearchRequest("", collectionName, expr, - floatVecField, distance.L2, params, nq, dim, topk, roundDecimal) + floatVecField, schemapb.DataType_FloatVector, nil, distance.L2, params, nq, dim, topk, roundDecimal) searchResult, _ := c.proxy.Search(ctx, searchReq) @@ -483,7 +309,7 @@ func TestRangeSearchL2(t *testing.T) { // pass in radius and range_filter when range search params["range_filter"] = filter searchReq = constructSearchRequest("", collectionName, expr, - floatVecField, distance.L2, params, nq, dim, topk, roundDecimal) + floatVecField, schemapb.DataType_FloatVector, nil, distance.L2, params, nq, dim, topk, roundDecimal) searchResult, _ = c.proxy.Search(ctx, searchReq) @@ -497,7 +323,7 @@ func TestRangeSearchL2(t *testing.T) { params["radius"] = filter params["range_filter"] = radius searchReq = constructSearchRequest("", collectionName, expr, - floatVecField, distance.L2, params, nq, dim, topk, roundDecimal) + floatVecField, schemapb.DataType_FloatVector, nil, distance.L2, params, nq, dim, topk, roundDecimal) searchResult, _ = c.proxy.Search(ctx, searchReq) diff --git a/tests/integration/upsert_test.go b/tests/integration/upsert_test.go index 07afff99d5..81711aea49 100644 --- a/tests/integration/upsert_test.go +++ b/tests/integration/upsert_test.go @@ -19,13 +19,10 @@ package integration import ( "context" "fmt" - "strconv" "testing" "time" - "github.com/cockroachdb/errors" "github.com/golang/protobuf/proto" - "github.com/milvus-io/milvus-proto/go-api/commonpb" "github.com/milvus-io/milvus-proto/go-api/milvuspb" "github.com/milvus-io/milvus-proto/go-api/schemapb" "github.com/milvus-io/milvus/pkg/common" @@ -38,59 +35,25 @@ import ( ) func TestUpsert(t *testing.T) { - ctx := context.Background() + ctx, cancel := context.WithTimeout(context.Background(), time.Second*180) c, err := StartMiniCluster(ctx) assert.NoError(t, err) err = c.Start() assert.NoError(t, err) - defer c.Stop() + defer func() { + err = c.Stop() + assert.NoError(t, err) + cancel() + }() assert.NoError(t, err) prefix := "TestUpsert" dbName := "" collectionName := prefix + funcutil.GenRandomStr() - int64Field := "int64" - floatVecField := "fvec" dim := 128 rowNum := 3000 - constructCollectionSchema := func() *schemapb.CollectionSchema { - pk := &schemapb.FieldSchema{ - FieldID: 0, - Name: int64Field, - IsPrimaryKey: true, - Description: "", - DataType: schemapb.DataType_Int64, - TypeParams: nil, - IndexParams: nil, - AutoID: false, - } - fVec := &schemapb.FieldSchema{ - FieldID: 0, - Name: floatVecField, - IsPrimaryKey: false, - Description: "", - DataType: schemapb.DataType_FloatVector, - TypeParams: []*commonpb.KeyValuePair{ - { - Key: "dim", - Value: strconv.Itoa(dim), - }, - }, - IndexParams: nil, - AutoID: false, - } - return &schemapb.CollectionSchema{ - Name: collectionName, - Description: "", - AutoID: false, - Fields: []*schemapb.FieldSchema{ - pk, - fVec, - }, - } - } - schema := constructCollectionSchema() + schema := constructSchema(collectionName, dim, false) marshaledSchema, err := proto.Marshal(schema) assert.NoError(t, err) @@ -113,7 +76,7 @@ func TestUpsert(t *testing.T) { assert.True(t, merr.Ok(showCollectionsResp.GetStatus())) log.Info("ShowCollections result", zap.Any("showCollectionsResp", showCollectionsResp)) - pkFieldData := newInt64PrimaryKey(int64Field, rowNum) + pkFieldData := newInt64FieldData(int64Field, rowNum) fVecColumn := newFloatVectorFieldData(floatVecField, rowNum, dim) hashKeys := generateHashKeys(rowNum) upsertResult, err := c.proxy.Upsert(ctx, &milvuspb.UpsertRequest{ @@ -133,6 +96,7 @@ func TestUpsert(t *testing.T) { }) assert.NoError(t, err) segmentIDs, has := flushResp.GetCollSegIDs()[collectionName] + assert.True(t, has) ids := segmentIDs.GetData() assert.NotEmpty(t, segmentIDs) @@ -142,52 +106,14 @@ func TestUpsert(t *testing.T) { for _, segment := range segments { log.Info("ShowSegments result", zap.String("segment", segment.String())) } - - if has && len(ids) > 0 { - flushed := func() bool { - resp, err := c.proxy.GetFlushState(ctx, &milvuspb.GetFlushStateRequest{ - SegmentIDs: ids, - }) - if err != nil { - //panic(errors.New("GetFlushState failed")) - return false - } - return resp.GetFlushed() - } - for !flushed() { - // respect context deadline/cancel - select { - case <-ctx.Done(): - panic(errors.New("deadline exceeded")) - default: - } - time.Sleep(500 * time.Millisecond) - } - } + waitingForFlush(ctx, c, ids) // create index createIndexStatus, err := c.proxy.CreateIndex(ctx, &milvuspb.CreateIndexRequest{ CollectionName: collectionName, FieldName: floatVecField, IndexName: "_default", - ExtraParams: []*commonpb.KeyValuePair{ - { - Key: "dim", - Value: strconv.Itoa(dim), - }, - { - Key: common.MetricTypeKey, - Value: distance.L2, - }, - { - Key: "index_type", - Value: "IVF_FLAT", - }, - { - Key: "nlist", - Value: strconv.Itoa(10), - }, - }, + ExtraParams: constructIndexParam(dim, IndexFaissIvfFlat, distance.IP), }) assert.NoError(t, err) err = merr.Error(createIndexStatus) @@ -205,30 +131,16 @@ func TestUpsert(t *testing.T) { if err != nil { log.Warn("LoadCollection fail reason", zap.Error(err)) } - for { - loadProgress, err := c.proxy.GetLoadingProgress(ctx, &milvuspb.GetLoadingProgressRequest{ - CollectionName: collectionName, - }) - if err != nil { - panic("GetLoadingProgress fail") - } - if loadProgress.GetProgress() == 100 { - break - } - time.Sleep(500 * time.Millisecond) - } + waitingForLoad(ctx, c, collectionName) // search - expr := fmt.Sprintf("%s > 0", "int64") + expr := fmt.Sprintf("%s > 0", int64Field) nq := 10 topk := 10 roundDecimal := -1 - nprobe := 10 - - params := make(map[string]int) - params["nprobe"] = nprobe + params := getSearchParams(IndexFaissIvfFlat, "") searchReq := constructSearchRequest("", collectionName, expr, - floatVecField, distance.IP, params, nq, dim, topk, roundDecimal) + floatVecField, schemapb.DataType_FloatVector, nil, distance.IP, params, nq, dim, topk, roundDecimal) searchResult, _ := c.proxy.Search(ctx, searchReq) diff --git a/tests/integration/util_index.go b/tests/integration/util_index.go new file mode 100644 index 0000000000..e3cb58e95d --- /dev/null +++ b/tests/integration/util_index.go @@ -0,0 +1,108 @@ +// Licensed to the LF AI & Data foundation under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package integration + +import ( + "fmt" + "strconv" + + "github.com/milvus-io/milvus-proto/go-api/commonpb" + "github.com/milvus-io/milvus/pkg/common" + "github.com/milvus-io/milvus/pkg/util/indexparamcheck" +) + +const ( + IndexRaftIvfFlat = indexparamcheck.IndexRaftIvfFlat + IndexRaftIvfPQ = indexparamcheck.IndexRaftIvfPQ + IndexFaissIDMap = indexparamcheck.IndexFaissIDMap + IndexFaissIvfFlat = indexparamcheck.IndexFaissIvfFlat + IndexFaissIvfPQ = indexparamcheck.IndexFaissIvfPQ + IndexFaissIvfSQ8 = indexparamcheck.IndexFaissIvfSQ8 + IndexFaissBinIDMap = indexparamcheck.IndexFaissBinIDMap + IndexFaissBinIvfFlat = indexparamcheck.IndexFaissBinIvfFlat + IndexHNSW = indexparamcheck.IndexHNSW + IndexDISKANN = indexparamcheck.IndexDISKANN +) + +func constructIndexParam(dim int, indexType string, metricType string) []*commonpb.KeyValuePair { + params := []*commonpb.KeyValuePair{ + { + Key: common.DimKey, + Value: strconv.Itoa(dim), + }, + { + Key: common.MetricTypeKey, + Value: metricType, + }, + { + Key: common.IndexTypeKey, + Value: indexType, + }, + } + switch indexType { + case IndexFaissIDMap, IndexFaissBinIDMap: + // no index param is required + case IndexFaissIvfFlat, IndexFaissBinIvfFlat, IndexFaissIvfSQ8: + params = append(params, &commonpb.KeyValuePair{ + Key: "nlist", + Value: "100", + }) + case IndexFaissIvfPQ: + params = append(params, &commonpb.KeyValuePair{ + Key: "nlist", + Value: "100", + }) + params = append(params, &commonpb.KeyValuePair{ + Key: "m", + Value: "16", + }) + params = append(params, &commonpb.KeyValuePair{ + Key: "nbits", + Value: "8", + }) + case IndexHNSW: + params = append(params, &commonpb.KeyValuePair{ + Key: "M", + Value: "16", + }) + params = append(params, &commonpb.KeyValuePair{ + Key: "efConstruction", + Value: "200", + }) + case IndexDISKANN: + default: + panic(fmt.Sprintf("unimplemented index param for %s, please help to improve it", indexType)) + } + return params +} + +func getSearchParams(indexType string, metricType string) map[string]any { + params := make(map[string]any) + switch indexType { + case IndexFaissIDMap, IndexFaissBinIDMap: + params["metric_type"] = metricType + case IndexFaissIvfFlat, IndexFaissBinIvfFlat, IndexFaissIvfSQ8, IndexFaissIvfPQ: + params["nprobe"] = 8 + case IndexHNSW: + params["ef"] = 200 + case IndexDISKANN: + params["search_list"] = 5 + default: + panic(fmt.Sprintf("unimplemented search param for %s, please help to improve it", indexType)) + } + return params +} diff --git a/tests/integration/util_insert.go b/tests/integration/util_insert.go new file mode 100644 index 0000000000..58508a4e2b --- /dev/null +++ b/tests/integration/util_insert.go @@ -0,0 +1,154 @@ +// Licensed to the LF AI & Data foundation under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package integration + +import ( + "context" + "fmt" + "math/rand" + "time" + + "github.com/milvus-io/milvus-proto/go-api/milvuspb" + "github.com/milvus-io/milvus-proto/go-api/schemapb" +) + +func waitingForFlush(ctx context.Context, cluster *MiniCluster, segIDs []int64) { + flushed := func() bool { + resp, err := cluster.proxy.GetFlushState(ctx, &milvuspb.GetFlushStateRequest{ + SegmentIDs: segIDs, + }) + if err != nil { + return false + } + return resp.GetFlushed() + } + for !flushed() { + select { + case <-ctx.Done(): + panic("flush timeout") + default: + time.Sleep(500 * time.Millisecond) + } + } +} + +func newInt64FieldData(fieldName string, numRows int) *schemapb.FieldData { + return &schemapb.FieldData{ + Type: schemapb.DataType_Int64, + FieldName: fieldName, + Field: &schemapb.FieldData_Scalars{ + Scalars: &schemapb.ScalarField{ + Data: &schemapb.ScalarField_LongData{ + LongData: &schemapb.LongArray{ + Data: generateInt64Array(numRows), + }, + }, + }, + }, + } +} + +func newStringFieldData(fieldName string, numRows int) *schemapb.FieldData { + return &schemapb.FieldData{ + Type: schemapb.DataType_Int64, + FieldName: fieldName, + Field: &schemapb.FieldData_Scalars{ + Scalars: &schemapb.ScalarField{ + Data: &schemapb.ScalarField_StringData{ + StringData: &schemapb.StringArray{ + Data: generateStringArray(numRows), + }, + }, + }, + }, + } +} + +func newFloatVectorFieldData(fieldName string, numRows, dim int) *schemapb.FieldData { + return &schemapb.FieldData{ + Type: schemapb.DataType_FloatVector, + FieldName: fieldName, + Field: &schemapb.FieldData_Vectors{ + Vectors: &schemapb.VectorField{ + Dim: int64(dim), + Data: &schemapb.VectorField_FloatVector{ + FloatVector: &schemapb.FloatArray{ + Data: generateFloatVectors(numRows, dim), + }, + }, + }, + }, + } +} + +func newBinaryVectorFieldData(fieldName string, numRows, dim int) *schemapb.FieldData { + return &schemapb.FieldData{ + Type: schemapb.DataType_BinaryVector, + FieldName: fieldName, + Field: &schemapb.FieldData_Vectors{ + Vectors: &schemapb.VectorField{ + Dim: int64(dim), + Data: &schemapb.VectorField_BinaryVector{ + BinaryVector: generateBinaryVectors(numRows, dim), + }, + }, + }, + } +} + +func generateInt64Array(numRows int) []int64 { + ret := make([]int64, numRows) + for i := 0; i < numRows; i++ { + ret[i] = int64(i) + } + return ret +} + +func generateStringArray(numRows int) []string { + ret := make([]string, numRows) + for i := 0; i < numRows; i++ { + ret[i] = fmt.Sprintf("%d", i) + } + return ret +} + +func generateFloatVectors(numRows, dim int) []float32 { + total := numRows * dim + ret := make([]float32, 0, total) + for i := 0; i < total; i++ { + ret = append(ret, rand.Float32()) + } + return ret +} + +func generateBinaryVectors(numRows, dim int) []byte { + total := (numRows * dim) / 8 + ret := make([]byte, total) + _, err := rand.Read(ret) + if err != nil { + panic(err) + } + return ret +} + +func generateHashKeys(numRows int) []uint32 { + ret := make([]uint32, 0, numRows) + for i := 0; i < numRows; i++ { + ret = append(ret, rand.Uint32()) + } + return ret +} diff --git a/tests/integration/util_query.go b/tests/integration/util_query.go new file mode 100644 index 0000000000..8f055152a3 --- /dev/null +++ b/tests/integration/util_query.go @@ -0,0 +1,166 @@ +// Licensed to the LF AI & Data foundation under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package integration + +import ( + "bytes" + "context" + "encoding/binary" + "encoding/json" + "math/rand" + "strconv" + "time" + + "github.com/golang/protobuf/proto" + + "github.com/milvus-io/milvus-proto/go-api/commonpb" + "github.com/milvus-io/milvus-proto/go-api/milvuspb" + "github.com/milvus-io/milvus-proto/go-api/schemapb" + "github.com/milvus-io/milvus/pkg/common" +) + +const ( + AnnsFieldKey = "anns_field" + TopKKey = "topk" + NQKey = "nq" + MetricTypeKey = "metric_type" + SearchParamsKey = "params" + RoundDecimalKey = "round_decimal" + OffsetKey = "offset" + LimitKey = "limit" +) + +func waitingForLoad(ctx context.Context, cluster *MiniCluster, collection string) { + getLoadingProgress := func() *milvuspb.GetLoadingProgressResponse { + loadProgress, err := cluster.proxy.GetLoadingProgress(ctx, &milvuspb.GetLoadingProgressRequest{ + CollectionName: collection, + }) + if err != nil { + panic("GetLoadingProgress fail") + } + return loadProgress + } + for getLoadingProgress().GetProgress() != 100 { + select { + case <-ctx.Done(): + panic("load timeout") + default: + time.Sleep(500 * time.Millisecond) + } + } +} + +func constructSearchRequest( + dbName, collectionName string, + expr string, + vecField string, + vectorType schemapb.DataType, + outputFields []string, + metricType string, + params map[string]any, + nq, dim int, topk, roundDecimal int, +) *milvuspb.SearchRequest { + b, err := json.Marshal(params) + if err != nil { + panic(err) + } + plg := constructPlaceholderGroup(nq, dim, vectorType) + plgBs, err := proto.Marshal(plg) + if err != nil { + panic(err) + } + + return &milvuspb.SearchRequest{ + Base: nil, + DbName: dbName, + CollectionName: collectionName, + PartitionNames: nil, + Dsl: expr, + PlaceholderGroup: plgBs, + DslType: commonpb.DslType_BoolExprV1, + OutputFields: outputFields, + SearchParams: []*commonpb.KeyValuePair{ + { + Key: common.MetricTypeKey, + Value: metricType, + }, + { + Key: SearchParamsKey, + Value: string(b), + }, + { + Key: AnnsFieldKey, + Value: vecField, + }, + { + Key: common.TopKKey, + Value: strconv.Itoa(topk), + }, + { + Key: RoundDecimalKey, + Value: strconv.Itoa(roundDecimal), + }, + }, + TravelTimestamp: 0, + GuaranteeTimestamp: 0, + } +} + +func constructPlaceholderGroup(nq, dim int, vectorType schemapb.DataType) *commonpb.PlaceholderGroup { + values := make([][]byte, 0, nq) + var placeholderType commonpb.PlaceholderType + switch vectorType { + case schemapb.DataType_FloatVector: + placeholderType = commonpb.PlaceholderType_FloatVector + for i := 0; i < nq; i++ { + bs := make([]byte, 0, dim*4) + for j := 0; j < dim; j++ { + var buffer bytes.Buffer + f := rand.Float32() + err := binary.Write(&buffer, common.Endian, f) + if err != nil { + panic(err) + } + bs = append(bs, buffer.Bytes()...) + } + values = append(values, bs) + } + case schemapb.DataType_BinaryVector: + placeholderType = commonpb.PlaceholderType_BinaryVector + for i := 0; i < nq; i++ { + total := dim / 8 + ret := make([]byte, total) + _, err := rand.Read(ret) + if err != nil { + panic(err) + } + values = append(values, ret) + } + default: + panic("invalid vector data type") + } + + return &commonpb.PlaceholderGroup{ + Placeholders: []*commonpb.PlaceholderValue{ + { + Tag: "$0", + Type: placeholderType, + Values: values, + }, + }, + } +} diff --git a/tests/integration/util_schema.go b/tests/integration/util_schema.go new file mode 100644 index 0000000000..16d94a1d4b --- /dev/null +++ b/tests/integration/util_schema.go @@ -0,0 +1,80 @@ +// Licensed to the LF AI & Data foundation under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package integration + +import ( + "fmt" + + "github.com/milvus-io/milvus-proto/go-api/commonpb" + "github.com/milvus-io/milvus-proto/go-api/schemapb" + "github.com/milvus-io/milvus/pkg/common" +) + +const ( + boolField = "boolField" + int8Field = "int8Field" + int16Field = "int16Field" + int32Field = "int32Field" + int64Field = "int64Field" + floatField = "floatField" + doubleField = "doubleField" + varCharField = "varCharField" + floatVecField = "floatVecField" + binVecField = "binVecField" +) + +func constructSchema(collection string, dim int, autoID bool, fields ...*schemapb.FieldSchema) *schemapb.CollectionSchema { + // if fields are specified, construct it + if len(fields) > 0 { + return &schemapb.CollectionSchema{ + Name: collection, + AutoID: autoID, + Fields: fields, + } + } + + // if no field is specified, use default + pk := &schemapb.FieldSchema{ + FieldID: 100, + Name: int64Field, + IsPrimaryKey: true, + Description: "", + DataType: schemapb.DataType_Int64, + TypeParams: nil, + IndexParams: nil, + AutoID: autoID, + } + fVec := &schemapb.FieldSchema{ + FieldID: 101, + Name: floatVecField, + IsPrimaryKey: false, + Description: "", + DataType: schemapb.DataType_FloatVector, + TypeParams: []*commonpb.KeyValuePair{ + { + Key: common.DimKey, + Value: fmt.Sprintf("%d", dim), + }, + }, + IndexParams: nil, + } + return &schemapb.CollectionSchema{ + Name: collection, + AutoID: autoID, + Fields: []*schemapb.FieldSchema{pk, fVec}, + } +} diff --git a/tests/python_client/testcases/test_search.py b/tests/python_client/testcases/test_search.py index feee3888ac..01abcba580 100644 --- a/tests/python_client/testcases/test_search.py +++ b/tests/python_client/testcases/test_search.py @@ -842,11 +842,7 @@ class TestCollectionSearchInvalid(TestcaseBase): log.info("test_search_output_field_vector: Searching collection %s" % collection_w.name) collection_w.search(vectors[:default_nq], default_search_field, default_search_params, default_limit, - default_search_exp, output_fields=output_fields, - check_task=CheckTasks.err_res, - check_items={"err_code": 1, - "err_msg": "Search doesn't support " - "vector field as output_fields"}) + default_search_exp, output_fields=output_fields) @pytest.mark.tags(CaseLabel.L2) @pytest.mark.parametrize("output_fields", [["*%"], ["**"], ["*", "@"]])