mirror of https://github.com/milvus-io/milvus.git
Signed-off-by: bigsheeper <yihao.dai@zilliz.com>
branch: hnsw_get_vec
parent 936ebf3266
commit 785bab5ba7

Makefile | 3 +++
@@ -293,6 +293,9 @@ rpm: install
 	@cp -r build/rpm/services ~/rpmbuild/BUILD/
 	@QA_RPATHS="$$[ 0x001|0x0002|0x0020 ]" rpmbuild -ba ./build/rpm/milvus.spec

+mock-proxy:
+	mockery --name=ProxyComponent --dir=$(PWD)/internal/types --output=$(PWD)/internal/mocks --filename=mock_proxy.go --structname=Proxy --with-expecter
+
 mock-datanode:
 	mockery --name=DataNode --dir=$(PWD)/internal/types --output=$(PWD)/internal/mocks --filename=mock_datanode.go --with-expecter
@@ -75,6 +75,16 @@ PrefixMatch(const std::string_view str, const std::string_view prefix) {
     return true;
 }

+inline DatasetPtr
+GenIdsDataset(const int64_t count, const int64_t* ids) {
+    auto ret_ds = std::make_shared<knowhere::Dataset>();
+    knowhere::SetDatasetRows(ret_ds, count);
+    knowhere::SetDatasetDim(ret_ds, 1);
+    // INPUT_IDS will not be freed in the dataset destructor, which is similar to `SetIsOwner(false)`.
+    knowhere::SetDatasetInputIDs(ret_ds, ids);
+    return ret_ds;
+}
+
 inline bool
 PostfixMatch(const std::string_view str, const std::string_view postfix) {
     if (postfix.length() > str.length()) {
@@ -187,6 +187,33 @@ VectorDiskAnnIndex<T>::Query(const DatasetPtr dataset, const SearchInfo& search_
     return result;
 }

+template <typename T>
+const bool
+VectorDiskAnnIndex<T>::HasRawData() const {
+    return index_->HasRawData(GetMetricType());
+}
+
+template <typename T>
+std::vector<uint8_t>
+VectorDiskAnnIndex<T>::GetVector(const DatasetPtr dataset, const Config& config) const {
+    auto res = index_->GetVectorById(dataset, config);
+    AssertInfo(res != nullptr, "failed to get vector, result is null");
+    auto index_type = GetIndexType();
+    auto tensor = knowhere::GetDatasetTensor(res);
+    auto row_num = knowhere::GetDatasetRows(res);
+    auto dim = knowhere::GetDatasetDim(res);
+    int64_t data_size;
+    if (is_in_bin_list(index_type)) {
+        data_size = dim / 8 * row_num;
+    } else {
+        data_size = dim * row_num * sizeof(float);
+    }
+    std::vector<uint8_t> raw_data;
+    raw_data.resize(data_size);
+    memcpy(raw_data.data(), tensor, data_size);
+    return raw_data;
+}
+
 template <typename T>
 void
 VectorDiskAnnIndex<T>::CleanLocalData() {
@@ -17,6 +17,7 @@
 #pragma once

 #include <memory>
+#include <vector>

 #include "index/VectorIndex.h"
 #include "storage/DiskFileManagerImpl.h"

@@ -68,6 +69,12 @@ class VectorDiskAnnIndex : public VectorIndex {
     std::unique_ptr<SearchResult>
     Query(const DatasetPtr dataset, const SearchInfo& search_info, const BitsetView& bitset) override;

+    const bool
+    HasRawData() const override;
+
+    std::vector<uint8_t>
+    GetVector(const DatasetPtr dataset, const Config& config = {}) const override;
+
     void
     CleanLocalData() override;
@@ -19,6 +19,7 @@
 #include <map>
 #include <memory>
 #include <string>
 #include <vector>
+#include <boost/dynamic_bitset.hpp>

 #include "knowhere/index/VecIndex.h"

@@ -45,6 +46,12 @@ class VectorIndex : public IndexBase {
     virtual std::unique_ptr<SearchResult>
     Query(const DatasetPtr dataset, const SearchInfo& search_info, const BitsetView& bitset) = 0;

+    virtual const bool
+    HasRawData() const = 0;
+
+    virtual std::vector<uint8_t>
+    GetVector(const DatasetPtr dataset, const Config& config = {}) const = 0;
+
     IndexType
     GetIndexType() const {
         return index_type_;
@@ -222,4 +222,29 @@ VectorMemIndex::parse_config(Config& config) {
     CheckParameter<int>(config, knowhere::indexparam::SEARCH_K, stoi_closure, std::nullopt);
 }

+const bool
+VectorMemIndex::HasRawData() const {
+    return index_->HasRawData(GetMetricType());
+}
+
+std::vector<uint8_t>
+VectorMemIndex::GetVector(const DatasetPtr dataset, const Config& config) const {
+    auto res = index_->GetVectorById(dataset, config);
+    AssertInfo(res != nullptr, "failed to get vector, result is null");
+    auto index_type = GetIndexType();
+    auto tensor = knowhere::GetDatasetOutputTensor(res);
+    auto row_num = knowhere::GetDatasetRows(res);
+    auto dim = knowhere::GetDatasetDim(res);
+    int64_t data_size;
+    if (is_in_bin_list(index_type)) {
+        data_size = dim / 8 * row_num;
+    } else {
+        data_size = dim * row_num * sizeof(float);
+    }
+    std::vector<uint8_t> raw_data;
+    raw_data.resize(data_size);
+    memcpy(raw_data.data(), tensor, data_size);
+    return raw_data;
+}
+
 }  // namespace milvus::index
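Note: the data_size arithmetic above appears in both the disk index and the memory index. Binary indexes pack dim bits into dim/8 bytes per row, while float indexes store one 4-byte float per dimension. A minimal, hypothetical Go sketch of the same computation (names here are illustrative, not part of the codebase):

package main

import "fmt"

// rawDataSize mirrors the data_size computation in GetVector above:
// binary vectors pack dim bits into dim/8 bytes per row, while float
// vectors use 4 bytes (one float32) per dimension per row.
func rawDataSize(isBinary bool, rows, dim int64) int64 {
	if isBinary {
		return dim / 8 * rows
	}
	return dim * rows * 4
}

func main() {
	fmt.Println(rawDataSize(true, 1000, 128))  // 16000 bytes
	fmt.Println(rawDataSize(false, 1000, 128)) // 512000 bytes
}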
@@ -67,6 +67,12 @@ class VectorMemIndex : public VectorIndex {
     virtual void
     LoadWithoutAssemble(const BinarySet& binary_set, const Config& config);

+    const bool
+    HasRawData() const override;
+
+    std::vector<uint8_t>
+    GetVector(const DatasetPtr dataset, const Config& config = {}) const override;
+
  protected:
     Config config_;
     knowhere::VecIndexPtr index_ = nullptr;
@@ -208,6 +208,11 @@ class SegmentGrowingImpl : public SegmentGrowing {
         return true;
     }

+    bool
+    HasRawData(int64_t field_id) const override {
+        return true;
+    }
+
  protected:
     int64_t
     num_chunk() const override;
@@ -87,6 +87,9 @@ class SegmentInterface {

     virtual int64_t
     get_segment_id() const = 0;

+    virtual bool
+    HasRawData(int64_t field_id) const = 0;
+
 };

 // internal API for DSL calculation
@@ -392,6 +392,30 @@ SegmentSealedImpl::vector_search(SearchInfo& search_info,
     }
 }

+std::unique_ptr<DataArray>
+SegmentSealedImpl::get_vector(FieldId field_id, const int64_t* ids, int64_t count) const {
+    auto& field_meta = schema_->operator[](field_id);
+    AssertInfo(field_meta.is_vector(), "vector field is not vector type");
+
+    if (get_bit(index_ready_bitset_, field_id)) {
+        AssertInfo(vector_indexings_.is_ready(field_id), "vector index is not ready");
+        auto field_indexing = vector_indexings_.get_field_indexing(field_id);
+        auto vec_index = dynamic_cast<index::VectorIndex*>(field_indexing->indexing_.get());
+
+        auto index_type = vec_index->GetIndexType();
+        auto metric_type = vec_index->GetMetricType();
+        auto has_raw_data = vec_index->HasRawData();
+
+        if (has_raw_data) {
+            auto ids_ds = GenIdsDataset(count, ids);
+            auto vector = vec_index->GetVector(ids_ds);
+            return segcore::CreateVectorDataArrayFrom(vector.data(), count, field_meta);
+        }
+    }
+
+    return fill_with_empty(field_id, count);
+}
+
 void
 SegmentSealedImpl::DropFieldData(const FieldId field_id) {
     if (SystemProperty::Instance().IsSystem(field_id)) {

@@ -553,9 +577,7 @@ SegmentSealedImpl::bulk_subscript(FieldId field_id, const int64_t* seg_offsets,
             return ReverseDataFromIndex(index, seg_offsets, count, field_meta);
         }

-        // TODO: knowhere support reverse data from vector index
-        // Now, real data will be filled in data array using chunk manager
-        return fill_with_empty(field_id, count);
+        return get_vector(field_id, seg_offsets, count);
     }

     Assert(get_bit(field_data_ready_bitset_, field_id));

@@ -649,6 +671,22 @@ SegmentSealedImpl::HasFieldData(FieldId field_id) const {
     }
 }

+bool
+SegmentSealedImpl::HasRawData(int64_t field_id) const {
+    std::shared_lock lck(mutex_);
+    auto fieldID = FieldId(field_id);
+    const auto& field_meta = schema_->operator[](fieldID);
+    if (datatype_is_vector(field_meta.get_data_type())) {
+        if (get_bit(index_ready_bitset_, fieldID)) {
+            AssertInfo(vector_indexings_.is_ready(fieldID), "vector index is not ready");
+            auto field_indexing = vector_indexings_.get_field_indexing(fieldID);
+            auto vec_index = dynamic_cast<index::VectorIndex*>(field_indexing->indexing_.get());
+            return vec_index->HasRawData();
+        }
+    }
+    return true;
+}
+
 std::pair<std::unique_ptr<IdArray>, std::vector<SegOffset>>
 SegmentSealedImpl::search_ids(const IdArray& id_array, Timestamp timestamp) const {
     AssertInfo(id_array.has_int_id(), "Id array doesn't have int_id element");
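Note: the raw-data availability rule split across the two segment types above can be condensed into one decision. A hedged Go sketch (hypothetical helper, simplified from the C++; "quantizing indexes may not keep raw vectors" is general knowledge, not something this diff states):

package main

import "fmt"

// segmentHasRawData condenses the C++ logic above: growing segments always
// keep raw data in memory; sealed segments delegate to the loaded vector
// index for indexed vector fields and otherwise serve resident field data.
func segmentHasRawData(growing, isVectorField, indexLoaded, indexKeepsRaw bool) bool {
	if growing {
		return true // SegmentGrowingImpl::HasRawData always returns true
	}
	if isVectorField && indexLoaded {
		return indexKeepsRaw // HNSW keeps raw vectors; quantizing indexes may not
	}
	return true // scalar fields and unindexed vectors stay in field data
}

func main() {
	fmt.Println(segmentHasRawData(false, true, true, true))  // true: index can serve vectors
	fmt.Println(segmentHasRawData(false, true, true, false)) // false: raw data must come from storage
}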
@@ -59,6 +59,9 @@ class SegmentSealedImpl : public SegmentSealed {
         return id_;
     }

+    bool
+    HasRawData(int64_t field_id) const override;
+
  public:
     int64_t
     GetMemoryUsageInBytes() const override;

@@ -72,6 +75,9 @@ class SegmentSealedImpl : public SegmentSealed {
     const Schema&
     get_schema() const override;

+    std::unique_ptr<DataArray>
+    get_vector(FieldId field_id, const int64_t* ids, int64_t count) const;
+
  public:
     int64_t
     num_chunk_index(FieldId field_id) const override;
@@ -138,6 +138,12 @@ GetRealCount(CSegmentInterface c_segment) {
     return segment->get_real_count();
 }

+bool
+HasRawData(CSegmentInterface c_segment, int64_t field_id) {
+    auto segment = reinterpret_cast<milvus::segcore::SegmentInterface*>(c_segment);
+    return segment->HasRawData(field_id);
+}
+
 ////////////////////////////// interfaces for growing segment //////////////////////////////
 CStatus
 Insert(CSegmentInterface c_segment,
@@ -63,6 +63,9 @@ GetDeletedCount(CSegmentInterface c_segment);
 int64_t
 GetRealCount(CSegmentInterface c_segment);

+bool
+HasRawData(CSegmentInterface c_segment, int64_t field_id);
+
 ////////////////////////////// interfaces for growing segment //////////////////////////////
 CStatus
 Insert(CSegmentInterface c_segment,
@@ -11,8 +11,8 @@
 # or implied. See the License for the specific language governing permissions and limitations under the License.
 #-------------------------------------------------------------------------------

-set(KNOWHERE_VERSION v1.3.16)
-set(KNOWHERE_SOURCE_MD5 "f0ded5a77b39ca7db7047191234ec6d7")
+set(KNOWHERE_VERSION v1.3.17)
+set(KNOWHERE_SOURCE_MD5 "00164cd97b2f35c09ae0bdd6e2a9fc02")

 if (DEFINED ENV{MILVUS_KNOWHERE_URL})
     set(KNOWHERE_SOURCE_URL "$ENV{MILVUS_KNOWHERE_URL}")
@@ -430,6 +430,92 @@ TEST_P(IndexTest, BuildAndQuery) {
     }
 }

+TEST_P(IndexTest, GetVector) {
+    milvus::index::CreateIndexInfo create_index_info;
+    create_index_info.index_type = index_type;
+    create_index_info.metric_type = metric_type;
+    create_index_info.field_type = vec_field_data_type;
+    index::IndexBasePtr index;
+
+    // only HNSW supports GetVector in 2.2.x
+    if (index_type != knowhere::IndexEnum::INDEX_HNSW) {
+        return;
+    }
+
+    if (index_type == knowhere::IndexEnum::INDEX_DISKANN) {
+#ifdef BUILD_DISK_ANN
+        milvus::storage::FieldDataMeta field_data_meta{1, 2, 3, 100};
+        milvus::storage::IndexMeta index_meta{3, 100, 1000, 1};
+        auto file_manager =
+            std::make_shared<milvus::storage::DiskFileManagerImpl>(field_data_meta, index_meta, storage_config_);
+        index = milvus::index::IndexFactory::GetInstance().CreateIndex(create_index_info, file_manager);
+#endif
+    } else {
+        index = milvus::index::IndexFactory::GetInstance().CreateIndex(create_index_info, nullptr);
+    }
+    ASSERT_NO_THROW(index->BuildWithDataset(xb_dataset, build_conf));
+    milvus::index::IndexBasePtr new_index;
+    milvus::index::VectorIndex* vec_index = nullptr;
+
+    if (index_type == knowhere::IndexEnum::INDEX_DISKANN) {
+#ifdef BUILD_DISK_ANN
+        // TODO: diskann.query needs a load first, ugly
+        auto binary_set = index->Serialize(milvus::Config{});
+        index.reset();
+        milvus::storage::FieldDataMeta field_data_meta{1, 2, 3, 100};
+        milvus::storage::IndexMeta index_meta{3, 100, 1000, 1};
+        auto file_manager =
+            std::make_shared<milvus::storage::DiskFileManagerImpl>(field_data_meta, index_meta, storage_config_);
+        new_index = milvus::index::IndexFactory::GetInstance().CreateIndex(create_index_info, file_manager);
+
+        vec_index = dynamic_cast<milvus::index::VectorIndex*>(new_index.get());
+
+        std::vector<std::string> index_files;
+        for (auto& binary : binary_set.binary_map_) {
+            index_files.emplace_back(binary.first);
+        }
+        load_conf["index_files"] = index_files;
+        vec_index->Load(binary_set, load_conf);
+        EXPECT_EQ(vec_index->Count(), NB);
+#endif
+    } else {
+        vec_index = dynamic_cast<milvus::index::VectorIndex*>(index.get());
+    }
+    EXPECT_EQ(vec_index->GetDim(), DIM);
+    EXPECT_EQ(vec_index->Count(), NB);
+
+    if (!vec_index->HasRawData()) {
+        return;
+    }
+
+    auto ids_ds = GenRandomIds(NB);
+    auto results = vec_index->GetVector(ids_ds);
+    EXPECT_TRUE(results.size() > 0);
+    if (!is_binary) {
+        std::vector<float> result_vectors(results.size() / (sizeof(float)));
+        memcpy(result_vectors.data(), results.data(), results.size());
+        EXPECT_TRUE(result_vectors.size() == xb_data.size());
+        for (size_t i = 0; i < NB; ++i) {
+            auto id = knowhere::GetDatasetInputIDs(ids_ds)[i];
+            for (size_t j = 0; j < DIM; ++j) {
+                EXPECT_TRUE(result_vectors[i * DIM + j] == xb_data[id * DIM + j]);
+            }
+        }
+    } else {
+        EXPECT_TRUE(results.size() == xb_bin_data.size());
+        const auto data_bytes = DIM / 8;
+        for (size_t i = 0; i < NB; ++i) {
+            auto id = knowhere::GetDatasetInputIDs(ids_ds)[i];
+            for (size_t j = 0; j < data_bytes; ++j) {
+                EXPECT_TRUE(results[i * data_bytes + j] == xb_bin_data[id * data_bytes + j]);
+            }
+        }
+    }
+
+    auto ids = knowhere::GetDatasetInputIDs(ids_ds);
+    delete[] ids;
+}
+
 //#ifdef BUILD_DISK_ANN
 // TEST(Indexing, SearchDiskAnnWithInvalidParam) {
 //     int64_t NB = 10000;
@@ -811,3 +811,54 @@ TEST(Sealed, RealCount) {
     ASSERT_TRUE(status.ok());
     ASSERT_EQ(0, segment->get_real_count());
 }

+TEST(Sealed, GetVector) {
+    auto dim = 16;
+    auto topK = 5;
+    auto N = ROW_COUNT;
+    auto metric_type = knowhere::metric::L2;
+    auto schema = std::make_shared<Schema>();
+    auto fakevec_id = schema->AddDebugField("fakevec", DataType::VECTOR_FLOAT, dim, metric_type);
+    auto counter_id = schema->AddDebugField("counter", DataType::INT64);
+    auto double_id = schema->AddDebugField("double", DataType::DOUBLE);
+    auto nothing_id = schema->AddDebugField("nothing", DataType::INT32);
+    auto str_id = schema->AddDebugField("str", DataType::VARCHAR);
+    schema->AddDebugField("int8", DataType::INT8);
+    schema->AddDebugField("int16", DataType::INT16);
+    schema->AddDebugField("float", DataType::FLOAT);
+    schema->set_primary_field_id(counter_id);
+
+    auto dataset = DataGen(schema, N);
+
+    auto fakevec = dataset.get_col<float>(fakevec_id);
+
+    auto indexing = GenHNSWIndex(N, dim, fakevec.data());
+
+    auto segment_sealed = CreateSealedSegment(schema);
+
+    LoadIndexInfo vec_info;
+    vec_info.field_id = fakevec_id.get();
+    vec_info.index = std::move(indexing);
+    vec_info.index_params["metric_type"] = knowhere::metric::L2;
+    segment_sealed->LoadIndex(vec_info);
+
+    auto segment = dynamic_cast<SegmentSealedImpl*>(segment_sealed.get());
+
+    auto has = segment->HasRawData(vec_info.field_id);
+    EXPECT_TRUE(has);
+
+    auto ids_ds = GenRandomIds(N);
+    auto result = segment->get_vector(fakevec_id, knowhere::GetDatasetInputIDs(ids_ds), N);
+
+    auto vector = result.get()->mutable_vectors()->float_vector().data();
+    EXPECT_TRUE(vector.size() == fakevec.size());
+    for (size_t i = 0; i < N; ++i) {
+        auto id = knowhere::GetDatasetInputIDs(ids_ds)[i];
+        for (size_t j = 0; j < dim; ++j) {
+            EXPECT_TRUE(vector[i * dim + j] == fakevec[id * dim + j]);
+        }
+    }
+
+    auto ids = knowhere::GetDatasetInputIDs(ids_ds);
+    delete[] ids;
+}
@@ -615,6 +615,20 @@ GenVecIndexing(int64_t N, int64_t dim, const float* vec) {
     return indexing;
 }

+inline index::VectorIndexPtr
+GenHNSWIndex(int64_t N, int64_t dim, const float* vec) {
+    auto conf = knowhere::Config{{knowhere::meta::METRIC_TYPE, knowhere::metric::L2},
+                                 {knowhere::meta::DIM, std::to_string(dim)},
+                                 {knowhere::indexparam::EF, "200"},
+                                 {knowhere::indexparam::M, "16"},
+                                 {knowhere::meta::DEVICE_ID, 0}};
+    auto database = knowhere::GenDataset(N, dim, vec);
+    auto indexing = std::make_unique<index::VectorMemIndex>(knowhere::IndexEnum::INDEX_HNSW,
+                                                            knowhere::metric::L2, IndexMode::MODE_CPU);
+    indexing->BuildWithDataset(database, conf);
+    return indexing;
+}
+
 template <typename T>
 inline index::IndexBasePtr
 GenScalarIndexing(int64_t N, const T* data) {

@@ -680,4 +694,15 @@ GenPKs(const std::vector<int64_t>& pks) {
     return GenPKs(pks.begin(), pks.end());
 }

+inline std::shared_ptr<knowhere::Dataset>
+GenRandomIds(int rows, int64_t seed = 42) {
+    std::mt19937 g(seed);
+    auto* ids = new int64_t[rows];
+    for (int i = 0; i < rows; ++i) ids[i] = i;
+    std::shuffle(ids, ids + rows, g);
+    // INPUT_IDS will not be freed in the dataset destructor; please delete it manually.
+    auto ids_ds = GenIdsDataset(rows, ids);
+    return ids_ds;
+}
+
 }  // namespace milvus::segcore
@@ -3019,9 +3019,10 @@ func (node *Proxy) Search(ctx context.Context, request *milvuspb.SearchRequest)
 			ReqID: Params.ProxyCfg.GetNodeID(),
 		},
 		request:  request,
-		qc:       node.queryCoord,
 		tr:       timerecord.NewTimeRecorder("search"),
 		shardMgr: node.shardMgr,
+		qc:       node.queryCoord,
+		node:     node,
 	}

 	travelTs := request.TravelTimestamp
@@ -8,6 +8,7 @@ import (
 	"strings"

 	"github.com/golang/protobuf/proto"
+	"github.com/samber/lo"
 	"go.uber.org/zap"

 	"github.com/milvus-io/milvus-proto/go-api/v2/commonpb"

@@ -506,7 +507,10 @@ func IDs2Expr(fieldName string, ids *schemapb.IDs) string {
 	case *schemapb.IDs_IntId:
 		idsStr = strings.Trim(strings.Join(strings.Fields(fmt.Sprint(ids.GetIntId().GetData())), ", "), "[]")
 	case *schemapb.IDs_StrId:
-		idsStr = strings.Trim(strings.Join(ids.GetStrId().GetData(), ", "), "[]")
+		strs := lo.Map(ids.GetStrId().GetData(), func(str string, _ int) string {
+			return fmt.Sprintf("\"%s\"", str)
+		})
+		idsStr = strings.Trim(strings.Join(strs, ", "), "[]")
 	}

 	return fieldName + " in [ " + idsStr + " ]"
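Note: the change above wraps string primary keys in quotes so the generated expression parses them as string literals instead of identifiers; the test in the next hunk pins the expected output. A stdlib-only Go sketch of the same quoting (quoteIDs is a hypothetical name, not the real helper):

package main

import (
	"fmt"
	"strings"
)

// quoteIDs mirrors what the lo.Map call above does: wrap each string
// primary key in double quotes before joining them into the expression.
func quoteIDs(fieldName string, ids []string) string {
	quoted := make([]string, len(ids))
	for i, s := range ids {
		quoted[i] = fmt.Sprintf("%q", s)
	}
	return fieldName + " in [ " + strings.Join(quoted, ", ") + " ]"
}

func main() {
	fmt.Println(quoteIDs("pk", []string{"a", "b", "c"}))
	// Output: pk in [ "a", "b", "c" ]
}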
@@ -690,3 +690,28 @@ func Test_filterSystemFields(t *testing.T) {
 	filtered := filterSystemFields(outputFieldIDs)
 	assert.ElementsMatch(t, []UniqueID{common.StartOfUserFieldID}, filtered)
 }

+func TestQueryTask_IDs2Expr(t *testing.T) {
+	fieldName := "pk"
+	intIDs := &schemapb.IDs{
+		IdField: &schemapb.IDs_IntId{
+			IntId: &schemapb.LongArray{
+				Data: []int64{1, 2, 3, 4, 5},
+			},
+		},
+	}
+	stringIDs := &schemapb.IDs{
+		IdField: &schemapb.IDs_StrId{
+			StrId: &schemapb.StringArray{
+				Data: []string{"a", "b", "c"},
+			},
+		},
+	}
+	idExpr := IDs2Expr(fieldName, intIDs)
+	expectIDExpr := "pk in [ 1, 2, 3, 4, 5 ]"
+	assert.Equal(t, expectIDExpr, idExpr)
+
+	strExpr := IDs2Expr(fieldName, stringIDs)
+	expectStrExpr := "pk in [ \"a\", \"b\", \"c\" ]"
+	assert.Equal(t, expectStrExpr, strExpr)
+}
@@ -4,10 +4,12 @@ import (
 	"context"
 	"errors"
 	"fmt"
+	"math"
 	"regexp"
 	"strconv"

 	"github.com/golang/protobuf/proto"
+	"github.com/samber/lo"
 	"go.uber.org/zap"

 	"github.com/milvus-io/milvus-proto/go-api/v2/commonpb"

@@ -36,6 +38,12 @@ import (
 const (
 	SearchTaskName = "SearchTask"
 	SearchLevelKey = "level"
+
+	// requeryThreshold is the estimated threshold for the size of the search results.
+	// If the number of estimated search results exceeds this threshold,
+	// a second query request will be initiated to retrieve output fields data.
+	// In this case, the first search will not return any output field from QueryNodes.
+	requeryThreshold = 0.5 * 1024 * 1024
 )

 type searchTask struct {

@@ -43,12 +51,13 @@ type searchTask struct {
 	*internalpb.SearchRequest
 	ctx context.Context

-	result  *milvuspb.SearchResults
-	request *milvuspb.SearchRequest
-	qc      types.QueryCoord
+	result  *milvuspb.SearchResults
+	request *milvuspb.SearchRequest
+
 	tr             *timerecord.TimeRecorder
 	collectionName string
 	schema         *schemapb.CollectionSchema
+	requery        bool

 	offset    int64
 	resultBuf chan *internalpb.SearchResults

@@ -57,6 +66,9 @@ type searchTask struct {

 	searchShardPolicy pickShardPolicy
 	shardMgr          *shardClientMgr
+
+	qc   types.QueryCoord
+	node types.ProxyComponent
 }

 func getPartitionIDs(ctx context.Context, collectionName string, partitionNames []string) (partitionIDs []UniqueID, err error) {

@@ -166,11 +178,7 @@ func getOutputFieldIDs(schema *schemapb.CollectionSchema, outputFields []string)
 		hitField := false
 		for _, field := range schema.GetFields() {
 			if field.Name == name {
-				if field.DataType == schemapb.DataType_BinaryVector || field.DataType == schemapb.DataType_FloatVector {
-					return nil, errors.New("search doesn't support vector field as output_fields")
-				}
 				outputFieldIDs = append(outputFieldIDs, field.GetFieldID())
-
 				hitField = true
 				break
 			}

@@ -272,6 +280,24 @@ func (t *searchTask) PreExecute(ctx context.Context) error {
 	t.SearchRequest.IgnoreGrowing = ignoreGrowing

 	partitionNames := t.request.GetPartitionNames()
+	// Manually update nq if not set.
+	nq, err := getNq(t.request)
+	if err != nil {
+		return err
+	}
+	// Check if nq is valid:
+	// https://milvus.io/docs/limitations.md
+	if err := validateNQLimit(nq); err != nil {
+		return fmt.Errorf("%s [%d] is invalid, %w", NQKey, nq, err)
+	}
+	t.SearchRequest.Nq = nq
+
+	outputFieldIDs, err := getOutputFieldIDs(t.schema, t.request.GetOutputFields())
+	if err != nil {
+		return err
+	}
+	t.SearchRequest.OutputFieldsId = outputFieldIDs
+
 	if t.request.GetDslType() == commonpb.DslType_BoolExprV1 {
 		annsField, err := funcutil.GetAttrByKeyFromRepeatedKV(AnnsFieldKey, t.request.GetSearchParams())
 		if err != nil || len(annsField) == 0 {

@@ -325,6 +351,16 @@ func (t *searchTask) PreExecute(ctx context.Context) error {
 		t.SearchRequest.Topk = queryInfo.GetTopk()
 		t.SearchRequest.MetricType = queryInfo.GetMetricType()
 		t.SearchRequest.DslType = commonpb.DslType_BoolExprV1
+
+		estimateSize, err := t.estimateResultSize(nq, t.SearchRequest.Topk)
+		if err != nil {
+			return err
+		}
+		if estimateSize >= requeryThreshold {
+			t.requery = true
+			plan.OutputFieldIds = nil
+		}
+
 		t.SearchRequest.SerializedExprPlan, err = proto.Marshal(plan)
 		if err != nil {
 			return err

@@ -381,17 +417,6 @@ func (t *searchTask) PreExecute(ctx context.Context) error {

 	t.SearchRequest.Dsl = t.request.Dsl
 	t.SearchRequest.PlaceholderGroup = t.request.PlaceholderGroup
-	// Manually update nq if not set.
-	nq, err := getNq(t.request)
-	if err != nil {
-		return err
-	}
-	// Check if nq is valid:
-	// https://milvus.io/docs/limitations.md
-	if err := validateNQLimit(nq); err != nil {
-		return fmt.Errorf("%s [%d] is invalid, %w", NQKey, nq, err)
-	}
-	t.SearchRequest.Nq = nq

 	log.Ctx(ctx).Debug("search PreExecute done.", zap.Int64("msgID", t.ID()),
 		zap.Uint64("travel_ts", travelTimestamp), zap.Uint64("guarantee_ts", guaranteeTs),

@@ -504,6 +529,17 @@ func (t *searchTask) PostExecute(ctx context.Context) error {
 	t.fillInFieldInfo()
 	t.result.Results.OutputFields = t.userOutputFields
-	log.Ctx(ctx).Debug("Search post execute done", zap.Int64("msgID", t.ID()))
+
+	if t.requery {
+		err = t.Requery()
+		if err != nil {
+			return err
+		}
+	}
+
+	log.Ctx(ctx).Debug("Search post execute done",
+		zap.Int64("collection", t.GetCollectionID()),
+		zap.Int64s("partitionIDs", t.GetPartitionIDs()))
 	return nil
 }

@@ -551,6 +587,99 @@ func (t *searchTask) searchShard(ctx context.Context, nodeID int64, qn types.Que
 	return nil
 }

+func (t *searchTask) estimateResultSize(nq int64, topK int64) (int64, error) {
+	vectorOutputFields := lo.Filter(t.schema.GetFields(), func(field *schemapb.FieldSchema, _ int) bool {
+		return lo.Contains(t.request.GetOutputFields(), field.GetName()) && typeutil.IsVectorType(field.GetDataType())
+	})
+	// Currently, we get vectors by requery. Once we support getting vectors from search,
+	// searches with small result size could no longer need requery.
+	if len(vectorOutputFields) > 0 {
+		return math.MaxInt64, nil
+	}
+	// If no vector field as output, no need to requery.
+	return 0, nil
+
+	//outputFields := lo.Filter(t.schema.GetFields(), func(field *schemapb.FieldSchema, _ int) bool {
+	//	return lo.Contains(t.request.GetOutputFields(), field.GetName())
+	//})
+	//sizePerRecord, err := typeutil.EstimateSizePerRecord(&schemapb.CollectionSchema{Fields: outputFields})
+	//if err != nil {
+	//	return 0, err
+	//}
+	//return int64(sizePerRecord) * nq * topK, nil
+}
+
+func (t *searchTask) Requery() error {
+	pkField, err := typeutil.GetPrimaryFieldSchema(t.schema)
+	if err != nil {
+		return err
+	}
+	ids := t.result.GetResults().GetIds()
+	expr := IDs2Expr(pkField.GetName(), ids)
+
+	queryReq := &milvuspb.QueryRequest{
+		Base: &commonpb.MsgBase{
+			MsgType: commonpb.MsgType_Retrieve,
+		},
+		CollectionName:     t.request.GetCollectionName(),
+		Expr:               expr,
+		OutputFields:       t.request.GetOutputFields(),
+		PartitionNames:     t.request.GetPartitionNames(),
+		TravelTimestamp:    t.request.GetTravelTimestamp(),
+		GuaranteeTimestamp: t.request.GetGuaranteeTimestamp(),
+		QueryParams:        t.request.GetSearchParams(),
+	}
+	queryResult, err := t.node.Query(t.ctx, queryReq)
+	if err != nil {
+		return err
+	}
+	if queryResult.GetStatus().GetErrorCode() != commonpb.ErrorCode_Success {
+		return common.NewCodeError(queryResult.GetStatus().GetErrorCode(),
+			fmt.Errorf("requery failed, err=%s", queryResult.GetStatus().GetReason()))
+	}
+	// Reorganize Results. The order of query result ids will be altered and differ from queried ids.
+	// We should reorganize query results to keep the order of original queried ids. For example:
+	// ===========================================
+	//  3  2  5  4  1  (query ids)
+	//       ||
+	//       || (query)
+	//       \/
+	//  4  3  5  1  2  (result ids)
+	// v4 v3 v5 v1 v2  (result vectors)
+	//       ||
+	//       || (reorganize)
+	//       \/
+	//  3  2  5  4  1  (result ids)
+	// v3 v2 v5 v4 v1  (result vectors)
+	// ===========================================
+	pkFieldData, err := typeutil.GetPrimaryFieldData(queryResult.GetFieldsData(), pkField)
+	if err != nil {
+		return err
+	}
+	offsets := make(map[any]int)
+	for i := 0; i < typeutil.GetPKSize(pkFieldData); i++ {
+		pk := typeutil.GetData(pkFieldData, i)
+		offsets[pk] = i
+	}
+
+	t.result.Results.FieldsData = make([]*schemapb.FieldData, len(queryResult.GetFieldsData()))
+	for i := 0; i < typeutil.GetSizeOfIDs(ids); i++ {
+		id := typeutil.GetPK(ids, int64(i))
+		if _, ok := offsets[id]; !ok {
+			return fmt.Errorf("incomplete query result, missing id %s, len(searchIDs) = %d, len(queryIDs) = %d, collection=%d",
+				id, typeutil.GetSizeOfIDs(ids), len(offsets), t.GetCollectionID())
+		}
+		typeutil.AppendFieldData(t.result.Results.FieldsData, queryResult.GetFieldsData(), int64(offsets[id]))
+	}

+	// filter id field out if it is not specified as output
+	t.result.Results.FieldsData = lo.Filter(t.result.Results.FieldsData, func(fieldData *schemapb.FieldData, i int) bool {
+		return lo.Contains(t.request.GetOutputFields(), fieldData.GetFieldName())
+	})
+
+	return nil
+}
+
 func (t *searchTask) fillInEmptyResult(numQueries int64) {
 	t.result = &milvuspb.SearchResults{
 		Status: &commonpb.Status{
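Note: the reorganize step in Requery above is the subtle part: query results come back in arbitrary order, so the code builds a pk-to-offset map and walks the original search ids. A self-contained Go sketch of the same idea, reproducing the ASCII diagram's numbers (reorder is a hypothetical helper, simplified to plain int64 slices):

package main

import "fmt"

// reorder maps each result primary key to its row offset, then walks the
// original search ids so output rows follow the search order, erroring
// out on any missing id, just like Requery does.
func reorder(searchIDs, resultIDs []int64) ([]int, error) {
	offsets := make(map[int64]int, len(resultIDs))
	for i, pk := range resultIDs {
		offsets[pk] = i
	}
	order := make([]int, 0, len(searchIDs))
	for _, id := range searchIDs {
		off, ok := offsets[id]
		if !ok {
			return nil, fmt.Errorf("incomplete query result, missing id %d", id)
		}
		order = append(order, off)
	}
	return order, nil
}

func main() {
	// Mirrors the diagram: search ids 3 2 5 4 1, result ids 4 3 5 1 2.
	order, _ := reorder([]int64{3, 2, 5, 4, 1}, []int64{4, 3, 5, 1, 2})
	fmt.Println(order) // [1 4 2 0 3] -> picks v3 v2 v5 v4 v1
}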
@@ -19,6 +19,7 @@ import (
 	"github.com/milvus-io/milvus-proto/go-api/v2/milvuspb"
 	"github.com/milvus-io/milvus-proto/go-api/v2/schemapb"
 	"github.com/milvus-io/milvus/internal/common"
+	"github.com/milvus-io/milvus/internal/mocks"
 	"github.com/milvus-io/milvus/internal/proto/internalpb"
 	"github.com/milvus-io/milvus/internal/proto/querypb"
 	"github.com/milvus-io/milvus/internal/types"

@@ -325,7 +326,7 @@ func TestSearchTask_PreExecute(t *testing.T) {

 		// contain vector field
 		task.request.OutputFields = []string{testFloatVecField}
-		assert.Error(t, task.PreExecute(ctx))
+		assert.NoError(t, task.PreExecute(ctx))
 	})
 }

@@ -2079,3 +2080,292 @@ func getSearchResultData(nq, topk int64) *schemapb.SearchResultData {
 	}
 	return &result
 }

+func TestSearchTask_Requery(t *testing.T) {
+	ctx, cancel := context.WithCancel(context.Background())
+	defer cancel()
+
+	const (
+		dim        = 128
+		rows       = 5
+		collection = "test-requery"
+
+		pkField  = "pk"
+		vecField = "vec"
+	)
+
+	ids := make([]int64, rows)
+	for i := range ids {
+		ids[i] = int64(i)
+	}
+
+	t.Run("Test normal", func(t *testing.T) {
+		schema := constructCollectionSchema(pkField, vecField, dim, collection)
+		node := mocks.NewProxy(t)
+		node.EXPECT().Query(mock.Anything, mock.Anything).
+			Return(&milvuspb.QueryResults{
+				FieldsData: []*schemapb.FieldData{{
+					Type:      schemapb.DataType_Int64,
+					FieldName: pkField,
+					Field: &schemapb.FieldData_Scalars{
+						Scalars: &schemapb.ScalarField{
+							Data: &schemapb.ScalarField_LongData{
+								LongData: &schemapb.LongArray{
+									Data: ids,
+								},
+							},
+						},
+					},
+				},
+					newFloatVectorFieldData(vecField, rows, dim),
+				},
+			}, nil)
+
+		resultIDs := &schemapb.IDs{
+			IdField: &schemapb.IDs_IntId{
+				IntId: &schemapb.LongArray{
+					Data: ids,
+				},
+			},
+		}
+
+		outputFields := []string{vecField}
+		qt := &searchTask{
+			ctx: ctx,
+			SearchRequest: &internalpb.SearchRequest{
+				Base: &commonpb.MsgBase{
+					MsgType:  commonpb.MsgType_Search,
+					SourceID: Params.ProxyCfg.GetNodeID(),
+				},
+			},
+			request: &milvuspb.SearchRequest{
+				OutputFields: outputFields,
+			},
+			result: &milvuspb.SearchResults{
+				Results: &schemapb.SearchResultData{
+					Ids: resultIDs,
+				},
+			},
+			schema: schema,
+			tr:     timerecord.NewTimeRecorder("search"),
+			node:   node,
+		}
+
+		err := qt.Requery()
+		assert.NoError(t, err)
+		assert.Len(t, qt.result.Results.FieldsData, 1)
+		assert.Equal(t, vecField, qt.result.Results.FieldsData[0].GetFieldName())
+	})
+
+	t.Run("Test no primary key", func(t *testing.T) {
+		schema := &schemapb.CollectionSchema{}
+		node := mocks.NewProxy(t)
+
+		qt := &searchTask{
+			ctx: ctx,
+			SearchRequest: &internalpb.SearchRequest{
+				Base: &commonpb.MsgBase{
+					MsgType:  commonpb.MsgType_Search,
+					SourceID: Params.ProxyCfg.GetNodeID(),
+				},
+			},
+			request: &milvuspb.SearchRequest{},
+			schema:  schema,
+			tr:      timerecord.NewTimeRecorder("search"),
+			node:    node,
+		}
+
+		err := qt.Requery()
+		t.Logf("err = %s", err)
+		assert.Error(t, err)
+	})
+
+	t.Run("Test requery failed 1", func(t *testing.T) {
+		schema := constructCollectionSchema(pkField, vecField, dim, collection)
+		node := mocks.NewProxy(t)
+		node.EXPECT().Query(mock.Anything, mock.Anything).
+			Return(nil, fmt.Errorf("mock err 1"))
+
+		qt := &searchTask{
+			ctx: ctx,
+			SearchRequest: &internalpb.SearchRequest{
+				Base: &commonpb.MsgBase{
+					MsgType:  commonpb.MsgType_Search,
+					SourceID: Params.ProxyCfg.GetNodeID(),
+				},
+			},
+			request: &milvuspb.SearchRequest{},
+			schema:  schema,
+			tr:      timerecord.NewTimeRecorder("search"),
+			node:    node,
+		}
+
+		err := qt.Requery()
+		t.Logf("err = %s", err)
+		assert.Error(t, err)
+	})
+
+	t.Run("Test requery failed 2", func(t *testing.T) {
+		schema := constructCollectionSchema(pkField, vecField, dim, collection)
+		node := mocks.NewProxy(t)
+		node.EXPECT().Query(mock.Anything, mock.Anything).
+			Return(&milvuspb.QueryResults{
+				Status: &commonpb.Status{
+					ErrorCode: commonpb.ErrorCode_UnexpectedError,
+					Reason:    "mock err 2",
+				},
+			}, nil)
+
+		qt := &searchTask{
+			ctx: ctx,
+			SearchRequest: &internalpb.SearchRequest{
+				Base: &commonpb.MsgBase{
+					MsgType:  commonpb.MsgType_Search,
+					SourceID: Params.ProxyCfg.GetNodeID(),
+				},
+			},
+			request: &milvuspb.SearchRequest{},
+			schema:  schema,
+			tr:      timerecord.NewTimeRecorder("search"),
+			node:    node,
+		}
+
+		err := qt.Requery()
+		t.Logf("err = %s", err)
+		assert.Error(t, err)
+	})
+
+	t.Run("Test get pk field data failed", func(t *testing.T) {
+		schema := constructCollectionSchema(pkField, vecField, dim, collection)
+		node := mocks.NewProxy(t)
+		node.EXPECT().Query(mock.Anything, mock.Anything).
+			Return(&milvuspb.QueryResults{
+				FieldsData: []*schemapb.FieldData{},
+			}, nil)
+
+		qt := &searchTask{
+			ctx: ctx,
+			SearchRequest: &internalpb.SearchRequest{
+				Base: &commonpb.MsgBase{
+					MsgType:  commonpb.MsgType_Search,
+					SourceID: Params.ProxyCfg.GetNodeID(),
+				},
+			},
+			request: &milvuspb.SearchRequest{},
+			schema:  schema,
+			tr:      timerecord.NewTimeRecorder("search"),
+			node:    node,
+		}
+
+		err := qt.Requery()
+		t.Logf("err = %s", err)
+		assert.Error(t, err)
+	})
+
+	t.Run("Test incomplete query result", func(t *testing.T) {
+		schema := constructCollectionSchema(pkField, vecField, dim, collection)
+		node := mocks.NewProxy(t)
+		node.EXPECT().Query(mock.Anything, mock.Anything).
+			Return(&milvuspb.QueryResults{
+				FieldsData: []*schemapb.FieldData{{
+					Type:      schemapb.DataType_Int64,
+					FieldName: pkField,
+					Field: &schemapb.FieldData_Scalars{
+						Scalars: &schemapb.ScalarField{
+							Data: &schemapb.ScalarField_LongData{
+								LongData: &schemapb.LongArray{
+									Data: ids[:len(ids)-1],
+								},
+							},
+						},
+					},
+				},
+					newFloatVectorFieldData(vecField, rows, dim),
+				},
+			}, nil)
+
+		resultIDs := &schemapb.IDs{
+			IdField: &schemapb.IDs_IntId{
+				IntId: &schemapb.LongArray{
+					Data: ids,
+				},
+			},
+		}
+
+		qt := &searchTask{
+			ctx: ctx,
+			SearchRequest: &internalpb.SearchRequest{
+				Base: &commonpb.MsgBase{
+					MsgType:  commonpb.MsgType_Search,
+					SourceID: Params.ProxyCfg.GetNodeID(),
+				},
+			},
+			request: &milvuspb.SearchRequest{},
+			result: &milvuspb.SearchResults{
+				Results: &schemapb.SearchResultData{
+					Ids: resultIDs,
+				},
+			},
+			schema: schema,
+			tr:     timerecord.NewTimeRecorder("search"),
+			node:   node,
+		}
+
+		err := qt.Requery()
+		t.Logf("err = %s", err)
+		assert.Error(t, err)
+	})
+
+	t.Run("Test postExecute with requery failed", func(t *testing.T) {
+		schema := constructCollectionSchema(pkField, vecField, dim, collection)
+		node := mocks.NewProxy(t)
+		node.EXPECT().Query(mock.Anything, mock.Anything).
+			Return(nil, fmt.Errorf("mock err 1"))
+
+		resultIDs := &schemapb.IDs{
+			IdField: &schemapb.IDs_IntId{
+				IntId: &schemapb.LongArray{
+					Data: ids,
+				},
+			},
+		}
+
+		qt := &searchTask{
+			ctx: ctx,
+			SearchRequest: &internalpb.SearchRequest{
+				Base: &commonpb.MsgBase{
+					MsgType:  commonpb.MsgType_Search,
+					SourceID: Params.ProxyCfg.GetNodeID(),
+				},
+			},
+			request: &milvuspb.SearchRequest{},
+			result: &milvuspb.SearchResults{
+				Results: &schemapb.SearchResultData{
+					Ids: resultIDs,
+				},
+			},
+			requery:   true,
+			schema:    schema,
+			resultBuf: make(chan *internalpb.SearchResults, 10),
+			tr:        timerecord.NewTimeRecorder("search"),
+			node:      node,
+		}
+		scores := make([]float32, rows)
+		for i := range scores {
+			scores[i] = float32(i)
+		}
+		partialResultData := &schemapb.SearchResultData{
+			Ids:    resultIDs,
+			Scores: scores,
+		}
+		bytes, err := proto.Marshal(partialResultData)
+		assert.NoError(t, err)
+		qt.resultBuf <- &internalpb.SearchResults{
+			SlicedBlob: bytes,
+		}
+
+		err = qt.PostExecute(ctx)
+		t.Logf("err = %s", err)
+		assert.Error(t, err)
+	})
+}
@@ -207,6 +207,16 @@ func (s *Segment) setUnhealthy() {
 	s.destroyed.Store(true)
 }

+func (s *Segment) hasRawData(fieldID int64) bool {
+	s.mut.RLock()
+	defer s.mut.RUnlock()
+	if !s.healthy() {
+		return false
+	}
+	ret := C.HasRawData(s.segmentPtr, C.int64_t(fieldID))
+	return bool(ret)
+}
+
 func newSegment(collection *Collection,
 	segmentID UniqueID,
 	partitionID UniqueID,

@@ -607,10 +617,18 @@ func (s *Segment) fillIndexedFieldsData(ctx context.Context, collectionID Unique
 	vcm storage.ChunkManager, result *segcorepb.RetrieveResults) error {

 	for _, fieldData := range result.FieldsData {
-		// If the vector field doesn't have indexed. Vector data is in memory for
-		// brute force search. No need to download data from remote.
-		if fieldData.GetType() != schemapb.DataType_FloatVector && fieldData.GetType() != schemapb.DataType_BinaryVector ||
-			!s.hasLoadIndexForIndexedField(fieldData.FieldId) {
+		// If the field is not a vector field, no need to download data from remote.
+		if !typeutil.IsVectorType(fieldData.GetType()) {
+			continue
+		}
+		// If the vector field hasn't been indexed, vector data is in memory
+		// for brute-force search; no need to download data from remote.
+		if !s.hasLoadIndexForIndexedField(fieldData.FieldId) {
+			continue
+		}
+		// If the index has raw data, vector data can be obtained from the index;
+		// no need to download data from remote.
+		if s.hasRawData(fieldData.FieldId) {
+			continue
+		}
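Note: the rewritten loop above replaces one compound condition with three explicit skips. A hedged Go sketch condensing the decision (needsRemoteFetch is a hypothetical helper, not the real Segment API):

package main

import "fmt"

// needsRemoteFetch condenses the skip chain in fillIndexedFieldsData:
// remote download is only needed for a vector field whose loaded index
// cannot return the raw vectors.
func needsRemoteFetch(isVector, indexLoaded, indexHasRawData bool) bool {
	if !isVector {
		return false // scalar fields are served locally
	}
	if !indexLoaded {
		return false // raw vectors are in memory for brute-force search
	}
	return !indexHasRawData // e.g. HNSW can serve vectors directly
}

func main() {
	fmt.Println(needsRemoteFetch(true, true, true))  // false: GetVector covers it
	fmt.Println(needsRemoteFetch(true, true, false)) // true: fall back to the chunk manager
}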
@@ -660,7 +660,7 @@ func TestSegment_fillIndexedFieldsData(t *testing.T) {
 			FieldsData: fieldData,
 		}
 		err = segment.fillIndexedFieldsData(ctx, defaultCollectionID, vecCM, result)
-		assert.Error(t, err)
+		assert.NoError(t, err)
 	})
 }

@@ -1038,3 +1038,41 @@ func TestDeleteBuff(t *testing.T) {
 		assert.NoError(t, err)
 	})
 }

+func TestHasRawData(t *testing.T) {
+	t.Run("growing", func(t *testing.T) {
+		schema := genTestCollectionSchema()
+		collection := newCollection(defaultCollectionID, schema)
+		segment, err := newSegment(collection,
+			defaultSegmentID,
+			defaultPartitionID,
+			defaultCollectionID,
+			defaultDMLChannel,
+			segmentTypeGrowing,
+			defaultSegmentVersion,
+			defaultSegmentStartPosition,
+		)
+		assert.Nil(t, err)
+
+		has := segment.hasRawData(simpleFloatVecField.id)
+		assert.True(t, has)
+	})
+
+	t.Run("sealed", func(t *testing.T) {
+		schema := genTestCollectionSchema()
+		collection := newCollection(defaultCollectionID, schema)
+		segment, err := newSegment(collection,
+			defaultSegmentID,
+			defaultPartitionID,
+			defaultCollectionID,
+			defaultDMLChannel,
+			segmentTypeSealed,
+			defaultSegmentVersion,
+			defaultSegmentStartPosition,
+		)
+		assert.Nil(t, err)
+
+		has := segment.hasRawData(simpleFloatVecField.id)
+		assert.True(t, has)
+	})
+}
@@ -814,6 +814,16 @@ func GetSizeOfIDs(data *schemapb.IDs) int {
 	return result
 }

+func GetPKSize(fieldData *schemapb.FieldData) int {
+	switch fieldData.GetType() {
+	case schemapb.DataType_Int64:
+		return len(fieldData.GetScalars().GetLongData().GetData())
+	case schemapb.DataType_VarChar:
+		return len(fieldData.GetScalars().GetStringData().GetData())
+	}
+	return 0
+}
+
 func IsPrimaryFieldType(dataType schemapb.DataType) bool {
 	if dataType == schemapb.DataType_Int64 || dataType == schemapb.DataType_VarChar {
 		return true

@@ -849,6 +859,31 @@ func GetTS(i *internalpb.RetrieveResults, idx int64) uint64 {
 	return 0
 }

+func GetData(field *schemapb.FieldData, idx int) interface{} {
+	switch field.GetType() {
+	case schemapb.DataType_Bool:
+		return field.GetScalars().GetBoolData().GetData()[idx]
+	case schemapb.DataType_Int8, schemapb.DataType_Int16, schemapb.DataType_Int32:
+		return field.GetScalars().GetIntData().GetData()[idx]
+	case schemapb.DataType_Int64:
+		return field.GetScalars().GetLongData().GetData()[idx]
+	case schemapb.DataType_Float:
+		return field.GetScalars().GetFloatData().GetData()[idx]
+	case schemapb.DataType_Double:
+		return field.GetScalars().GetDoubleData().GetData()[idx]
+	case schemapb.DataType_VarChar:
+		return field.GetScalars().GetStringData().GetData()[idx]
+	case schemapb.DataType_FloatVector:
+		dim := int(field.GetVectors().GetDim())
+		return field.GetVectors().GetFloatVector().GetData()[idx*dim : (idx+1)*dim]
+	case schemapb.DataType_BinaryVector:
+		dim := int(field.GetVectors().GetDim())
+		dataBytes := dim / 8
+		return field.GetVectors().GetBinaryVector()[idx*dataBytes : (idx+1)*dataBytes]
+	}
+	return nil
+}
+
 func AppendPKs(pks *schemapb.IDs, pk interface{}) {
 	switch realPK := pk.(type) {
 	case int64:
@@ -1111,3 +1111,69 @@ func TestMergeFieldData(t *testing.T) {

 	MergeFieldData([]*schemapb.FieldData{emptyField}, []*schemapb.FieldData{emptyField})
 }

+func TestGetDataAndGetDataSize(t *testing.T) {
+	const (
+		Dim       = 8
+		fieldName = "field-0"
+		fieldID   = 0
+	)
+
+	BoolArray := []bool{true, false}
+	Int8Array := []int8{1, 2}
+	Int16Array := []int16{3, 4}
+	Int32Array := []int32{5, 6}
+	Int64Array := []int64{11, 22}
+	FloatArray := []float32{1.0, 2.0}
+	DoubleArray := []float64{11.0, 22.0}
+	VarCharArray := []string{"a", "b"}
+	BinaryVector := []byte{0x12, 0x34}
+	FloatVector := []float32{1.0, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0, 8.0, 11.0, 22.0, 33.0, 44.0, 55.0, 66.0, 77.0, 88.0}
+
+	boolData := genFieldData(fieldName, fieldID, schemapb.DataType_Bool, BoolArray, 1)
+	int8Data := genFieldData(fieldName, fieldID, schemapb.DataType_Int8, Int8Array, 1)
+	int16Data := genFieldData(fieldName, fieldID, schemapb.DataType_Int16, Int16Array, 1)
+	int32Data := genFieldData(fieldName, fieldID, schemapb.DataType_Int32, Int32Array, 1)
+	int64Data := genFieldData(fieldName, fieldID, schemapb.DataType_Int64, Int64Array, 1)
+	floatData := genFieldData(fieldName, fieldID, schemapb.DataType_Float, FloatArray, 1)
+	doubleData := genFieldData(fieldName, fieldID, schemapb.DataType_Double, DoubleArray, 1)
+	varCharData := genFieldData(fieldName, fieldID, schemapb.DataType_VarChar, VarCharArray, 1)
+	binVecData := genFieldData(fieldName, fieldID, schemapb.DataType_BinaryVector, BinaryVector, Dim)
+	floatVecData := genFieldData(fieldName, fieldID, schemapb.DataType_FloatVector, FloatVector, Dim)
+	invalidData := &schemapb.FieldData{
+		Type: schemapb.DataType_None,
+	}
+
+	t.Run("test GetPKSize", func(t *testing.T) {
+		int64DataRes := GetPKSize(int64Data)
+		varCharDataRes := GetPKSize(varCharData)
+		assert.Equal(t, 2, int64DataRes)
+		assert.Equal(t, 2, varCharDataRes)
+	})
+
+	t.Run("test GetData", func(t *testing.T) {
+		boolDataRes := GetData(boolData, 0)
+		int8DataRes := GetData(int8Data, 0)
+		int16DataRes := GetData(int16Data, 0)
+		int32DataRes := GetData(int32Data, 0)
+		int64DataRes := GetData(int64Data, 0)
+		floatDataRes := GetData(floatData, 0)
+		doubleDataRes := GetData(doubleData, 0)
+		varCharDataRes := GetData(varCharData, 0)
+		binVecDataRes := GetData(binVecData, 0)
+		floatVecDataRes := GetData(floatVecData, 0)
+		invalidDataRes := GetData(invalidData, 0)
+
+		assert.Equal(t, BoolArray[0], boolDataRes)
+		assert.Equal(t, int32(Int8Array[0]), int8DataRes)
+		assert.Equal(t, int32(Int16Array[0]), int16DataRes)
+		assert.Equal(t, Int32Array[0], int32DataRes)
+		assert.Equal(t, Int64Array[0], int64DataRes)
+		assert.Equal(t, FloatArray[0], floatDataRes)
+		assert.Equal(t, DoubleArray[0], doubleDataRes)
+		assert.Equal(t, VarCharArray[0], varCharDataRes)
+		assert.ElementsMatch(t, BinaryVector[:Dim/8], binVecDataRes)
+		assert.ElementsMatch(t, FloatVector[:Dim], floatVecDataRes)
+		assert.Nil(t, invalidDataRes)
+	})
+}
@@ -230,6 +230,7 @@ binary_support = ["BIN_FLAT", "BIN_IVF_FLAT"]
 delete_support = ["FLAT", "IVF_FLAT", "IVF_SQ8", "IVF_PQ"]
 ivf = ["FLAT", "IVF_FLAT", "IVF_SQ8", "IVF_PQ"]
 skip_pq = ["IVF_PQ"]
+float_metrics = ["L2", "IP"]
 binary_metrics = ["JACCARD", "HAMMING", "TANIMOTO", "SUBSTRUCTURE", "SUPERSTRUCTURE"]
 structure_metrics = ["SUBSTRUCTURE", "SUPERSTRUCTURE"]
@@ -4,6 +4,8 @@ import random

 import pytest
 import pandas as pd
+import decimal
+from decimal import Decimal, getcontext
 from time import sleep

 from base.client_base import TestcaseBase

@@ -812,11 +814,7 @@ class TestCollectionSearchInvalid(TestcaseBase):
         log.info("test_search_output_field_vector: Searching collection %s" % collection_w.name)
         collection_w.search(vectors[:default_nq], default_search_field,
                             default_search_params, default_limit,
-                            default_search_exp, output_fields=output_fields,
-                            check_task=CheckTasks.err_res,
-                            check_items={"err_code": 1,
-                                         "err_msg": "Search doesn't support "
-                                                    "vector field as output_fields"})
+                            default_search_exp, output_fields=output_fields)

     @pytest.mark.tags(CaseLabel.L2)
     @pytest.mark.parametrize("output_fields", [["*%"], ["**"], ["*", "@"]])

@@ -2644,6 +2642,33 @@ class TestCollectionSearch(TestcaseBase):
         assert len(res[0][0].entity._row_data) != 0
         assert default_int64_field_name in res[0][0].entity._row_data

+    @pytest.mark.tags(CaseLabel.L1)
+    def test_search_with_output_vector_field(self, auto_id, _async):
+        """
+        target: test search with output fields
+        method: search with one output_field
+        expected: search success
+        """
+        # 1. initialize with data
+        collection_w, _, _, insert_ids = self.init_collection_general(prefix, True,
+                                                                      auto_id=auto_id)[0:4]
+        # 2. search
+        log.info("test_search_with_output_field: Searching collection %s" % collection_w.name)
+        res = collection_w.search(vectors[:default_nq], default_search_field,
+                                  default_search_params, default_limit,
+                                  default_search_exp, _async=_async,
+                                  output_fields=[field_name],
+                                  check_task=CheckTasks.check_search_results,
+                                  check_items={"nq": default_nq,
+                                               "ids": insert_ids,
+                                               "limit": default_limit,
+                                               "_async": _async})[0]
+        if _async:
+            res.done()
+            res = res.result()
+        assert len(res[0][0].entity._row_data) != 0
+        assert field_name in res[0][0].entity._row_data
+
     @pytest.mark.tags(CaseLabel.L2)
     def test_search_with_output_fields(self, nb, nq, dim, auto_id, _async):
         """

@@ -2675,6 +2700,183 @@ class TestCollectionSearch(TestcaseBase):
         assert len(res[0][0].entity._row_data) != 0
         assert (default_int64_field_name and default_float_field_name) in res[0][0].entity._row_data

+    @pytest.mark.tags(CaseLabel.L2)
+    @pytest.mark.parametrize("index, params",
+                             zip(ct.all_index_types[:6],
+                                 ct.default_index_params[:6]))
+    @pytest.mark.parametrize("metrics", ct.float_metrics)
+    def test_search_output_field_vector_after_different_index_metrics(self, index, params, metrics):
+        """
+        target: test search with output vector field after different index
+        method: 1. create a collection and insert data
+                2. create index and load
+                3. search with output field vector
+                4. check the result vectors should be equal to the inserted
+        expected: search success
+        """
+        # 1. create a collection and insert data
+        collection_w = self.init_collection_general(prefix, is_index=True)[0]
+        data = cf.gen_default_dataframe_data()
+        collection_w.insert(data)
+
+        # 2. create index and load
+        default_index = {"index_type": index, "params": params, "metric_type": metrics}
+        collection_w.create_index(field_name, default_index)
+        collection_w.load()
+
+        # 3. search with output field vector
+        search_params = cf.gen_search_param(index, metrics)[0]
+        res = collection_w.search(vectors[:1], default_search_field,
+                                  search_params, default_limit, default_search_exp,
+                                  output_fields=[field_name],
+                                  check_task=CheckTasks.check_search_results,
+                                  check_items={"nq": 1,
+                                               "limit": default_limit})[0]
+
+        # 4. check the result vectors should be equal to the inserted
+        for _id in range(default_limit):
+            for i in range(default_dim):
+                vectorInsert = str(data[field_name][res[0][_id].id][i])[:7]
+                vectorRes = str(res[0][_id].entity.float_vector[i])[:7]
+                if vectorInsert != vectorRes:
+                    getcontext().rounding = getattr(decimal, 'ROUND_HALF_UP')
+                    vectorInsert = Decimal(data[field_name][res[0][_id].id][i]).quantize(Decimal('0.00000'))
+                    assert str(vectorInsert) == vectorRes
+
+    @pytest.mark.tags(CaseLabel.L2)
+    @pytest.mark.skip(reason="issue #23661")
+    @pytest.mark.parametrize("index", ct.all_index_types[6:8])
+    def test_search_output_field_vector_after_binary_index(self, index):
+        """
+        target: test search with output vector field after binary index
+        method: 1. create a collection and insert data
+                2. create index and load
+                3. search with output field vector
+                4. check the result vectors should be equal to the inserted
+        expected: search success
+        """
+        # 1. create a collection and insert data
+        collection_w = self.init_collection_general(prefix, is_binary=True, is_index=False)[0]
+        data = cf.gen_default_binary_dataframe_data()[0]
+        collection_w.insert(data)
+
+        # 2. create index and load
+        default_index = {"index_type": index, "params": {"nlist": 128}, "metric_type": "JACCARD"}
+        collection_w.create_index(binary_field_name, default_index)
+        collection_w.load()
+
+        # 3. search with output field vector
+        search_params = {"metric_type": "JACCARD", "params": {"nprobe": 10}}
+        binary_vectors = cf.gen_binary_vectors(1, default_dim)[1]
+        res = collection_w.search(binary_vectors, binary_field_name,
+                                  ct.default_search_binary_params, 2, default_search_exp,
+                                  output_fields=[binary_field_name])[0]
+
+        # 4. check the result vectors should be equal to the inserted
+        log.info(res[0][0].id)
+        log.info(res[0][0].entity.float_vector)
+        log.info(data['binary_vector'][0])
+        assert res[0][0].entity.binary_vector == data[binary_field_name][res[0][0].id]
+        # log.info(data['float_vector'][1])
+
+    @pytest.mark.tags(CaseLabel.L2)
+    @pytest.mark.parametrize("dim", [32, 128, 768])
+    def test_search_output_field_vector_with_different_dim(self, dim):
+        """
+        target: test search with output vector field with different dimensions
+        method: 1. create a collection and insert data
+                2. create index and load
+                3. search with output field vector
+                4. check the result vectors should be equal to the inserted
+        expected: search success
+        """
+        # 1. create a collection and insert data
+        collection_w = self.init_collection_general(prefix, is_index=True, dim=dim)[0]
+        data = cf.gen_default_dataframe_data(dim=dim)
+        collection_w.insert(data)
+
+        # 2. create index and load
+        index_params = {"index_type": "IVF_FLAT", "params": {"nlist": 128}, "metric_type": "L2"}
+        collection_w.create_index("float_vector", index_params)
+        collection_w.load()
+
+        # 3. search with output field vector
+        vectors = cf.gen_vectors(default_nq, dim=dim)
+        res = collection_w.search(vectors[:default_nq], default_search_field,
+                                  default_search_params, default_limit, default_search_exp,
+                                  output_fields=[field_name],
+                                  check_task=CheckTasks.check_search_results,
+                                  check_items={"nq": default_nq,
+                                               "limit": default_limit})[0]
+
+        # 4. check the result vectors should be equal to the inserted
+        for i in range(default_limit):
+            assert len(res[0][i].entity.float_vector) == len(data[field_name][res[0][i].id])
+
+    @pytest.mark.tags(CaseLabel.L2)
+    def test_search_output_vector_field_and_scalar_field(self):
+        """
+        target: test search with output vector field and scalar field
+        method: 1. initialize a collection
+                2. search with output field vector
+                3. check no field missing
+        expected: search success
+        """
+        # 1. initialize a collection
+        collection_w = self.init_collection_general(prefix, True)[0]
+
+        # 2. search with output field vector
+        res = collection_w.search(vectors[:1], default_search_field,
+                                  default_search_params, default_limit, default_search_exp,
+                                  output_fields=[default_float_field_name,
+                                                 default_string_field_name,
+                                                 default_search_field],
+                                  check_task=CheckTasks.check_search_results,
+                                  check_items={"nq": 1,
+                                               "limit": default_limit})[0]
+
+        # 3. check the result
+        assert default_float_field_name in res[0][0].entity._row_data
+        assert default_string_field_name in res[0][0].entity._row_data
+        assert default_search_field in res[0][0].entity._row_data
+
+    @pytest.mark.tags(CaseLabel.L2)
+    def test_search_output_field_vector_with_partition(self):
+        """
+        target: test search with output vector field
+        method: 1. create a collection and insert data
+                2. create index and load
+                3. search with output field vector
+                4. check the result vectors should be equal to the inserted
+        expected: search success
+        """
+        # 1. create a collection and insert data
+        collection_w = self.init_collection_general(prefix, is_index=True)[0]
+        partition_w = self.init_partition_wrap(collection_w)
+        data = cf.gen_default_dataframe_data()
+        partition_w.insert(data)
+
+        # 2. create index and load
+        collection_w.create_index(field_name, default_index_params)
+        collection_w.load()
+
+        # 3. search with output field vector
+        res = partition_w.search(vectors[:1], default_search_field,
+                                 default_search_params, default_limit, default_search_exp,
+                                 output_fields=[field_name],
+                                 check_task=CheckTasks.check_search_results,
+                                 check_items={"nq": 1,
+                                              "limit": default_limit})[0]
+
+        # 4. check the result vectors should be equal to the inserted
+        for _id in range(default_limit):
+            for i in range(default_dim):
+                vectorInsert = str(data[field_name][res[0][_id].id][i])[:7]
+                vectorRes = str(res[0][_id].entity.float_vector[i])[:7]
+                if vectorInsert != vectorRes:
+                    getcontext().rounding = getattr(decimal, 'ROUND_HALF_UP')
+                    vectorInsert = Decimal(data[field_name][res[0][_id].id][i]).quantize(Decimal('0.00000'))
+                    assert str(vectorInsert) == vectorRes
+
     @pytest.mark.tags(CaseLabel.L2)
     @pytest.mark.parametrize("output_fields", [["*"], ["*", default_float_field_name]])
     def test_search_with_output_field_wildcard(self, output_fields, auto_id, _async, enable_dynamic_field):

@@ -4978,6 +5180,40 @@ class TestsearchDiskann(TestcaseBase):
                                          "limit": limit,
                                          "_async": _async})

+    @pytest.mark.tags(CaseLabel.L2)
+    @pytest.mark.xfail(reason="issue #23672")
+    def test_search_diskann_search_list_up_to_min(self, dim, auto_id, _async):
+        """
+        target: test search diskann index when search_list up to min
+        method: 1. create collection, insert data, primary_field is int field
+                2. create diskann index, then load
+                3. search
+        expected: search successfully
+        """
+        # 1. initialize with data
+        collection_w, _, _, insert_ids = self.init_collection_general(prefix, True, auto_id=auto_id,
+                                                                      dim=dim, is_index=False)[0:4]
+
+        # 2. create index
+        default_index = {"index_type": "DISKANN", "metric_type": "L2", "params": {}}
+        collection_w.create_index(ct.default_float_vec_field_name, default_index)
+        collection_w.load()
+
+        search_params = {"metric_type": "L2", "params": {"k": 200, "search_list": 201}}
+        search_vectors = [[random.random() for _ in range(dim)] for _ in range(default_nq)]
+        output_fields = [default_int64_field_name, default_float_field_name, default_string_field_name]
+        collection_w.search(search_vectors[:default_nq], default_search_field,
+                            search_params, default_limit,
+                            default_search_exp,
+                            output_fields=output_fields,
+                            _async=_async,
+                            travel_timestamp=0,
+                            check_task=CheckTasks.check_search_results,
+                            check_items={"nq": default_nq,
+                                         "ids": insert_ids,
+                                         "limit": default_limit,
+                                         "_async": _async})
+
+
 class TestCollectionSearchJSON(TestcaseBase):
     """ Test case of search interface """