Add support for getting vectors by ids (#23450) (#25180)

Signed-off-by: bigsheeper <yihao.dai@zilliz.com>
hnsw_get_vec
yihao.dai 2023-06-28 10:06:46 +08:00 committed by GitHub
parent 936ebf3266
commit 785bab5ba7
28 changed files with 1187 additions and 36 deletions


@ -293,6 +293,9 @@ rpm: install
@cp -r build/rpm/services ~/rpmbuild/BUILD/
@QA_RPATHS="$$[ 0x001|0x0002|0x0020 ]" rpmbuild -ba ./build/rpm/milvus.spec
mock-proxy:
mockery --name=ProxyComponent --dir=$(PWD)/internal/types --output=$(PWD)/internal/mocks --filename=mock_proxy.go --structname=Proxy --with-expecter
mock-datanode:
mockery --name=DataNode --dir=$(PWD)/internal/types --output=$(PWD)/internal/mocks --filename=mock_datanode.go --with-expecter
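
The two mock targets run mockery with --with-expecter, which generates expecter-style mocks under internal/mocks. A minimal hedged sketch of how a test consumes the generated Proxy mock, the same pattern the proxy tests later in this commit use (the test name here is hypothetical):

package proxy

import (
	"context"
	"testing"

	"github.com/stretchr/testify/mock"

	"github.com/milvus-io/milvus-proto/go-api/v2/milvuspb"
	"github.com/milvus-io/milvus/internal/mocks"
)

// Sketch only: mocks.NewProxy is the mock generated by `make mock-proxy`.
func TestGeneratedProxyMockSketch(t *testing.T) {
	node := mocks.NewProxy(t)
	// Expecter style: EXPECT() records the expected call and its canned return.
	node.EXPECT().Query(mock.Anything, mock.Anything).
		Return(&milvuspb.QueryResults{}, nil)

	res, err := node.Query(context.Background(), &milvuspb.QueryRequest{})
	if err != nil || res == nil {
		t.Fatal("mocked Query should return the canned result")
	}
}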


@ -75,6 +75,16 @@ PrefixMatch(const std::string_view str, const std::string_view prefix) {
return true;
}
inline DatasetPtr
GenIdsDataset(const int64_t count, const int64_t* ids) {
auto ret_ds = std::make_shared<knowhere::Dataset>();
knowhere::SetDatasetRows(ret_ds, count);
knowhere::SetDatasetDim(ret_ds, 1);
// INPUT_IDS will not be freed in the dataset destructor, similar to `SetIsOwner(false)`.
knowhere::SetDatasetInputIDs(ret_ds, ids);
return ret_ds;
}
inline bool
PostfixMatch(const std::string_view str, const std::string_view postfix) {
if (postfix.length() > str.length()) {


@ -187,6 +187,33 @@ VectorDiskAnnIndex<T>::Query(const DatasetPtr dataset, const SearchInfo& search_
return result;
}
template <typename T>
const bool
VectorDiskAnnIndex<T>::HasRawData() const {
return index_->HasRawData(GetMetricType());
}
template <typename T>
std::vector<uint8_t>
VectorDiskAnnIndex<T>::GetVector(const DatasetPtr dataset, const Config& config) const {
auto res = index_->GetVectorById(dataset, config);
AssertInfo(res != nullptr, "failed to get vector, result is null");
auto index_type = GetIndexType();
auto tensor = knowhere::GetDatasetTensor(res);
auto row_num = knowhere::GetDatasetRows(res);
auto dim = knowhere::GetDatasetDim(res);
int64_t data_size;
if (is_in_bin_list(index_type)) {
data_size = dim / 8 * row_num;
} else {
data_size = dim * row_num * sizeof(float);
}
std::vector<uint8_t> raw_data;
raw_data.resize(data_size);
memcpy(raw_data.data(), tensor, data_size);
return raw_data;
}
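
GetVector sizes its output buffer from the result dataset's rows and dimension: binary indexes pack one bit per dimension, so a row takes dim/8 bytes, while float indexes take sizeof(float) bytes per dimension. A small Go sketch of the same arithmetic, for illustration only (the function name is ours, not part of the change):

package main

import "fmt"

// rawDataSize mirrors the buffer sizing in GetVector above: binary vectors
// occupy dim/8 bytes per row, float vectors dim * 4 bytes per row.
func rawDataSize(dim, rows int64, binaryIndex bool) int64 {
	if binaryIndex {
		return dim / 8 * rows
	}
	return dim * rows * 4 // sizeof(float32)
}

func main() {
	fmt.Println(rawDataSize(128, 1000, false)) // float:  512000 bytes
	fmt.Println(rawDataSize(128, 1000, true))  // binary:  16000 bytes
}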
template <typename T>
void
VectorDiskAnnIndex<T>::CleanLocalData() {


@ -17,6 +17,7 @@
#pragma once
#include <memory>
#include <vector>
#include "index/VectorIndex.h"
#include "storage/DiskFileManagerImpl.h"
@ -68,6 +69,12 @@ class VectorDiskAnnIndex : public VectorIndex {
std::unique_ptr<SearchResult>
Query(const DatasetPtr dataset, const SearchInfo& search_info, const BitsetView& bitset) override;
const bool
HasRawData() const override;
std::vector<uint8_t>
GetVector(const DatasetPtr dataset, const Config& config = {}) const override;
void
CleanLocalData() override;


@ -19,6 +19,7 @@
#include <map>
#include <memory>
#include <string>
#include <vector>
#include <boost/dynamic_bitset.hpp>
#include "knowhere/index/VecIndex.h"
@ -45,6 +46,12 @@ class VectorIndex : public IndexBase {
virtual std::unique_ptr<SearchResult>
Query(const DatasetPtr dataset, const SearchInfo& search_info, const BitsetView& bitset) = 0;
virtual const bool
HasRawData() const = 0;
virtual std::vector<uint8_t>
GetVector(const DatasetPtr dataset, const Config& config = {}) const = 0;
IndexType
GetIndexType() const {
return index_type_;


@ -222,4 +222,29 @@ VectorMemIndex::parse_config(Config& config) {
CheckParameter<int>(config, knowhere::indexparam::SEARCH_K, stoi_closure, std::nullopt);
}
const bool
VectorMemIndex::HasRawData() const {
return index_->HasRawData(GetMetricType());
}
std::vector<uint8_t>
VectorMemIndex::GetVector(const DatasetPtr dataset, const Config& config) const {
auto res = index_->GetVectorById(dataset, config);
AssertInfo(res != nullptr, "failed to get vector, result is null");
auto index_type = GetIndexType();
auto tensor = knowhere::GetDatasetOutputTensor(res);
auto row_num = knowhere::GetDatasetRows(res);
auto dim = knowhere::GetDatasetDim(res);
int64_t data_size;
if (is_in_bin_list(index_type)) {
data_size = dim / 8 * row_num;
} else {
data_size = dim * row_num * sizeof(float);
}
std::vector<uint8_t> raw_data;
raw_data.resize(data_size);
memcpy(raw_data.data(), tensor, data_size);
return raw_data;
}
} // namespace milvus::index
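
VectorMemIndex::GetVector mirrors the DiskANN variant, differing only in reading the dataset's output tensor. Either way the caller receives raw bytes; for float indexes those bytes are the vectors' float32 values in native byte order. A hedged decoding sketch, assuming a little-endian host (which is what the C++ memcpy preserves on x86/ARM):

package main

import (
	"encoding/binary"
	"fmt"
	"math"
)

// floatsFromRaw reinterprets the byte buffer returned by GetVector as
// float32 values, assuming a little-endian layout.
func floatsFromRaw(raw []byte) []float32 {
	out := make([]float32, len(raw)/4)
	for i := range out {
		bits := binary.LittleEndian.Uint32(raw[i*4:])
		out[i] = math.Float32frombits(bits)
	}
	return out
}

func main() {
	raw := []byte{0x00, 0x00, 0x80, 0x3f} // 1.0 as little-endian float32
	fmt.Println(floatsFromRaw(raw))       // [1]
}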


@ -67,6 +67,12 @@ class VectorMemIndex : public VectorIndex {
virtual void
LoadWithoutAssemble(const BinarySet& binary_set, const Config& config);
const bool
HasRawData() const override;
std::vector<uint8_t>
GetVector(const DatasetPtr dataset, const Config& config = {}) const override;
protected:
Config config_;
knowhere::VecIndexPtr index_ = nullptr;


@ -208,6 +208,11 @@ class SegmentGrowingImpl : public SegmentGrowing {
return true;
}
bool
HasRawData(int64_t field_id) const override {
return true;
}
protected:
int64_t
num_chunk() const override;


@ -87,6 +87,9 @@ class SegmentInterface {
virtual int64_t
get_segment_id() const = 0;
virtual bool
HasRawData(int64_t field_id) const = 0;
};
// internal API for DSL calculation


@ -392,6 +392,30 @@ SegmentSealedImpl::vector_search(SearchInfo& search_info,
}
}
std::unique_ptr<DataArray>
SegmentSealedImpl::get_vector(FieldId field_id, const int64_t* ids, int64_t count) const {
auto& filed_meta = schema_->operator[](field_id);
AssertInfo(filed_meta.is_vector(), "vector field is not vector type");
if (get_bit(index_ready_bitset_, field_id)) {
AssertInfo(vector_indexings_.is_ready(field_id), "vector index is not ready");
auto field_indexing = vector_indexings_.get_field_indexing(field_id);
auto vec_index = dynamic_cast<index::VectorIndex*>(field_indexing->indexing_.get());
auto index_type = vec_index->GetIndexType();
auto metric_type = vec_index->GetMetricType();
auto has_raw_data = vec_index->HasRawData();
if (has_raw_data) {
auto ids_ds = GenIdsDataset(count, ids);
auto vector = vec_index->GetVector(ids_ds);
return segcore::CreateVectorDataArrayFrom(vector.data(), count, filed_meta);
}
}
return fill_with_empty(field_id, count);
}
void
SegmentSealedImpl::DropFieldData(const FieldId field_id) {
if (SystemProperty::Instance().IsSystem(field_id)) {
@ -553,9 +577,7 @@ SegmentSealedImpl::bulk_subscript(FieldId field_id, const int64_t* seg_offsets,
return ReverseDataFromIndex(index, seg_offsets, count, field_meta);
}
- // TODO: knowhere support reverse data from vector index
- // Now, real data will be filled in data array using chunk manager
- return fill_with_empty(field_id, count);
+ return get_vector(field_id, seg_offsets, count);
}
Assert(get_bit(field_data_ready_bitset_, field_id));
@ -649,6 +671,22 @@ SegmentSealedImpl::HasFieldData(FieldId field_id) const {
}
}
bool
SegmentSealedImpl::HasRawData(int64_t field_id) const {
std::shared_lock lck(mutex_);
auto fieldID = FieldId(field_id);
const auto& field_meta = schema_->operator[](fieldID);
if (datatype_is_vector(field_meta.get_data_type())) {
if (get_bit(index_ready_bitset_, fieldID)) {
AssertInfo(vector_indexings_.is_ready(fieldID), "vector index is not ready");
auto field_indexing = vector_indexings_.get_field_indexing(fieldID);
auto vec_index = dynamic_cast<index::VectorIndex*>(field_indexing->indexing_.get());
return vec_index->HasRawData();
}
}
return true;
}
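
The rule above reduces to: raw data is always reachable unless the field is a vector field whose loaded index cannot return the original vectors. A simplified restatement in Go (not the real API; the index capability is passed in as a flag):

package segcore

// hasRawData sketches the decision in SegmentSealedImpl::HasRawData above.
// Growing segments (sealed == false) always keep inserted data in memory.
func hasRawData(sealed, isVectorField, indexLoaded, indexHasRawData bool) bool {
	if !sealed {
		return true // growing segment: data sits in the insert buffer
	}
	if isVectorField && indexLoaded {
		// e.g. HNSW typically retains raw vectors; quantized indexes may not
		return indexHasRawData
	}
	return true // scalar field, or vector field loaded as raw column data
}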
std::pair<std::unique_ptr<IdArray>, std::vector<SegOffset>>
SegmentSealedImpl::search_ids(const IdArray& id_array, Timestamp timestamp) const {
AssertInfo(id_array.has_int_id(), "Id array doesn't have int_id element");


@ -59,6 +59,9 @@ class SegmentSealedImpl : public SegmentSealed {
return id_;
}
bool
HasRawData(int64_t field_id) const override;
public:
int64_t
GetMemoryUsageInBytes() const override;
@ -72,6 +75,9 @@ class SegmentSealedImpl : public SegmentSealed {
const Schema&
get_schema() const override;
std::unique_ptr<DataArray>
get_vector(FieldId field_id, const int64_t* ids, int64_t count) const;
public:
int64_t
num_chunk_index(FieldId field_id) const override;


@ -138,6 +138,12 @@ GetRealCount(CSegmentInterface c_segment) {
return segment->get_real_count();
}
bool
HasRawData(CSegmentInterface c_segment, int64_t field_id) {
auto segment = reinterpret_cast<milvus::segcore::SegmentInterface*>(c_segment);
return segment->HasRawData(field_id);
}
////////////////////////////// interfaces for growing segment //////////////////////////////
CStatus
Insert(CSegmentInterface c_segment,


@ -63,6 +63,9 @@ GetDeletedCount(CSegmentInterface c_segment);
int64_t
GetRealCount(CSegmentInterface c_segment);
bool
HasRawData(CSegmentInterface c_segment, int64_t field_id);
////////////////////////////// interfaces for growing segment //////////////////////////////
CStatus
Insert(CSegmentInterface c_segment,


@ -11,8 +11,8 @@
# or implied. See the License for the specific language governing permissions and limitations under the License.
#-------------------------------------------------------------------------------
- set(KNOWHERE_VERSION v1.3.16)
- set(KNOWHERE_SOURCE_MD5 "f0ded5a77b39ca7db7047191234ec6d7")
+ set(KNOWHERE_VERSION v1.3.17)
+ set(KNOWHERE_SOURCE_MD5 "00164cd97b2f35c09ae0bdd6e2a9fc02")
if (DEFINED ENV{MILVUS_KNOWHERE_URL})
set(KNOWHERE_SOURCE_URL "$ENV{MILVUS_KNOWHERE_URL}")


@ -430,6 +430,92 @@ TEST_P(IndexTest, BuildAndQuery) {
}
}
TEST_P(IndexTest, GetVector) {
milvus::index::CreateIndexInfo create_index_info;
create_index_info.index_type = index_type;
create_index_info.metric_type = metric_type;
create_index_info.field_type = vec_field_data_type;
index::IndexBasePtr index;
// only HNSW supports GetVector in 2.2.x
if (index_type != knowhere::IndexEnum::INDEX_HNSW) {
return;
}
if (index_type == knowhere::IndexEnum::INDEX_DISKANN) {
#ifdef BUILD_DISK_ANN
milvus::storage::FieldDataMeta field_data_meta{1, 2, 3, 100};
milvus::storage::IndexMeta index_meta{3, 100, 1000, 1};
auto file_manager =
std::make_shared<milvus::storage::DiskFileManagerImpl>(field_data_meta, index_meta, storage_config_);
index = milvus::index::IndexFactory::GetInstance().CreateIndex(create_index_info, file_manager);
#endif
} else {
index = milvus::index::IndexFactory::GetInstance().CreateIndex(create_index_info, nullptr);
}
ASSERT_NO_THROW(index->BuildWithDataset(xb_dataset, build_conf));
milvus::index::IndexBasePtr new_index;
milvus::index::VectorIndex* vec_index = nullptr;
if (index_type == knowhere::IndexEnum::INDEX_DISKANN) {
#ifdef BUILD_DISK_ANN
// TODO: diskann.query needs to load first; ugly
auto binary_set = index->Serialize(milvus::Config{});
index.reset();
milvus::storage::FieldDataMeta field_data_meta{1, 2, 3, 100};
milvus::storage::IndexMeta index_meta{3, 100, 1000, 1};
auto file_manager =
std::make_shared<milvus::storage::DiskFileManagerImpl>(field_data_meta, index_meta, storage_config_);
new_index = milvus::index::IndexFactory::GetInstance().CreateIndex(create_index_info, file_manager);
vec_index = dynamic_cast<milvus::index::VectorIndex*>(new_index.get());
std::vector<std::string> index_files;
for (auto& binary : binary_set.binary_map_) {
index_files.emplace_back(binary.first);
}
load_conf["index_files"] = index_files;
vec_index->Load(binary_set, load_conf);
EXPECT_EQ(vec_index->Count(), NB);
#endif
} else {
vec_index = dynamic_cast<milvus::index::VectorIndex*>(index.get());
}
EXPECT_EQ(vec_index->GetDim(), DIM);
EXPECT_EQ(vec_index->Count(), NB);
if (!vec_index->HasRawData()) {
return;
}
auto ids_ds = GenRandomIds(NB);
auto results = vec_index->GetVector(ids_ds);
EXPECT_TRUE(results.size() > 0);
if (!is_binary) {
std::vector<float> result_vectors(results.size() / (sizeof(float)));
memcpy(result_vectors.data(), results.data(), results.size());
EXPECT_TRUE(result_vectors.size() == xb_data.size());
for (size_t i = 0; i < NB; ++i) {
auto id = knowhere::GetDatasetInputIDs(ids_ds)[i];
for (size_t j = 0; j < DIM; ++j) {
EXPECT_TRUE(result_vectors[i * DIM + j] == xb_data[id * DIM + j]);
}
}
} else {
EXPECT_TRUE(results.size() == xb_bin_data.size());
const auto data_bytes = DIM / 8;
for (size_t i = 0; i < NB; ++i) {
auto id = knowhere::GetDatasetInputIDs(ids_ds)[i];
for (size_t j = 0; j < data_bytes; ++j) {
EXPECT_TRUE(results[i * data_bytes + j] == xb_bin_data[id * data_bytes + j]);
}
}
}
auto ids = knowhere::GetDatasetInputIDs(ids_ds);
delete[] ids;
}
//#ifdef BUILD_DISK_ANN
// TEST(Indexing, SearchDiskAnnWithInvalidParam) {
// int64_t NB = 10000;


@ -811,3 +811,54 @@ TEST(Sealed, RealCount) {
ASSERT_TRUE(status.ok());
ASSERT_EQ(0, segment->get_real_count());
}
TEST(Sealed, GetVector) {
auto dim = 16;
auto topK = 5;
auto N = ROW_COUNT;
auto metric_type = knowhere::metric::L2;
auto schema = std::make_shared<Schema>();
auto fakevec_id = schema->AddDebugField("fakevec", DataType::VECTOR_FLOAT, dim, metric_type);
auto counter_id = schema->AddDebugField("counter", DataType::INT64);
auto double_id = schema->AddDebugField("double", DataType::DOUBLE);
auto nothing_id = schema->AddDebugField("nothing", DataType::INT32);
auto str_id = schema->AddDebugField("str", DataType::VARCHAR);
schema->AddDebugField("int8", DataType::INT8);
schema->AddDebugField("int16", DataType::INT16);
schema->AddDebugField("float", DataType::FLOAT);
schema->set_primary_field_id(counter_id);
auto dataset = DataGen(schema, N);
auto fakevec = dataset.get_col<float>(fakevec_id);
auto indexing = GenHNSWIndex(N, dim, fakevec.data());
auto segment_sealed = CreateSealedSegment(schema);
LoadIndexInfo vec_info;
vec_info.field_id = fakevec_id.get();
vec_info.index = std::move(indexing);
vec_info.index_params["metric_type"] = knowhere::metric::L2;
segment_sealed->LoadIndex(vec_info);
auto segment = dynamic_cast<SegmentSealedImpl*>(segment_sealed.get());
auto has = segment->HasRawData(vec_info.field_id);
EXPECT_TRUE(has);
auto ids_ds = GenRandomIds(N);
auto result = segment->get_vector(fakevec_id, knowhere::GetDatasetInputIDs(ids_ds), N);
auto vector = result.get()->mutable_vectors()->float_vector().data();
EXPECT_TRUE(vector.size() == fakevec.size());
for (size_t i = 0; i < N; ++i) {
auto id = knowhere::GetDatasetInputIDs(ids_ds)[i];
for (size_t j = 0; j < dim; ++j) {
EXPECT_TRUE(vector[i * dim + j] == fakevec[id * dim + j]);
}
}
auto ids = knowhere::GetDatasetInputIDs(ids_ds);
delete[] ids;
}


@ -615,6 +615,20 @@ GenVecIndexing(int64_t N, int64_t dim, const float* vec) {
return indexing;
}
inline index::VectorIndexPtr
GenHNSWIndex(int64_t N, int64_t dim, const float* vec) {
auto conf = knowhere::Config{{knowhere::meta::METRIC_TYPE, knowhere::metric::L2},
{knowhere::meta::DIM, std::to_string(dim)},
{knowhere::indexparam::EF, "200"},
{knowhere::indexparam::M, "16"},
{knowhere::meta::DEVICE_ID, 0}};
auto database = knowhere::GenDataset(N, dim, vec);
auto indexing = std::make_unique<index::VectorMemIndex>(knowhere::IndexEnum::INDEX_HNSW,
knowhere::metric::L2, IndexMode::MODE_CPU);
indexing->BuildWithDataset(database, conf);
return indexing;
}
template <typename T>
inline index::IndexBasePtr
GenScalarIndexing(int64_t N, const T* data) {
@ -680,4 +694,15 @@ GenPKs(const std::vector<int64_t>& pks) {
return GenPKs(pks.begin(), pks.end());
}
inline std::shared_ptr<knowhere::Dataset>
GenRandomIds(int rows, int64_t seed = 42) {
std::mt19937 g(seed);
auto* ids = new int64_t[rows];
for (int i = 0; i < rows; ++i) ids[i] = i;
std::shuffle(ids, ids + rows, g);
// INPUT_IDS will not be freed in the dataset destructor; delete it manually.
auto ids_ds = GenIdsDataset(rows, ids);
return ids_ds;
}
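
GenRandomIds allocates the id array with new[] because GenIdsDataset only borrows the pointer; the tests above delete[] it once done. For comparison, a Go rendering of the same shuffled-permutation helper, where garbage collection makes the manual free unnecessary (sketch only):

package segcore

import "math/rand"

// genRandomIds returns a random permutation of [0, rows), mirroring
// GenRandomIds above; no manual delete[] is needed in Go.
func genRandomIds(rows int, seed int64) []int64 {
	ids := make([]int64, rows)
	for i := range ids {
		ids[i] = int64(i)
	}
	r := rand.New(rand.NewSource(seed))
	r.Shuffle(rows, func(i, j int) { ids[i], ids[j] = ids[j], ids[i] })
	return ids
}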
} // namespace milvus::segcore


@ -3019,9 +3019,10 @@ func (node *Proxy) Search(ctx context.Context, request *milvuspb.SearchRequest)
ReqID: Params.ProxyCfg.GetNodeID(),
},
request: request,
- qc: node.queryCoord,
tr: timerecord.NewTimeRecorder("search"),
shardMgr: node.shardMgr,
+ qc: node.queryCoord,
+ node: node,
}
travelTs := request.TravelTimestamp


@ -8,6 +8,7 @@ import (
"strings"
"github.com/golang/protobuf/proto"
"github.com/samber/lo"
"go.uber.org/zap"
"github.com/milvus-io/milvus-proto/go-api/v2/commonpb"
@ -506,7 +507,10 @@ func IDs2Expr(fieldName string, ids *schemapb.IDs) string {
case *schemapb.IDs_IntId:
idsStr = strings.Trim(strings.Join(strings.Fields(fmt.Sprint(ids.GetIntId().GetData())), ", "), "[]")
case *schemapb.IDs_StrId:
idsStr = strings.Trim(strings.Join(ids.GetStrId().GetData(), ", "), "[]")
strs := lo.Map(ids.GetStrId().GetData(), func(str string, _ int) string {
return fmt.Sprintf("\"%s\"", str)
})
idsStr = strings.Trim(strings.Join(strs, ", "), "[]")
}
return fieldName + " in [ " + idsStr + " ]"


@ -690,3 +690,28 @@ func Test_filterSystemFields(t *testing.T) {
filtered := filterSystemFields(outputFieldIDs)
assert.ElementsMatch(t, []UniqueID{common.StartOfUserFieldID}, filtered)
}
func TestQueryTask_IDs2Expr(t *testing.T) {
fieldName := "pk"
intIDs := &schemapb.IDs{
IdField: &schemapb.IDs_IntId{
IntId: &schemapb.LongArray{
Data: []int64{1, 2, 3, 4, 5},
},
},
}
stringIDs := &schemapb.IDs{
IdField: &schemapb.IDs_StrId{
StrId: &schemapb.StringArray{
Data: []string{"a", "b", "c"},
},
},
}
idExpr := IDs2Expr(fieldName, intIDs)
expectIDExpr := "pk in [ 1, 2, 3, 4, 5 ]"
assert.Equal(t, expectIDExpr, idExpr)
strExpr := IDs2Expr(fieldName, stringIDs)
expectStrExpr := "pk in [ \"a\", \"b\", \"c\" ]"
assert.Equal(t, expectStrExpr, strExpr)
}


@ -4,10 +4,12 @@ import (
"context"
"errors"
"fmt"
"math"
"regexp"
"strconv"
"github.com/golang/protobuf/proto"
"github.com/samber/lo"
"go.uber.org/zap"
"github.com/milvus-io/milvus-proto/go-api/v2/commonpb"
@ -36,6 +38,12 @@ import (
const (
SearchTaskName = "SearchTask"
SearchLevelKey = "level"
// requeryThreshold is the threshold for the estimated size of the search results.
// If the estimated size of the search results exceeds this threshold, a second
// query request will be initiated to retrieve the output fields' data.
// In this case, the first search will not return any output fields from the QueryNodes.
requeryThreshold = 0.5 * 1024 * 1024
)
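
This constant gates a two-phase flow: when the estimated result payload is large, PreExecute clears the plan's output fields and PostExecute requeries them by primary key afterwards. In this change estimateResultSize short-circuits (any vector output field forces a requery), but the commented-out per-record estimate shows the intended computation. A hedged sketch of that gate (needRequery is our name, not the code's):

package proxy

const requeryThreshold = 0.5 * 1024 * 1024 // 512 KiB, as defined above

// needRequery sketches the intended gate: estimated result bytes
// (bytes per record * nq * topk) measured against the threshold.
// sizePerRecord would come from typeutil.EstimateSizePerRecord.
func needRequery(sizePerRecord, nq, topk int64) bool {
	return sizePerRecord*nq*topk >= requeryThreshold
}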
type searchTask struct {
@ -43,12 +51,13 @@ type searchTask struct {
*internalpb.SearchRequest
ctx context.Context
- result *milvuspb.SearchResults
- request *milvuspb.SearchRequest
- qc types.QueryCoord
+ result *milvuspb.SearchResults
+ request *milvuspb.SearchRequest
tr *timerecord.TimeRecorder
collectionName string
schema *schemapb.CollectionSchema
requery bool
offset int64
resultBuf chan *internalpb.SearchResults
@ -57,6 +66,9 @@ type searchTask struct {
searchShardPolicy pickShardPolicy
shardMgr *shardClientMgr
qc types.QueryCoord
node types.ProxyComponent
}
func getPartitionIDs(ctx context.Context, collectionName string, partitionNames []string) (partitionIDs []UniqueID, err error) {
@ -166,11 +178,7 @@ func getOutputFieldIDs(schema *schemapb.CollectionSchema, outputFields []string)
hitField := false
for _, field := range schema.GetFields() {
if field.Name == name {
- if field.DataType == schemapb.DataType_BinaryVector || field.DataType == schemapb.DataType_FloatVector {
- return nil, errors.New("search doesn't support vector field as output_fields")
- }
outputFieldIDs = append(outputFieldIDs, field.GetFieldID())
hitField = true
break
}
@ -272,6 +280,24 @@ func (t *searchTask) PreExecute(ctx context.Context) error {
t.SearchRequest.IgnoreGrowing = ignoreGrowing
partitionNames := t.request.GetPartitionNames()
// Manually update nq if not set.
nq, err := getNq(t.request)
if err != nil {
return err
}
// Check if nq is valid:
// https://milvus.io/docs/limitations.md
if err := validateNQLimit(nq); err != nil {
return fmt.Errorf("%s [%d] is invalid, %w", NQKey, nq, err)
}
t.SearchRequest.Nq = nq
outputFieldIDs, err := getOutputFieldIDs(t.schema, t.request.GetOutputFields())
if err != nil {
return err
}
t.SearchRequest.OutputFieldsId = outputFieldIDs
if t.request.GetDslType() == commonpb.DslType_BoolExprV1 {
annsField, err := funcutil.GetAttrByKeyFromRepeatedKV(AnnsFieldKey, t.request.GetSearchParams())
if err != nil || len(annsField) == 0 {
@ -325,6 +351,16 @@ func (t *searchTask) PreExecute(ctx context.Context) error {
t.SearchRequest.Topk = queryInfo.GetTopk()
t.SearchRequest.MetricType = queryInfo.GetMetricType()
t.SearchRequest.DslType = commonpb.DslType_BoolExprV1
estimateSize, err := t.estimateResultSize(nq, t.SearchRequest.Topk)
if err != nil {
return err
}
if estimateSize >= requeryThreshold {
t.requery = true
plan.OutputFieldIds = nil
}
t.SearchRequest.SerializedExprPlan, err = proto.Marshal(plan)
if err != nil {
return err
@ -381,17 +417,6 @@ func (t *searchTask) PreExecute(ctx context.Context) error {
t.SearchRequest.Dsl = t.request.Dsl
t.SearchRequest.PlaceholderGroup = t.request.PlaceholderGroup
- // Manually update nq if not set.
- nq, err := getNq(t.request)
- if err != nil {
- return err
- }
- // Check if nq is valid:
- // https://milvus.io/docs/limitations.md
- if err := validateNQLimit(nq); err != nil {
- return fmt.Errorf("%s [%d] is invalid, %w", NQKey, nq, err)
- }
- t.SearchRequest.Nq = nq
log.Ctx(ctx).Debug("search PreExecute done.", zap.Int64("msgID", t.ID()),
zap.Uint64("travel_ts", travelTimestamp), zap.Uint64("guarantee_ts", guaranteeTs),
@ -504,6 +529,17 @@ func (t *searchTask) PostExecute(ctx context.Context) error {
t.fillInFieldInfo()
t.result.Results.OutputFields = t.userOutputFields
log.Ctx(ctx).Debug("Search post execute done", zap.Int64("msgID", t.ID()))
if t.requery {
err = t.Requery()
if err != nil {
return err
}
}
log.Ctx(ctx).Debug("Search post execute done",
zap.Int64("collection", t.GetCollectionID()),
zap.Int64s("partitionIDs", t.GetPartitionIDs()))
return nil
}
@ -551,6 +587,99 @@ func (t *searchTask) searchShard(ctx context.Context, nodeID int64, qn types.Que
return nil
}
func (t *searchTask) estimateResultSize(nq int64, topK int64) (int64, error) {
vectorOutputFields := lo.Filter(t.schema.GetFields(), func(field *schemapb.FieldSchema, _ int) bool {
return lo.Contains(t.request.GetOutputFields(), field.GetName()) && typeutil.IsVectorType(field.GetDataType())
})
// Currently, we get vectors by requery. Once getting vectors directly from search
// is supported, searches with a small result size will no longer need a requery.
if len(vectorOutputFields) > 0 {
return math.MaxInt64, nil
}
// If no vector field is requested as output, there is no need to requery.
return 0, nil
//outputFields := lo.Filter(t.schema.GetFields(), func(field *schemapb.FieldSchema, _ int) bool {
// return lo.Contains(t.request.GetOutputFields(), field.GetName())
//})
//sizePerRecord, err := typeutil.EstimateSizePerRecord(&schemapb.CollectionSchema{Fields: outputFields})
//if err != nil {
// return 0, err
//}
//return int64(sizePerRecord) * nq * topK, nil
}
func (t *searchTask) Requery() error {
pkField, err := typeutil.GetPrimaryFieldSchema(t.schema)
if err != nil {
return err
}
ids := t.result.GetResults().GetIds()
expr := IDs2Expr(pkField.GetName(), ids)
queryReq := &milvuspb.QueryRequest{
Base: &commonpb.MsgBase{
MsgType: commonpb.MsgType_Retrieve,
},
CollectionName: t.request.GetCollectionName(),
Expr: expr,
OutputFields: t.request.GetOutputFields(),
PartitionNames: t.request.GetPartitionNames(),
TravelTimestamp: t.request.GetTravelTimestamp(),
GuaranteeTimestamp: t.request.GetGuaranteeTimestamp(),
QueryParams: t.request.GetSearchParams(),
}
queryResult, err := t.node.Query(t.ctx, queryReq)
if err != nil {
return err
}
if queryResult.GetStatus().GetErrorCode() != commonpb.ErrorCode_Success {
return common.NewCodeError(queryResult.GetStatus().GetErrorCode(),
fmt.Errorf("requery failed, err=%s", queryResult.GetStatus().GetReason()))
}
// Reorganize results. The order of the query result ids may differ from the
// order of the queried ids, so we reorganize the query results to preserve
// the original order of the queried ids. For example:
// ===========================================
// 3 2 5 4 1 (query ids)
// ||
// || (query)
// \/
// 4 3 5 1 2 (result ids)
// v4 v3 v5 v1 v2 (result vectors)
// ||
// || (reorganize)
// \/
// 3 2 5 4 1 (result ids)
// v3 v2 v5 v4 v1 (result vectors)
// ===========================================
pkFieldData, err := typeutil.GetPrimaryFieldData(queryResult.GetFieldsData(), pkField)
if err != nil {
return err
}
offsets := make(map[any]int)
for i := 0; i < typeutil.GetPKSize(pkFieldData); i++ {
pk := typeutil.GetData(pkFieldData, i)
offsets[pk] = i
}
t.result.Results.FieldsData = make([]*schemapb.FieldData, len(queryResult.GetFieldsData()))
for i := 0; i < typeutil.GetSizeOfIDs(ids); i++ {
id := typeutil.GetPK(ids, int64(i))
if _, ok := offsets[id]; !ok {
return fmt.Errorf("incomplete query result, missing id %s, len(searchIDs) = %d, len(queryIDs) = %d, collection=%d",
id, typeutil.GetSizeOfIDs(ids), len(offsets), t.GetCollectionID())
}
typeutil.AppendFieldData(t.result.Results.FieldsData, queryResult.GetFieldsData(), int64(offsets[id]))
}
// filter id field out if it is not specified as output
t.result.Results.FieldsData = lo.Filter(t.result.Results.FieldsData, func(fieldData *schemapb.FieldData, i int) bool {
return lo.Contains(t.request.GetOutputFields(), fieldData.GetFieldName())
})
return nil
}
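
The heart of Requery is the reorder step from the ASCII diagram above: build a primary-key-to-offset map over the query results, then append rows following the original search order. A standalone, simplified version with placeholder types (the real code operates on schemapb.FieldData):

package proxy

import "fmt"

// reorderByPK reorganizes fetched rows so they follow the original
// search-result id order; rows[i] is the row returned for resultIDs[i].
func reorderByPK(searchIDs, resultIDs []int64, rows [][]float32) ([][]float32, error) {
	offsets := make(map[int64]int, len(resultIDs))
	for i, id := range resultIDs {
		offsets[id] = i
	}
	out := make([][]float32, 0, len(searchIDs))
	for _, id := range searchIDs {
		off, ok := offsets[id]
		if !ok {
			return nil, fmt.Errorf("incomplete query result, missing id %d", id)
		}
		out = append(out, rows[off])
	}
	return out, nil
}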
func (t *searchTask) fillInEmptyResult(numQueries int64) {
t.result = &milvuspb.SearchResults{
Status: &commonpb.Status{


@ -19,6 +19,7 @@ import (
"github.com/milvus-io/milvus-proto/go-api/v2/milvuspb"
"github.com/milvus-io/milvus-proto/go-api/v2/schemapb"
"github.com/milvus-io/milvus/internal/common"
"github.com/milvus-io/milvus/internal/mocks"
"github.com/milvus-io/milvus/internal/proto/internalpb"
"github.com/milvus-io/milvus/internal/proto/querypb"
"github.com/milvus-io/milvus/internal/types"
@ -325,7 +326,7 @@ func TestSearchTask_PreExecute(t *testing.T) {
// contain vector field
task.request.OutputFields = []string{testFloatVecField}
- assert.Error(t, task.PreExecute(ctx))
+ assert.NoError(t, task.PreExecute(ctx))
})
}
@ -2079,3 +2080,292 @@ func getSearchResultData(nq, topk int64) *schemapb.SearchResultData {
}
return &result
}
func TestSearchTask_Requery(t *testing.T) {
ctx, cancel := context.WithCancel(context.Background())
defer cancel()
const (
dim = 128
rows = 5
collection = "test-requery"
pkField = "pk"
vecField = "vec"
)
ids := make([]int64, rows)
for i := range ids {
ids[i] = int64(i)
}
t.Run("Test normal", func(t *testing.T) {
schema := constructCollectionSchema(pkField, vecField, dim, collection)
node := mocks.NewProxy(t)
node.EXPECT().Query(mock.Anything, mock.Anything).
Return(&milvuspb.QueryResults{
FieldsData: []*schemapb.FieldData{{
Type: schemapb.DataType_Int64,
FieldName: pkField,
Field: &schemapb.FieldData_Scalars{
Scalars: &schemapb.ScalarField{
Data: &schemapb.ScalarField_LongData{
LongData: &schemapb.LongArray{
Data: ids,
},
},
},
},
},
newFloatVectorFieldData(vecField, rows, dim),
},
}, nil)
resultIDs := &schemapb.IDs{
IdField: &schemapb.IDs_IntId{
IntId: &schemapb.LongArray{
Data: ids,
},
},
}
outputFields := []string{vecField}
qt := &searchTask{
ctx: ctx,
SearchRequest: &internalpb.SearchRequest{
Base: &commonpb.MsgBase{
MsgType: commonpb.MsgType_Search,
SourceID: Params.ProxyCfg.GetNodeID(),
},
},
request: &milvuspb.SearchRequest{
OutputFields: outputFields,
},
result: &milvuspb.SearchResults{
Results: &schemapb.SearchResultData{
Ids: resultIDs,
},
},
schema: schema,
tr: timerecord.NewTimeRecorder("search"),
node: node,
}
err := qt.Requery()
assert.NoError(t, err)
assert.Len(t, qt.result.Results.FieldsData, 1)
assert.Equal(t, vecField, qt.result.Results.FieldsData[0].GetFieldName())
})
t.Run("Test no primary key", func(t *testing.T) {
schema := &schemapb.CollectionSchema{}
node := mocks.NewProxy(t)
qt := &searchTask{
ctx: ctx,
SearchRequest: &internalpb.SearchRequest{
Base: &commonpb.MsgBase{
MsgType: commonpb.MsgType_Search,
SourceID: Params.ProxyCfg.GetNodeID(),
},
},
request: &milvuspb.SearchRequest{},
schema: schema,
tr: timerecord.NewTimeRecorder("search"),
node: node,
}
err := qt.Requery()
t.Logf("err = %s", err)
assert.Error(t, err)
})
t.Run("Test requery failed 1", func(t *testing.T) {
schema := constructCollectionSchema(pkField, vecField, dim, collection)
node := mocks.NewProxy(t)
node.EXPECT().Query(mock.Anything, mock.Anything).
Return(nil, fmt.Errorf("mock err 1"))
qt := &searchTask{
ctx: ctx,
SearchRequest: &internalpb.SearchRequest{
Base: &commonpb.MsgBase{
MsgType: commonpb.MsgType_Search,
SourceID: Params.ProxyCfg.GetNodeID(),
},
},
request: &milvuspb.SearchRequest{},
schema: schema,
tr: timerecord.NewTimeRecorder("search"),
node: node,
}
err := qt.Requery()
t.Logf("err = %s", err)
assert.Error(t, err)
})
t.Run("Test requery failed 2", func(t *testing.T) {
schema := constructCollectionSchema(pkField, vecField, dim, collection)
node := mocks.NewProxy(t)
node.EXPECT().Query(mock.Anything, mock.Anything).
Return(&milvuspb.QueryResults{
Status: &commonpb.Status{
ErrorCode: commonpb.ErrorCode_UnexpectedError,
Reason: "mock err 2",
},
}, nil)
qt := &searchTask{
ctx: ctx,
SearchRequest: &internalpb.SearchRequest{
Base: &commonpb.MsgBase{
MsgType: commonpb.MsgType_Search,
SourceID: Params.ProxyCfg.GetNodeID(),
},
},
request: &milvuspb.SearchRequest{},
schema: schema,
tr: timerecord.NewTimeRecorder("search"),
node: node,
}
err := qt.Requery()
t.Logf("err = %s", err)
assert.Error(t, err)
})
t.Run("Test get pk filed data failed", func(t *testing.T) {
schema := constructCollectionSchema(pkField, vecField, dim, collection)
node := mocks.NewProxy(t)
node.EXPECT().Query(mock.Anything, mock.Anything).
Return(&milvuspb.QueryResults{
FieldsData: []*schemapb.FieldData{},
}, nil)
qt := &searchTask{
ctx: ctx,
SearchRequest: &internalpb.SearchRequest{
Base: &commonpb.MsgBase{
MsgType: commonpb.MsgType_Search,
SourceID: Params.ProxyCfg.GetNodeID(),
},
},
request: &milvuspb.SearchRequest{},
schema: schema,
tr: timerecord.NewTimeRecorder("search"),
node: node,
}
err := qt.Requery()
t.Logf("err = %s", err)
assert.Error(t, err)
})
t.Run("Test incomplete query result", func(t *testing.T) {
schema := constructCollectionSchema(pkField, vecField, dim, collection)
node := mocks.NewProxy(t)
node.EXPECT().Query(mock.Anything, mock.Anything).
Return(&milvuspb.QueryResults{
FieldsData: []*schemapb.FieldData{{
Type: schemapb.DataType_Int64,
FieldName: pkField,
Field: &schemapb.FieldData_Scalars{
Scalars: &schemapb.ScalarField{
Data: &schemapb.ScalarField_LongData{
LongData: &schemapb.LongArray{
Data: ids[:len(ids)-1],
},
},
},
},
},
newFloatVectorFieldData(vecField, rows, dim),
},
}, nil)
resultIDs := &schemapb.IDs{
IdField: &schemapb.IDs_IntId{
IntId: &schemapb.LongArray{
Data: ids,
},
},
}
qt := &searchTask{
ctx: ctx,
SearchRequest: &internalpb.SearchRequest{
Base: &commonpb.MsgBase{
MsgType: commonpb.MsgType_Search,
SourceID: Params.ProxyCfg.GetNodeID(),
},
},
request: &milvuspb.SearchRequest{},
result: &milvuspb.SearchResults{
Results: &schemapb.SearchResultData{
Ids: resultIDs,
},
},
schema: schema,
tr: timerecord.NewTimeRecorder("search"),
node: node,
}
err := qt.Requery()
t.Logf("err = %s", err)
assert.Error(t, err)
})
t.Run("Test postExecute with requery failed", func(t *testing.T) {
schema := constructCollectionSchema(pkField, vecField, dim, collection)
node := mocks.NewProxy(t)
node.EXPECT().Query(mock.Anything, mock.Anything).
Return(nil, fmt.Errorf("mock err 1"))
resultIDs := &schemapb.IDs{
IdField: &schemapb.IDs_IntId{
IntId: &schemapb.LongArray{
Data: ids,
},
},
}
qt := &searchTask{
ctx: ctx,
SearchRequest: &internalpb.SearchRequest{
Base: &commonpb.MsgBase{
MsgType: commonpb.MsgType_Search,
SourceID: Params.ProxyCfg.GetNodeID(),
},
},
request: &milvuspb.SearchRequest{},
result: &milvuspb.SearchResults{
Results: &schemapb.SearchResultData{
Ids: resultIDs,
},
},
requery: true,
schema: schema,
resultBuf: make(chan *internalpb.SearchResults, 10),
tr: timerecord.NewTimeRecorder("search"),
node: node,
}
scores := make([]float32, rows)
for i := range scores {
scores[i] = float32(i)
}
partialResultData := &schemapb.SearchResultData{
Ids: resultIDs,
Scores: scores,
}
bytes, err := proto.Marshal(partialResultData)
assert.NoError(t, err)
qt.resultBuf <- &internalpb.SearchResults{
SlicedBlob: bytes,
}
err = qt.PostExecute(ctx)
t.Logf("err = %s", err)
assert.Error(t, err)
})
}


@ -207,6 +207,16 @@ func (s *Segment) setUnhealthy() {
s.destroyed.Store(true)
}
func (s *Segment) hasRawData(fieldID int64) bool {
s.mut.RLock()
defer s.mut.RUnlock()
if !s.healthy() {
return false
}
ret := C.HasRawData(s.segmentPtr, C.int64_t(fieldID))
return bool(ret)
}
func newSegment(collection *Collection,
segmentID UniqueID,
partitionID UniqueID,
@ -607,10 +617,18 @@ func (s *Segment) fillIndexedFieldsData(ctx context.Context, collectionID Unique
vcm storage.ChunkManager, result *segcorepb.RetrieveResults) error {
for _, fieldData := range result.FieldsData {
- // If the vector field doesn't have indexed. Vector data is in memory for
- // brute force search. No need to download data from remote.
- if fieldData.GetType() != schemapb.DataType_FloatVector && fieldData.GetType() != schemapb.DataType_BinaryVector ||
- !s.hasLoadIndexForIndexedField(fieldData.FieldId) {
+ // If the field is not a vector field, there is no need to download data from remote.
+ if !typeutil.IsVectorType(fieldData.GetType()) {
continue
}
// If the vector field is not indexed, the vector data stays in memory
// for brute-force search; no need to download data from remote.
if !s.hasLoadIndexForIndexedField(fieldData.FieldId) {
continue
}
// If the index has raw data, the vector data can be obtained from the index;
// no need to download data from remote.
if s.hasRawData(fieldData.FieldId) {
continue
}
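
Taken together, the three early continues mean remote vector data is fetched only when every one of these conditions holds; a compact restatement (sketch, not the actual code):

package querynode

// needRemoteVectorData restates the skip logic above: download from remote
// only for vector fields whose loaded index cannot hand back raw vectors.
func needRemoteVectorData(isVectorField, indexLoaded, indexHasRawData bool) bool {
	return isVectorField && indexLoaded && !indexHasRawData
}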


@ -660,7 +660,7 @@ func TestSegment_fillIndexedFieldsData(t *testing.T) {
FieldsData: fieldData,
}
err = segment.fillIndexedFieldsData(ctx, defaultCollectionID, vecCM, result)
- assert.Error(t, err)
+ assert.NoError(t, err)
})
}
@ -1038,3 +1038,41 @@ func TestDeleteBuff(t *testing.T) {
assert.NoError(t, err)
})
}
func TestHasRawData(t *testing.T) {
t.Run("growing", func(t *testing.T) {
schema := genTestCollectionSchema()
collection := newCollection(defaultCollectionID, schema)
segment, err := newSegment(collection,
defaultSegmentID,
defaultPartitionID,
defaultCollectionID,
defaultDMLChannel,
segmentTypeGrowing,
defaultSegmentVersion,
defaultSegmentStartPosition,
)
assert.Nil(t, err)
has := segment.hasRawData(simpleFloatVecField.id)
assert.True(t, has)
})
t.Run("sealed", func(t *testing.T) {
schema := genTestCollectionSchema()
collection := newCollection(defaultCollectionID, schema)
segment, err := newSegment(collection,
defaultSegmentID,
defaultPartitionID,
defaultCollectionID,
defaultDMLChannel,
segmentTypeSealed,
defaultSegmentVersion,
defaultSegmentStartPosition,
)
assert.Nil(t, err)
has := segment.hasRawData(simpleFloatVecField.id)
assert.True(t, has)
})
}


@ -814,6 +814,16 @@ func GetSizeOfIDs(data *schemapb.IDs) int {
return result
}
func GetPKSize(fieldData *schemapb.FieldData) int {
switch fieldData.GetType() {
case schemapb.DataType_Int64:
return len(fieldData.GetScalars().GetLongData().GetData())
case schemapb.DataType_VarChar:
return len(fieldData.GetScalars().GetStringData().GetData())
}
return 0
}
func IsPrimaryFieldType(dataType schemapb.DataType) bool {
if dataType == schemapb.DataType_Int64 || dataType == schemapb.DataType_VarChar {
return true
@ -849,6 +859,31 @@ func GetTS(i *internalpb.RetrieveResults, idx int64) uint64 {
return 0
}
func GetData(field *schemapb.FieldData, idx int) interface{} {
switch field.GetType() {
case schemapb.DataType_Bool:
return field.GetScalars().GetBoolData().GetData()[idx]
case schemapb.DataType_Int8, schemapb.DataType_Int16, schemapb.DataType_Int32:
return field.GetScalars().GetIntData().GetData()[idx]
case schemapb.DataType_Int64:
return field.GetScalars().GetLongData().GetData()[idx]
case schemapb.DataType_Float:
return field.GetScalars().GetFloatData().GetData()[idx]
case schemapb.DataType_Double:
return field.GetScalars().GetDoubleData().GetData()[idx]
case schemapb.DataType_VarChar:
return field.GetScalars().GetStringData().GetData()[idx]
case schemapb.DataType_FloatVector:
dim := int(field.GetVectors().GetDim())
return field.GetVectors().GetFloatVector().GetData()[idx*dim : (idx+1)*dim]
case schemapb.DataType_BinaryVector:
dim := int(field.GetVectors().GetDim())
dataBytes := dim / 8
return field.GetVectors().GetBinaryVector()[idx*dataBytes : (idx+1)*dataBytes]
}
return nil
}
func AppendPKs(pks *schemapb.IDs, pk interface{}) {
switch realPK := pk.(type) {
case int64:


@ -1111,3 +1111,69 @@ func TestMergeFieldData(t *testing.T) {
MergeFieldData([]*schemapb.FieldData{emptyField}, []*schemapb.FieldData{emptyField})
}
func TestGetDataAndGetDataSize(t *testing.T) {
const (
Dim = 8
fieldName = "filed-0"
fieldID = 0
)
BoolArray := []bool{true, false}
Int8Array := []int8{1, 2}
Int16Array := []int16{3, 4}
Int32Array := []int32{5, 6}
Int64Array := []int64{11, 22}
FloatArray := []float32{1.0, 2.0}
DoubleArray := []float64{11.0, 22.0}
VarCharArray := []string{"a", "b"}
BinaryVector := []byte{0x12, 0x34}
FloatVector := []float32{1.0, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0, 8.0, 11.0, 22.0, 33.0, 44.0, 55.0, 66.0, 77.0, 88.0}
boolData := genFieldData(fieldName, fieldID, schemapb.DataType_Bool, BoolArray, 1)
int8Data := genFieldData(fieldName, fieldID, schemapb.DataType_Int8, Int8Array, 1)
int16Data := genFieldData(fieldName, fieldID, schemapb.DataType_Int16, Int16Array, 1)
int32Data := genFieldData(fieldName, fieldID, schemapb.DataType_Int32, Int32Array, 1)
int64Data := genFieldData(fieldName, fieldID, schemapb.DataType_Int64, Int64Array, 1)
floatData := genFieldData(fieldName, fieldID, schemapb.DataType_Float, FloatArray, 1)
doubleData := genFieldData(fieldName, fieldID, schemapb.DataType_Double, DoubleArray, 1)
varCharData := genFieldData(fieldName, fieldID, schemapb.DataType_VarChar, VarCharArray, 1)
binVecData := genFieldData(fieldName, fieldID, schemapb.DataType_BinaryVector, BinaryVector, Dim)
floatVecData := genFieldData(fieldName, fieldID, schemapb.DataType_FloatVector, FloatVector, Dim)
invalidData := &schemapb.FieldData{
Type: schemapb.DataType_None,
}
t.Run("test GetPKSize", func(t *testing.T) {
int64DataRes := GetPKSize(int64Data)
varCharDataRes := GetPKSize(varCharData)
assert.Equal(t, 2, int64DataRes)
assert.Equal(t, 2, varCharDataRes)
})
t.Run("test GetData", func(t *testing.T) {
boolDataRes := GetData(boolData, 0)
int8DataRes := GetData(int8Data, 0)
int16DataRes := GetData(int16Data, 0)
int32DataRes := GetData(int32Data, 0)
int64DataRes := GetData(int64Data, 0)
floatDataRes := GetData(floatData, 0)
doubleDataRes := GetData(doubleData, 0)
varCharDataRes := GetData(varCharData, 0)
binVecDataRes := GetData(binVecData, 0)
floatVecDataRes := GetData(floatVecData, 0)
invalidDataRes := GetData(invalidData, 0)
assert.Equal(t, BoolArray[0], boolDataRes)
assert.Equal(t, int32(Int8Array[0]), int8DataRes)
assert.Equal(t, int32(Int16Array[0]), int16DataRes)
assert.Equal(t, Int32Array[0], int32DataRes)
assert.Equal(t, Int64Array[0], int64DataRes)
assert.Equal(t, FloatArray[0], floatDataRes)
assert.Equal(t, DoubleArray[0], doubleDataRes)
assert.Equal(t, VarCharArray[0], varCharDataRes)
assert.ElementsMatch(t, BinaryVector[:Dim/8], binVecDataRes)
assert.ElementsMatch(t, FloatVector[:Dim], floatVecDataRes)
assert.Nil(t, invalidDataRes)
})
}


@ -230,6 +230,7 @@ binary_support = ["BIN_FLAT", "BIN_IVF_FLAT"]
delete_support = ["FLAT", "IVF_FLAT", "IVF_SQ8", "IVF_PQ"]
ivf = ["FLAT", "IVF_FLAT", "IVF_SQ8", "IVF_PQ"]
skip_pq = ["IVF_PQ"]
float_metrics = ["L2", "IP"]
binary_metrics = ["JACCARD", "HAMMING", "TANIMOTO", "SUBSTRUCTURE", "SUPERSTRUCTURE"]
structure_metrics = ["SUBSTRUCTURE", "SUPERSTRUCTURE"]


@ -4,6 +4,8 @@ import random
import pytest
import pandas as pd
import decimal
from decimal import Decimal, getcontext
from time import sleep
from base.client_base import TestcaseBase
@ -812,11 +814,7 @@ class TestCollectionSearchInvalid(TestcaseBase):
log.info("test_search_output_field_vector: Searching collection %s" % collection_w.name)
collection_w.search(vectors[:default_nq], default_search_field,
default_search_params, default_limit,
- default_search_exp, output_fields=output_fields,
- check_task=CheckTasks.err_res,
- check_items={"err_code": 1,
-              "err_msg": "Search doesn't support "
-                         "vector field as output_fields"})
+ default_search_exp, output_fields=output_fields)
@pytest.mark.tags(CaseLabel.L2)
@pytest.mark.parametrize("output_fields", [["*%"], ["**"], ["*", "@"]])
@ -2644,6 +2642,33 @@ class TestCollectionSearch(TestcaseBase):
assert len(res[0][0].entity._row_data) != 0
assert default_int64_field_name in res[0][0].entity._row_data
@pytest.mark.tags(CaseLabel.L1)
def test_search_with_output_vector_field(self, auto_id, _async):
"""
target: test search with vector field as output field
method: search with one output_field
expected: search success
"""
# 1. initialize with data
collection_w, _, _, insert_ids = self.init_collection_general(prefix, True,
auto_id=auto_id)[0:4]
# 2. search
log.info("test_search_with_output_field: Searching collection %s" % collection_w.name)
res = collection_w.search(vectors[:default_nq], default_search_field,
default_search_params, default_limit,
default_search_exp, _async=_async,
output_fields=[field_name],
check_task=CheckTasks.check_search_results,
check_items={"nq": default_nq,
"ids": insert_ids,
"limit": default_limit,
"_async": _async})[0]
if _async:
res.done()
res = res.result()
assert len(res[0][0].entity._row_data) != 0
assert field_name in res[0][0].entity._row_data
@pytest.mark.tags(CaseLabel.L2)
def test_search_with_output_fields(self, nb, nq, dim, auto_id, _async):
"""
@ -2675,6 +2700,183 @@ class TestCollectionSearch(TestcaseBase):
assert len(res[0][0].entity._row_data) != 0
assert (default_int64_field_name and default_float_field_name) in res[0][0].entity._row_data
@pytest.mark.tags(CaseLabel.L2)
@pytest.mark.parametrize("index, params",
zip(ct.all_index_types[:6],
ct.default_index_params[:6]))
@pytest.mark.parametrize("metrics", ct.float_metrics)
def test_search_output_field_vector_after_different_index_metrics(self, index, params, metrics):
"""
target: test search with output vector field after different index
method: 1. create a collection and insert data
2. create index and load
3. search with output field vector
4. check that the result vectors equal the inserted vectors
expected: search success
"""
# 1. create a collection and insert data
collection_w = self.init_collection_general(prefix, is_index=True)[0]
data = cf.gen_default_dataframe_data()
collection_w.insert(data)
# 2. create index and load
default_index = {"index_type": index, "params": params, "metric_type": metrics}
collection_w.create_index(field_name, default_index)
collection_w.load()
# 3. search with output field vector
search_params = cf.gen_search_param(index, metrics)[0]
res = collection_w.search(vectors[:1], default_search_field,
search_params, default_limit, default_search_exp,
output_fields=[field_name],
check_task=CheckTasks.check_search_results,
check_items={"nq": 1,
"limit": default_limit})[0]
# 4. check that the result vectors equal the inserted vectors
for _id in range(default_limit):
for i in range(default_dim):
vectorInsert = str(data[field_name][res[0][_id].id][i])[:7]
vectorRes = str(res[0][_id].entity.float_vector[i])[:7]
if vectorInsert != vectorRes:
getcontext().rounding = getattr(decimal, 'ROUND_HALF_UP')
vectorInsert = Decimal(data[field_name][res[0][_id].id][i]).quantize(Decimal('0.00000'))
assert str(vectorInsert) == vectorRes
@pytest.mark.tags(CaseLabel.L2)
@pytest.mark.skip(reason="issue #23661")
@pytest.mark.parametrize("index", ct.all_index_types[6:8])
def test_search_output_field_vector_after_binary_index(self, index):
"""
target: test search with output vector field after binary index
method: 1. create a collection and insert data
2. create index and load
3. search with output field vector
4. check that the result vectors equal the inserted vectors
expected: search success
"""
# 1. create a collection and insert data
collection_w = self.init_collection_general(prefix, is_binary=True, is_index=False)[0]
data = cf.gen_default_binary_dataframe_data()[0]
collection_w.insert(data)
# 2. create index and load
default_index = {"index_type": index, "params": {"nlist": 128}, "metric_type": "JACCARD"}
collection_w.create_index(binary_field_name, default_index)
collection_w.load()
# 3. search with output field vector
search_params = {"metric_type": "JACCARD", "params": {"nprobe": 10}}
binary_vectors = cf.gen_binary_vectors(1, default_dim)[1]
res = collection_w.search(binary_vectors, binary_field_name,
ct.default_search_binary_params, 2, default_search_exp,
output_fields=[binary_field_name])[0]
# 4. check that the result vectors equal the inserted vectors
log.info(res[0][0].id)
log.info(res[0][0].entity.float_vector)
log.info(data['binary_vector'][0])
assert res[0][0].entity.binary_vector == data[binary_field_name][res[0][0].id]
# log.info(data['float_vector'][1])
@pytest.mark.tags(CaseLabel.L2)
@pytest.mark.parametrize("dim", [32, 128, 768])
def test_search_output_field_vector_with_different_dim(self, dim):
"""
target: test search with output vector field with different dimensions
method: 1. create a collection and insert data
2. create index and load
3. search with output field vector
4. check that the result vectors equal the inserted vectors
expected: search success
"""
# 1. create a collection and insert data
collection_w = self.init_collection_general(prefix, is_index=True, dim=dim)[0]
data = cf.gen_default_dataframe_data(dim=dim)
collection_w.insert(data)
# 2. create index and load
index_params = {"index_type": "IVF_FLAT", "params": {"nlist": 128}, "metric_type": "L2"}
collection_w.create_index("float_vector", index_params)
collection_w.load()
# 3. search with output field vector
vectors = cf.gen_vectors(default_nq, dim=dim)
res = collection_w.search(vectors[:default_nq], default_search_field,
default_search_params, default_limit, default_search_exp,
output_fields=[field_name],
check_task=CheckTasks.check_search_results,
check_items={"nq": default_nq,
"limit": default_limit})[0]
# 4. check that the result vectors equal the inserted vectors
for i in range(default_limit):
assert len(res[0][i].entity.float_vector) == len(data[field_name][res[0][i].id])
@pytest.mark.tags(CaseLabel.L2)
def test_search_output_vector_field_and_scalar_field(self):
"""
target: test search with output vector field and scalar field
method: 1. initialize a collection
2. search with output field vector
3. check no field missing
expected: search success
"""
# 1. initialize a collection
collection_w = self.init_collection_general(prefix, True)[0]
# 2. search with output field vector
res = collection_w.search(vectors[:1], default_search_field,
default_search_params, default_limit, default_search_exp,
output_fields=[default_float_field_name,
default_string_field_name,
default_search_field],
check_task=CheckTasks.check_search_results,
check_items={"nq": 1,
"limit": default_limit})[0]
# 3. check the result
assert default_float_field_name in res[0][0].entity._row_data
assert default_string_field_name in res[0][0].entity._row_data
assert default_search_field in res[0][0].entity._row_data
@pytest.mark.tags(CaseLabel.L2)
def test_search_output_field_vector_with_partition(self):
"""
target: test search with output vector field
method: 1. create a collection and insert data
2. create index and load
3. search with output field vector
4. check that the result vectors equal the inserted vectors
expected: search success
"""
# 1. create a collection and insert data
collection_w = self.init_collection_general(prefix, is_index=True)[0]
partition_w = self.init_partition_wrap(collection_w)
data = cf.gen_default_dataframe_data()
partition_w.insert(data)
# 2. create index and load
collection_w.create_index(field_name, default_index_params)
collection_w.load()
# 3. search with output field vector
res = partition_w.search(vectors[:1], default_search_field,
default_search_params, default_limit, default_search_exp,
output_fields=[field_name],
check_task=CheckTasks.check_search_results,
check_items={"nq": 1,
"limit": default_limit})[0]
# 4. check that the result vectors equal the inserted vectors
for _id in range(default_limit):
for i in range(default_dim):
vectorInsert = str(data[field_name][res[0][_id].id][i])[:7]
vectorRes = str(res[0][_id].entity.float_vector[i])[:7]
if vectorInsert != vectorRes:
getcontext().rounding = getattr(decimal, 'ROUND_HALF_UP')
vectorInsert = Decimal(data[field_name][res[0][_id].id][i]).quantize(Decimal('0.00000'))
assert str(vectorInsert) == vectorRes
@pytest.mark.tags(CaseLabel.L2)
@pytest.mark.parametrize("output_fields", [["*"], ["*", default_float_field_name]])
def test_search_with_output_field_wildcard(self, output_fields, auto_id, _async, enable_dynamic_field):
@ -4978,6 +5180,40 @@ class TestsearchDiskann(TestcaseBase):
"limit": limit,
"_async": _async})
@pytest.mark.tags(CaseLabel.L2)
@pytest.mark.xfail(reason="issue #23672")
def test_search_diskann_search_list_up_to_min(self, dim, auto_id, _async):
"""
target: test search diskann index when search_list up to min
method: 1. create collection, insert data; the primary field is an int field
2. create diskann index, then load
3. search
expected: search successfully
"""
# 1. initialize with data
collection_w, _, _, insert_ids = self.init_collection_general(prefix, True, auto_id=auto_id,
dim=dim, is_index=False)[0:4]
# 2. create index
default_index = {"index_type": "DISKANN", "metric_type": "L2", "params": {}}
collection_w.create_index(ct.default_float_vec_field_name, default_index)
collection_w.load()
search_params = {"metric_type": "L2", "params": {"k": 200, "search_list": 201}}
search_vectors = [[random.random() for _ in range(dim)] for _ in range(default_nq)]
output_fields = [default_int64_field_name, default_float_field_name, default_string_field_name]
collection_w.search(search_vectors[:default_nq], default_search_field,
search_params, default_limit,
default_search_exp,
output_fields=output_fields,
_async=_async,
travel_timestamp=0,
check_task=CheckTasks.check_search_results,
check_items={"nq": default_nq,
"ids": insert_ids,
"limit": default_limit,
"_async": _async})
class TestCollectionSearchJSON(TestcaseBase):
""" Test case of search interface """