enhance: add search params in search request in restful(#36304) (#37673)

pr: #36304 
pr: #36714 
pr: #36448

---------

Signed-off-by: lixinguo <xinguo.li@zilliz.com>
Co-authored-by: lixinguo <xinguo.li@zilliz.com>
Co-authored-by: zhuwenxing <wenxing.zhu@zilliz.com>
pull/37731/head
smellthemoon 2024-11-15 17:54:30 +08:00 committed by GitHub
parent 4e11fe7adf
commit b3e6482367
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
10 changed files with 315 additions and 117 deletions

View File

@ -20,9 +20,10 @@ import (
"fmt"
"testing"
"github.com/stretchr/testify/assert"
"github.com/milvus-io/milvus/pkg/config"
"github.com/milvus-io/milvus/pkg/util/paramtable"
"github.com/stretchr/testify/assert"
)
func TestResizePools(t *testing.T) {

View File

@ -944,45 +944,35 @@ func generatePlaceholderGroup(ctx context.Context, body string, collSchema *sche
})
}
func generateSearchParams(ctx context.Context, c *gin.Context, reqParams map[string]float64) ([]*commonpb.KeyValuePair, error) {
params := map[string]interface{}{ // auto generated mapping
"level": int(commonpb.ConsistencyLevel_Bounded),
}
if reqParams != nil {
radius, radiusOk := reqParams[ParamRadius]
rangeFilter, rangeFilterOk := reqParams[ParamRangeFilter]
if rangeFilterOk {
if !radiusOk {
log.Ctx(ctx).Warn("high level restful api, search params invalid, because only " + ParamRangeFilter)
HTTPAbortReturn(c, http.StatusOK, gin.H{
HTTPReturnCode: merr.Code(merr.ErrIncorrectParameterFormat),
HTTPReturnMessage: merr.ErrIncorrectParameterFormat.Error() + ", error: invalid search params",
})
return nil, merr.ErrIncorrectParameterFormat
}
params[ParamRangeFilter] = rangeFilter
}
if radiusOk {
params[ParamRadius] = radius
}
}
bs, _ := json.Marshal(params)
searchParams := []*commonpb.KeyValuePair{
{Key: Params, Value: string(bs)},
}
return searchParams, nil
func generateSearchParams(ctx context.Context, c *gin.Context, reqSearchParams searchParams) []*commonpb.KeyValuePair {
var searchParams []*commonpb.KeyValuePair
bs, _ := json.Marshal(reqSearchParams.Params)
searchParams = append(searchParams, &commonpb.KeyValuePair{Key: Params, Value: string(bs)})
searchParams = append(searchParams, &commonpb.KeyValuePair{Key: common.IgnoreGrowing, Value: strconv.FormatBool(reqSearchParams.IgnoreGrowing)})
// need to exposure ParamRoundDecimal in req?
searchParams = append(searchParams, &commonpb.KeyValuePair{Key: ParamRoundDecimal, Value: "-1"})
return searchParams
}
func (h *HandlersV2) search(ctx context.Context, c *gin.Context, anyReq any, dbName string) (interface{}, error) {
httpReq := anyReq.(*SearchReqV2)
req := &milvuspb.SearchRequest{
DbName: dbName,
CollectionName: httpReq.CollectionName,
Dsl: httpReq.Filter,
DslType: commonpb.DslType_BoolExprV1,
OutputFields: httpReq.OutputFields,
PartitionNames: httpReq.PartitionNames,
UseDefaultConsistency: true,
DbName: dbName,
CollectionName: httpReq.CollectionName,
Dsl: httpReq.Filter,
DslType: commonpb.DslType_BoolExprV1,
OutputFields: httpReq.OutputFields,
PartitionNames: httpReq.PartitionNames,
}
var err error
req.ConsistencyLevel, req.UseDefaultConsistency, err = convertConsistencyLevel(httpReq.ConsistencyLevel)
if err != nil {
log.Ctx(ctx).Warn("high level restful api, search with consistency_level invalid", zap.Error(err))
HTTPAbortReturn(c, http.StatusOK, gin.H{
HTTPReturnCode: merr.Code(err),
HTTPReturnMessage: "consistencyLevel can only be [Strong, Session, Bounded, Eventually, Customized], default: Bounded, err:" + err.Error(),
})
return nil, err
}
c.Set(ContextRequest, req)
@ -990,15 +980,12 @@ func (h *HandlersV2) search(ctx context.Context, c *gin.Context, anyReq any, dbN
if err != nil {
return nil, err
}
searchParams, err := generateSearchParams(ctx, c, httpReq.Params)
if err != nil {
return nil, err
}
searchParams := generateSearchParams(ctx, c, httpReq.SearchParams)
searchParams = append(searchParams, &commonpb.KeyValuePair{Key: common.TopKKey, Value: strconv.FormatInt(int64(httpReq.Limit), 10)})
searchParams = append(searchParams, &commonpb.KeyValuePair{Key: ParamOffset, Value: strconv.FormatInt(int64(httpReq.Offset), 10)})
searchParams = append(searchParams, &commonpb.KeyValuePair{Key: ParamGroupByField, Value: httpReq.GroupByField})
searchParams = append(searchParams, &commonpb.KeyValuePair{Key: proxy.AnnsFieldKey, Value: httpReq.AnnsField})
searchParams = append(searchParams, &commonpb.KeyValuePair{Key: ParamRoundDecimal, Value: "-1"})
body, _ := c.Get(gin.BodyBytesKey)
placeholderGroup, err := generatePlaceholderGroup(ctx, string(body.([]byte)), collSchema, httpReq.AnnsField)
if err != nil {
@ -1044,6 +1031,16 @@ func (h *HandlersV2) advancedSearch(ctx context.Context, c *gin.Context, anyReq
Requests: []*milvuspb.SearchRequest{},
OutputFields: httpReq.OutputFields,
}
var err error
req.ConsistencyLevel, req.UseDefaultConsistency, err = convertConsistencyLevel(httpReq.ConsistencyLevel)
if err != nil {
log.Ctx(ctx).Warn("high level restful api, search with consistency_level invalid", zap.Error(err))
HTTPAbortReturn(c, http.StatusOK, gin.H{
HTTPReturnCode: merr.Code(err),
HTTPReturnMessage: "consistencyLevel can only be [Strong, Session, Bounded, Eventually, Customized], default: Bounded, err:" + err.Error(),
})
return nil, err
}
c.Set(ContextRequest, req)
collSchema, err := h.GetCollectionSchema(ctx, c, dbName, httpReq.CollectionName)
@ -1053,15 +1050,11 @@ func (h *HandlersV2) advancedSearch(ctx context.Context, c *gin.Context, anyReq
body, _ := c.Get(gin.BodyBytesKey)
searchArray := gjson.Get(string(body.([]byte)), "search").Array()
for i, subReq := range httpReq.Search {
searchParams, err := generateSearchParams(ctx, c, subReq.Params)
if err != nil {
return nil, err
}
searchParams := generateSearchParams(ctx, c, subReq.SearchParams)
searchParams = append(searchParams, &commonpb.KeyValuePair{Key: common.TopKKey, Value: strconv.FormatInt(int64(subReq.Limit), 10)})
searchParams = append(searchParams, &commonpb.KeyValuePair{Key: ParamOffset, Value: strconv.FormatInt(int64(subReq.Offset), 10)})
searchParams = append(searchParams, &commonpb.KeyValuePair{Key: ParamGroupByField, Value: subReq.GroupByField})
searchParams = append(searchParams, &commonpb.KeyValuePair{Key: proxy.AnnsFieldKey, Value: subReq.AnnsField})
searchParams = append(searchParams, &commonpb.KeyValuePair{Key: ParamRoundDecimal, Value: "-1"})
placeholderGroup, err := generatePlaceholderGroup(ctx, searchArray[i].Raw, collSchema, subReq.AnnsField)
if err != nil {
log.Ctx(ctx).Warn("high level restful api, search with vector invalid", zap.Error(err))
@ -1072,15 +1065,14 @@ func (h *HandlersV2) advancedSearch(ctx context.Context, c *gin.Context, anyReq
return nil, err
}
searchReq := &milvuspb.SearchRequest{
DbName: dbName,
CollectionName: httpReq.CollectionName,
Dsl: subReq.Filter,
PlaceholderGroup: placeholderGroup,
DslType: commonpb.DslType_BoolExprV1,
OutputFields: httpReq.OutputFields,
PartitionNames: httpReq.PartitionNames,
SearchParams: searchParams,
UseDefaultConsistency: true,
DbName: dbName,
CollectionName: httpReq.CollectionName,
Dsl: subReq.Filter,
PlaceholderGroup: placeholderGroup,
DslType: commonpb.DslType_BoolExprV1,
OutputFields: httpReq.OutputFields,
PartitionNames: httpReq.PartitionNames,
SearchParams: searchParams,
}
req.Requests = append(req.Requests, searchReq)
}

View File

@ -1424,7 +1424,7 @@ func TestSearchV2(t *testing.T) {
Schema: generateCollectionSchema(schemapb.DataType_Int64),
ShardsNum: ShardNumDefault,
Status: &StatusSuccess,
}, nil).Times(12)
}, nil).Times(11)
mp.EXPECT().Search(mock.Anything, mock.Anything).Return(&milvuspb.SearchResults{Status: commonSuccessStatus, Results: &schemapb.SearchResultData{
TopK: int64(3),
OutputFields: outputFields,
@ -1465,6 +1465,12 @@ func TestSearchV2(t *testing.T) {
Status: &StatusSuccess,
}, nil).Times(10)
mp.EXPECT().Search(mock.Anything, mock.Anything).Return(&milvuspb.SearchResults{Status: commonSuccessStatus, Results: &schemapb.SearchResultData{TopK: int64(0)}}, nil).Times(3)
mp.EXPECT().DescribeCollection(mock.Anything, mock.Anything).Return(&milvuspb.DescribeCollectionResponse{
Status: &commonpb.Status{
Code: 1100,
Reason: "mock",
},
}, nil).Once()
testEngine := initHTTPServerV2(mp, false)
queryTestCases := []requestBodyTestCase{}
queryTestCases = append(queryTestCases, requestBodyTestCase{
@ -1473,7 +1479,7 @@ func TestSearchV2(t *testing.T) {
})
queryTestCases = append(queryTestCases, requestBodyTestCase{
path: SearchAction,
requestBody: []byte(`{"collectionName": "book", "data": [[0.1, 0.2]], "filter": "book_id in [2, 4, 6, 8]", "limit": 4, "outputFields": ["word_count"]}`),
requestBody: []byte(`{"collectionName": "book", "data": [[0.1, 0.2]], "filter": "book_id in [2, 4, 6, 8]", "limit": 4, "outputFields": ["word_count"],"consistencyLevel": "Strong"}`),
})
queryTestCases = append(queryTestCases, requestBodyTestCase{
path: SearchAction,
@ -1481,8 +1487,8 @@ func TestSearchV2(t *testing.T) {
})
queryTestCases = append(queryTestCases, requestBodyTestCase{
path: SearchAction,
requestBody: []byte(`{"collectionName": "book", "data": [[0.1, 0.2]], "filter": "book_id in [2, 4, 6, 8]", "limit": 4, "outputFields": ["word_count"], "params": {"range_filter": 0.1}}`),
errMsg: "can only accept json format request, error: invalid search params",
requestBody: []byte(`{"collectionName": "book", "data": [[0.1, 0.2]], "filter": "book_id in [2, 4, 6, 8]", "limit": 4, "outputFields": ["word_count"], "searchParams": {"ignoreGrowing": "true"}}`),
errMsg: "can only accept json format request, error: json: cannot unmarshal string into Go struct field searchParams.searchParams.ignoreGrowing of type bool",
errCode: 1801, // ErrIncorrectParameterFormat
})
queryTestCases = append(queryTestCases, requestBodyTestCase{
@ -1556,6 +1562,17 @@ func TestSearchV2(t *testing.T) {
`{"data": ["AQIDBA=="], "annsField": "bfloat16Vector", "metricType": "L2", "limit": 3}` +
`], "rerank": {"strategy": "weighted", "params": {"weights": [0.9, 0.8]}}}`),
})
queryTestCases = append(queryTestCases, requestBodyTestCase{
path: AdvancedSearchAction,
requestBody: []byte(`{"collectionName": "hello_milvus", "search": [` +
`{"data": [[0.1, 0.2]], "annsField": "book_intro", "metricType": "L2", "limit": 3},` +
`{"data": ["AQ=="], "annsField": "binaryVector", "metricType": "L2", "limit": 3},` +
`{"data": ["AQIDBA=="], "annsField": "float16Vector", "metricType": "L2", "limit": 3},` +
`{"data": ["AQIDBA=="], "annsField": "bfloat16Vector", "metricType": "L2", "limit": 3}` +
`], "consistencyLevel":"unknown","rerank": {"strategy": "weighted", "params": {"weights": [0.9, 0.8]}}}`),
errMsg: "consistencyLevel can only be [Strong, Session, Bounded, Eventually, Customized], default: Bounded, err:parameter:'unknown' is incorrect, please check it: invalid parameter",
errCode: 1100, // ErrParameterInvalid
})
queryTestCases = append(queryTestCases, requestBodyTestCase{
path: AdvancedSearchAction,
requestBody: []byte(`{"collectionName": "hello_milvus", "search": [` +
@ -1604,6 +1621,24 @@ func TestSearchV2(t *testing.T) {
path: SearchAction,
requestBody: []byte(`{"collectionName": "book", "data": [{"1": 0.1}], "annsField": "sparseFloatVector", "filter": "book_id in [2, 4, 6, 8]", "limit": 4, "outputFields": ["word_count"]}`),
})
queryTestCases = append(queryTestCases, requestBodyTestCase{
path: SearchAction,
requestBody: []byte(`{"collectionName": "book", "data": [[0.1, 0.2]], "filter": "book_id in [2, 4, 6, 8]", "limit": 4, "outputFields": ["word_count"], "searchParams": {"params":"a"}}`),
errMsg: "can only accept json format request, error: json: cannot unmarshal string into Go struct field searchParams.searchParams.params of type map[string]interface {}",
errCode: 1801, // ErrIncorrectParameterFormat
})
queryTestCases = append(queryTestCases, requestBodyTestCase{
path: SearchAction,
requestBody: []byte(`{"collectionName": "book", "data": [[0.1, 0.2]], "filter": "book_id in [2, 4, 6, 8]", "limit": 4, "outputFields": ["word_count"],"consistencyLevel": "unknown"}`),
errMsg: "consistencyLevel can only be [Strong, Session, Bounded, Eventually, Customized], default: Bounded, err:parameter:'unknown' is incorrect, please check it: invalid parameter",
errCode: 1100, // ErrParameterInvalid
})
queryTestCases = append(queryTestCases, requestBodyTestCase{
path: SearchAction,
requestBody: []byte(`{"collectionName": "book", "data": ["AQ=="], "annsField": "binaryVector", "filter": "book_id in [2, 4, 6, 8]", "limit": 4, "outputFields": ["word_count"]}`),
errMsg: "mock",
errCode: 1100, // ErrParameterInvalid
})
for _, testcase := range queryTestCases {
t.Run(testcase.path, func(t *testing.T) {

View File

@ -141,18 +141,28 @@ type CollectionDataReq struct {
func (req *CollectionDataReq) GetDbName() string { return req.DbName }
type searchParams struct {
// not use metricType any more, just for compatibility
MetricType string `json:"metricType"`
Params map[string]interface{} `json:"params"`
IgnoreGrowing bool `json:"ignoreGrowing"`
}
type SearchReqV2 struct {
DbName string `json:"dbName"`
CollectionName string `json:"collectionName" binding:"required"`
Data []interface{} `json:"data" binding:"required"`
AnnsField string `json:"annsField"`
PartitionNames []string `json:"partitionNames"`
Filter string `json:"filter"`
GroupByField string `json:"groupingField"`
Limit int32 `json:"limit"`
Offset int32 `json:"offset"`
OutputFields []string `json:"outputFields"`
Params map[string]float64 `json:"params"`
DbName string `json:"dbName"`
CollectionName string `json:"collectionName" binding:"required"`
Data []interface{} `json:"data" binding:"required"`
AnnsField string `json:"annsField"`
PartitionNames []string `json:"partitionNames"`
Filter string `json:"filter"`
GroupByField string `json:"groupingField"`
Limit int32 `json:"limit"`
Offset int32 `json:"offset"`
OutputFields []string `json:"outputFields"`
SearchParams searchParams `json:"searchParams"`
ConsistencyLevel string `json:"consistencyLevel"`
// not use Params any more, just for compatibility
Params map[string]float64 `json:"params"`
}
func (req *SearchReqV2) GetDbName() string { return req.DbName }
@ -163,25 +173,25 @@ type Rand struct {
}
type SubSearchReq struct {
Data []interface{} `json:"data" binding:"required"`
AnnsField string `json:"annsField"`
Filter string `json:"filter"`
GroupByField string `json:"groupingField"`
MetricType string `json:"metricType"`
Limit int32 `json:"limit"`
Offset int32 `json:"offset"`
IgnoreGrowing bool `json:"ignoreGrowing"`
Params map[string]float64 `json:"params"`
Data []interface{} `json:"data" binding:"required"`
AnnsField string `json:"annsField"`
Filter string `json:"filter"`
GroupByField string `json:"groupingField"`
MetricType string `json:"metricType"`
Limit int32 `json:"limit"`
Offset int32 `json:"offset"`
SearchParams searchParams `json:"searchParams"`
}
type HybridSearchReq struct {
DbName string `json:"dbName"`
CollectionName string `json:"collectionName" binding:"required"`
PartitionNames []string `json:"partitionNames"`
Search []SubSearchReq `json:"search"`
Rerank Rand `json:"rerank"`
Limit int32 `json:"limit"`
OutputFields []string `json:"outputFields"`
DbName string `json:"dbName"`
CollectionName string `json:"collectionName" binding:"required"`
PartitionNames []string `json:"partitionNames"`
Search []SubSearchReq `json:"search"`
Rerank Rand `json:"rerank"`
Limit int32 `json:"limit"`
OutputFields []string `json:"outputFields"`
ConsistencyLevel string `json:"consistencyLevel"`
}
func (req *HybridSearchReq) GetDbName() string { return req.DbName }

View File

@ -1314,3 +1314,15 @@ func convertToExtraParams(indexParam IndexParam) ([]*commonpb.KeyValuePair, erro
}
return params, nil
}
func convertConsistencyLevel(reqConsistencyLevel string) (commonpb.ConsistencyLevel, bool, error) {
if reqConsistencyLevel != "" {
level, ok := commonpb.ConsistencyLevel_value[reqConsistencyLevel]
if !ok {
return 0, false, merr.WrapErrParameterInvalidMsg(fmt.Sprintf("parameter:'%s' is incorrect, please check it", reqConsistencyLevel))
}
return commonpb.ConsistencyLevel(level), false, nil
}
// ConsistencyLevel_Bounded default in PyMilvus
return commonpb.ConsistencyLevel_Bounded, true, nil
}

View File

@ -1406,3 +1406,16 @@ func TestConvertToExtraParams(t *testing.T) {
}
}
}
func TestConvertConsistencyLevel(t *testing.T) {
consistencyLevel, useDefaultConsistency, err := convertConsistencyLevel("")
assert.Equal(t, nil, err)
assert.Equal(t, consistencyLevel, commonpb.ConsistencyLevel_Bounded)
assert.Equal(t, true, useDefaultConsistency)
consistencyLevel, useDefaultConsistency, err = convertConsistencyLevel("Strong")
assert.Equal(t, nil, err)
assert.Equal(t, consistencyLevel, commonpb.ConsistencyLevel_Strong)
assert.Equal(t, false, useDefaultConsistency)
_, _, err = convertConsistencyLevel("test")
assert.NotNil(t, err)
}

View File

@ -128,6 +128,8 @@ const (
BitmapCardinalityLimitKey = "bitmap_cardinality_limit"
IsSparseKey = "is_sparse"
AutoIndexName = "AUTOINDEX"
IgnoreGrowing = "ignore_growing"
ConsistencyLevel = "consistency_level"
)
// Collection properties key

View File

@ -101,7 +101,8 @@ class TestBase(Base):
batch_size = batch_size
batch = nb // batch_size
remainder = nb % batch_size
data = []
full_data = []
insert_ids = []
for i in range(batch):
nb = batch_size
@ -116,6 +117,7 @@ class TestBase(Base):
assert rsp['code'] == 0
if return_insert_id:
insert_ids.extend(rsp['data']['insertIds'])
full_data.extend(data)
# insert remainder data
if remainder:
nb = remainder
@ -128,10 +130,11 @@ class TestBase(Base):
assert rsp['code'] == 0
if return_insert_id:
insert_ids.extend(rsp['data']['insertIds'])
full_data.extend(data)
if return_insert_id:
return schema_payload, data, insert_ids
return schema_payload, full_data, insert_ids
return schema_payload, data
return schema_payload, full_data
def wait_collection_load_completed(self, name):
t0 = time.time()

View File

@ -4,8 +4,10 @@ import numpy as np
import sys
import json
import time
import utils.utils
from utils import constant
from utils.utils import gen_collection_name
from utils.utils import gen_collection_name, get_sorted_distance
from utils.util_log import test_log as logger
import pytest
from base.testbase import TestBase
@ -921,7 +923,6 @@ class TestUpsertVector(TestBase):
@pytest.mark.L0
class TestSearchVector(TestBase):
@pytest.mark.parametrize("insert_round", [1])
@pytest.mark.parametrize("auto_id", [True])
@pytest.mark.parametrize("is_partition_key", [True])
@ -1010,14 +1011,7 @@ class TestSearchVector(TestBase):
"filter": "word_count > 100",
"groupingField": "user_id",
"outputFields": ["*"],
"searchParams": {
"metricType": "COSINE",
"params": {
"radius": "0.1",
"range_filter": "0.8"
}
},
"limit": 100,
"limit": 100
}
rsp = self.vector_client.vector_search(payload)
assert rsp['code'] == 0
@ -1032,8 +1026,9 @@ class TestSearchVector(TestBase):
@pytest.mark.parametrize("nb", [3000])
@pytest.mark.parametrize("dim", [128])
@pytest.mark.parametrize("nq", [1, 2])
@pytest.mark.parametrize("metric_type", ['COSINE', "L2", "IP"])
def test_search_vector_with_float_vector_datatype(self, nb, dim, insert_round, auto_id,
is_partition_key, enable_dynamic_schema, nq):
is_partition_key, enable_dynamic_schema, nq, metric_type):
"""
Insert a vector with a simple payload
"""
@ -1054,7 +1049,7 @@ class TestSearchVector(TestBase):
]
},
"indexParams": [
{"fieldName": "float_vector", "indexName": "float_vector", "metricType": "COSINE"},
{"fieldName": "float_vector", "indexName": "float_vector", "metricType": metric_type},
]
}
rsp = self.collection_client.collection_create(payload)
@ -1098,13 +1093,6 @@ class TestSearchVector(TestBase):
"filter": "word_count > 100",
"groupingField": "user_id",
"outputFields": ["*"],
"searchParams": {
"metricType": "COSINE",
"params": {
"radius": "0.1",
"range_filter": "0.8"
}
},
"limit": 100,
}
rsp = self.vector_client.vector_search(payload)
@ -1225,7 +1213,8 @@ class TestSearchVector(TestBase):
@pytest.mark.parametrize("enable_dynamic_schema", [True])
@pytest.mark.parametrize("nb", [3000])
@pytest.mark.parametrize("dim", [128])
def test_search_vector_with_binary_vector_datatype(self, nb, dim, insert_round, auto_id,
@pytest.mark.parametrize("metric_type", ['HAMMING'])
def test_search_vector_with_binary_vector_datatype(self, metric_type, nb, dim, insert_round, auto_id,
is_partition_key, enable_dynamic_schema):
"""
Insert a vector with a simple payload
@ -1247,7 +1236,7 @@ class TestSearchVector(TestBase):
]
},
"indexParams": [
{"fieldName": "binary_vector", "indexName": "binary_vector", "metricType": "HAMMING",
{"fieldName": "binary_vector", "indexName": "binary_vector", "metricType": metric_type,
"params": {"index_type": "BIN_IVF_FLAT", "nlist": "512"}}
]
}
@ -1298,13 +1287,6 @@ class TestSearchVector(TestBase):
"data": [gen_vector(datatype="BinaryVector", dim=dim)],
"filter": "word_count > 100",
"outputFields": ["*"],
"searchParams": {
"metricType": "HAMMING",
"params": {
"radius": "0.1",
"range_filter": "0.8"
}
},
"limit": 100,
}
rsp = self.vector_client.vector_search(payload)
@ -1546,6 +1528,130 @@ class TestSearchVector(TestBase):
if "like" in varchar_expr:
assert name.startswith(prefix)
@pytest.mark.parametrize("consistency_level", ["Strong", "Bounded", "Eventually", "Session"])
def test_search_vector_with_consistency_level(self, consistency_level):
"""
Search a vector with different consistency level
"""
name = gen_collection_name()
self.name = name
nb = 200
dim = 128
limit = 100
schema_payload, data = self.init_collection(name, dim=dim, nb=nb)
names = []
for item in data:
names.append(item.get("name"))
names.sort()
logger.info(f"names: {names}")
mid = len(names) // 2
prefix = names[mid][0:2]
vector_field = schema_payload.get("vectorField")
# search data
vector_to_search = preprocessing.normalize([np.array([random.random() for i in range(dim)])])[0].tolist()
output_fields = get_common_fields_by_data(data, exclude_fields=[vector_field])
payload = {
"collectionName": name,
"data": [vector_to_search],
"outputFields": output_fields,
"limit": limit,
"offset": 0,
"consistencyLevel": consistency_level
}
rsp = self.vector_client.vector_search(payload)
assert rsp['code'] == 0
res = rsp['data']
logger.info(f"res: {len(res)}")
assert len(res) == limit
@pytest.mark.parametrize("metric_type", ["L2", "COSINE", "IP"])
def test_search_vector_with_range_search(self, metric_type):
"""
Search a vector with range search with different metric type
"""
name = gen_collection_name()
self.name = name
nb = 3000
dim = 128
limit = 100
schema_payload, data = self.init_collection(name, dim=dim, nb=nb, metric_type=metric_type)
vector_field = schema_payload.get("vectorField")
# search data
vector_to_search = preprocessing.normalize([np.array([random.random() for i in range(dim)])])[0].tolist()
training_data = [item[vector_field] for item in data]
distance_sorted = get_sorted_distance(training_data, [vector_to_search], metric_type)
r1, r2 = distance_sorted[0][nb//2], distance_sorted[0][nb//2+limit+int((0.2*limit))] # recall is not 100% so add 20% to make sure the range is correct
if metric_type == "L2":
r1, r2 = r2, r1
output_fields = get_common_fields_by_data(data, exclude_fields=[vector_field])
payload = {
"collectionName": name,
"data": [vector_to_search],
"outputFields": output_fields,
"limit": limit,
"offset": 0,
"searchParams": {
"params": {
"radius": r1,
"range_filter": r2,
}
}
}
rsp = self.vector_client.vector_search(payload)
assert rsp['code'] == 0
res = rsp['data']
logger.info(f"res: {len(res)}")
assert len(res) == limit
for item in res:
distance = item.get("distance")
if metric_type == "L2":
assert r1 > distance > r2
else:
assert r1 < distance < r2
@pytest.mark.parametrize("ignore_growing", [True, False])
def test_search_vector_with_ignore_growing(self, ignore_growing):
"""
Search a vector with range search with different metric type
"""
name = gen_collection_name()
self.name = name
metric_type = "COSINE"
nb = 1000
dim = 128
limit = 100
schema_payload, data = self.init_collection(name, dim=dim, nb=nb, metric_type=metric_type)
vector_field = schema_payload.get("vectorField")
# search data
vector_to_search = preprocessing.normalize([np.array([random.random() for i in range(dim)])])[0].tolist()
training_data = [item[vector_field] for item in data]
distance_sorted = get_sorted_distance(training_data, [vector_to_search], metric_type)
r1, r2 = distance_sorted[0][nb//2], distance_sorted[0][nb//2+limit+int((0.2*limit))] # recall is not 100% so add 20% to make sure the range is correct
if metric_type == "L2":
r1, r2 = r2, r1
output_fields = get_common_fields_by_data(data, exclude_fields=[vector_field])
payload = {
"collectionName": name,
"data": [vector_to_search],
"outputFields": output_fields,
"limit": limit,
"offset": 0,
"searchParams": {
"ignoreGrowing": ignore_growing
}
}
rsp = self.vector_client.vector_search(payload)
assert rsp['code'] == 0
res = rsp['data']
logger.info(f"res: {len(res)}")
if ignore_growing is True:
assert len(res) == 0
else:
assert len(res) == limit
@pytest.mark.L1
class TestSearchVectorNegative(TestBase):

View File

@ -10,7 +10,7 @@ import base64
import requests
from loguru import logger
import datetime
from sklearn.metrics import pairwise_distances
fake = Faker()
rng = np.random.default_rng()
@ -240,4 +240,28 @@ def get_all_fields_by_data(data, exclude_fields=None):
return list(fields)
def ip_distance(x, y):
return np.dot(x, y)
def cosine_distance(u, v, epsilon=1e-8):
dot_product = np.dot(u, v)
norm_u = np.linalg.norm(u)
norm_v = np.linalg.norm(v)
return dot_product / (max(norm_u * norm_v, epsilon))
def l2_distance(u, v):
return np.sum((u - v) ** 2)
def get_sorted_distance(train_emb, test_emb, metric_type):
milvus_sklearn_metric_map = {
"L2": l2_distance,
"COSINE": cosine_distance,
"IP": ip_distance
}
distance = pairwise_distances(train_emb, Y=test_emb, metric=milvus_sklearn_metric_map[metric_type], n_jobs=-1)
distance = np.array(distance.T, order='C', dtype=np.float16)
distance_sorted = np.sort(distance, axis=1).tolist()
return distance_sorted