mirror of https://github.com/milvus-io/milvus.git
Add support for getting vectors by ids (#23450)
Signed-off-by: bigsheeper <yihao.dai@zilliz.com>
pull/23610/head
parent
897ed620e4
commit
092d743917
3
Makefile
3
Makefile
|
@ -318,6 +318,9 @@ rpm: install
|
|||
@cp -r build/rpm/services ~/rpmbuild/BUILD/
|
||||
@QA_RPATHS="$$[ 0x001|0x0002|0x0020 ]" rpmbuild -ba ./build/rpm/milvus.spec
|
||||
|
||||
mock-proxy:
|
||||
mockery --name=ProxyComponent --dir=$(PWD)/internal/types --output=$(PWD)/internal/mocks --filename=mock_proxy.go --structname=Proxy --with-expecter
|
||||
|
||||
mock-datanode:
|
||||
mockery --name=DataNode --dir=$(PWD)/internal/types --output=$(PWD)/internal/mocks --filename=mock_datanode.go --with-expecter
|
||||
|
||||
|
|
|
@ -85,6 +85,16 @@ PrefixMatch(const std::string_view str, const std::string_view prefix) {
|
|||
return true;
|
||||
}
|
||||
|
||||
// Builds a Dataset that carries `count` ids: rows = count, dim = 1, ids = `ids`.
// The dataset is flagged as non-owner (SetIsOwner(false)); presumably the
// caller keeps ownership of the `ids` buffer -- confirm against Dataset.
inline DatasetPtr
GenIdsDataset(const int64_t count, const int64_t* ids) {
    auto ids_dataset = std::make_shared<Dataset>();
    ids_dataset->SetRows(count);
    ids_dataset->SetDim(1);
    ids_dataset->SetIds(ids);
    ids_dataset->SetIsOwner(false);
    return ids_dataset;
}
|
||||
|
||||
inline DatasetPtr
|
||||
GenResultDataset(const int64_t nq,
|
||||
const int64_t topk,
|
||||
|
|
|
@ -230,6 +230,38 @@ VectorDiskAnnIndex<T>::Query(const DatasetPtr dataset,
|
|||
return result;
|
||||
}
|
||||
|
||||
template <typename T>
|
||||
const bool
|
||||
VectorDiskAnnIndex<T>::HasRawData() const {
|
||||
return index_.HasRawData(GetMetricType());
|
||||
}
|
||||
|
||||
template <typename T>
|
||||
const std::vector<uint8_t>
|
||||
VectorDiskAnnIndex<T>::GetVector(const DatasetPtr dataset,
|
||||
const Config& config) const {
|
||||
auto res = index_.GetVectorByIds(*dataset, config);
|
||||
if (!res.has_value()) {
|
||||
PanicCodeInfo(
|
||||
ErrorCodeEnum::UnexpectedError,
|
||||
"failed to get vector, " + MatchKnowhereError(res.error()));
|
||||
}
|
||||
auto index_type = GetIndexType();
|
||||
auto tensor = res.value()->GetTensor();
|
||||
auto row_num = res.value()->GetRows();
|
||||
auto dim = res.value()->GetDim();
|
||||
int64_t data_size;
|
||||
if (is_in_bin_list(index_type)) {
|
||||
data_size = dim / 8 * row_num;
|
||||
} else {
|
||||
data_size = dim * row_num * sizeof(float);
|
||||
}
|
||||
std::vector<uint8_t> raw_data;
|
||||
raw_data.resize(data_size);
|
||||
memcpy(raw_data.data(), tensor, data_size);
|
||||
return raw_data;
|
||||
}
|
||||
|
||||
template <typename T>
|
||||
void
|
||||
VectorDiskAnnIndex<T>::CleanLocalData() {
|
||||
|
|
|
@ -17,6 +17,7 @@
|
|||
#pragma once
|
||||
|
||||
#include <memory>
|
||||
#include <vector>
|
||||
|
||||
#include "index/VectorIndex.h"
|
||||
#include "storage/DiskFileManagerImpl.h"
|
||||
|
@ -60,6 +61,13 @@ class VectorDiskAnnIndex : public VectorIndex {
|
|||
const SearchInfo& search_info,
|
||||
const BitsetView& bitset) override;
|
||||
|
||||
const bool
|
||||
HasRawData() const override;
|
||||
|
||||
const std::vector<uint8_t>
|
||||
GetVector(const DatasetPtr dataset,
|
||||
const Config& config = {}) const override;
|
||||
|
||||
void
|
||||
CleanLocalData() override;
|
||||
|
||||
|
|
|
@ -19,6 +19,7 @@
|
|||
#include <map>
|
||||
#include <memory>
|
||||
#include <string>
|
||||
#include <vector>
|
||||
#include <boost/dynamic_bitset.hpp>
|
||||
|
||||
#include "knowhere/factory.h"
|
||||
|
@ -50,6 +51,12 @@ class VectorIndex : public IndexBase {
|
|||
const SearchInfo& search_info,
|
||||
const BitsetView& bitset) = 0;
|
||||
|
||||
virtual const bool
|
||||
HasRawData() const = 0;
|
||||
|
||||
virtual const std::vector<uint8_t>
|
||||
GetVector(const DatasetPtr dataset, const Config& config = {}) const = 0;
|
||||
|
||||
IndexType
|
||||
GetIndexType() const {
|
||||
return index_type_;
|
||||
|
|
|
@ -145,4 +145,34 @@ VectorMemIndex::Query(const DatasetPtr dataset,
|
|||
return result;
|
||||
}
|
||||
|
||||
const bool
|
||||
VectorMemIndex::HasRawData() const {
|
||||
return index_.HasRawData(GetMetricType());
|
||||
}
|
||||
|
||||
const std::vector<uint8_t>
|
||||
VectorMemIndex::GetVector(const DatasetPtr dataset,
|
||||
const Config& config) const {
|
||||
auto res = index_.GetVectorByIds(*dataset, config);
|
||||
if (!res.has_value()) {
|
||||
PanicCodeInfo(
|
||||
ErrorCodeEnum::UnexpectedError,
|
||||
"failed to get vector, " + MatchKnowhereError(res.error()));
|
||||
}
|
||||
auto index_type = GetIndexType();
|
||||
auto tensor = res.value()->GetTensor();
|
||||
auto row_num = res.value()->GetRows();
|
||||
auto dim = res.value()->GetDim();
|
||||
int64_t data_size;
|
||||
if (is_in_bin_list(index_type)) {
|
||||
data_size = dim / 8 * row_num;
|
||||
} else {
|
||||
data_size = dim * row_num * sizeof(float);
|
||||
}
|
||||
std::vector<uint8_t> raw_data;
|
||||
raw_data.resize(data_size);
|
||||
memcpy(raw_data.data(), tensor, data_size);
|
||||
return raw_data;
|
||||
}
|
||||
|
||||
} // namespace milvus::index
|
||||
|
|
|
@ -51,6 +51,13 @@ class VectorMemIndex : public VectorIndex {
|
|||
const SearchInfo& search_info,
|
||||
const BitsetView& bitset) override;
|
||||
|
||||
const bool
|
||||
HasRawData() const override;
|
||||
|
||||
const std::vector<uint8_t>
|
||||
GetVector(const DatasetPtr dataset,
|
||||
const Config& config = {}) const override;
|
||||
|
||||
protected:
|
||||
Config config_;
|
||||
knowhere::Index<knowhere::IndexNode> index_;
|
||||
|
|
|
@ -223,6 +223,11 @@ class SegmentGrowingImpl : public SegmentGrowing {
|
|||
return true;
|
||||
}
|
||||
|
||||
bool
|
||||
HasRawData(int64_t field_id) const override {
|
||||
return true;
|
||||
}
|
||||
|
||||
protected:
|
||||
int64_t
|
||||
num_chunk() const override;
|
||||
|
|
|
@ -88,6 +88,9 @@ class SegmentInterface {
|
|||
|
||||
virtual SegmentType
|
||||
type() const = 0;
|
||||
|
||||
virtual bool
|
||||
HasRawData(int64_t field_id) const = 0;
|
||||
};
|
||||
|
||||
// internal API for DSL calculation
|
||||
|
|
|
@ -29,6 +29,7 @@
|
|||
#include "query/ScalarIndex.h"
|
||||
#include "query/SearchBruteForce.h"
|
||||
#include "query/SearchOnSealed.h"
|
||||
#include "index/Utils.h"
|
||||
|
||||
namespace milvus::segcore {
|
||||
|
||||
|
@ -475,6 +476,35 @@ SegmentSealedImpl::vector_search(SearchInfo& search_info,
|
|||
}
|
||||
}
|
||||
|
||||
std::unique_ptr<DataArray>
|
||||
SegmentSealedImpl::get_vector(FieldId field_id,
|
||||
const int64_t* ids,
|
||||
int64_t count) const {
|
||||
auto& filed_meta = schema_->operator[](field_id);
|
||||
AssertInfo(filed_meta.is_vector(), "vector field is not vector type");
|
||||
|
||||
if (get_bit(index_ready_bitset_, field_id)) {
|
||||
AssertInfo(vector_indexings_.is_ready(field_id),
|
||||
"vector index is not ready");
|
||||
auto field_indexing = vector_indexings_.get_field_indexing(field_id);
|
||||
auto vec_index =
|
||||
dynamic_cast<index::VectorIndex*>(field_indexing->indexing_.get());
|
||||
|
||||
auto index_type = vec_index->GetIndexType();
|
||||
auto metric_type = vec_index->GetMetricType();
|
||||
auto has_raw_data = vec_index->HasRawData();
|
||||
|
||||
if (has_raw_data) {
|
||||
auto ids_ds = GenIdsDataset(count, ids);
|
||||
auto& vector = vec_index->GetVector(ids_ds);
|
||||
return segcore::CreateVectorDataArrayFrom(
|
||||
vector.data(), count, filed_meta);
|
||||
}
|
||||
}
|
||||
|
||||
return fill_with_empty(field_id, count);
|
||||
}
|
||||
|
||||
void
|
||||
SegmentSealedImpl::DropFieldData(const FieldId field_id) {
|
||||
if (SystemProperty::Instance().IsSystem(field_id)) {
|
||||
|
@ -666,9 +696,7 @@ SegmentSealedImpl::bulk_subscript(FieldId field_id,
|
|||
return ReverseDataFromIndex(index, seg_offsets, count, field_meta);
|
||||
}
|
||||
|
||||
// TODO: knowhere support reverse data from vector index
|
||||
// Now, real data will be filled in data array using chunk manager
|
||||
return fill_with_empty(field_id, count);
|
||||
return get_vector(field_id, seg_offsets, count);
|
||||
}
|
||||
|
||||
Assert(get_bit(field_data_ready_bitset_, field_id));
|
||||
|
@ -783,6 +811,24 @@ SegmentSealedImpl::HasFieldData(FieldId field_id) const {
|
|||
}
|
||||
}
|
||||
|
||||
bool
|
||||
SegmentSealedImpl::HasRawData(int64_t field_id) const {
|
||||
std::shared_lock lck(mutex_);
|
||||
auto fieldID = FieldId(field_id);
|
||||
const auto& field_meta = schema_->operator[](fieldID);
|
||||
if (datatype_is_vector(field_meta.get_data_type())) {
|
||||
if (get_bit(index_ready_bitset_, fieldID)) {
|
||||
AssertInfo(vector_indexings_.is_ready(fieldID),
|
||||
"vector index is not ready");
|
||||
auto field_indexing = vector_indexings_.get_field_indexing(fieldID);
|
||||
auto vec_index = dynamic_cast<index::VectorIndex*>(
|
||||
field_indexing->indexing_.get());
|
||||
return vec_index->HasRawData();
|
||||
}
|
||||
}
|
||||
return true;
|
||||
}
|
||||
|
||||
std::pair<std::unique_ptr<IdArray>, std::vector<SegOffset>>
|
||||
SegmentSealedImpl::search_ids(const IdArray& id_array,
|
||||
Timestamp timestamp) const {
|
||||
|
|
|
@ -61,6 +61,9 @@ class SegmentSealedImpl : public SegmentSealed {
|
|||
return id_;
|
||||
}
|
||||
|
||||
bool
|
||||
HasRawData(int64_t field_id) const override;
|
||||
|
||||
public:
|
||||
int64_t
|
||||
GetMemoryUsageInBytes() const override;
|
||||
|
@ -74,6 +77,9 @@ class SegmentSealedImpl : public SegmentSealed {
|
|||
const Schema&
|
||||
get_schema() const override;
|
||||
|
||||
std::unique_ptr<DataArray>
|
||||
get_vector(FieldId field_id, const int64_t* ids, int64_t count) const;
|
||||
|
||||
public:
|
||||
int64_t
|
||||
num_chunk_index(FieldId field_id) const override;
|
||||
|
|
|
@ -164,6 +164,13 @@ GetRealCount(CSegmentInterface c_segment) {
|
|||
return segment->get_real_count();
|
||||
}
|
||||
|
||||
bool
|
||||
HasRawData(CSegmentInterface c_segment, int64_t field_id) {
|
||||
auto segment =
|
||||
reinterpret_cast<milvus::segcore::SegmentInterface*>(c_segment);
|
||||
return segment->HasRawData(field_id);
|
||||
}
|
||||
|
||||
////////////////////////////// interfaces for growing segment //////////////////////////////
|
||||
CStatus
|
||||
Insert(CSegmentInterface c_segment,
|
||||
|
|
|
@ -67,6 +67,9 @@ GetDeletedCount(CSegmentInterface c_segment);
|
|||
int64_t
|
||||
GetRealCount(CSegmentInterface c_segment);
|
||||
|
||||
bool
|
||||
HasRawData(CSegmentInterface c_segment, int64_t field_id);
|
||||
|
||||
////////////////////////////// interfaces for growing segment //////////////////////////////
|
||||
CStatus
|
||||
Insert(CSegmentInterface c_segment,
|
||||
|
|
|
@ -449,6 +449,91 @@ TEST_P(IndexTest, BuildAndQuery) {
|
|||
vec_index->Query(xq_dataset, search_info, nullptr);
|
||||
}
|
||||
|
||||
// Builds an index of the parameterized type, then checks that GetVector()
// returns, for a shuffled list of ids, the same rows that were used to build
// the index. Index types that report no raw data end the test early.
TEST_P(IndexTest, GetVector) {
    milvus::index::CreateIndexInfo create_index_info;
    create_index_info.index_type = index_type;
    create_index_info.metric_type = metric_type;
    create_index_info.field_type = vec_field_data_type;
    index::IndexBasePtr index;

    if (index_type == knowhere::IndexEnum::INDEX_DISKANN) {
#ifdef BUILD_DISK_ANN
        milvus::storage::FieldDataMeta field_data_meta{1, 2, 3, 100};
        milvus::storage::IndexMeta index_meta{3, 100, 1000, 1};
        auto file_manager =
            std::make_shared<milvus::storage::DiskFileManagerImpl>(
                field_data_meta, index_meta, storage_config_);
        index = milvus::index::IndexFactory::GetInstance().CreateIndex(
            create_index_info, file_manager);
#endif
    } else {
        index = milvus::index::IndexFactory::GetInstance().CreateIndex(
            create_index_info, nullptr);
    }
    // `index` stays empty when DiskANN is requested but BUILD_DISK_ANN is not
    // defined; stop here rather than dereference a null pointer.
    ASSERT_TRUE(index != nullptr);
    ASSERT_NO_THROW(index->BuildWithDataset(xb_dataset, build_conf));
    milvus::index::IndexBasePtr new_index;
    milvus::index::VectorIndex* vec_index = nullptr;

    if (index_type == knowhere::IndexEnum::INDEX_DISKANN) {
#ifdef BUILD_DISK_ANN
        // TODO ::diskann.query need load first, ugly
        auto binary_set = index->Serialize(milvus::Config{});
        index.reset();
        milvus::storage::FieldDataMeta field_data_meta{1, 2, 3, 100};
        milvus::storage::IndexMeta index_meta{3, 100, 1000, 1};
        auto file_manager =
            std::make_shared<milvus::storage::DiskFileManagerImpl>(
                field_data_meta, index_meta, storage_config_);
        new_index = milvus::index::IndexFactory::GetInstance().CreateIndex(
            create_index_info, file_manager);

        vec_index = dynamic_cast<milvus::index::VectorIndex*>(new_index.get());
        ASSERT_TRUE(vec_index != nullptr);

        std::vector<std::string> index_files;
        for (auto& binary : binary_set.binary_map_) {
            index_files.emplace_back(binary.first);
        }
        load_conf["index_files"] = index_files;
        vec_index->Load(binary_set, load_conf);
        EXPECT_EQ(vec_index->Count(), NB);
#endif
    } else {
        vec_index = dynamic_cast<milvus::index::VectorIndex*>(index.get());
    }
    ASSERT_TRUE(vec_index != nullptr);
    EXPECT_EQ(vec_index->GetDim(), DIM);
    EXPECT_EQ(vec_index->Count(), NB);

    if (!vec_index->HasRawData()) {
        return;
    }

    auto ids_ds = GenRandomIds(NB);
    const auto* ids = ids_ds->GetIds();
    auto results = vec_index->GetVector(ids_ds);
    EXPECT_TRUE(results.size() > 0);
    if (!is_binary) {
        std::vector<float> result_vectors(results.size() / (sizeof(float)));
        // Copy only whole floats so the destination is never overrun when the
        // byte count is not a multiple of sizeof(float).
        memcpy(result_vectors.data(),
               results.data(),
               result_vectors.size() * sizeof(float));
        // Fatal on mismatch: the loops below index both buffers by NB * DIM.
        ASSERT_EQ(result_vectors.size(), xb_data.size());
        for (size_t i = 0; i < NB; ++i) {
            auto id = ids[i];
            for (size_t j = 0; j < DIM; ++j) {
                EXPECT_EQ(result_vectors[i * DIM + j], xb_data[id * DIM + j]);
            }
        }
    } else {
        // Fatal on mismatch: the loops below index both buffers.
        ASSERT_EQ(results.size(), xb_bin_data.size());
        const auto data_bytes = DIM / 8;
        for (size_t i = 0; i < NB; ++i) {
            auto id = ids[i];
            for (size_t j = 0; j < data_bytes; ++j) {
                EXPECT_EQ(results[i * data_bytes + j],
                          xb_bin_data[id * data_bytes + j]);
            }
        }
    }
}
|
||||
|
||||
// #ifdef BUILD_DISK_ANN
|
||||
// TEST(Indexing, SearchDiskAnnWithInvalidParam) {
|
||||
// int64_t NB = 10000;
|
||||
|
|
|
@ -1067,3 +1067,52 @@ TEST(Sealed, RealCount) {
|
|||
ASSERT_TRUE(status.ok());
|
||||
ASSERT_EQ(0, segment->get_real_count());
|
||||
}
|
||||
|
||||
// Loads a vector index into a sealed segment and checks that get_vector()
// returns, for a shuffled list of ids, the same float rows that the index was
// generated from.
TEST(Sealed, GetVector) {
    auto dim = 16;
    auto N = ROW_COUNT;
    auto metric_type = knowhere::metric::L2;
    auto schema = std::make_shared<Schema>();
    auto fakevec_id = schema->AddDebugField(
        "fakevec", DataType::VECTOR_FLOAT, dim, metric_type);
    auto counter_id = schema->AddDebugField("counter", DataType::INT64);
    // The remaining fields are only registered in the schema; their ids are
    // not needed by this test.
    schema->AddDebugField("double", DataType::DOUBLE);
    schema->AddDebugField("nothing", DataType::INT32);
    schema->AddDebugField("str", DataType::VARCHAR);
    schema->AddDebugField("int8", DataType::INT8);
    schema->AddDebugField("int16", DataType::INT16);
    schema->AddDebugField("float", DataType::FLOAT);
    schema->set_primary_field_id(counter_id);

    auto dataset = DataGen(schema, N);

    auto fakevec = dataset.get_col<float>(fakevec_id);

    auto indexing = GenVecIndexing(N, dim, fakevec.data());

    auto segment_sealed = CreateSealedSegment(schema);

    LoadIndexInfo vec_info;
    vec_info.field_id = fakevec_id.get();
    vec_info.index = std::move(indexing);
    vec_info.index_params["metric_type"] = knowhere::metric::L2;
    segment_sealed->LoadIndex(vec_info);

    auto segment = dynamic_cast<SegmentSealedImpl*>(segment_sealed.get());
    ASSERT_TRUE(segment != nullptr);

    auto has = segment->HasRawData(vec_info.field_id);
    EXPECT_TRUE(has);

    auto ids_ds = GenRandomIds(N);
    const auto* ids = ids_ds->GetIds();
    auto result = segment->get_vector(fakevec_id, ids, N);
    ASSERT_TRUE(result != nullptr);

    // Bind by reference: copying the repeated float field is unnecessary.
    const auto& vector = result->mutable_vectors()->float_vector().data();
    // Fatal on mismatch: the loops below index both buffers by N * dim.
    ASSERT_EQ(static_cast<size_t>(vector.size()), fakevec.size());
    for (size_t i = 0; i < N; ++i) {
        auto id = ids[i];
        for (size_t j = 0; j < dim; ++j) {
            EXPECT_EQ(vector[i * dim + j], fakevec[id * dim + j]);
        }
    }
}
|
||||
|
|
|
@ -599,4 +599,15 @@ GenPKs(const std::vector<int64_t>& pks) {
|
|||
return GenPKs(pks.begin(), pks.end());
|
||||
}
|
||||
|
||||
inline std::shared_ptr<knowhere::DataSet>
|
||||
GenRandomIds(int rows, int64_t seed = 42) {
|
||||
std::mt19937 g(seed);
|
||||
auto* ids = new int64_t[rows];
|
||||
for (int i = 0; i < rows; ++i) ids[i] = i;
|
||||
std::shuffle(ids, ids + rows, g);
|
||||
auto ids_ds = GenIdsDataset(rows, ids);
|
||||
ids_ds->SetIsOwner(true);
|
||||
return ids_ds;
|
||||
}
|
||||
|
||||
} // namespace milvus::segcore
|
||||
|
|
File diff suppressed because it is too large
Load Diff
|
@ -2411,9 +2411,10 @@ func (node *Proxy) Search(ctx context.Context, request *milvuspb.SearchRequest)
|
|||
ReqID: paramtable.GetNodeID(),
|
||||
},
|
||||
request: request,
|
||||
qc: node.queryCoord,
|
||||
tr: timerecord.NewTimeRecorder("search"),
|
||||
shardMgr: node.shardMgr,
|
||||
qc: node.queryCoord,
|
||||
node: node,
|
||||
}
|
||||
|
||||
travelTs := request.TravelTimestamp
|
||||
|
|
|
@ -9,6 +9,7 @@ import (
|
|||
"github.com/golang/protobuf/proto"
|
||||
"github.com/milvus-io/milvus/internal/proto/planpb"
|
||||
"github.com/milvus-io/milvus/pkg/util/tsoutil"
|
||||
"github.com/samber/lo"
|
||||
|
||||
"github.com/cockroachdb/errors"
|
||||
"github.com/milvus-io/milvus/internal/parser/planparserv2"
|
||||
|
@ -492,7 +493,10 @@ func IDs2Expr(fieldName string, ids *schemapb.IDs) string {
|
|||
case *schemapb.IDs_IntId:
|
||||
idsStr = strings.Trim(strings.Join(strings.Fields(fmt.Sprint(ids.GetIntId().GetData())), ", "), "[]")
|
||||
case *schemapb.IDs_StrId:
|
||||
idsStr = strings.Trim(strings.Join(ids.GetStrId().GetData(), ", "), "[]")
|
||||
strs := lo.Map(ids.GetStrId().GetData(), func(str string, _ int) string {
|
||||
return fmt.Sprintf("\"%s\"", str)
|
||||
})
|
||||
idsStr = strings.Trim(strings.Join(strs, ", "), "[]")
|
||||
}
|
||||
|
||||
return fieldName + " in [ " + idsStr + " ]"
|
||||
|
|
|
@ -869,3 +869,28 @@ func Test_queryTask_createPlan(t *testing.T) {
|
|||
assert.Error(t, err)
|
||||
})
|
||||
}
|
||||
|
||||
func TestQueryTask_IDs2Expr(t *testing.T) {
|
||||
fieldName := "pk"
|
||||
intIDs := &schemapb.IDs{
|
||||
IdField: &schemapb.IDs_IntId{
|
||||
IntId: &schemapb.LongArray{
|
||||
Data: []int64{1, 2, 3, 4, 5},
|
||||
},
|
||||
},
|
||||
}
|
||||
stringIDs := &schemapb.IDs{
|
||||
IdField: &schemapb.IDs_StrId{
|
||||
StrId: &schemapb.StringArray{
|
||||
Data: []string{"a", "b", "c"},
|
||||
},
|
||||
},
|
||||
}
|
||||
idExpr := IDs2Expr(fieldName, intIDs)
|
||||
expectIDExpr := "pk in [ 1, 2, 3, 4, 5 ]"
|
||||
assert.Equal(t, expectIDExpr, idExpr)
|
||||
|
||||
strExpr := IDs2Expr(fieldName, stringIDs)
|
||||
expectStrExpr := "pk in [ \"a\", \"b\", \"c\" ]"
|
||||
assert.Equal(t, expectStrExpr, strExpr)
|
||||
}
|
||||
|
|
|
@ -3,11 +3,13 @@ package proxy
|
|||
import (
|
||||
"context"
|
||||
"fmt"
|
||||
"math"
|
||||
"regexp"
|
||||
"strconv"
|
||||
|
||||
"github.com/cockroachdb/errors"
|
||||
"github.com/golang/protobuf/proto"
|
||||
"github.com/samber/lo"
|
||||
"go.opentelemetry.io/otel"
|
||||
"go.uber.org/zap"
|
||||
|
||||
|
@ -25,6 +27,7 @@ import (
|
|||
"github.com/milvus-io/milvus/pkg/util/commonpbutil"
|
||||
"github.com/milvus-io/milvus/pkg/util/distance"
|
||||
"github.com/milvus-io/milvus/pkg/util/funcutil"
|
||||
"github.com/milvus-io/milvus/pkg/util/merr"
|
||||
"github.com/milvus-io/milvus/pkg/util/paramtable"
|
||||
"github.com/milvus-io/milvus/pkg/util/timerecord"
|
||||
"github.com/milvus-io/milvus/pkg/util/tsoutil"
|
||||
|
@ -34,6 +37,12 @@ import (
|
|||
const (
|
||||
SearchTaskName = "SearchTask"
|
||||
SearchLevelKey = "level"
|
||||
|
||||
// requeryThreshold is the estimated threshold for the size of the search results.
|
||||
// If the number of estimated search results exceeds this threshold,
|
||||
// a second query request will be initiated to retrieve output fields data.
|
||||
// In this case, the first search will not return any output field from QueryNodes.
|
||||
requeryThreshold = 0.5 * 1024 * 1024
|
||||
)
|
||||
|
||||
type searchTask struct {
|
||||
|
@ -43,11 +52,12 @@ type searchTask struct {
|
|||
|
||||
result *milvuspb.SearchResults
|
||||
request *milvuspb.SearchRequest
|
||||
qc types.QueryCoord
|
||||
|
||||
tr *timerecord.TimeRecorder
|
||||
collectionName string
|
||||
channelNum int32
|
||||
schema *schemapb.CollectionSchema
|
||||
requery bool
|
||||
|
||||
offset int64
|
||||
resultBuf chan *internalpb.SearchResults
|
||||
|
@ -55,6 +65,9 @@ type searchTask struct {
|
|||
|
||||
searchShardPolicy pickShardPolicy
|
||||
shardMgr *shardClientMgr
|
||||
|
||||
qc types.QueryCoord
|
||||
node types.ProxyComponent
|
||||
}
|
||||
|
||||
func getPartitionIDs(ctx context.Context, collectionName string, partitionNames []string) (partitionIDs []UniqueID, err error) {
|
||||
|
@ -164,11 +177,7 @@ func getOutputFieldIDs(schema *schemapb.CollectionSchema, outputFields []string)
|
|||
hitField := false
|
||||
for _, field := range schema.GetFields() {
|
||||
if field.Name == name {
|
||||
if field.DataType == schemapb.DataType_BinaryVector || field.DataType == schemapb.DataType_FloatVector {
|
||||
return nil, errors.New("search doesn't support vector field as output_fields")
|
||||
}
|
||||
outputFieldIDs = append(outputFieldIDs, field.GetFieldID())
|
||||
|
||||
hitField = true
|
||||
break
|
||||
}
|
||||
|
@ -255,6 +264,24 @@ func (t *searchTask) PreExecute(ctx context.Context) error {
|
|||
}
|
||||
t.SearchRequest.IgnoreGrowing = ignoreGrowing
|
||||
|
||||
// Manually update nq if not set.
|
||||
nq, err := getNq(t.request)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
// Check if nq is valid:
|
||||
// https://milvus.io/docs/limitations.md
|
||||
if err := validateLimit(nq); err != nil {
|
||||
return fmt.Errorf("%s [%d] is invalid, %w", NQKey, nq, err)
|
||||
}
|
||||
t.SearchRequest.Nq = nq
|
||||
|
||||
outputFieldIDs, err := getOutputFieldIDs(t.schema, t.request.GetOutputFields())
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
t.SearchRequest.OutputFieldsId = outputFieldIDs
|
||||
|
||||
if t.request.GetDslType() == commonpb.DslType_BoolExprV1 {
|
||||
annsField, err := funcutil.GetAttrByKeyFromRepeatedKV(AnnsFieldKey, t.request.GetSearchParams())
|
||||
if err != nil {
|
||||
|
@ -278,17 +305,21 @@ func (t *searchTask) PreExecute(ctx context.Context) error {
|
|||
zap.String("dsl", t.request.Dsl), // may be very large if large term passed.
|
||||
zap.String("anns field", annsField), zap.Any("query info", queryInfo))
|
||||
|
||||
outputFieldIDs, err := getOutputFieldIDs(t.schema, t.request.GetOutputFields())
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
t.SearchRequest.OutputFieldsId = outputFieldIDs
|
||||
plan.OutputFieldIds = outputFieldIDs
|
||||
|
||||
t.SearchRequest.Topk = queryInfo.GetTopk()
|
||||
t.SearchRequest.MetricType = queryInfo.GetMetricType()
|
||||
t.SearchRequest.DslType = commonpb.DslType_BoolExprV1
|
||||
|
||||
estimateSize, err := t.estimateResultSize(nq, t.SearchRequest.Topk)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
if estimateSize >= requeryThreshold {
|
||||
t.requery = true
|
||||
plan.OutputFieldIds = nil
|
||||
}
|
||||
|
||||
t.SearchRequest.SerializedExprPlan, err = proto.Marshal(plan)
|
||||
if err != nil {
|
||||
return err
|
||||
|
@ -319,17 +350,6 @@ func (t *searchTask) PreExecute(ctx context.Context) error {
|
|||
|
||||
t.SearchRequest.Dsl = t.request.Dsl
|
||||
t.SearchRequest.PlaceholderGroup = t.request.PlaceholderGroup
|
||||
// Manually update nq if not set.
|
||||
nq, err := getNq(t.request)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
// Check if nq is valid:
|
||||
// https://milvus.io/docs/limitations.md
|
||||
if err := validateLimit(nq); err != nil {
|
||||
return fmt.Errorf("%s [%d] is invalid, %w", NQKey, nq, err)
|
||||
}
|
||||
t.SearchRequest.Nq = nq
|
||||
|
||||
log.Ctx(ctx).Debug("search PreExecute done.",
|
||||
zap.Uint64("travel_ts", travelTimestamp), zap.Uint64("guarantee_ts", guaranteeTs),
|
||||
|
@ -435,6 +455,13 @@ func (t *searchTask) PostExecute(ctx context.Context) error {
|
|||
t.result.CollectionName = t.collectionName
|
||||
t.fillInFieldInfo()
|
||||
|
||||
if t.requery {
|
||||
err = t.Requery()
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
}
|
||||
|
||||
log.Ctx(ctx).Debug("Search post execute done",
|
||||
zap.Int64("collection", t.GetCollectionID()),
|
||||
zap.Int64s("partitionIDs", t.GetPartitionIDs()))
|
||||
|
@ -480,6 +507,93 @@ func (t *searchTask) searchShard(ctx context.Context, nodeID int64, qn types.Que
|
|||
return nil
|
||||
}
|
||||
|
||||
func (t *searchTask) estimateResultSize(nq int64, topK int64) (int64, error) {
|
||||
vectorOutputFields := lo.Filter(t.schema.GetFields(), func(field *schemapb.FieldSchema, _ int) bool {
|
||||
return lo.Contains(t.request.GetOutputFields(), field.GetName()) && typeutil.IsVectorType(field.GetDataType())
|
||||
})
|
||||
// Currently, we get vectors by requery. Once we support getting vectors from search,
|
||||
// searches with small result size could no longer need requery.
|
||||
if len(vectorOutputFields) > 0 {
|
||||
return math.MaxInt64, nil
|
||||
}
|
||||
// If no vector field as output, no need to requery.
|
||||
return 0, nil
|
||||
|
||||
//outputFields := lo.Filter(t.schema.GetFields(), func(field *schemapb.FieldSchema, _ int) bool {
|
||||
// return lo.Contains(t.request.GetOutputFields(), field.GetName())
|
||||
//})
|
||||
//sizePerRecord, err := typeutil.EstimateSizePerRecord(&schemapb.CollectionSchema{Fields: outputFields})
|
||||
//if err != nil {
|
||||
// return 0, err
|
||||
//}
|
||||
//return int64(sizePerRecord) * nq * topK, nil
|
||||
}
|
||||
|
||||
func (t *searchTask) Requery() error {
|
||||
pkField, err := typeutil.GetPrimaryFieldSchema(t.schema)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
ids := t.result.GetResults().GetIds()
|
||||
expr := IDs2Expr(pkField.GetName(), ids)
|
||||
|
||||
queryReq := &milvuspb.QueryRequest{
|
||||
Base: &commonpb.MsgBase{
|
||||
MsgType: commonpb.MsgType_Retrieve,
|
||||
},
|
||||
CollectionName: t.request.GetCollectionName(),
|
||||
Expr: expr,
|
||||
OutputFields: t.request.GetOutputFields(),
|
||||
PartitionNames: t.request.GetPartitionNames(),
|
||||
TravelTimestamp: t.request.GetTravelTimestamp(),
|
||||
GuaranteeTimestamp: t.request.GetGuaranteeTimestamp(),
|
||||
QueryParams: t.request.GetSearchParams(),
|
||||
}
|
||||
queryResult, err := t.node.Query(t.ctx, queryReq)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
if queryResult.GetStatus().GetErrorCode() != commonpb.ErrorCode_Success {
|
||||
return merr.Error(queryResult.GetStatus())
|
||||
}
|
||||
// Reorganize Results. The order of query result ids will be altered and differ from queried ids.
|
||||
// We should reorganize query results to keep the order of original queried ids. For example:
|
||||
// ===========================================
|
||||
// 3 2 5 4 1 (query ids)
|
||||
// ||
|
||||
// || (query)
|
||||
// \/
|
||||
// 4 3 5 1 2 (result ids)
|
||||
// v4 v3 v5 v1 v2 (result vectors)
|
||||
// ||
|
||||
// || (reorganize)
|
||||
// \/
|
||||
// 3 2 5 4 1 (result ids)
|
||||
// v3 v2 v5 v4 v1 (result vectors)
|
||||
// ===========================================
|
||||
pkFieldData, err := typeutil.GetPrimaryFieldData(queryResult.GetFieldsData(), pkField)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
offsets := make(map[any]int)
|
||||
for i := 0; i < typeutil.GetDataSize(pkFieldData); i++ {
|
||||
pk := typeutil.GetData(pkFieldData, i)
|
||||
offsets[pk] = i
|
||||
}
|
||||
|
||||
t.result.Results.FieldsData = make([]*schemapb.FieldData, len(queryResult.GetFieldsData()))
|
||||
for i := 0; i < typeutil.GetSizeOfIDs(ids); i++ {
|
||||
id := typeutil.GetPK(ids, int64(i))
|
||||
if _, ok := offsets[id]; !ok {
|
||||
return fmt.Errorf("incomplete query result, missing id %s, len(searchIDs) = %d, len(queryIDs) = %d, collection=%d",
|
||||
id, typeutil.GetSizeOfIDs(ids), len(offsets), t.GetCollectionID())
|
||||
}
|
||||
typeutil.AppendFieldData(t.result.Results.FieldsData, queryResult.GetFieldsData(), int64(offsets[id]))
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
func (t *searchTask) fillInEmptyResult(numQueries int64) {
|
||||
t.result = &milvuspb.SearchResults{
|
||||
Status: &commonpb.Status{
|
||||
|
|
|
@ -17,6 +17,7 @@ import (
|
|||
"github.com/milvus-io/milvus-proto/go-api/commonpb"
|
||||
"github.com/milvus-io/milvus-proto/go-api/milvuspb"
|
||||
"github.com/milvus-io/milvus-proto/go-api/schemapb"
|
||||
"github.com/milvus-io/milvus/internal/mocks"
|
||||
"github.com/milvus-io/milvus/internal/proto/internalpb"
|
||||
"github.com/milvus-io/milvus/internal/proto/querypb"
|
||||
"github.com/milvus-io/milvus/internal/types"
|
||||
|
@ -268,7 +269,7 @@ func TestSearchTask_PreExecute(t *testing.T) {
|
|||
|
||||
// contain vector field
|
||||
task.request.OutputFields = []string{testFloatVecField}
|
||||
assert.Error(t, task.PreExecute(ctx))
|
||||
assert.NoError(t, task.PreExecute(ctx))
|
||||
})
|
||||
}
|
||||
|
||||
|
@ -1959,3 +1960,287 @@ func getSearchResultData(nq, topk int64) *schemapb.SearchResultData {
|
|||
}
|
||||
return &result
|
||||
}
|
||||
|
||||
func TestSearchTask_Requery(t *testing.T) {
|
||||
ctx, cancel := context.WithCancel(context.Background())
|
||||
defer cancel()
|
||||
|
||||
const (
|
||||
dim = 128
|
||||
rows = 5
|
||||
collection = "test-requery"
|
||||
|
||||
pkField = "pk"
|
||||
vecField = "vec"
|
||||
)
|
||||
|
||||
ids := make([]int64, rows)
|
||||
for i := range ids {
|
||||
ids[i] = int64(i)
|
||||
}
|
||||
|
||||
t.Run("Test normal", func(t *testing.T) {
|
||||
schema := constructCollectionSchema(pkField, vecField, dim, collection)
|
||||
node := mocks.NewProxy(t)
|
||||
node.EXPECT().Query(mock.Anything, mock.Anything).
|
||||
Return(&milvuspb.QueryResults{
|
||||
FieldsData: []*schemapb.FieldData{{
|
||||
Type: schemapb.DataType_Int64,
|
||||
FieldName: pkField,
|
||||
Field: &schemapb.FieldData_Scalars{
|
||||
Scalars: &schemapb.ScalarField{
|
||||
Data: &schemapb.ScalarField_LongData{
|
||||
LongData: &schemapb.LongArray{
|
||||
Data: ids,
|
||||
},
|
||||
},
|
||||
},
|
||||
},
|
||||
},
|
||||
newFloatVectorFieldData(vecField, rows, dim),
|
||||
},
|
||||
}, nil)
|
||||
|
||||
resultIDs := &schemapb.IDs{
|
||||
IdField: &schemapb.IDs_IntId{
|
||||
IntId: &schemapb.LongArray{
|
||||
Data: ids,
|
||||
},
|
||||
},
|
||||
}
|
||||
|
||||
qt := &searchTask{
|
||||
ctx: ctx,
|
||||
SearchRequest: &internalpb.SearchRequest{
|
||||
Base: &commonpb.MsgBase{
|
||||
MsgType: commonpb.MsgType_Search,
|
||||
SourceID: paramtable.GetNodeID(),
|
||||
},
|
||||
},
|
||||
request: &milvuspb.SearchRequest{},
|
||||
result: &milvuspb.SearchResults{
|
||||
Results: &schemapb.SearchResultData{
|
||||
Ids: resultIDs,
|
||||
},
|
||||
},
|
||||
schema: schema,
|
||||
tr: timerecord.NewTimeRecorder("search"),
|
||||
node: node,
|
||||
}
|
||||
|
||||
err := qt.Requery()
|
||||
assert.NoError(t, err)
|
||||
})
|
||||
|
||||
t.Run("Test no primary key", func(t *testing.T) {
|
||||
schema := &schemapb.CollectionSchema{}
|
||||
node := mocks.NewProxy(t)
|
||||
|
||||
qt := &searchTask{
|
||||
ctx: ctx,
|
||||
SearchRequest: &internalpb.SearchRequest{
|
||||
Base: &commonpb.MsgBase{
|
||||
MsgType: commonpb.MsgType_Search,
|
||||
SourceID: paramtable.GetNodeID(),
|
||||
},
|
||||
},
|
||||
request: &milvuspb.SearchRequest{},
|
||||
schema: schema,
|
||||
tr: timerecord.NewTimeRecorder("search"),
|
||||
node: node,
|
||||
}
|
||||
|
||||
err := qt.Requery()
|
||||
t.Logf("err = %s", err)
|
||||
assert.Error(t, err)
|
||||
})
|
||||
|
||||
t.Run("Test requery failed 1", func(t *testing.T) {
|
||||
schema := constructCollectionSchema(pkField, vecField, dim, collection)
|
||||
node := mocks.NewProxy(t)
|
||||
node.EXPECT().Query(mock.Anything, mock.Anything).
|
||||
Return(nil, fmt.Errorf("mock err 1"))
|
||||
|
||||
qt := &searchTask{
|
||||
ctx: ctx,
|
||||
SearchRequest: &internalpb.SearchRequest{
|
||||
Base: &commonpb.MsgBase{
|
||||
MsgType: commonpb.MsgType_Search,
|
||||
SourceID: paramtable.GetNodeID(),
|
||||
},
|
||||
},
|
||||
request: &milvuspb.SearchRequest{},
|
||||
schema: schema,
|
||||
tr: timerecord.NewTimeRecorder("search"),
|
||||
node: node,
|
||||
}
|
||||
|
||||
err := qt.Requery()
|
||||
t.Logf("err = %s", err)
|
||||
assert.Error(t, err)
|
||||
})
|
||||
|
||||
t.Run("Test requery failed 2", func(t *testing.T) {
|
||||
schema := constructCollectionSchema(pkField, vecField, dim, collection)
|
||||
node := mocks.NewProxy(t)
|
||||
node.EXPECT().Query(mock.Anything, mock.Anything).
|
||||
Return(&milvuspb.QueryResults{
|
||||
Status: &commonpb.Status{
|
||||
ErrorCode: commonpb.ErrorCode_UnexpectedError,
|
||||
Reason: "mock err 2",
|
||||
},
|
||||
}, nil)
|
||||
|
||||
qt := &searchTask{
|
||||
ctx: ctx,
|
||||
SearchRequest: &internalpb.SearchRequest{
|
||||
Base: &commonpb.MsgBase{
|
||||
MsgType: commonpb.MsgType_Search,
|
||||
SourceID: paramtable.GetNodeID(),
|
||||
},
|
||||
},
|
||||
request: &milvuspb.SearchRequest{},
|
||||
schema: schema,
|
||||
tr: timerecord.NewTimeRecorder("search"),
|
||||
node: node,
|
||||
}
|
||||
|
||||
err := qt.Requery()
|
||||
t.Logf("err = %s", err)
|
||||
assert.Error(t, err)
|
||||
})
|
||||
|
||||
t.Run("Test get pk filed data failed", func(t *testing.T) {
|
||||
schema := constructCollectionSchema(pkField, vecField, dim, collection)
|
||||
node := mocks.NewProxy(t)
|
||||
node.EXPECT().Query(mock.Anything, mock.Anything).
|
||||
Return(&milvuspb.QueryResults{
|
||||
FieldsData: []*schemapb.FieldData{},
|
||||
}, nil)
|
||||
|
||||
qt := &searchTask{
|
||||
ctx: ctx,
|
||||
SearchRequest: &internalpb.SearchRequest{
|
||||
Base: &commonpb.MsgBase{
|
||||
MsgType: commonpb.MsgType_Search,
|
||||
SourceID: paramtable.GetNodeID(),
|
||||
},
|
||||
},
|
||||
request: &milvuspb.SearchRequest{},
|
||||
schema: schema,
|
||||
tr: timerecord.NewTimeRecorder("search"),
|
||||
node: node,
|
||||
}
|
||||
|
||||
err := qt.Requery()
|
||||
t.Logf("err = %s", err)
|
||||
assert.Error(t, err)
|
||||
})
|
||||
|
||||
t.Run("Test incomplete query result", func(t *testing.T) {
|
||||
schema := constructCollectionSchema(pkField, vecField, dim, collection)
|
||||
node := mocks.NewProxy(t)
|
||||
node.EXPECT().Query(mock.Anything, mock.Anything).
|
||||
Return(&milvuspb.QueryResults{
|
||||
FieldsData: []*schemapb.FieldData{{
|
||||
Type: schemapb.DataType_Int64,
|
||||
FieldName: pkField,
|
||||
Field: &schemapb.FieldData_Scalars{
|
||||
Scalars: &schemapb.ScalarField{
|
||||
Data: &schemapb.ScalarField_LongData{
|
||||
LongData: &schemapb.LongArray{
|
||||
Data: ids[:len(ids)-1],
|
||||
},
|
||||
},
|
||||
},
|
||||
},
|
||||
},
|
||||
newFloatVectorFieldData(vecField, rows, dim),
|
||||
},
|
||||
}, nil)
|
||||
|
||||
resultIDs := &schemapb.IDs{
|
||||
IdField: &schemapb.IDs_IntId{
|
||||
IntId: &schemapb.LongArray{
|
||||
Data: ids,
|
||||
},
|
||||
},
|
||||
}
|
||||
|
||||
qt := &searchTask{
|
||||
ctx: ctx,
|
||||
SearchRequest: &internalpb.SearchRequest{
|
||||
Base: &commonpb.MsgBase{
|
||||
MsgType: commonpb.MsgType_Search,
|
||||
SourceID: paramtable.GetNodeID(),
|
||||
},
|
||||
},
|
||||
request: &milvuspb.SearchRequest{},
|
||||
result: &milvuspb.SearchResults{
|
||||
Results: &schemapb.SearchResultData{
|
||||
Ids: resultIDs,
|
||||
},
|
||||
},
|
||||
schema: schema,
|
||||
tr: timerecord.NewTimeRecorder("search"),
|
||||
node: node,
|
||||
}
|
||||
|
||||
err := qt.Requery()
|
||||
t.Logf("err = %s", err)
|
||||
assert.Error(t, err)
|
||||
})
|
||||
|
||||
t.Run("Test postExecute with requery failed", func(t *testing.T) {
|
||||
schema := constructCollectionSchema(pkField, vecField, dim, collection)
|
||||
node := mocks.NewProxy(t)
|
||||
node.EXPECT().Query(mock.Anything, mock.Anything).
|
||||
Return(nil, fmt.Errorf("mock err 1"))
|
||||
|
||||
resultIDs := &schemapb.IDs{
|
||||
IdField: &schemapb.IDs_IntId{
|
||||
IntId: &schemapb.LongArray{
|
||||
Data: ids,
|
||||
},
|
||||
},
|
||||
}
|
||||
|
||||
qt := &searchTask{
|
||||
ctx: ctx,
|
||||
SearchRequest: &internalpb.SearchRequest{
|
||||
Base: &commonpb.MsgBase{
|
||||
MsgType: commonpb.MsgType_Search,
|
||||
SourceID: paramtable.GetNodeID(),
|
||||
},
|
||||
},
|
||||
request: &milvuspb.SearchRequest{},
|
||||
result: &milvuspb.SearchResults{
|
||||
Results: &schemapb.SearchResultData{
|
||||
Ids: resultIDs,
|
||||
},
|
||||
},
|
||||
requery: true,
|
||||
schema: schema,
|
||||
resultBuf: make(chan *internalpb.SearchResults, 10),
|
||||
tr: timerecord.NewTimeRecorder("search"),
|
||||
node: node,
|
||||
}
|
||||
scores := make([]float32, rows)
|
||||
for i := range scores {
|
||||
scores[i] = float32(i)
|
||||
}
|
||||
partialResultData := &schemapb.SearchResultData{
|
||||
Ids: resultIDs,
|
||||
Scores: scores,
|
||||
}
|
||||
bytes, err := proto.Marshal(partialResultData)
|
||||
assert.NoError(t, err)
|
||||
qt.resultBuf <- &internalpb.SearchResults{
|
||||
SlicedBlob: bytes,
|
||||
}
|
||||
|
||||
err = qt.PostExecute(ctx)
|
||||
t.Logf("err = %s", err)
|
||||
assert.Error(t, err)
|
||||
})
|
||||
}
|
||||
|
|
|
@ -282,6 +282,13 @@ func (s *LocalSegment) ExistIndex(fieldID int64) bool {
|
|||
return fieldInfo.IndexInfo != nil && fieldInfo.IndexInfo.EnableIndex
|
||||
}
|
||||
|
||||
func (s *LocalSegment) HasRawData(fieldID int64) bool {
|
||||
s.mut.RLock()
|
||||
defer s.mut.RUnlock()
|
||||
ret := C.HasRawData(s.ptr, C.int64_t(fieldID))
|
||||
return bool(ret)
|
||||
}
|
||||
|
||||
func (s *LocalSegment) Indexes() []*IndexedFieldInfo {
|
||||
var result []*IndexedFieldInfo
|
||||
s.fieldIndexes.Range(func(key int64, value *IndexedFieldInfo) bool {
|
||||
|
@ -463,10 +470,18 @@ func (s *LocalSegment) FillIndexedFieldsData(ctx context.Context,
|
|||
)
|
||||
|
||||
for _, fieldData := range result.FieldsData {
|
||||
// If the vector field doesn't have indexed. Vector data is in memory for
|
||||
// brute force search. No need to download data from remote.
|
||||
if fieldData.GetType() != schemapb.DataType_FloatVector && fieldData.GetType() != schemapb.DataType_BinaryVector ||
|
||||
!s.ExistIndex(fieldData.FieldId) {
|
||||
// If the field is not vector field, no need to download data from remote.
|
||||
if !typeutil.IsVectorType(fieldData.GetType()) {
|
||||
continue
|
||||
}
|
||||
// If the vector field has no index, vector data is in memory
|
||||
// for brute force search, no need to download data from remote.
|
||||
if !s.ExistIndex(fieldData.FieldId) {
|
||||
continue
|
||||
}
|
||||
// If the index has raw data, vector data could be obtained from index,
|
||||
// no need to download data from remote.
|
||||
if s.HasRawData(fieldData.FieldId) {
|
||||
continue
|
||||
}
|
||||
|
||||
|
|
|
@ -117,6 +117,13 @@ func (suite *SegmentSuite) TestDelete() {
|
|||
suite.Equal(rowNum, suite.growing.InsertCount())
|
||||
}
|
||||
|
||||
func (suite *SegmentSuite) TestHasRawData() {
|
||||
has := suite.growing.HasRawData(simpleFloatVecField.id)
|
||||
suite.True(has)
|
||||
has = suite.sealed.HasRawData(simpleFloatVecField.id)
|
||||
suite.True(has)
|
||||
}
|
||||
|
||||
func TestSegment(t *testing.T) {
|
||||
suite.Run(t, new(SegmentSuite))
|
||||
}
|
||||
|
|
|
@ -735,6 +735,30 @@ func GetSizeOfIDs(data *schemapb.IDs) int {
|
|||
return result
|
||||
}
|
||||
|
||||
func GetDataSize(fieldData *schemapb.FieldData) int {
|
||||
switch fieldData.GetType() {
|
||||
case schemapb.DataType_Bool:
|
||||
return len(fieldData.GetScalars().GetBoolData().GetData())
|
||||
case schemapb.DataType_Int8, schemapb.DataType_Int16, schemapb.DataType_Int32:
|
||||
return len(fieldData.GetScalars().GetIntData().GetData())
|
||||
case schemapb.DataType_Int64:
|
||||
return len(fieldData.GetScalars().GetLongData().GetData())
|
||||
case schemapb.DataType_Float:
|
||||
return len(fieldData.GetScalars().GetFloatData().GetData())
|
||||
case schemapb.DataType_Double:
|
||||
return len(fieldData.GetScalars().GetDoubleData().GetData())
|
||||
case schemapb.DataType_String:
|
||||
return len(fieldData.GetScalars().GetStringData().GetData())
|
||||
case schemapb.DataType_VarChar:
|
||||
return len(fieldData.GetScalars().GetStringData().GetData())
|
||||
case schemapb.DataType_FloatVector:
|
||||
return len(fieldData.GetVectors().GetFloatVector().GetData())
|
||||
case schemapb.DataType_BinaryVector:
|
||||
return len(fieldData.GetVectors().GetBinaryVector())
|
||||
}
|
||||
return 0
|
||||
}
|
||||
|
||||
func IsPrimaryFieldType(dataType schemapb.DataType) bool {
|
||||
if dataType == schemapb.DataType_Int64 || dataType == schemapb.DataType_VarChar {
|
||||
return true
|
||||
|
@ -756,6 +780,33 @@ func GetPK(data *schemapb.IDs, idx int64) interface{} {
|
|||
return nil
|
||||
}
|
||||
|
||||
func GetData(field *schemapb.FieldData, idx int) interface{} {
|
||||
switch field.GetType() {
|
||||
case schemapb.DataType_Bool:
|
||||
return field.GetScalars().GetBoolData().GetData()[idx]
|
||||
case schemapb.DataType_Int8, schemapb.DataType_Int16, schemapb.DataType_Int32:
|
||||
return field.GetScalars().GetIntData().GetData()[idx]
|
||||
case schemapb.DataType_Int64:
|
||||
return field.GetScalars().GetLongData().GetData()[idx]
|
||||
case schemapb.DataType_Float:
|
||||
return field.GetScalars().GetFloatData().GetData()[idx]
|
||||
case schemapb.DataType_Double:
|
||||
return field.GetScalars().GetDoubleData().GetData()[idx]
|
||||
case schemapb.DataType_String:
|
||||
return field.GetScalars().GetStringData().GetData()[idx]
|
||||
case schemapb.DataType_VarChar:
|
||||
return field.GetScalars().GetStringData().GetData()[idx]
|
||||
case schemapb.DataType_FloatVector:
|
||||
dim := int(field.GetVectors().GetDim())
|
||||
return field.GetVectors().GetFloatVector().GetData()[idx*dim : (idx+1)*dim]
|
||||
case schemapb.DataType_BinaryVector:
|
||||
dim := int(field.GetVectors().GetDim())
|
||||
dataBytes := dim / 8
|
||||
return field.GetVectors().GetBinaryVector()[idx*dataBytes : (idx+1)*dataBytes]
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func AppendPKs(pks *schemapb.IDs, pk interface{}) {
|
||||
switch realPK := pk.(type) {
|
||||
case int64:
|
||||
|
|
|
@ -470,6 +470,21 @@ func genFieldData(fieldName string, fieldID int64, fieldType schemapb.DataType,
|
|||
},
|
||||
FieldId: fieldID,
|
||||
}
|
||||
case schemapb.DataType_String:
|
||||
fieldData = &schemapb.FieldData{
|
||||
Type: schemapb.DataType_String,
|
||||
FieldName: fieldName,
|
||||
Field: &schemapb.FieldData_Scalars{
|
||||
Scalars: &schemapb.ScalarField{
|
||||
Data: &schemapb.ScalarField_StringData{
|
||||
StringData: &schemapb.StringArray{
|
||||
Data: fieldValue.([]string),
|
||||
},
|
||||
},
|
||||
},
|
||||
},
|
||||
FieldId: fieldID,
|
||||
}
|
||||
case schemapb.DataType_VarChar:
|
||||
fieldData = &schemapb.FieldData{
|
||||
Type: schemapb.DataType_VarChar,
|
||||
|
@ -990,3 +1005,94 @@ func TestCalcColumnSize(t *testing.T) {
|
|||
assert.Equal(t, expected, size, field.GetName())
|
||||
}
|
||||
}
|
||||
|
||||
func TestGetDataAndGetDataSize(t *testing.T) {
|
||||
const (
|
||||
Dim = 8
|
||||
fieldName = "filed-0"
|
||||
fieldID = 0
|
||||
)
|
||||
|
||||
BoolArray := []bool{true, false}
|
||||
Int8Array := []int8{1, 2}
|
||||
Int16Array := []int16{3, 4}
|
||||
Int32Array := []int32{5, 6}
|
||||
Int64Array := []int64{11, 22}
|
||||
FloatArray := []float32{1.0, 2.0}
|
||||
DoubleArray := []float64{11.0, 22.0}
|
||||
VarCharArray := []string{"a", "b"}
|
||||
StringArray := []string{"c", "d"}
|
||||
BinaryVector := []byte{0x12, 0x34}
|
||||
FloatVector := []float32{1.0, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0, 8.0, 11.0, 22.0, 33.0, 44.0, 55.0, 66.0, 77.0, 88.0}
|
||||
|
||||
boolData := genFieldData(fieldName, fieldID, schemapb.DataType_Bool, BoolArray, 1)
|
||||
int8Data := genFieldData(fieldName, fieldID, schemapb.DataType_Int8, Int8Array, 1)
|
||||
int16Data := genFieldData(fieldName, fieldID, schemapb.DataType_Int16, Int16Array, 1)
|
||||
int32Data := genFieldData(fieldName, fieldID, schemapb.DataType_Int32, Int32Array, 1)
|
||||
int64Data := genFieldData(fieldName, fieldID, schemapb.DataType_Int64, Int64Array, 1)
|
||||
floatData := genFieldData(fieldName, fieldID, schemapb.DataType_Float, FloatArray, 1)
|
||||
doubleData := genFieldData(fieldName, fieldID, schemapb.DataType_Double, DoubleArray, 1)
|
||||
varCharData := genFieldData(fieldName, fieldID, schemapb.DataType_VarChar, VarCharArray, 1)
|
||||
stringData := genFieldData(fieldName, fieldID, schemapb.DataType_String, StringArray, 1)
|
||||
binVecData := genFieldData(fieldName, fieldID, schemapb.DataType_BinaryVector, BinaryVector, Dim)
|
||||
floatVecData := genFieldData(fieldName, fieldID, schemapb.DataType_FloatVector, FloatVector, Dim)
|
||||
invalidData := &schemapb.FieldData{
|
||||
Type: schemapb.DataType_None,
|
||||
}
|
||||
|
||||
t.Run("test GetDataSize", func(t *testing.T) {
|
||||
boolDataRes := GetDataSize(boolData)
|
||||
int8DataRes := GetDataSize(int8Data)
|
||||
int16DataRes := GetDataSize(int16Data)
|
||||
int32DataRes := GetDataSize(int32Data)
|
||||
int64DataRes := GetDataSize(int64Data)
|
||||
floatDataRes := GetDataSize(floatData)
|
||||
doubleDataRes := GetDataSize(doubleData)
|
||||
varCharDataRes := GetDataSize(varCharData)
|
||||
stringDataRes := GetDataSize(stringData)
|
||||
binVecDataRes := GetDataSize(binVecData)
|
||||
floatVecDataRes := GetDataSize(floatVecData)
|
||||
invalidDataRes := GetDataSize(invalidData)
|
||||
|
||||
assert.Equal(t, 2, boolDataRes)
|
||||
assert.Equal(t, 2, int8DataRes)
|
||||
assert.Equal(t, 2, int16DataRes)
|
||||
assert.Equal(t, 2, int32DataRes)
|
||||
assert.Equal(t, 2, int64DataRes)
|
||||
assert.Equal(t, 2, floatDataRes)
|
||||
assert.Equal(t, 2, doubleDataRes)
|
||||
assert.Equal(t, 2, varCharDataRes)
|
||||
assert.Equal(t, 2, stringDataRes)
|
||||
assert.Equal(t, 2*Dim/8, binVecDataRes)
|
||||
assert.Equal(t, 2*Dim, floatVecDataRes)
|
||||
assert.Equal(t, 0, invalidDataRes)
|
||||
})
|
||||
|
||||
t.Run("test GetData", func(t *testing.T) {
|
||||
boolDataRes := GetData(boolData, 0)
|
||||
int8DataRes := GetData(int8Data, 0)
|
||||
int16DataRes := GetData(int16Data, 0)
|
||||
int32DataRes := GetData(int32Data, 0)
|
||||
int64DataRes := GetData(int64Data, 0)
|
||||
floatDataRes := GetData(floatData, 0)
|
||||
doubleDataRes := GetData(doubleData, 0)
|
||||
varCharDataRes := GetData(varCharData, 0)
|
||||
stringDataRes := GetData(stringData, 0)
|
||||
binVecDataRes := GetData(binVecData, 0)
|
||||
floatVecDataRes := GetData(floatVecData, 0)
|
||||
invalidDataRes := GetData(invalidData, 0)
|
||||
|
||||
assert.Equal(t, BoolArray[0], boolDataRes)
|
||||
assert.Equal(t, int32(Int8Array[0]), int8DataRes)
|
||||
assert.Equal(t, int32(Int16Array[0]), int16DataRes)
|
||||
assert.Equal(t, Int32Array[0], int32DataRes)
|
||||
assert.Equal(t, Int64Array[0], int64DataRes)
|
||||
assert.Equal(t, FloatArray[0], floatDataRes)
|
||||
assert.Equal(t, DoubleArray[0], doubleDataRes)
|
||||
assert.Equal(t, VarCharArray[0], varCharDataRes)
|
||||
assert.Equal(t, StringArray[0], stringDataRes)
|
||||
assert.ElementsMatch(t, BinaryVector[:Dim/8], binVecDataRes)
|
||||
assert.ElementsMatch(t, FloatVector[:Dim], floatVecDataRes)
|
||||
assert.Nil(t, invalidDataRes)
|
||||
})
|
||||
}
|
||||
|
|
|
@ -51,72 +51,24 @@ const (
|
|||
// 5, load
|
||||
// 6, search
|
||||
func TestBulkInsert(t *testing.T) {
|
||||
ctx := context.Background()
|
||||
ctx, cancel := context.WithTimeout(context.Background(), time.Second*180)
|
||||
c, err := StartMiniCluster(ctx)
|
||||
assert.NoError(t, err)
|
||||
err = c.Start()
|
||||
assert.NoError(t, err)
|
||||
defer c.Stop()
|
||||
defer func() {
|
||||
err = c.Stop()
|
||||
assert.NoError(t, err)
|
||||
cancel()
|
||||
}()
|
||||
|
||||
prefix := "TestBulkInsert"
|
||||
dbName := ""
|
||||
collectionName := prefix + funcutil.GenRandomStr()
|
||||
int64Field := "int64"
|
||||
floatVecField := "embeddings"
|
||||
scalarField := "image_path"
|
||||
floatVecField := floatVecField
|
||||
dim := 128
|
||||
|
||||
constructCollectionSchema := func() *schemapb.CollectionSchema {
|
||||
pk := &schemapb.FieldSchema{
|
||||
Name: int64Field,
|
||||
IsPrimaryKey: true,
|
||||
Description: "pk",
|
||||
DataType: schemapb.DataType_Int64,
|
||||
TypeParams: nil,
|
||||
IndexParams: nil,
|
||||
AutoID: true,
|
||||
}
|
||||
fVec := &schemapb.FieldSchema{
|
||||
Name: floatVecField,
|
||||
IsPrimaryKey: false,
|
||||
Description: "",
|
||||
DataType: schemapb.DataType_FloatVector,
|
||||
TypeParams: []*commonpb.KeyValuePair{
|
||||
{
|
||||
Key: "dim",
|
||||
Value: strconv.Itoa(dim),
|
||||
},
|
||||
},
|
||||
IndexParams: nil,
|
||||
AutoID: false,
|
||||
}
|
||||
scalar := &schemapb.FieldSchema{
|
||||
Name: scalarField,
|
||||
IsPrimaryKey: false,
|
||||
Description: "",
|
||||
DataType: schemapb.DataType_VarChar,
|
||||
TypeParams: []*commonpb.KeyValuePair{
|
||||
{
|
||||
Key: "max_length",
|
||||
Value: "65535",
|
||||
},
|
||||
},
|
||||
IndexParams: nil,
|
||||
AutoID: false,
|
||||
}
|
||||
return &schemapb.CollectionSchema{
|
||||
Name: collectionName,
|
||||
Description: "",
|
||||
AutoID: false,
|
||||
Fields: []*schemapb.FieldSchema{
|
||||
pk,
|
||||
fVec,
|
||||
scalar,
|
||||
},
|
||||
}
|
||||
}
|
||||
schema := constructCollectionSchema()
|
||||
schema := constructSchema(collectionName, dim, true)
|
||||
marshaledSchema, err := proto.Marshal(schema)
|
||||
assert.NoError(t, err)
|
||||
|
||||
|
@ -207,28 +159,7 @@ func TestBulkInsert(t *testing.T) {
|
|||
CollectionName: collectionName,
|
||||
FieldName: floatVecField,
|
||||
IndexName: "_default",
|
||||
ExtraParams: []*commonpb.KeyValuePair{
|
||||
{
|
||||
Key: "dim",
|
||||
Value: strconv.Itoa(dim),
|
||||
},
|
||||
{
|
||||
Key: common.MetricTypeKey,
|
||||
Value: distance.L2,
|
||||
},
|
||||
{
|
||||
Key: "index_type",
|
||||
Value: "HNSW",
|
||||
},
|
||||
{
|
||||
Key: "M",
|
||||
Value: "64",
|
||||
},
|
||||
{
|
||||
Key: "efConstruction",
|
||||
Value: "512",
|
||||
},
|
||||
},
|
||||
ExtraParams: constructIndexParam(dim, IndexHNSW, distance.L2),
|
||||
})
|
||||
if createIndexStatus.GetErrorCode() != commonpb.ErrorCode_Success {
|
||||
log.Warn("createIndexStatus fail reason", zap.String("reason", createIndexStatus.GetReason()))
|
||||
|
@ -246,30 +177,17 @@ func TestBulkInsert(t *testing.T) {
|
|||
log.Warn("loadStatus fail reason", zap.String("reason", loadStatus.GetReason()))
|
||||
}
|
||||
assert.Equal(t, commonpb.ErrorCode_Success, loadStatus.GetErrorCode())
|
||||
for {
|
||||
loadProgress, err := c.proxy.GetLoadingProgress(ctx, &milvuspb.GetLoadingProgressRequest{
|
||||
CollectionName: collectionName,
|
||||
})
|
||||
if err != nil {
|
||||
panic("GetLoadingProgress fail")
|
||||
}
|
||||
if loadProgress.GetProgress() == 100 {
|
||||
break
|
||||
}
|
||||
time.Sleep(500 * time.Millisecond)
|
||||
}
|
||||
waitingForLoad(ctx, c, collectionName)
|
||||
|
||||
// search
|
||||
expr := fmt.Sprintf("%s > 0", "int64")
|
||||
expr := fmt.Sprintf("%s > 0", int64Field)
|
||||
nq := 10
|
||||
topk := 10
|
||||
roundDecimal := -1
|
||||
nprobe := 10
|
||||
params := make(map[string]int)
|
||||
params["nprobe"] = nprobe
|
||||
|
||||
params := getSearchParams(IndexHNSW, distance.L2)
|
||||
searchReq := constructSearchRequest("", collectionName, expr,
|
||||
floatVecField, distance.L2, params, nq, dim, topk, roundDecimal)
|
||||
floatVecField, schemapb.DataType_FloatVector, nil, distance.L2, params, nq, dim, topk, roundDecimal)
|
||||
|
||||
searchResult, err := c.proxy.Search(ctx, searchReq)
|
||||
|
||||
|
|
|
@ -0,0 +1,366 @@
|
|||
// Licensed to the LF AI & Data foundation under one
|
||||
// or more contributor license agreements. See the NOTICE file
|
||||
// distributed with this work for additional information
|
||||
// regarding copyright ownership. The ASF licenses this file
|
||||
// to you under the Apache License, Version 2.0 (the
|
||||
// "License"); you may not use this file except in compliance
|
||||
// with the License. You may obtain a copy of the License at
|
||||
//
|
||||
// http://www.apache.org/licenses/LICENSE-2.0
|
||||
//
|
||||
// Unless required by applicable law or agreed to in writing, software
|
||||
// distributed under the License is distributed on an "AS IS" BASIS,
|
||||
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
// See the License for the specific language governing permissions and
|
||||
// limitations under the License.
|
||||
|
||||
package integration
|
||||
|
||||
import (
|
||||
"context"
|
||||
"fmt"
|
||||
"strconv"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"github.com/golang/protobuf/proto"
|
||||
"github.com/stretchr/testify/suite"
|
||||
|
||||
"github.com/milvus-io/milvus-proto/go-api/commonpb"
|
||||
"github.com/milvus-io/milvus-proto/go-api/milvuspb"
|
||||
"github.com/milvus-io/milvus-proto/go-api/schemapb"
|
||||
"github.com/milvus-io/milvus/pkg/common"
|
||||
"github.com/milvus-io/milvus/pkg/util/distance"
|
||||
"github.com/milvus-io/milvus/pkg/util/funcutil"
|
||||
"github.com/milvus-io/milvus/pkg/util/typeutil"
|
||||
)
|
||||
|
||||
type TestGetVectorSuite struct {
|
||||
suite.Suite
|
||||
|
||||
ctx context.Context
|
||||
cancel context.CancelFunc
|
||||
cluster *MiniCluster
|
||||
|
||||
// test params
|
||||
nq int
|
||||
topK int
|
||||
indexType string
|
||||
metricType string
|
||||
pkType schemapb.DataType
|
||||
vecType schemapb.DataType
|
||||
}
|
||||
|
||||
func (suite *TestGetVectorSuite) SetupTest() {
|
||||
suite.ctx, suite.cancel = context.WithTimeout(context.Background(), time.Second*600)
|
||||
|
||||
var err error
|
||||
suite.cluster, err = StartMiniCluster(suite.ctx)
|
||||
suite.Require().NoError(err)
|
||||
err = suite.cluster.Start()
|
||||
suite.Require().NoError(err)
|
||||
}
|
||||
|
||||
func (suite *TestGetVectorSuite) run() {
|
||||
collection := fmt.Sprintf("TestGetVector_%d_%d_%s_%s_%s",
|
||||
suite.nq, suite.topK, suite.indexType, suite.metricType, funcutil.GenRandomStr())
|
||||
|
||||
const (
|
||||
NB = 10000
|
||||
dim = 128
|
||||
)
|
||||
|
||||
pkFieldName := "pkField"
|
||||
vecFieldName := "vecField"
|
||||
pk := &schemapb.FieldSchema{
|
||||
FieldID: 100,
|
||||
Name: pkFieldName,
|
||||
IsPrimaryKey: true,
|
||||
Description: "",
|
||||
DataType: suite.pkType,
|
||||
TypeParams: []*commonpb.KeyValuePair{
|
||||
{
|
||||
Key: "max_length",
|
||||
Value: "100",
|
||||
},
|
||||
},
|
||||
IndexParams: nil,
|
||||
AutoID: false,
|
||||
}
|
||||
fVec := &schemapb.FieldSchema{
|
||||
FieldID: 101,
|
||||
Name: vecFieldName,
|
||||
IsPrimaryKey: false,
|
||||
Description: "",
|
||||
DataType: suite.vecType,
|
||||
TypeParams: []*commonpb.KeyValuePair{
|
||||
{
|
||||
Key: common.DimKey,
|
||||
Value: fmt.Sprintf("%d", dim),
|
||||
},
|
||||
},
|
||||
IndexParams: nil,
|
||||
}
|
||||
schema := constructSchema(collection, dim, false, pk, fVec)
|
||||
marshaledSchema, err := proto.Marshal(schema)
|
||||
suite.Require().NoError(err)
|
||||
|
||||
createCollectionStatus, err := suite.cluster.proxy.CreateCollection(suite.ctx, &milvuspb.CreateCollectionRequest{
|
||||
CollectionName: collection,
|
||||
Schema: marshaledSchema,
|
||||
ShardsNum: 2,
|
||||
})
|
||||
suite.Require().NoError(err)
|
||||
suite.Require().Equal(createCollectionStatus.GetErrorCode(), commonpb.ErrorCode_Success)
|
||||
|
||||
fieldsData := make([]*schemapb.FieldData, 0)
|
||||
if suite.pkType == schemapb.DataType_Int64 {
|
||||
fieldsData = append(fieldsData, newInt64FieldData(pkFieldName, NB))
|
||||
} else {
|
||||
fieldsData = append(fieldsData, newStringFieldData(pkFieldName, NB))
|
||||
}
|
||||
var vecFieldData *schemapb.FieldData
|
||||
if suite.vecType == schemapb.DataType_FloatVector {
|
||||
vecFieldData = newFloatVectorFieldData(vecFieldName, NB, dim)
|
||||
} else {
|
||||
vecFieldData = newBinaryVectorFieldData(vecFieldName, NB, dim)
|
||||
}
|
||||
fieldsData = append(fieldsData, vecFieldData)
|
||||
hashKeys := generateHashKeys(NB)
|
||||
_, err = suite.cluster.proxy.Insert(suite.ctx, &milvuspb.InsertRequest{
|
||||
CollectionName: collection,
|
||||
FieldsData: fieldsData,
|
||||
HashKeys: hashKeys,
|
||||
NumRows: uint32(NB),
|
||||
})
|
||||
suite.Require().NoError(err)
|
||||
suite.Require().Equal(createCollectionStatus.GetErrorCode(), commonpb.ErrorCode_Success)
|
||||
|
||||
// flush
|
||||
flushResp, err := suite.cluster.proxy.Flush(suite.ctx, &milvuspb.FlushRequest{
|
||||
CollectionNames: []string{collection},
|
||||
})
|
||||
suite.Require().NoError(err)
|
||||
segmentIDs, has := flushResp.GetCollSegIDs()[collection]
|
||||
ids := segmentIDs.GetData()
|
||||
suite.Require().NotEmpty(segmentIDs)
|
||||
suite.Require().True(has)
|
||||
|
||||
segments, err := suite.cluster.metaWatcher.ShowSegments()
|
||||
suite.Require().NoError(err)
|
||||
suite.Require().NotEmpty(segments)
|
||||
|
||||
waitingForFlush(suite.ctx, suite.cluster, ids)
|
||||
|
||||
// create index
|
||||
_, err = suite.cluster.proxy.CreateIndex(suite.ctx, &milvuspb.CreateIndexRequest{
|
||||
CollectionName: collection,
|
||||
FieldName: vecFieldName,
|
||||
IndexName: "_default",
|
||||
ExtraParams: constructIndexParam(dim, suite.indexType, suite.metricType),
|
||||
})
|
||||
suite.Require().NoError(err)
|
||||
suite.Require().Equal(createCollectionStatus.GetErrorCode(), commonpb.ErrorCode_Success)
|
||||
|
||||
// load
|
||||
_, err = suite.cluster.proxy.LoadCollection(suite.ctx, &milvuspb.LoadCollectionRequest{
|
||||
CollectionName: collection,
|
||||
})
|
||||
suite.Require().NoError(err)
|
||||
suite.Require().Equal(createCollectionStatus.GetErrorCode(), commonpb.ErrorCode_Success)
|
||||
waitingForLoad(suite.ctx, suite.cluster, collection)
|
||||
|
||||
// search
|
||||
nq := suite.nq
|
||||
topk := suite.topK
|
||||
|
||||
outputFields := []string{vecFieldName}
|
||||
params := getSearchParams(suite.indexType, suite.metricType)
|
||||
searchReq := constructSearchRequest("", collection, "",
|
||||
vecFieldName, suite.vecType, outputFields, suite.metricType, params, nq, dim, topk, -1)
|
||||
|
||||
searchResp, err := suite.cluster.proxy.Search(suite.ctx, searchReq)
|
||||
suite.Require().NoError(err)
|
||||
suite.Require().Equal(searchResp.GetStatus().GetErrorCode(), commonpb.ErrorCode_Success)
|
||||
|
||||
result := searchResp.GetResults()
|
||||
if suite.pkType == schemapb.DataType_Int64 {
|
||||
suite.Require().Len(result.GetIds().GetIntId().GetData(), nq*topk)
|
||||
} else {
|
||||
suite.Require().Len(result.GetIds().GetStrId().GetData(), nq*topk)
|
||||
}
|
||||
suite.Require().Len(result.GetScores(), nq*topk)
|
||||
suite.Require().GreaterOrEqual(len(result.GetFieldsData()), 1)
|
||||
var vecFieldIndex = -1
|
||||
for i, fieldData := range result.GetFieldsData() {
|
||||
if typeutil.IsVectorType(fieldData.GetType()) {
|
||||
vecFieldIndex = i
|
||||
break
|
||||
}
|
||||
}
|
||||
suite.Require().EqualValues(nq, result.GetNumQueries())
|
||||
suite.Require().EqualValues(topk, result.GetTopK())
|
||||
|
||||
// check output vectors
|
||||
if suite.vecType == schemapb.DataType_FloatVector {
|
||||
suite.Require().Len(result.GetFieldsData()[vecFieldIndex].GetVectors().GetFloatVector().GetData(), nq*topk*dim)
|
||||
rawData := vecFieldData.GetVectors().GetFloatVector().GetData()
|
||||
resData := result.GetFieldsData()[vecFieldIndex].GetVectors().GetFloatVector().GetData()
|
||||
if suite.pkType == schemapb.DataType_Int64 {
|
||||
for i, id := range result.GetIds().GetIntId().GetData() {
|
||||
expect := rawData[int(id)*dim : (int(id)+1)*dim]
|
||||
actual := resData[i*dim : (i+1)*dim]
|
||||
suite.Require().ElementsMatch(expect, actual)
|
||||
}
|
||||
} else {
|
||||
for i, idStr := range result.GetIds().GetStrId().GetData() {
|
||||
id, err := strconv.Atoi(idStr)
|
||||
suite.Require().NoError(err)
|
||||
expect := rawData[id*dim : (id+1)*dim]
|
||||
actual := resData[i*dim : (i+1)*dim]
|
||||
suite.Require().ElementsMatch(expect, actual)
|
||||
}
|
||||
}
|
||||
} else {
|
||||
suite.Require().Len(result.GetFieldsData()[vecFieldIndex].GetVectors().GetBinaryVector(), nq*topk*dim/8)
|
||||
rawData := vecFieldData.GetVectors().GetBinaryVector()
|
||||
resData := result.GetFieldsData()[vecFieldIndex].GetVectors().GetBinaryVector()
|
||||
if suite.pkType == schemapb.DataType_Int64 {
|
||||
for i, id := range result.GetIds().GetIntId().GetData() {
|
||||
dataBytes := dim / 8
|
||||
for j := 0; j < dataBytes; j++ {
|
||||
expect := rawData[int(id)*dataBytes+j]
|
||||
actual := resData[i*dataBytes+j]
|
||||
suite.Require().Equal(expect, actual)
|
||||
}
|
||||
}
|
||||
} else {
|
||||
for i, idStr := range result.GetIds().GetStrId().GetData() {
|
||||
dataBytes := dim / 8
|
||||
id, err := strconv.Atoi(idStr)
|
||||
suite.Require().NoError(err)
|
||||
for j := 0; j < dataBytes; j++ {
|
||||
expect := rawData[id*dataBytes+j]
|
||||
actual := resData[i*dataBytes+j]
|
||||
suite.Require().Equal(expect, actual)
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
status, err := suite.cluster.proxy.DropCollection(suite.ctx, &milvuspb.DropCollectionRequest{
|
||||
CollectionName: collection,
|
||||
})
|
||||
suite.Require().NoError(err)
|
||||
suite.Require().Equal(status.GetErrorCode(), commonpb.ErrorCode_Success)
|
||||
}
|
||||
|
||||
func (suite *TestGetVectorSuite) TestGetVector_FLAT() {
|
||||
suite.nq = 10
|
||||
suite.topK = 10
|
||||
suite.indexType = IndexFaissIDMap
|
||||
suite.metricType = distance.L2
|
||||
suite.pkType = schemapb.DataType_Int64
|
||||
suite.vecType = schemapb.DataType_FloatVector
|
||||
suite.run()
|
||||
}
|
||||
|
||||
func (suite *TestGetVectorSuite) TestGetVector_IVF_FLAT() {
|
||||
suite.nq = 10
|
||||
suite.topK = 10
|
||||
suite.indexType = IndexFaissIvfFlat
|
||||
suite.metricType = distance.L2
|
||||
suite.pkType = schemapb.DataType_Int64
|
||||
suite.vecType = schemapb.DataType_FloatVector
|
||||
suite.run()
|
||||
}
|
||||
|
||||
func (suite *TestGetVectorSuite) TestGetVector_IVF_PQ() {
|
||||
suite.nq = 10
|
||||
suite.topK = 10
|
||||
suite.indexType = IndexFaissIvfPQ
|
||||
suite.metricType = distance.L2
|
||||
suite.pkType = schemapb.DataType_Int64
|
||||
suite.vecType = schemapb.DataType_FloatVector
|
||||
suite.run()
|
||||
}
|
||||
|
||||
func (suite *TestGetVectorSuite) TestGetVector_IVF_SQ8() {
|
||||
suite.nq = 10
|
||||
suite.topK = 10
|
||||
suite.indexType = IndexFaissIvfSQ8
|
||||
suite.metricType = distance.L2
|
||||
suite.pkType = schemapb.DataType_Int64
|
||||
suite.vecType = schemapb.DataType_FloatVector
|
||||
suite.run()
|
||||
}
|
||||
|
||||
func (suite *TestGetVectorSuite) TestGetVector_HNSW() {
|
||||
suite.nq = 10
|
||||
suite.topK = 10
|
||||
suite.indexType = IndexHNSW
|
||||
suite.metricType = distance.L2
|
||||
suite.pkType = schemapb.DataType_Int64
|
||||
suite.vecType = schemapb.DataType_FloatVector
|
||||
suite.run()
|
||||
}
|
||||
|
||||
func (suite *TestGetVectorSuite) TestGetVector_IP() {
|
||||
suite.nq = 10
|
||||
suite.topK = 10
|
||||
suite.indexType = IndexHNSW
|
||||
suite.metricType = distance.IP
|
||||
suite.pkType = schemapb.DataType_Int64
|
||||
suite.vecType = schemapb.DataType_FloatVector
|
||||
suite.run()
|
||||
}
|
||||
|
||||
func (suite *TestGetVectorSuite) TestGetVector_StringPK() {
|
||||
suite.nq = 10
|
||||
suite.topK = 10
|
||||
suite.indexType = IndexHNSW
|
||||
suite.metricType = distance.L2
|
||||
suite.pkType = schemapb.DataType_VarChar
|
||||
suite.vecType = schemapb.DataType_FloatVector
|
||||
suite.run()
|
||||
}
|
||||
|
||||
func (suite *TestGetVectorSuite) TestGetVector_BinaryVector() {
|
||||
suite.nq = 10
|
||||
suite.topK = 10
|
||||
suite.indexType = IndexFaissBinIvfFlat
|
||||
suite.metricType = distance.JACCARD
|
||||
suite.pkType = schemapb.DataType_Int64
|
||||
suite.vecType = schemapb.DataType_BinaryVector
|
||||
suite.run()
|
||||
}
|
||||
|
||||
func (suite *TestGetVectorSuite) TestGetVector_Big_NQ_TOPK() {
|
||||
suite.nq = 10000
|
||||
suite.topK = 200
|
||||
suite.indexType = IndexHNSW
|
||||
suite.metricType = distance.L2
|
||||
suite.pkType = schemapb.DataType_Int64
|
||||
suite.vecType = schemapb.DataType_FloatVector
|
||||
suite.run()
|
||||
}
|
||||
|
||||
//func (suite *TestGetVectorSuite) TestGetVector_DISKANN() {
|
||||
// suite.nq = 10
|
||||
// suite.topK = 10
|
||||
// suite.indexType = IndexDISKANN
|
||||
// suite.metricType = distance.L2
|
||||
// suite.pkType = schemapb.DataType_Int64
|
||||
// suite.vecType = schemapb.DataType_FloatVector
|
||||
// suite.run()
|
||||
//}
|
||||
|
||||
func (suite *TestGetVectorSuite) TearDownTest() {
|
||||
err := suite.cluster.Stop()
|
||||
suite.Require().NoError(err)
|
||||
suite.cancel()
|
||||
}
|
||||
|
||||
func TestGetVector(t *testing.T) {
|
||||
suite.Run(t, new(TestGetVectorSuite))
|
||||
}
|
|
@ -17,17 +17,11 @@
|
|||
package integration
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"context"
|
||||
"encoding/binary"
|
||||
"encoding/json"
|
||||
"fmt"
|
||||
"math/rand"
|
||||
"strconv"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"github.com/cockroachdb/errors"
|
||||
"github.com/golang/protobuf/proto"
|
||||
"github.com/stretchr/testify/assert"
|
||||
"go.uber.org/zap"
|
||||
|
@ -42,59 +36,27 @@ import (
|
|||
)
|
||||
|
||||
func TestHelloMilvus(t *testing.T) {
|
||||
ctx := context.Background()
|
||||
ctx, cancel := context.WithTimeout(context.Background(), time.Second*180)
|
||||
defer cancel()
|
||||
c, err := StartMiniCluster(ctx)
|
||||
assert.NoError(t, err)
|
||||
err = c.Start()
|
||||
assert.NoError(t, err)
|
||||
defer c.Stop()
|
||||
defer func() {
|
||||
err = c.Stop()
|
||||
assert.NoError(t, err)
|
||||
cancel()
|
||||
}()
|
||||
|
||||
prefix := "TestHelloMilvus"
|
||||
dbName := ""
|
||||
collectionName := prefix + funcutil.GenRandomStr()
|
||||
int64Field := "int64"
|
||||
floatVecField := "fvec"
|
||||
dim := 128
|
||||
rowNum := 3000
|
||||
const (
|
||||
dim = 128
|
||||
dbName = ""
|
||||
rowNum = 3000
|
||||
)
|
||||
|
||||
constructCollectionSchema := func() *schemapb.CollectionSchema {
|
||||
pk := &schemapb.FieldSchema{
|
||||
FieldID: 0,
|
||||
Name: int64Field,
|
||||
IsPrimaryKey: true,
|
||||
Description: "",
|
||||
DataType: schemapb.DataType_Int64,
|
||||
TypeParams: nil,
|
||||
IndexParams: nil,
|
||||
AutoID: true,
|
||||
}
|
||||
fVec := &schemapb.FieldSchema{
|
||||
FieldID: 0,
|
||||
Name: floatVecField,
|
||||
IsPrimaryKey: false,
|
||||
Description: "",
|
||||
DataType: schemapb.DataType_FloatVector,
|
||||
TypeParams: []*commonpb.KeyValuePair{
|
||||
{
|
||||
Key: "dim",
|
||||
Value: strconv.Itoa(dim),
|
||||
},
|
||||
},
|
||||
IndexParams: nil,
|
||||
AutoID: false,
|
||||
}
|
||||
return &schemapb.CollectionSchema{
|
||||
Name: collectionName,
|
||||
Description: "",
|
||||
AutoID: false,
|
||||
Fields: []*schemapb.FieldSchema{
|
||||
pk,
|
||||
fVec,
|
||||
},
|
||||
}
|
||||
}
|
||||
schema := constructCollectionSchema()
|
||||
collectionName := "TestHelloMilvus" + funcutil.GenRandomStr()
|
||||
|
||||
schema := constructSchema(collectionName, dim, true)
|
||||
marshaledSchema, err := proto.Marshal(schema)
|
||||
assert.NoError(t, err)
|
||||
|
||||
|
@ -137,6 +99,7 @@ func TestHelloMilvus(t *testing.T) {
|
|||
segmentIDs, has := flushResp.GetCollSegIDs()[collectionName]
|
||||
ids := segmentIDs.GetData()
|
||||
assert.NotEmpty(t, segmentIDs)
|
||||
assert.True(t, has)
|
||||
|
||||
segments, err := c.metaWatcher.ShowSegments()
|
||||
assert.NoError(t, err)
|
||||
|
@ -144,52 +107,14 @@ func TestHelloMilvus(t *testing.T) {
|
|||
for _, segment := range segments {
|
||||
log.Info("ShowSegments result", zap.String("segment", segment.String()))
|
||||
}
|
||||
|
||||
if has && len(ids) > 0 {
|
||||
flushed := func() bool {
|
||||
resp, err := c.proxy.GetFlushState(ctx, &milvuspb.GetFlushStateRequest{
|
||||
SegmentIDs: ids,
|
||||
})
|
||||
if err != nil {
|
||||
//panic(errors.New("GetFlushState failed"))
|
||||
return false
|
||||
}
|
||||
return resp.GetFlushed()
|
||||
}
|
||||
for !flushed() {
|
||||
// respect context deadline/cancel
|
||||
select {
|
||||
case <-ctx.Done():
|
||||
panic(errors.New("deadline exceeded"))
|
||||
default:
|
||||
}
|
||||
time.Sleep(500 * time.Millisecond)
|
||||
}
|
||||
}
|
||||
waitingForFlush(ctx, c, ids)
|
||||
|
||||
// create index
|
||||
createIndexStatus, err := c.proxy.CreateIndex(ctx, &milvuspb.CreateIndexRequest{
|
||||
CollectionName: collectionName,
|
||||
FieldName: floatVecField,
|
||||
IndexName: "_default",
|
||||
ExtraParams: []*commonpb.KeyValuePair{
|
||||
{
|
||||
Key: "dim",
|
||||
Value: strconv.Itoa(dim),
|
||||
},
|
||||
{
|
||||
Key: common.MetricTypeKey,
|
||||
Value: distance.L2,
|
||||
},
|
||||
{
|
||||
Key: "index_type",
|
||||
Value: "IVF_FLAT",
|
||||
},
|
||||
{
|
||||
Key: "nlist",
|
||||
Value: strconv.Itoa(10),
|
||||
},
|
||||
},
|
||||
ExtraParams: constructIndexParam(dim, IndexFaissIvfFlat, distance.L2),
|
||||
})
|
||||
if createIndexStatus.GetErrorCode() != commonpb.ErrorCode_Success {
|
||||
log.Warn("createIndexStatus fail reason", zap.String("reason", createIndexStatus.GetReason()))
|
||||
|
@ -207,30 +132,17 @@ func TestHelloMilvus(t *testing.T) {
|
|||
log.Warn("loadStatus fail reason", zap.String("reason", loadStatus.GetReason()))
|
||||
}
|
||||
assert.Equal(t, commonpb.ErrorCode_Success, loadStatus.GetErrorCode())
|
||||
for {
|
||||
loadProgress, err := c.proxy.GetLoadingProgress(ctx, &milvuspb.GetLoadingProgressRequest{
|
||||
CollectionName: collectionName,
|
||||
})
|
||||
if err != nil {
|
||||
panic("GetLoadingProgress fail")
|
||||
}
|
||||
if loadProgress.GetProgress() == 100 {
|
||||
break
|
||||
}
|
||||
time.Sleep(500 * time.Millisecond)
|
||||
}
|
||||
waitingForLoad(ctx, c, collectionName)
|
||||
|
||||
// search
|
||||
expr := fmt.Sprintf("%s > 0", "int64")
|
||||
expr := fmt.Sprintf("%s > 0", int64Field)
|
||||
nq := 10
|
||||
topk := 10
|
||||
roundDecimal := -1
|
||||
nprobe := 10
|
||||
params := make(map[string]int)
|
||||
params["nprobe"] = nprobe
|
||||
|
||||
params := getSearchParams(IndexFaissIvfFlat, distance.L2)
|
||||
searchReq := constructSearchRequest("", collectionName, expr,
|
||||
floatVecField, distance.L2, params, nq, dim, topk, roundDecimal)
|
||||
floatVecField, schemapb.DataType_FloatVector, nil, distance.L2, params, nq, dim, topk, roundDecimal)
|
||||
|
||||
searchResult, err := c.proxy.Search(ctx, searchReq)
|
||||
|
||||
|
@ -242,155 +154,3 @@ func TestHelloMilvus(t *testing.T) {
|
|||
|
||||
log.Info("TestHelloMilvus succeed")
|
||||
}
|
||||
|
||||
// Key names for the key/value pairs attached to search/query requests.
const (
	// Keys emitted by constructSearchRequest in SearchParams.
	AnnsFieldKey    = "anns_field"
	TopKKey         = "topk"
	SearchParamsKey = "params"
	RoundDecimalKey = "round_decimal"

	// Remaining request keys.
	NQKey         = "nq"
	MetricTypeKey = "metric_type"
	OffsetKey     = "offset"
	LimitKey      = "limit"
)
|
||||
|
||||
func constructSearchRequest(
|
||||
dbName, collectionName string,
|
||||
expr string,
|
||||
floatVecField string,
|
||||
metricType string,
|
||||
params map[string]int,
|
||||
nq, dim, topk, roundDecimal int,
|
||||
) *milvuspb.SearchRequest {
|
||||
b, err := json.Marshal(params)
|
||||
if err != nil {
|
||||
panic(err)
|
||||
}
|
||||
plg := constructPlaceholderGroup(nq, dim)
|
||||
plgBs, err := proto.Marshal(plg)
|
||||
if err != nil {
|
||||
panic(err)
|
||||
}
|
||||
|
||||
return &milvuspb.SearchRequest{
|
||||
Base: nil,
|
||||
DbName: dbName,
|
||||
CollectionName: collectionName,
|
||||
PartitionNames: nil,
|
||||
Dsl: expr,
|
||||
PlaceholderGroup: plgBs,
|
||||
DslType: commonpb.DslType_BoolExprV1,
|
||||
OutputFields: nil,
|
||||
SearchParams: []*commonpb.KeyValuePair{
|
||||
{
|
||||
Key: common.MetricTypeKey,
|
||||
Value: metricType,
|
||||
},
|
||||
{
|
||||
Key: SearchParamsKey,
|
||||
Value: string(b),
|
||||
},
|
||||
{
|
||||
Key: AnnsFieldKey,
|
||||
Value: floatVecField,
|
||||
},
|
||||
{
|
||||
Key: TopKKey,
|
||||
Value: strconv.Itoa(topk),
|
||||
},
|
||||
{
|
||||
Key: RoundDecimalKey,
|
||||
Value: strconv.Itoa(roundDecimal),
|
||||
},
|
||||
},
|
||||
TravelTimestamp: 0,
|
||||
GuaranteeTimestamp: 0,
|
||||
}
|
||||
}
|
||||
|
||||
func constructPlaceholderGroup(
|
||||
nq, dim int,
|
||||
) *commonpb.PlaceholderGroup {
|
||||
values := make([][]byte, 0, nq)
|
||||
for i := 0; i < nq; i++ {
|
||||
bs := make([]byte, 0, dim*4)
|
||||
for j := 0; j < dim; j++ {
|
||||
var buffer bytes.Buffer
|
||||
f := rand.Float32()
|
||||
err := binary.Write(&buffer, common.Endian, f)
|
||||
if err != nil {
|
||||
panic(err)
|
||||
}
|
||||
bs = append(bs, buffer.Bytes()...)
|
||||
}
|
||||
values = append(values, bs)
|
||||
}
|
||||
|
||||
return &commonpb.PlaceholderGroup{
|
||||
Placeholders: []*commonpb.PlaceholderValue{
|
||||
{
|
||||
Tag: "$0",
|
||||
Type: commonpb.PlaceholderType_FloatVector,
|
||||
Values: values,
|
||||
},
|
||||
},
|
||||
}
|
||||
}
|
||||
|
||||
func newFloatVectorFieldData(fieldName string, numRows, dim int) *schemapb.FieldData {
|
||||
return &schemapb.FieldData{
|
||||
Type: schemapb.DataType_FloatVector,
|
||||
FieldName: fieldName,
|
||||
Field: &schemapb.FieldData_Vectors{
|
||||
Vectors: &schemapb.VectorField{
|
||||
Dim: int64(dim),
|
||||
Data: &schemapb.VectorField_FloatVector{
|
||||
FloatVector: &schemapb.FloatArray{
|
||||
Data: generateFloatVectors(numRows, dim),
|
||||
},
|
||||
},
|
||||
},
|
||||
},
|
||||
}
|
||||
}
|
||||
|
||||
func newInt64PrimaryKey(fieldName string, numRows int) *schemapb.FieldData {
|
||||
return &schemapb.FieldData{
|
||||
Type: schemapb.DataType_Int64,
|
||||
FieldName: fieldName,
|
||||
Field: &schemapb.FieldData_Scalars{
|
||||
Scalars: &schemapb.ScalarField{
|
||||
Data: &schemapb.ScalarField_LongData{
|
||||
LongData: &schemapb.LongArray{
|
||||
Data: generateInt64Array(numRows),
|
||||
},
|
||||
},
|
||||
},
|
||||
},
|
||||
}
|
||||
}
|
||||
|
||||
// generateFloatVectors returns numRows*dim random float32 values laid out as
// one flat slice, each drawn from rand.Float32.
func generateFloatVectors(numRows, dim int) []float32 {
	flat := make([]float32, numRows*dim)
	for idx := range flat {
		flat[idx] = rand.Float32()
	}
	return flat
}
|
||||
|
||||
// generateInt64Array returns numRows random values, each produced by
// rand.Int and converted to int64.
func generateInt64Array(numRows int) []int64 {
	values := make([]int64, numRows)
	for idx := range values {
		values[idx] = int64(rand.Int())
	}
	return values
}
|
||||
|
||||
// generateHashKeys returns numRows random uint32 values drawn from
// rand.Uint32.
func generateHashKeys(numRows int) []uint32 {
	keys := make([]uint32, numRows)
	for idx := range keys {
		keys[idx] = rand.Uint32()
	}
	return keys
}
|
||||
|
|
|
@ -19,16 +19,13 @@ package integration
|
|||
import (
|
||||
"context"
|
||||
"fmt"
|
||||
"strconv"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"github.com/cockroachdb/errors"
|
||||
"github.com/golang/protobuf/proto"
|
||||
"github.com/stretchr/testify/assert"
|
||||
"go.uber.org/zap"
|
||||
|
||||
"github.com/milvus-io/milvus-proto/go-api/commonpb"
|
||||
"github.com/milvus-io/milvus-proto/go-api/milvuspb"
|
||||
"github.com/milvus-io/milvus-proto/go-api/schemapb"
|
||||
"github.com/milvus-io/milvus/pkg/common"
|
||||
|
@ -39,59 +36,24 @@ import (
|
|||
)
|
||||
|
||||
func TestRangeSearchIP(t *testing.T) {
|
||||
ctx := context.Background()
|
||||
ctx, cancel := context.WithTimeout(context.Background(), time.Second*180)
|
||||
c, err := StartMiniCluster(ctx)
|
||||
assert.NoError(t, err)
|
||||
err = c.Start()
|
||||
assert.NoError(t, err)
|
||||
defer c.Stop()
|
||||
defer func() {
|
||||
err = c.Stop()
|
||||
assert.NoError(t, err)
|
||||
cancel()
|
||||
}()
|
||||
|
||||
prefix := "TestRangeSearchIP"
|
||||
dbName := ""
|
||||
collectionName := prefix + funcutil.GenRandomStr()
|
||||
int64Field := "int64"
|
||||
floatVecField := "fvec"
|
||||
dim := 128
|
||||
rowNum := 3000
|
||||
|
||||
constructCollectionSchema := func() *schemapb.CollectionSchema {
|
||||
pk := &schemapb.FieldSchema{
|
||||
FieldID: 0,
|
||||
Name: int64Field,
|
||||
IsPrimaryKey: true,
|
||||
Description: "",
|
||||
DataType: schemapb.DataType_Int64,
|
||||
TypeParams: nil,
|
||||
IndexParams: nil,
|
||||
AutoID: true,
|
||||
}
|
||||
fVec := &schemapb.FieldSchema{
|
||||
FieldID: 0,
|
||||
Name: floatVecField,
|
||||
IsPrimaryKey: false,
|
||||
Description: "",
|
||||
DataType: schemapb.DataType_FloatVector,
|
||||
TypeParams: []*commonpb.KeyValuePair{
|
||||
{
|
||||
Key: "dim",
|
||||
Value: strconv.Itoa(dim),
|
||||
},
|
||||
},
|
||||
IndexParams: nil,
|
||||
AutoID: false,
|
||||
}
|
||||
return &schemapb.CollectionSchema{
|
||||
Name: collectionName,
|
||||
Description: "",
|
||||
AutoID: false,
|
||||
Fields: []*schemapb.FieldSchema{
|
||||
pk,
|
||||
fVec,
|
||||
},
|
||||
}
|
||||
}
|
||||
schema := constructCollectionSchema()
|
||||
schema := constructSchema(collectionName, dim, true)
|
||||
marshaledSchema, err := proto.Marshal(schema)
|
||||
assert.NoError(t, err)
|
||||
|
||||
|
@ -133,6 +95,7 @@ func TestRangeSearchIP(t *testing.T) {
|
|||
})
|
||||
assert.NoError(t, err)
|
||||
segmentIDs, has := flushResp.GetCollSegIDs()[collectionName]
|
||||
assert.True(t, has)
|
||||
ids := segmentIDs.GetData()
|
||||
assert.NotEmpty(t, segmentIDs)
|
||||
|
||||
|
@ -142,52 +105,14 @@ func TestRangeSearchIP(t *testing.T) {
|
|||
for _, segment := range segments {
|
||||
log.Info("ShowSegments result", zap.String("segment", segment.String()))
|
||||
}
|
||||
|
||||
if has && len(ids) > 0 {
|
||||
flushed := func() bool {
|
||||
resp, err := c.proxy.GetFlushState(ctx, &milvuspb.GetFlushStateRequest{
|
||||
SegmentIDs: ids,
|
||||
})
|
||||
if err != nil {
|
||||
//panic(errors.New("GetFlushState failed"))
|
||||
return false
|
||||
}
|
||||
return resp.GetFlushed()
|
||||
}
|
||||
for !flushed() {
|
||||
// respect context deadline/cancel
|
||||
select {
|
||||
case <-ctx.Done():
|
||||
panic(errors.New("deadline exceeded"))
|
||||
default:
|
||||
}
|
||||
time.Sleep(500 * time.Millisecond)
|
||||
}
|
||||
}
|
||||
waitingForFlush(ctx, c, ids)
|
||||
|
||||
// create index
|
||||
createIndexStatus, err := c.proxy.CreateIndex(ctx, &milvuspb.CreateIndexRequest{
|
||||
CollectionName: collectionName,
|
||||
FieldName: floatVecField,
|
||||
IndexName: "_default",
|
||||
ExtraParams: []*commonpb.KeyValuePair{
|
||||
{
|
||||
Key: "dim",
|
||||
Value: strconv.Itoa(dim),
|
||||
},
|
||||
{
|
||||
Key: common.MetricTypeKey,
|
||||
Value: distance.IP,
|
||||
},
|
||||
{
|
||||
Key: "index_type",
|
||||
Value: "IVF_FLAT",
|
||||
},
|
||||
{
|
||||
Key: "nlist",
|
||||
Value: strconv.Itoa(10),
|
||||
},
|
||||
},
|
||||
ExtraParams: constructIndexParam(dim, IndexFaissIvfFlat, distance.IP),
|
||||
})
|
||||
assert.NoError(t, err)
|
||||
err = merr.Error(createIndexStatus)
|
||||
|
@ -205,34 +130,21 @@ func TestRangeSearchIP(t *testing.T) {
|
|||
if err != nil {
|
||||
log.Warn("LoadCollection fail reason", zap.Error(err))
|
||||
}
|
||||
for {
|
||||
loadProgress, err := c.proxy.GetLoadingProgress(ctx, &milvuspb.GetLoadingProgressRequest{
|
||||
CollectionName: collectionName,
|
||||
})
|
||||
if err != nil {
|
||||
panic("GetLoadingProgress fail")
|
||||
}
|
||||
if loadProgress.GetProgress() == 100 {
|
||||
break
|
||||
}
|
||||
time.Sleep(500 * time.Millisecond)
|
||||
}
|
||||
waitingForLoad(ctx, c, collectionName)
|
||||
// search
|
||||
expr := fmt.Sprintf("%s > 0", "int64")
|
||||
expr := fmt.Sprintf("%s > 0", int64Field)
|
||||
nq := 10
|
||||
topk := 10
|
||||
roundDecimal := -1
|
||||
nprobe := 10
|
||||
radius := 10
|
||||
filter := 20
|
||||
|
||||
params := make(map[string]int)
|
||||
params["nprobe"] = nprobe
|
||||
params := getSearchParams(IndexFaissIvfFlat, distance.IP)
|
||||
|
||||
// only pass in radius when range search
|
||||
params["radius"] = radius
|
||||
searchReq := constructSearchRequest("", collectionName, expr,
|
||||
floatVecField, distance.IP, params, nq, dim, topk, roundDecimal)
|
||||
floatVecField, schemapb.DataType_FloatVector, nil, distance.IP, params, nq, dim, topk, roundDecimal)
|
||||
|
||||
searchResult, _ := c.proxy.Search(ctx, searchReq)
|
||||
|
||||
|
@ -245,7 +157,7 @@ func TestRangeSearchIP(t *testing.T) {
|
|||
// pass in radius and range_filter when range search
|
||||
params["range_filter"] = filter
|
||||
searchReq = constructSearchRequest("", collectionName, expr,
|
||||
floatVecField, distance.IP, params, nq, dim, topk, roundDecimal)
|
||||
floatVecField, schemapb.DataType_FloatVector, nil, distance.IP, params, nq, dim, topk, roundDecimal)
|
||||
|
||||
searchResult, _ = c.proxy.Search(ctx, searchReq)
|
||||
|
||||
|
@ -259,7 +171,7 @@ func TestRangeSearchIP(t *testing.T) {
|
|||
params["radius"] = filter
|
||||
params["range_filter"] = radius
|
||||
searchReq = constructSearchRequest("", collectionName, expr,
|
||||
floatVecField, distance.IP, params, nq, dim, topk, roundDecimal)
|
||||
floatVecField, schemapb.DataType_FloatVector, nil, distance.IP, params, nq, dim, topk, roundDecimal)
|
||||
|
||||
searchResult, _ = c.proxy.Search(ctx, searchReq)
|
||||
|
||||
|
@ -277,59 +189,24 @@ func TestRangeSearchIP(t *testing.T) {
|
|||
}
|
||||
|
||||
func TestRangeSearchL2(t *testing.T) {
|
||||
ctx := context.Background()
|
||||
ctx, cancel := context.WithTimeout(context.Background(), time.Second*180)
|
||||
c, err := StartMiniCluster(ctx)
|
||||
assert.NoError(t, err)
|
||||
err = c.Start()
|
||||
assert.NoError(t, err)
|
||||
defer c.Stop()
|
||||
defer func() {
|
||||
err = c.Stop()
|
||||
assert.NoError(t, err)
|
||||
cancel()
|
||||
}()
|
||||
|
||||
prefix := "TestRangeSearchL2"
|
||||
dbName := ""
|
||||
collectionName := prefix + funcutil.GenRandomStr()
|
||||
int64Field := "int64"
|
||||
floatVecField := "fvec"
|
||||
dim := 128
|
||||
rowNum := 3000
|
||||
|
||||
constructCollectionSchema := func() *schemapb.CollectionSchema {
|
||||
pk := &schemapb.FieldSchema{
|
||||
FieldID: 0,
|
||||
Name: int64Field,
|
||||
IsPrimaryKey: true,
|
||||
Description: "",
|
||||
DataType: schemapb.DataType_Int64,
|
||||
TypeParams: nil,
|
||||
IndexParams: nil,
|
||||
AutoID: true,
|
||||
}
|
||||
fVec := &schemapb.FieldSchema{
|
||||
FieldID: 0,
|
||||
Name: floatVecField,
|
||||
IsPrimaryKey: false,
|
||||
Description: "",
|
||||
DataType: schemapb.DataType_FloatVector,
|
||||
TypeParams: []*commonpb.KeyValuePair{
|
||||
{
|
||||
Key: "dim",
|
||||
Value: strconv.Itoa(dim),
|
||||
},
|
||||
},
|
||||
IndexParams: nil,
|
||||
AutoID: false,
|
||||
}
|
||||
return &schemapb.CollectionSchema{
|
||||
Name: collectionName,
|
||||
Description: "",
|
||||
AutoID: false,
|
||||
Fields: []*schemapb.FieldSchema{
|
||||
pk,
|
||||
fVec,
|
||||
},
|
||||
}
|
||||
}
|
||||
schema := constructCollectionSchema()
|
||||
schema := constructSchema(collectionName, dim, true)
|
||||
marshaledSchema, err := proto.Marshal(schema)
|
||||
assert.NoError(t, err)
|
||||
|
||||
|
@ -371,6 +248,7 @@ func TestRangeSearchL2(t *testing.T) {
|
|||
})
|
||||
assert.NoError(t, err)
|
||||
segmentIDs, has := flushResp.GetCollSegIDs()[collectionName]
|
||||
assert.True(t, has)
|
||||
ids := segmentIDs.GetData()
|
||||
assert.NotEmpty(t, segmentIDs)
|
||||
|
||||
|
@ -380,52 +258,14 @@ func TestRangeSearchL2(t *testing.T) {
|
|||
for _, segment := range segments {
|
||||
log.Info("ShowSegments result", zap.String("segment", segment.String()))
|
||||
}
|
||||
|
||||
if has && len(ids) > 0 {
|
||||
flushed := func() bool {
|
||||
resp, err := c.proxy.GetFlushState(ctx, &milvuspb.GetFlushStateRequest{
|
||||
SegmentIDs: ids,
|
||||
})
|
||||
if err != nil {
|
||||
//panic(errors.New("GetFlushState failed"))
|
||||
return false
|
||||
}
|
||||
return resp.GetFlushed()
|
||||
}
|
||||
for !flushed() {
|
||||
// respect context deadline/cancel
|
||||
select {
|
||||
case <-ctx.Done():
|
||||
panic(errors.New("deadline exceeded"))
|
||||
default:
|
||||
}
|
||||
time.Sleep(500 * time.Millisecond)
|
||||
}
|
||||
}
|
||||
waitingForFlush(ctx, c, ids)
|
||||
|
||||
// create index
|
||||
createIndexStatus, err := c.proxy.CreateIndex(ctx, &milvuspb.CreateIndexRequest{
|
||||
CollectionName: collectionName,
|
||||
FieldName: floatVecField,
|
||||
IndexName: "_default",
|
||||
ExtraParams: []*commonpb.KeyValuePair{
|
||||
{
|
||||
Key: "dim",
|
||||
Value: strconv.Itoa(dim),
|
||||
},
|
||||
{
|
||||
Key: common.MetricTypeKey,
|
||||
Value: distance.L2,
|
||||
},
|
||||
{
|
||||
Key: "index_type",
|
||||
Value: "IVF_FLAT",
|
||||
},
|
||||
{
|
||||
Key: "nlist",
|
||||
Value: strconv.Itoa(10),
|
||||
},
|
||||
},
|
||||
ExtraParams: constructIndexParam(dim, IndexFaissIvfFlat, distance.L2),
|
||||
})
|
||||
assert.NoError(t, err)
|
||||
err = merr.Error(createIndexStatus)
|
||||
|
@ -443,34 +283,20 @@ func TestRangeSearchL2(t *testing.T) {
|
|||
if err != nil {
|
||||
log.Warn("LoadCollection fail reason", zap.Error(err))
|
||||
}
|
||||
for {
|
||||
loadProgress, err := c.proxy.GetLoadingProgress(ctx, &milvuspb.GetLoadingProgressRequest{
|
||||
CollectionName: collectionName,
|
||||
})
|
||||
if err != nil {
|
||||
panic("GetLoadingProgress fail")
|
||||
}
|
||||
if loadProgress.GetProgress() == 100 {
|
||||
break
|
||||
}
|
||||
time.Sleep(500 * time.Millisecond)
|
||||
}
|
||||
waitingForLoad(ctx, c, collectionName)
|
||||
// search
|
||||
expr := fmt.Sprintf("%s > 0", "int64")
|
||||
expr := fmt.Sprintf("%s > 0", int64Field)
|
||||
nq := 10
|
||||
topk := 10
|
||||
roundDecimal := -1
|
||||
nprobe := 10
|
||||
radius := 20
|
||||
filter := 10
|
||||
|
||||
params := make(map[string]int)
|
||||
params["nprobe"] = nprobe
|
||||
|
||||
params := getSearchParams(IndexFaissIvfFlat, distance.L2)
|
||||
// only pass in radius when range search
|
||||
params["radius"] = radius
|
||||
searchReq := constructSearchRequest("", collectionName, expr,
|
||||
floatVecField, distance.L2, params, nq, dim, topk, roundDecimal)
|
||||
floatVecField, schemapb.DataType_FloatVector, nil, distance.L2, params, nq, dim, topk, roundDecimal)
|
||||
|
||||
searchResult, _ := c.proxy.Search(ctx, searchReq)
|
||||
|
||||
|
@ -483,7 +309,7 @@ func TestRangeSearchL2(t *testing.T) {
|
|||
// pass in radius and range_filter when range search
|
||||
params["range_filter"] = filter
|
||||
searchReq = constructSearchRequest("", collectionName, expr,
|
||||
floatVecField, distance.L2, params, nq, dim, topk, roundDecimal)
|
||||
floatVecField, schemapb.DataType_FloatVector, nil, distance.L2, params, nq, dim, topk, roundDecimal)
|
||||
|
||||
searchResult, _ = c.proxy.Search(ctx, searchReq)
|
||||
|
||||
|
@ -497,7 +323,7 @@ func TestRangeSearchL2(t *testing.T) {
|
|||
params["radius"] = filter
|
||||
params["range_filter"] = radius
|
||||
searchReq = constructSearchRequest("", collectionName, expr,
|
||||
floatVecField, distance.L2, params, nq, dim, topk, roundDecimal)
|
||||
floatVecField, schemapb.DataType_FloatVector, nil, distance.L2, params, nq, dim, topk, roundDecimal)
|
||||
|
||||
searchResult, _ = c.proxy.Search(ctx, searchReq)
|
||||
|
||||
|
|
|
@ -19,13 +19,10 @@ package integration
|
|||
import (
|
||||
"context"
|
||||
"fmt"
|
||||
"strconv"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"github.com/cockroachdb/errors"
|
||||
"github.com/golang/protobuf/proto"
|
||||
"github.com/milvus-io/milvus-proto/go-api/commonpb"
|
||||
"github.com/milvus-io/milvus-proto/go-api/milvuspb"
|
||||
"github.com/milvus-io/milvus-proto/go-api/schemapb"
|
||||
"github.com/milvus-io/milvus/pkg/common"
|
||||
|
@ -38,59 +35,25 @@ import (
|
|||
)
|
||||
|
||||
func TestUpsert(t *testing.T) {
|
||||
ctx := context.Background()
|
||||
ctx, cancel := context.WithTimeout(context.Background(), time.Second*180)
|
||||
c, err := StartMiniCluster(ctx)
|
||||
assert.NoError(t, err)
|
||||
err = c.Start()
|
||||
assert.NoError(t, err)
|
||||
defer c.Stop()
|
||||
defer func() {
|
||||
err = c.Stop()
|
||||
assert.NoError(t, err)
|
||||
cancel()
|
||||
}()
|
||||
assert.NoError(t, err)
|
||||
|
||||
prefix := "TestUpsert"
|
||||
dbName := ""
|
||||
collectionName := prefix + funcutil.GenRandomStr()
|
||||
int64Field := "int64"
|
||||
floatVecField := "fvec"
|
||||
dim := 128
|
||||
rowNum := 3000
|
||||
|
||||
constructCollectionSchema := func() *schemapb.CollectionSchema {
|
||||
pk := &schemapb.FieldSchema{
|
||||
FieldID: 0,
|
||||
Name: int64Field,
|
||||
IsPrimaryKey: true,
|
||||
Description: "",
|
||||
DataType: schemapb.DataType_Int64,
|
||||
TypeParams: nil,
|
||||
IndexParams: nil,
|
||||
AutoID: false,
|
||||
}
|
||||
fVec := &schemapb.FieldSchema{
|
||||
FieldID: 0,
|
||||
Name: floatVecField,
|
||||
IsPrimaryKey: false,
|
||||
Description: "",
|
||||
DataType: schemapb.DataType_FloatVector,
|
||||
TypeParams: []*commonpb.KeyValuePair{
|
||||
{
|
||||
Key: "dim",
|
||||
Value: strconv.Itoa(dim),
|
||||
},
|
||||
},
|
||||
IndexParams: nil,
|
||||
AutoID: false,
|
||||
}
|
||||
return &schemapb.CollectionSchema{
|
||||
Name: collectionName,
|
||||
Description: "",
|
||||
AutoID: false,
|
||||
Fields: []*schemapb.FieldSchema{
|
||||
pk,
|
||||
fVec,
|
||||
},
|
||||
}
|
||||
}
|
||||
schema := constructCollectionSchema()
|
||||
schema := constructSchema(collectionName, dim, false)
|
||||
marshaledSchema, err := proto.Marshal(schema)
|
||||
assert.NoError(t, err)
|
||||
|
||||
|
@ -113,7 +76,7 @@ func TestUpsert(t *testing.T) {
|
|||
assert.True(t, merr.Ok(showCollectionsResp.GetStatus()))
|
||||
log.Info("ShowCollections result", zap.Any("showCollectionsResp", showCollectionsResp))
|
||||
|
||||
pkFieldData := newInt64PrimaryKey(int64Field, rowNum)
|
||||
pkFieldData := newInt64FieldData(int64Field, rowNum)
|
||||
fVecColumn := newFloatVectorFieldData(floatVecField, rowNum, dim)
|
||||
hashKeys := generateHashKeys(rowNum)
|
||||
upsertResult, err := c.proxy.Upsert(ctx, &milvuspb.UpsertRequest{
|
||||
|
@ -133,6 +96,7 @@ func TestUpsert(t *testing.T) {
|
|||
})
|
||||
assert.NoError(t, err)
|
||||
segmentIDs, has := flushResp.GetCollSegIDs()[collectionName]
|
||||
assert.True(t, has)
|
||||
ids := segmentIDs.GetData()
|
||||
assert.NotEmpty(t, segmentIDs)
|
||||
|
||||
|
@ -142,52 +106,14 @@ func TestUpsert(t *testing.T) {
|
|||
for _, segment := range segments {
|
||||
log.Info("ShowSegments result", zap.String("segment", segment.String()))
|
||||
}
|
||||
|
||||
if has && len(ids) > 0 {
|
||||
flushed := func() bool {
|
||||
resp, err := c.proxy.GetFlushState(ctx, &milvuspb.GetFlushStateRequest{
|
||||
SegmentIDs: ids,
|
||||
})
|
||||
if err != nil {
|
||||
//panic(errors.New("GetFlushState failed"))
|
||||
return false
|
||||
}
|
||||
return resp.GetFlushed()
|
||||
}
|
||||
for !flushed() {
|
||||
// respect context deadline/cancel
|
||||
select {
|
||||
case <-ctx.Done():
|
||||
panic(errors.New("deadline exceeded"))
|
||||
default:
|
||||
}
|
||||
time.Sleep(500 * time.Millisecond)
|
||||
}
|
||||
}
|
||||
waitingForFlush(ctx, c, ids)
|
||||
|
||||
// create index
|
||||
createIndexStatus, err := c.proxy.CreateIndex(ctx, &milvuspb.CreateIndexRequest{
|
||||
CollectionName: collectionName,
|
||||
FieldName: floatVecField,
|
||||
IndexName: "_default",
|
||||
ExtraParams: []*commonpb.KeyValuePair{
|
||||
{
|
||||
Key: "dim",
|
||||
Value: strconv.Itoa(dim),
|
||||
},
|
||||
{
|
||||
Key: common.MetricTypeKey,
|
||||
Value: distance.L2,
|
||||
},
|
||||
{
|
||||
Key: "index_type",
|
||||
Value: "IVF_FLAT",
|
||||
},
|
||||
{
|
||||
Key: "nlist",
|
||||
Value: strconv.Itoa(10),
|
||||
},
|
||||
},
|
||||
ExtraParams: constructIndexParam(dim, IndexFaissIvfFlat, distance.IP),
|
||||
})
|
||||
assert.NoError(t, err)
|
||||
err = merr.Error(createIndexStatus)
|
||||
|
@ -205,30 +131,16 @@ func TestUpsert(t *testing.T) {
|
|||
if err != nil {
|
||||
log.Warn("LoadCollection fail reason", zap.Error(err))
|
||||
}
|
||||
for {
|
||||
loadProgress, err := c.proxy.GetLoadingProgress(ctx, &milvuspb.GetLoadingProgressRequest{
|
||||
CollectionName: collectionName,
|
||||
})
|
||||
if err != nil {
|
||||
panic("GetLoadingProgress fail")
|
||||
}
|
||||
if loadProgress.GetProgress() == 100 {
|
||||
break
|
||||
}
|
||||
time.Sleep(500 * time.Millisecond)
|
||||
}
|
||||
waitingForLoad(ctx, c, collectionName)
|
||||
// search
|
||||
expr := fmt.Sprintf("%s > 0", "int64")
|
||||
expr := fmt.Sprintf("%s > 0", int64Field)
|
||||
nq := 10
|
||||
topk := 10
|
||||
roundDecimal := -1
|
||||
nprobe := 10
|
||||
|
||||
params := make(map[string]int)
|
||||
params["nprobe"] = nprobe
|
||||
|
||||
params := getSearchParams(IndexFaissIvfFlat, "")
|
||||
searchReq := constructSearchRequest("", collectionName, expr,
|
||||
floatVecField, distance.IP, params, nq, dim, topk, roundDecimal)
|
||||
floatVecField, schemapb.DataType_FloatVector, nil, distance.IP, params, nq, dim, topk, roundDecimal)
|
||||
|
||||
searchResult, _ := c.proxy.Search(ctx, searchReq)
|
||||
|
||||
|
|
|
@ -0,0 +1,108 @@
|
|||
// Licensed to the LF AI & Data foundation under one
|
||||
// or more contributor license agreements. See the NOTICE file
|
||||
// distributed with this work for additional information
|
||||
// regarding copyright ownership. The ASF licenses this file
|
||||
// to you under the Apache License, Version 2.0 (the
|
||||
// "License"); you may not use this file except in compliance
|
||||
// with the License. You may obtain a copy of the License at
|
||||
//
|
||||
// http://www.apache.org/licenses/LICENSE-2.0
|
||||
//
|
||||
// Unless required by applicable law or agreed to in writing, software
|
||||
// distributed under the License is distributed on an "AS IS" BASIS,
|
||||
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
// See the License for the specific language governing permissions and
|
||||
// limitations under the License.
|
||||
|
||||
package integration
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"strconv"
|
||||
|
||||
"github.com/milvus-io/milvus-proto/go-api/commonpb"
|
||||
"github.com/milvus-io/milvus/pkg/common"
|
||||
"github.com/milvus-io/milvus/pkg/util/indexparamcheck"
|
||||
)
|
||||
|
||||
// Index type names used by the helpers in this package. Each one is an alias
// of the identically named constant exported by indexparamcheck.
const (
|
||||
IndexRaftIvfFlat = indexparamcheck.IndexRaftIvfFlat
|
||||
IndexRaftIvfPQ = indexparamcheck.IndexRaftIvfPQ
|
||||
IndexFaissIDMap = indexparamcheck.IndexFaissIDMap
|
||||
IndexFaissIvfFlat = indexparamcheck.IndexFaissIvfFlat
|
||||
IndexFaissIvfPQ = indexparamcheck.IndexFaissIvfPQ
|
||||
IndexFaissIvfSQ8 = indexparamcheck.IndexFaissIvfSQ8
|
||||
IndexFaissBinIDMap = indexparamcheck.IndexFaissBinIDMap
|
||||
IndexFaissBinIvfFlat = indexparamcheck.IndexFaissBinIvfFlat
|
||||
IndexHNSW = indexparamcheck.IndexHNSW
|
||||
IndexDISKANN = indexparamcheck.IndexDISKANN
|
||||
)
|
||||
|
||||
func constructIndexParam(dim int, indexType string, metricType string) []*commonpb.KeyValuePair {
|
||||
params := []*commonpb.KeyValuePair{
|
||||
{
|
||||
Key: common.DimKey,
|
||||
Value: strconv.Itoa(dim),
|
||||
},
|
||||
{
|
||||
Key: common.MetricTypeKey,
|
||||
Value: metricType,
|
||||
},
|
||||
{
|
||||
Key: common.IndexTypeKey,
|
||||
Value: indexType,
|
||||
},
|
||||
}
|
||||
switch indexType {
|
||||
case IndexFaissIDMap, IndexFaissBinIDMap:
|
||||
// no index param is required
|
||||
case IndexFaissIvfFlat, IndexFaissBinIvfFlat, IndexFaissIvfSQ8:
|
||||
params = append(params, &commonpb.KeyValuePair{
|
||||
Key: "nlist",
|
||||
Value: "100",
|
||||
})
|
||||
case IndexFaissIvfPQ:
|
||||
params = append(params, &commonpb.KeyValuePair{
|
||||
Key: "nlist",
|
||||
Value: "100",
|
||||
})
|
||||
params = append(params, &commonpb.KeyValuePair{
|
||||
Key: "m",
|
||||
Value: "16",
|
||||
})
|
||||
params = append(params, &commonpb.KeyValuePair{
|
||||
Key: "nbits",
|
||||
Value: "8",
|
||||
})
|
||||
case IndexHNSW:
|
||||
params = append(params, &commonpb.KeyValuePair{
|
||||
Key: "M",
|
||||
Value: "16",
|
||||
})
|
||||
params = append(params, &commonpb.KeyValuePair{
|
||||
Key: "efConstruction",
|
||||
Value: "200",
|
||||
})
|
||||
case IndexDISKANN:
|
||||
default:
|
||||
panic(fmt.Sprintf("unimplemented index param for %s, please help to improve it", indexType))
|
||||
}
|
||||
return params
|
||||
}
|
||||
|
||||
func getSearchParams(indexType string, metricType string) map[string]any {
|
||||
params := make(map[string]any)
|
||||
switch indexType {
|
||||
case IndexFaissIDMap, IndexFaissBinIDMap:
|
||||
params["metric_type"] = metricType
|
||||
case IndexFaissIvfFlat, IndexFaissBinIvfFlat, IndexFaissIvfSQ8, IndexFaissIvfPQ:
|
||||
params["nprobe"] = 8
|
||||
case IndexHNSW:
|
||||
params["ef"] = 200
|
||||
case IndexDISKANN:
|
||||
params["search_list"] = 5
|
||||
default:
|
||||
panic(fmt.Sprintf("unimplemented search param for %s, please help to improve it", indexType))
|
||||
}
|
||||
return params
|
||||
}
|
|
@ -0,0 +1,154 @@
|
|||
// Licensed to the LF AI & Data foundation under one
|
||||
// or more contributor license agreements. See the NOTICE file
|
||||
// distributed with this work for additional information
|
||||
// regarding copyright ownership. The ASF licenses this file
|
||||
// to you under the Apache License, Version 2.0 (the
|
||||
// "License"); you may not use this file except in compliance
|
||||
// with the License. You may obtain a copy of the License at
|
||||
//
|
||||
// http://www.apache.org/licenses/LICENSE-2.0
|
||||
//
|
||||
// Unless required by applicable law or agreed to in writing, software
|
||||
// distributed under the License is distributed on an "AS IS" BASIS,
|
||||
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
// See the License for the specific language governing permissions and
|
||||
// limitations under the License.
|
||||
|
||||
package integration
|
||||
|
||||
import (
|
||||
"context"
|
||||
"fmt"
|
||||
"math/rand"
|
||||
"time"
|
||||
|
||||
"github.com/milvus-io/milvus-proto/go-api/milvuspb"
|
||||
"github.com/milvus-io/milvus-proto/go-api/schemapb"
|
||||
)
|
||||
|
||||
func waitingForFlush(ctx context.Context, cluster *MiniCluster, segIDs []int64) {
|
||||
flushed := func() bool {
|
||||
resp, err := cluster.proxy.GetFlushState(ctx, &milvuspb.GetFlushStateRequest{
|
||||
SegmentIDs: segIDs,
|
||||
})
|
||||
if err != nil {
|
||||
return false
|
||||
}
|
||||
return resp.GetFlushed()
|
||||
}
|
||||
for !flushed() {
|
||||
select {
|
||||
case <-ctx.Done():
|
||||
panic("flush timeout")
|
||||
default:
|
||||
time.Sleep(500 * time.Millisecond)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func newInt64FieldData(fieldName string, numRows int) *schemapb.FieldData {
|
||||
return &schemapb.FieldData{
|
||||
Type: schemapb.DataType_Int64,
|
||||
FieldName: fieldName,
|
||||
Field: &schemapb.FieldData_Scalars{
|
||||
Scalars: &schemapb.ScalarField{
|
||||
Data: &schemapb.ScalarField_LongData{
|
||||
LongData: &schemapb.LongArray{
|
||||
Data: generateInt64Array(numRows),
|
||||
},
|
||||
},
|
||||
},
|
||||
},
|
||||
}
|
||||
}
|
||||
|
||||
func newStringFieldData(fieldName string, numRows int) *schemapb.FieldData {
|
||||
return &schemapb.FieldData{
|
||||
Type: schemapb.DataType_Int64,
|
||||
FieldName: fieldName,
|
||||
Field: &schemapb.FieldData_Scalars{
|
||||
Scalars: &schemapb.ScalarField{
|
||||
Data: &schemapb.ScalarField_StringData{
|
||||
StringData: &schemapb.StringArray{
|
||||
Data: generateStringArray(numRows),
|
||||
},
|
||||
},
|
||||
},
|
||||
},
|
||||
}
|
||||
}
|
||||
|
||||
func newFloatVectorFieldData(fieldName string, numRows, dim int) *schemapb.FieldData {
|
||||
return &schemapb.FieldData{
|
||||
Type: schemapb.DataType_FloatVector,
|
||||
FieldName: fieldName,
|
||||
Field: &schemapb.FieldData_Vectors{
|
||||
Vectors: &schemapb.VectorField{
|
||||
Dim: int64(dim),
|
||||
Data: &schemapb.VectorField_FloatVector{
|
||||
FloatVector: &schemapb.FloatArray{
|
||||
Data: generateFloatVectors(numRows, dim),
|
||||
},
|
||||
},
|
||||
},
|
||||
},
|
||||
}
|
||||
}
|
||||
|
||||
func newBinaryVectorFieldData(fieldName string, numRows, dim int) *schemapb.FieldData {
|
||||
return &schemapb.FieldData{
|
||||
Type: schemapb.DataType_BinaryVector,
|
||||
FieldName: fieldName,
|
||||
Field: &schemapb.FieldData_Vectors{
|
||||
Vectors: &schemapb.VectorField{
|
||||
Dim: int64(dim),
|
||||
Data: &schemapb.VectorField_BinaryVector{
|
||||
BinaryVector: generateBinaryVectors(numRows, dim),
|
||||
},
|
||||
},
|
||||
},
|
||||
}
|
||||
}
|
||||
|
||||
// generateInt64Array returns the sequence 0, 1, ..., numRows-1 as int64 values.
func generateInt64Array(numRows int) []int64 {
	values := make([]int64, numRows)
	for idx := range values {
		values[idx] = int64(idx)
	}
	return values
}
|
||||
|
||||
// generateStringArray returns numRows strings holding the decimal text of
// 0, 1, ..., numRows-1.
func generateStringArray(numRows int) []string {
	values := make([]string, numRows)
	for idx := range values {
		values[idx] = fmt.Sprint(idx)
	}
	return values
}
|
||||
|
||||
// generateFloatVectors returns numRows*dim values drawn from rand.Float32,
// laid out as one flat slice.
func generateFloatVectors(numRows, dim int) []float32 {
	flat := make([]float32, numRows*dim)
	for idx := range flat {
		flat[idx] = rand.Float32()
	}
	return flat
}
|
||||
|
||||
// generateBinaryVectors returns (numRows*dim)/8 bytes filled by rand.Read.
// It panics if rand.Read reports an error.
func generateBinaryVectors(numRows, dim int) []byte {
	packed := make([]byte, numRows*dim/8)
	if _, err := rand.Read(packed); err != nil {
		panic(err)
	}
	return packed
}
|
||||
|
||||
// generateHashKeys returns numRows values drawn from rand.Uint32.
func generateHashKeys(numRows int) []uint32 {
	keys := make([]uint32, numRows)
	for idx := range keys {
		keys[idx] = rand.Uint32()
	}
	return keys
}
|
|
@ -0,0 +1,166 @@
|
|||
// Licensed to the LF AI & Data foundation under one
|
||||
// or more contributor license agreements. See the NOTICE file
|
||||
// distributed with this work for additional information
|
||||
// regarding copyright ownership. The ASF licenses this file
|
||||
// to you under the Apache License, Version 2.0 (the
|
||||
// "License"); you may not use this file except in compliance
|
||||
// with the License. You may obtain a copy of the License at
|
||||
//
|
||||
// http://www.apache.org/licenses/LICENSE-2.0
|
||||
//
|
||||
// Unless required by applicable law or agreed to in writing, software
|
||||
// distributed under the License is distributed on an "AS IS" BASIS,
|
||||
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
// See the License for the specific language governing permissions and
|
||||
// limitations under the License.
|
||||
|
||||
package integration
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"context"
|
||||
"encoding/binary"
|
||||
"encoding/json"
|
||||
"math/rand"
|
||||
"strconv"
|
||||
"time"
|
||||
|
||||
"github.com/golang/protobuf/proto"
|
||||
|
||||
"github.com/milvus-io/milvus-proto/go-api/commonpb"
|
||||
"github.com/milvus-io/milvus-proto/go-api/milvuspb"
|
||||
"github.com/milvus-io/milvus-proto/go-api/schemapb"
|
||||
"github.com/milvus-io/milvus/pkg/common"
|
||||
)
|
||||
|
||||
// String keys for search-request parameters. constructSearchRequest uses
// AnnsFieldKey, SearchParamsKey and RoundDecimalKey when filling SearchParams.
// NOTE(review): the remaining keys are not referenced by the code visible
// here; confirm their users before changing them.
const (
|
||||
AnnsFieldKey = "anns_field"
|
||||
TopKKey = "topk"
|
||||
NQKey = "nq"
|
||||
MetricTypeKey = "metric_type"
|
||||
SearchParamsKey = "params"
|
||||
RoundDecimalKey = "round_decimal"
|
||||
OffsetKey = "offset"
|
||||
LimitKey = "limit"
|
||||
)
|
||||
|
||||
func waitingForLoad(ctx context.Context, cluster *MiniCluster, collection string) {
|
||||
getLoadingProgress := func() *milvuspb.GetLoadingProgressResponse {
|
||||
loadProgress, err := cluster.proxy.GetLoadingProgress(ctx, &milvuspb.GetLoadingProgressRequest{
|
||||
CollectionName: collection,
|
||||
})
|
||||
if err != nil {
|
||||
panic("GetLoadingProgress fail")
|
||||
}
|
||||
return loadProgress
|
||||
}
|
||||
for getLoadingProgress().GetProgress() != 100 {
|
||||
select {
|
||||
case <-ctx.Done():
|
||||
panic("load timeout")
|
||||
default:
|
||||
time.Sleep(500 * time.Millisecond)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func constructSearchRequest(
|
||||
dbName, collectionName string,
|
||||
expr string,
|
||||
vecField string,
|
||||
vectorType schemapb.DataType,
|
||||
outputFields []string,
|
||||
metricType string,
|
||||
params map[string]any,
|
||||
nq, dim int, topk, roundDecimal int,
|
||||
) *milvuspb.SearchRequest {
|
||||
b, err := json.Marshal(params)
|
||||
if err != nil {
|
||||
panic(err)
|
||||
}
|
||||
plg := constructPlaceholderGroup(nq, dim, vectorType)
|
||||
plgBs, err := proto.Marshal(plg)
|
||||
if err != nil {
|
||||
panic(err)
|
||||
}
|
||||
|
||||
return &milvuspb.SearchRequest{
|
||||
Base: nil,
|
||||
DbName: dbName,
|
||||
CollectionName: collectionName,
|
||||
PartitionNames: nil,
|
||||
Dsl: expr,
|
||||
PlaceholderGroup: plgBs,
|
||||
DslType: commonpb.DslType_BoolExprV1,
|
||||
OutputFields: outputFields,
|
||||
SearchParams: []*commonpb.KeyValuePair{
|
||||
{
|
||||
Key: common.MetricTypeKey,
|
||||
Value: metricType,
|
||||
},
|
||||
{
|
||||
Key: SearchParamsKey,
|
||||
Value: string(b),
|
||||
},
|
||||
{
|
||||
Key: AnnsFieldKey,
|
||||
Value: vecField,
|
||||
},
|
||||
{
|
||||
Key: common.TopKKey,
|
||||
Value: strconv.Itoa(topk),
|
||||
},
|
||||
{
|
||||
Key: RoundDecimalKey,
|
||||
Value: strconv.Itoa(roundDecimal),
|
||||
},
|
||||
},
|
||||
TravelTimestamp: 0,
|
||||
GuaranteeTimestamp: 0,
|
||||
}
|
||||
}
|
||||
|
||||
func constructPlaceholderGroup(nq, dim int, vectorType schemapb.DataType) *commonpb.PlaceholderGroup {
|
||||
values := make([][]byte, 0, nq)
|
||||
var placeholderType commonpb.PlaceholderType
|
||||
switch vectorType {
|
||||
case schemapb.DataType_FloatVector:
|
||||
placeholderType = commonpb.PlaceholderType_FloatVector
|
||||
for i := 0; i < nq; i++ {
|
||||
bs := make([]byte, 0, dim*4)
|
||||
for j := 0; j < dim; j++ {
|
||||
var buffer bytes.Buffer
|
||||
f := rand.Float32()
|
||||
err := binary.Write(&buffer, common.Endian, f)
|
||||
if err != nil {
|
||||
panic(err)
|
||||
}
|
||||
bs = append(bs, buffer.Bytes()...)
|
||||
}
|
||||
values = append(values, bs)
|
||||
}
|
||||
case schemapb.DataType_BinaryVector:
|
||||
placeholderType = commonpb.PlaceholderType_BinaryVector
|
||||
for i := 0; i < nq; i++ {
|
||||
total := dim / 8
|
||||
ret := make([]byte, total)
|
||||
_, err := rand.Read(ret)
|
||||
if err != nil {
|
||||
panic(err)
|
||||
}
|
||||
values = append(values, ret)
|
||||
}
|
||||
default:
|
||||
panic("invalid vector data type")
|
||||
}
|
||||
|
||||
return &commonpb.PlaceholderGroup{
|
||||
Placeholders: []*commonpb.PlaceholderValue{
|
||||
{
|
||||
Tag: "$0",
|
||||
Type: placeholderType,
|
||||
Values: values,
|
||||
},
|
||||
},
|
||||
}
|
||||
}
|
|
@ -0,0 +1,80 @@
|
|||
// Licensed to the LF AI & Data foundation under one
|
||||
// or more contributor license agreements. See the NOTICE file
|
||||
// distributed with this work for additional information
|
||||
// regarding copyright ownership. The ASF licenses this file
|
||||
// to you under the Apache License, Version 2.0 (the
|
||||
// "License"); you may not use this file except in compliance
|
||||
// with the License. You may obtain a copy of the License at
|
||||
//
|
||||
// http://www.apache.org/licenses/LICENSE-2.0
|
||||
//
|
||||
// Unless required by applicable law or agreed to in writing, software
|
||||
// distributed under the License is distributed on an "AS IS" BASIS,
|
||||
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
// See the License for the specific language governing permissions and
|
||||
// limitations under the License.
|
||||
|
||||
package integration
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
|
||||
"github.com/milvus-io/milvus-proto/go-api/commonpb"
|
||||
"github.com/milvus-io/milvus-proto/go-api/schemapb"
|
||||
"github.com/milvus-io/milvus/pkg/common"
|
||||
)
|
||||
|
||||
// Field names, one per data type, for collection schemas built in this
// package. constructSchema uses int64Field for its default primary key and
// floatVecField for its default float vector field.
const (
|
||||
boolField = "boolField"
|
||||
int8Field = "int8Field"
|
||||
int16Field = "int16Field"
|
||||
int32Field = "int32Field"
|
||||
int64Field = "int64Field"
|
||||
floatField = "floatField"
|
||||
doubleField = "doubleField"
|
||||
varCharField = "varCharField"
|
||||
floatVecField = "floatVecField"
|
||||
binVecField = "binVecField"
|
||||
)
|
||||
|
||||
func constructSchema(collection string, dim int, autoID bool, fields ...*schemapb.FieldSchema) *schemapb.CollectionSchema {
|
||||
// if fields are specified, construct it
|
||||
if len(fields) > 0 {
|
||||
return &schemapb.CollectionSchema{
|
||||
Name: collection,
|
||||
AutoID: autoID,
|
||||
Fields: fields,
|
||||
}
|
||||
}
|
||||
|
||||
// if no field is specified, use default
|
||||
pk := &schemapb.FieldSchema{
|
||||
FieldID: 100,
|
||||
Name: int64Field,
|
||||
IsPrimaryKey: true,
|
||||
Description: "",
|
||||
DataType: schemapb.DataType_Int64,
|
||||
TypeParams: nil,
|
||||
IndexParams: nil,
|
||||
AutoID: autoID,
|
||||
}
|
||||
fVec := &schemapb.FieldSchema{
|
||||
FieldID: 101,
|
||||
Name: floatVecField,
|
||||
IsPrimaryKey: false,
|
||||
Description: "",
|
||||
DataType: schemapb.DataType_FloatVector,
|
||||
TypeParams: []*commonpb.KeyValuePair{
|
||||
{
|
||||
Key: common.DimKey,
|
||||
Value: fmt.Sprintf("%d", dim),
|
||||
},
|
||||
},
|
||||
IndexParams: nil,
|
||||
}
|
||||
return &schemapb.CollectionSchema{
|
||||
Name: collection,
|
||||
AutoID: autoID,
|
||||
Fields: []*schemapb.FieldSchema{pk, fVec},
|
||||
}
|
||||
}
|
|
@ -842,11 +842,7 @@ class TestCollectionSearchInvalid(TestcaseBase):
|
|||
log.info("test_search_output_field_vector: Searching collection %s" % collection_w.name)
|
||||
collection_w.search(vectors[:default_nq], default_search_field,
|
||||
default_search_params, default_limit,
|
||||
default_search_exp, output_fields=output_fields,
|
||||
check_task=CheckTasks.err_res,
|
||||
check_items={"err_code": 1,
|
||||
"err_msg": "Search doesn't support "
|
||||
"vector field as output_fields"})
|
||||
default_search_exp, output_fields=output_fields)
|
||||
|
||||
@pytest.mark.tags(CaseLabel.L2)
|
||||
@pytest.mark.parametrize("output_fields", [["*%"], ["**"], ["*", "@"]])
|
||||
|
|
Loading…
Reference in New Issue