mirror of https://github.com/milvus-io/milvus.git
pr: #36304 pr: #36714 pr: #36448 --------- Signed-off-by: lixinguo <xinguo.li@zilliz.com> Co-authored-by: lixinguo <xinguo.li@zilliz.com> Co-authored-by: zhuwenxing <wenxing.zhu@zilliz.com>pull/37731/head
parent
4e11fe7adf
commit
b3e6482367
|
@ -20,9 +20,10 @@ import (
|
|||
"fmt"
|
||||
"testing"
|
||||
|
||||
"github.com/stretchr/testify/assert"
|
||||
|
||||
"github.com/milvus-io/milvus/pkg/config"
|
||||
"github.com/milvus-io/milvus/pkg/util/paramtable"
|
||||
"github.com/stretchr/testify/assert"
|
||||
)
|
||||
|
||||
func TestResizePools(t *testing.T) {
|
||||
|
|
|
@ -944,45 +944,35 @@ func generatePlaceholderGroup(ctx context.Context, body string, collSchema *sche
|
|||
})
|
||||
}
|
||||
|
||||
func generateSearchParams(ctx context.Context, c *gin.Context, reqParams map[string]float64) ([]*commonpb.KeyValuePair, error) {
|
||||
params := map[string]interface{}{ // auto generated mapping
|
||||
"level": int(commonpb.ConsistencyLevel_Bounded),
|
||||
}
|
||||
if reqParams != nil {
|
||||
radius, radiusOk := reqParams[ParamRadius]
|
||||
rangeFilter, rangeFilterOk := reqParams[ParamRangeFilter]
|
||||
if rangeFilterOk {
|
||||
if !radiusOk {
|
||||
log.Ctx(ctx).Warn("high level restful api, search params invalid, because only " + ParamRangeFilter)
|
||||
HTTPAbortReturn(c, http.StatusOK, gin.H{
|
||||
HTTPReturnCode: merr.Code(merr.ErrIncorrectParameterFormat),
|
||||
HTTPReturnMessage: merr.ErrIncorrectParameterFormat.Error() + ", error: invalid search params",
|
||||
})
|
||||
return nil, merr.ErrIncorrectParameterFormat
|
||||
}
|
||||
params[ParamRangeFilter] = rangeFilter
|
||||
}
|
||||
if radiusOk {
|
||||
params[ParamRadius] = radius
|
||||
}
|
||||
}
|
||||
bs, _ := json.Marshal(params)
|
||||
searchParams := []*commonpb.KeyValuePair{
|
||||
{Key: Params, Value: string(bs)},
|
||||
}
|
||||
return searchParams, nil
|
||||
func generateSearchParams(ctx context.Context, c *gin.Context, reqSearchParams searchParams) []*commonpb.KeyValuePair {
|
||||
var searchParams []*commonpb.KeyValuePair
|
||||
bs, _ := json.Marshal(reqSearchParams.Params)
|
||||
searchParams = append(searchParams, &commonpb.KeyValuePair{Key: Params, Value: string(bs)})
|
||||
searchParams = append(searchParams, &commonpb.KeyValuePair{Key: common.IgnoreGrowing, Value: strconv.FormatBool(reqSearchParams.IgnoreGrowing)})
|
||||
// need to exposure ParamRoundDecimal in req?
|
||||
searchParams = append(searchParams, &commonpb.KeyValuePair{Key: ParamRoundDecimal, Value: "-1"})
|
||||
return searchParams
|
||||
}
|
||||
|
||||
func (h *HandlersV2) search(ctx context.Context, c *gin.Context, anyReq any, dbName string) (interface{}, error) {
|
||||
httpReq := anyReq.(*SearchReqV2)
|
||||
req := &milvuspb.SearchRequest{
|
||||
DbName: dbName,
|
||||
CollectionName: httpReq.CollectionName,
|
||||
Dsl: httpReq.Filter,
|
||||
DslType: commonpb.DslType_BoolExprV1,
|
||||
OutputFields: httpReq.OutputFields,
|
||||
PartitionNames: httpReq.PartitionNames,
|
||||
UseDefaultConsistency: true,
|
||||
DbName: dbName,
|
||||
CollectionName: httpReq.CollectionName,
|
||||
Dsl: httpReq.Filter,
|
||||
DslType: commonpb.DslType_BoolExprV1,
|
||||
OutputFields: httpReq.OutputFields,
|
||||
PartitionNames: httpReq.PartitionNames,
|
||||
}
|
||||
var err error
|
||||
req.ConsistencyLevel, req.UseDefaultConsistency, err = convertConsistencyLevel(httpReq.ConsistencyLevel)
|
||||
if err != nil {
|
||||
log.Ctx(ctx).Warn("high level restful api, search with consistency_level invalid", zap.Error(err))
|
||||
HTTPAbortReturn(c, http.StatusOK, gin.H{
|
||||
HTTPReturnCode: merr.Code(err),
|
||||
HTTPReturnMessage: "consistencyLevel can only be [Strong, Session, Bounded, Eventually, Customized], default: Bounded, err:" + err.Error(),
|
||||
})
|
||||
return nil, err
|
||||
}
|
||||
c.Set(ContextRequest, req)
|
||||
|
||||
|
@ -990,15 +980,12 @@ func (h *HandlersV2) search(ctx context.Context, c *gin.Context, anyReq any, dbN
|
|||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
searchParams, err := generateSearchParams(ctx, c, httpReq.Params)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
searchParams := generateSearchParams(ctx, c, httpReq.SearchParams)
|
||||
searchParams = append(searchParams, &commonpb.KeyValuePair{Key: common.TopKKey, Value: strconv.FormatInt(int64(httpReq.Limit), 10)})
|
||||
searchParams = append(searchParams, &commonpb.KeyValuePair{Key: ParamOffset, Value: strconv.FormatInt(int64(httpReq.Offset), 10)})
|
||||
searchParams = append(searchParams, &commonpb.KeyValuePair{Key: ParamGroupByField, Value: httpReq.GroupByField})
|
||||
searchParams = append(searchParams, &commonpb.KeyValuePair{Key: proxy.AnnsFieldKey, Value: httpReq.AnnsField})
|
||||
searchParams = append(searchParams, &commonpb.KeyValuePair{Key: ParamRoundDecimal, Value: "-1"})
|
||||
body, _ := c.Get(gin.BodyBytesKey)
|
||||
placeholderGroup, err := generatePlaceholderGroup(ctx, string(body.([]byte)), collSchema, httpReq.AnnsField)
|
||||
if err != nil {
|
||||
|
@ -1044,6 +1031,16 @@ func (h *HandlersV2) advancedSearch(ctx context.Context, c *gin.Context, anyReq
|
|||
Requests: []*milvuspb.SearchRequest{},
|
||||
OutputFields: httpReq.OutputFields,
|
||||
}
|
||||
var err error
|
||||
req.ConsistencyLevel, req.UseDefaultConsistency, err = convertConsistencyLevel(httpReq.ConsistencyLevel)
|
||||
if err != nil {
|
||||
log.Ctx(ctx).Warn("high level restful api, search with consistency_level invalid", zap.Error(err))
|
||||
HTTPAbortReturn(c, http.StatusOK, gin.H{
|
||||
HTTPReturnCode: merr.Code(err),
|
||||
HTTPReturnMessage: "consistencyLevel can only be [Strong, Session, Bounded, Eventually, Customized], default: Bounded, err:" + err.Error(),
|
||||
})
|
||||
return nil, err
|
||||
}
|
||||
c.Set(ContextRequest, req)
|
||||
|
||||
collSchema, err := h.GetCollectionSchema(ctx, c, dbName, httpReq.CollectionName)
|
||||
|
@ -1053,15 +1050,11 @@ func (h *HandlersV2) advancedSearch(ctx context.Context, c *gin.Context, anyReq
|
|||
body, _ := c.Get(gin.BodyBytesKey)
|
||||
searchArray := gjson.Get(string(body.([]byte)), "search").Array()
|
||||
for i, subReq := range httpReq.Search {
|
||||
searchParams, err := generateSearchParams(ctx, c, subReq.Params)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
searchParams := generateSearchParams(ctx, c, subReq.SearchParams)
|
||||
searchParams = append(searchParams, &commonpb.KeyValuePair{Key: common.TopKKey, Value: strconv.FormatInt(int64(subReq.Limit), 10)})
|
||||
searchParams = append(searchParams, &commonpb.KeyValuePair{Key: ParamOffset, Value: strconv.FormatInt(int64(subReq.Offset), 10)})
|
||||
searchParams = append(searchParams, &commonpb.KeyValuePair{Key: ParamGroupByField, Value: subReq.GroupByField})
|
||||
searchParams = append(searchParams, &commonpb.KeyValuePair{Key: proxy.AnnsFieldKey, Value: subReq.AnnsField})
|
||||
searchParams = append(searchParams, &commonpb.KeyValuePair{Key: ParamRoundDecimal, Value: "-1"})
|
||||
placeholderGroup, err := generatePlaceholderGroup(ctx, searchArray[i].Raw, collSchema, subReq.AnnsField)
|
||||
if err != nil {
|
||||
log.Ctx(ctx).Warn("high level restful api, search with vector invalid", zap.Error(err))
|
||||
|
@ -1072,15 +1065,14 @@ func (h *HandlersV2) advancedSearch(ctx context.Context, c *gin.Context, anyReq
|
|||
return nil, err
|
||||
}
|
||||
searchReq := &milvuspb.SearchRequest{
|
||||
DbName: dbName,
|
||||
CollectionName: httpReq.CollectionName,
|
||||
Dsl: subReq.Filter,
|
||||
PlaceholderGroup: placeholderGroup,
|
||||
DslType: commonpb.DslType_BoolExprV1,
|
||||
OutputFields: httpReq.OutputFields,
|
||||
PartitionNames: httpReq.PartitionNames,
|
||||
SearchParams: searchParams,
|
||||
UseDefaultConsistency: true,
|
||||
DbName: dbName,
|
||||
CollectionName: httpReq.CollectionName,
|
||||
Dsl: subReq.Filter,
|
||||
PlaceholderGroup: placeholderGroup,
|
||||
DslType: commonpb.DslType_BoolExprV1,
|
||||
OutputFields: httpReq.OutputFields,
|
||||
PartitionNames: httpReq.PartitionNames,
|
||||
SearchParams: searchParams,
|
||||
}
|
||||
req.Requests = append(req.Requests, searchReq)
|
||||
}
|
||||
|
|
|
@ -1424,7 +1424,7 @@ func TestSearchV2(t *testing.T) {
|
|||
Schema: generateCollectionSchema(schemapb.DataType_Int64),
|
||||
ShardsNum: ShardNumDefault,
|
||||
Status: &StatusSuccess,
|
||||
}, nil).Times(12)
|
||||
}, nil).Times(11)
|
||||
mp.EXPECT().Search(mock.Anything, mock.Anything).Return(&milvuspb.SearchResults{Status: commonSuccessStatus, Results: &schemapb.SearchResultData{
|
||||
TopK: int64(3),
|
||||
OutputFields: outputFields,
|
||||
|
@ -1465,6 +1465,12 @@ func TestSearchV2(t *testing.T) {
|
|||
Status: &StatusSuccess,
|
||||
}, nil).Times(10)
|
||||
mp.EXPECT().Search(mock.Anything, mock.Anything).Return(&milvuspb.SearchResults{Status: commonSuccessStatus, Results: &schemapb.SearchResultData{TopK: int64(0)}}, nil).Times(3)
|
||||
mp.EXPECT().DescribeCollection(mock.Anything, mock.Anything).Return(&milvuspb.DescribeCollectionResponse{
|
||||
Status: &commonpb.Status{
|
||||
Code: 1100,
|
||||
Reason: "mock",
|
||||
},
|
||||
}, nil).Once()
|
||||
testEngine := initHTTPServerV2(mp, false)
|
||||
queryTestCases := []requestBodyTestCase{}
|
||||
queryTestCases = append(queryTestCases, requestBodyTestCase{
|
||||
|
@ -1473,7 +1479,7 @@ func TestSearchV2(t *testing.T) {
|
|||
})
|
||||
queryTestCases = append(queryTestCases, requestBodyTestCase{
|
||||
path: SearchAction,
|
||||
requestBody: []byte(`{"collectionName": "book", "data": [[0.1, 0.2]], "filter": "book_id in [2, 4, 6, 8]", "limit": 4, "outputFields": ["word_count"]}`),
|
||||
requestBody: []byte(`{"collectionName": "book", "data": [[0.1, 0.2]], "filter": "book_id in [2, 4, 6, 8]", "limit": 4, "outputFields": ["word_count"],"consistencyLevel": "Strong"}`),
|
||||
})
|
||||
queryTestCases = append(queryTestCases, requestBodyTestCase{
|
||||
path: SearchAction,
|
||||
|
@ -1481,8 +1487,8 @@ func TestSearchV2(t *testing.T) {
|
|||
})
|
||||
queryTestCases = append(queryTestCases, requestBodyTestCase{
|
||||
path: SearchAction,
|
||||
requestBody: []byte(`{"collectionName": "book", "data": [[0.1, 0.2]], "filter": "book_id in [2, 4, 6, 8]", "limit": 4, "outputFields": ["word_count"], "params": {"range_filter": 0.1}}`),
|
||||
errMsg: "can only accept json format request, error: invalid search params",
|
||||
requestBody: []byte(`{"collectionName": "book", "data": [[0.1, 0.2]], "filter": "book_id in [2, 4, 6, 8]", "limit": 4, "outputFields": ["word_count"], "searchParams": {"ignoreGrowing": "true"}}`),
|
||||
errMsg: "can only accept json format request, error: json: cannot unmarshal string into Go struct field searchParams.searchParams.ignoreGrowing of type bool",
|
||||
errCode: 1801, // ErrIncorrectParameterFormat
|
||||
})
|
||||
queryTestCases = append(queryTestCases, requestBodyTestCase{
|
||||
|
@ -1556,6 +1562,17 @@ func TestSearchV2(t *testing.T) {
|
|||
`{"data": ["AQIDBA=="], "annsField": "bfloat16Vector", "metricType": "L2", "limit": 3}` +
|
||||
`], "rerank": {"strategy": "weighted", "params": {"weights": [0.9, 0.8]}}}`),
|
||||
})
|
||||
queryTestCases = append(queryTestCases, requestBodyTestCase{
|
||||
path: AdvancedSearchAction,
|
||||
requestBody: []byte(`{"collectionName": "hello_milvus", "search": [` +
|
||||
`{"data": [[0.1, 0.2]], "annsField": "book_intro", "metricType": "L2", "limit": 3},` +
|
||||
`{"data": ["AQ=="], "annsField": "binaryVector", "metricType": "L2", "limit": 3},` +
|
||||
`{"data": ["AQIDBA=="], "annsField": "float16Vector", "metricType": "L2", "limit": 3},` +
|
||||
`{"data": ["AQIDBA=="], "annsField": "bfloat16Vector", "metricType": "L2", "limit": 3}` +
|
||||
`], "consistencyLevel":"unknown","rerank": {"strategy": "weighted", "params": {"weights": [0.9, 0.8]}}}`),
|
||||
errMsg: "consistencyLevel can only be [Strong, Session, Bounded, Eventually, Customized], default: Bounded, err:parameter:'unknown' is incorrect, please check it: invalid parameter",
|
||||
errCode: 1100, // ErrParameterInvalid
|
||||
})
|
||||
queryTestCases = append(queryTestCases, requestBodyTestCase{
|
||||
path: AdvancedSearchAction,
|
||||
requestBody: []byte(`{"collectionName": "hello_milvus", "search": [` +
|
||||
|
@ -1604,6 +1621,24 @@ func TestSearchV2(t *testing.T) {
|
|||
path: SearchAction,
|
||||
requestBody: []byte(`{"collectionName": "book", "data": [{"1": 0.1}], "annsField": "sparseFloatVector", "filter": "book_id in [2, 4, 6, 8]", "limit": 4, "outputFields": ["word_count"]}`),
|
||||
})
|
||||
queryTestCases = append(queryTestCases, requestBodyTestCase{
|
||||
path: SearchAction,
|
||||
requestBody: []byte(`{"collectionName": "book", "data": [[0.1, 0.2]], "filter": "book_id in [2, 4, 6, 8]", "limit": 4, "outputFields": ["word_count"], "searchParams": {"params":"a"}}`),
|
||||
errMsg: "can only accept json format request, error: json: cannot unmarshal string into Go struct field searchParams.searchParams.params of type map[string]interface {}",
|
||||
errCode: 1801, // ErrIncorrectParameterFormat
|
||||
})
|
||||
queryTestCases = append(queryTestCases, requestBodyTestCase{
|
||||
path: SearchAction,
|
||||
requestBody: []byte(`{"collectionName": "book", "data": [[0.1, 0.2]], "filter": "book_id in [2, 4, 6, 8]", "limit": 4, "outputFields": ["word_count"],"consistencyLevel": "unknown"}`),
|
||||
errMsg: "consistencyLevel can only be [Strong, Session, Bounded, Eventually, Customized], default: Bounded, err:parameter:'unknown' is incorrect, please check it: invalid parameter",
|
||||
errCode: 1100, // ErrParameterInvalid
|
||||
})
|
||||
queryTestCases = append(queryTestCases, requestBodyTestCase{
|
||||
path: SearchAction,
|
||||
requestBody: []byte(`{"collectionName": "book", "data": ["AQ=="], "annsField": "binaryVector", "filter": "book_id in [2, 4, 6, 8]", "limit": 4, "outputFields": ["word_count"]}`),
|
||||
errMsg: "mock",
|
||||
errCode: 1100, // ErrParameterInvalid
|
||||
})
|
||||
|
||||
for _, testcase := range queryTestCases {
|
||||
t.Run(testcase.path, func(t *testing.T) {
|
||||
|
|
|
@ -141,18 +141,28 @@ type CollectionDataReq struct {
|
|||
|
||||
func (req *CollectionDataReq) GetDbName() string { return req.DbName }
|
||||
|
||||
type searchParams struct {
|
||||
// not use metricType any more, just for compatibility
|
||||
MetricType string `json:"metricType"`
|
||||
Params map[string]interface{} `json:"params"`
|
||||
IgnoreGrowing bool `json:"ignoreGrowing"`
|
||||
}
|
||||
|
||||
type SearchReqV2 struct {
|
||||
DbName string `json:"dbName"`
|
||||
CollectionName string `json:"collectionName" binding:"required"`
|
||||
Data []interface{} `json:"data" binding:"required"`
|
||||
AnnsField string `json:"annsField"`
|
||||
PartitionNames []string `json:"partitionNames"`
|
||||
Filter string `json:"filter"`
|
||||
GroupByField string `json:"groupingField"`
|
||||
Limit int32 `json:"limit"`
|
||||
Offset int32 `json:"offset"`
|
||||
OutputFields []string `json:"outputFields"`
|
||||
Params map[string]float64 `json:"params"`
|
||||
DbName string `json:"dbName"`
|
||||
CollectionName string `json:"collectionName" binding:"required"`
|
||||
Data []interface{} `json:"data" binding:"required"`
|
||||
AnnsField string `json:"annsField"`
|
||||
PartitionNames []string `json:"partitionNames"`
|
||||
Filter string `json:"filter"`
|
||||
GroupByField string `json:"groupingField"`
|
||||
Limit int32 `json:"limit"`
|
||||
Offset int32 `json:"offset"`
|
||||
OutputFields []string `json:"outputFields"`
|
||||
SearchParams searchParams `json:"searchParams"`
|
||||
ConsistencyLevel string `json:"consistencyLevel"`
|
||||
// not use Params any more, just for compatibility
|
||||
Params map[string]float64 `json:"params"`
|
||||
}
|
||||
|
||||
func (req *SearchReqV2) GetDbName() string { return req.DbName }
|
||||
|
@ -163,25 +173,25 @@ type Rand struct {
|
|||
}
|
||||
|
||||
type SubSearchReq struct {
|
||||
Data []interface{} `json:"data" binding:"required"`
|
||||
AnnsField string `json:"annsField"`
|
||||
Filter string `json:"filter"`
|
||||
GroupByField string `json:"groupingField"`
|
||||
MetricType string `json:"metricType"`
|
||||
Limit int32 `json:"limit"`
|
||||
Offset int32 `json:"offset"`
|
||||
IgnoreGrowing bool `json:"ignoreGrowing"`
|
||||
Params map[string]float64 `json:"params"`
|
||||
Data []interface{} `json:"data" binding:"required"`
|
||||
AnnsField string `json:"annsField"`
|
||||
Filter string `json:"filter"`
|
||||
GroupByField string `json:"groupingField"`
|
||||
MetricType string `json:"metricType"`
|
||||
Limit int32 `json:"limit"`
|
||||
Offset int32 `json:"offset"`
|
||||
SearchParams searchParams `json:"searchParams"`
|
||||
}
|
||||
|
||||
type HybridSearchReq struct {
|
||||
DbName string `json:"dbName"`
|
||||
CollectionName string `json:"collectionName" binding:"required"`
|
||||
PartitionNames []string `json:"partitionNames"`
|
||||
Search []SubSearchReq `json:"search"`
|
||||
Rerank Rand `json:"rerank"`
|
||||
Limit int32 `json:"limit"`
|
||||
OutputFields []string `json:"outputFields"`
|
||||
DbName string `json:"dbName"`
|
||||
CollectionName string `json:"collectionName" binding:"required"`
|
||||
PartitionNames []string `json:"partitionNames"`
|
||||
Search []SubSearchReq `json:"search"`
|
||||
Rerank Rand `json:"rerank"`
|
||||
Limit int32 `json:"limit"`
|
||||
OutputFields []string `json:"outputFields"`
|
||||
ConsistencyLevel string `json:"consistencyLevel"`
|
||||
}
|
||||
|
||||
func (req *HybridSearchReq) GetDbName() string { return req.DbName }
|
||||
|
|
|
@ -1314,3 +1314,15 @@ func convertToExtraParams(indexParam IndexParam) ([]*commonpb.KeyValuePair, erro
|
|||
}
|
||||
return params, nil
|
||||
}
|
||||
|
||||
func convertConsistencyLevel(reqConsistencyLevel string) (commonpb.ConsistencyLevel, bool, error) {
|
||||
if reqConsistencyLevel != "" {
|
||||
level, ok := commonpb.ConsistencyLevel_value[reqConsistencyLevel]
|
||||
if !ok {
|
||||
return 0, false, merr.WrapErrParameterInvalidMsg(fmt.Sprintf("parameter:'%s' is incorrect, please check it", reqConsistencyLevel))
|
||||
}
|
||||
return commonpb.ConsistencyLevel(level), false, nil
|
||||
}
|
||||
// ConsistencyLevel_Bounded default in PyMilvus
|
||||
return commonpb.ConsistencyLevel_Bounded, true, nil
|
||||
}
|
||||
|
|
|
@ -1406,3 +1406,16 @@ func TestConvertToExtraParams(t *testing.T) {
|
|||
}
|
||||
}
|
||||
}
|
||||
|
||||
func TestConvertConsistencyLevel(t *testing.T) {
|
||||
consistencyLevel, useDefaultConsistency, err := convertConsistencyLevel("")
|
||||
assert.Equal(t, nil, err)
|
||||
assert.Equal(t, consistencyLevel, commonpb.ConsistencyLevel_Bounded)
|
||||
assert.Equal(t, true, useDefaultConsistency)
|
||||
consistencyLevel, useDefaultConsistency, err = convertConsistencyLevel("Strong")
|
||||
assert.Equal(t, nil, err)
|
||||
assert.Equal(t, consistencyLevel, commonpb.ConsistencyLevel_Strong)
|
||||
assert.Equal(t, false, useDefaultConsistency)
|
||||
_, _, err = convertConsistencyLevel("test")
|
||||
assert.NotNil(t, err)
|
||||
}
|
||||
|
|
|
@ -128,6 +128,8 @@ const (
|
|||
BitmapCardinalityLimitKey = "bitmap_cardinality_limit"
|
||||
IsSparseKey = "is_sparse"
|
||||
AutoIndexName = "AUTOINDEX"
|
||||
IgnoreGrowing = "ignore_growing"
|
||||
ConsistencyLevel = "consistency_level"
|
||||
)
|
||||
|
||||
// Collection properties key
|
||||
|
|
|
@ -101,7 +101,8 @@ class TestBase(Base):
|
|||
batch_size = batch_size
|
||||
batch = nb // batch_size
|
||||
remainder = nb % batch_size
|
||||
data = []
|
||||
|
||||
full_data = []
|
||||
insert_ids = []
|
||||
for i in range(batch):
|
||||
nb = batch_size
|
||||
|
@ -116,6 +117,7 @@ class TestBase(Base):
|
|||
assert rsp['code'] == 0
|
||||
if return_insert_id:
|
||||
insert_ids.extend(rsp['data']['insertIds'])
|
||||
full_data.extend(data)
|
||||
# insert remainder data
|
||||
if remainder:
|
||||
nb = remainder
|
||||
|
@ -128,10 +130,11 @@ class TestBase(Base):
|
|||
assert rsp['code'] == 0
|
||||
if return_insert_id:
|
||||
insert_ids.extend(rsp['data']['insertIds'])
|
||||
full_data.extend(data)
|
||||
if return_insert_id:
|
||||
return schema_payload, data, insert_ids
|
||||
return schema_payload, full_data, insert_ids
|
||||
|
||||
return schema_payload, data
|
||||
return schema_payload, full_data
|
||||
|
||||
def wait_collection_load_completed(self, name):
|
||||
t0 = time.time()
|
||||
|
|
|
@ -4,8 +4,10 @@ import numpy as np
|
|||
import sys
|
||||
import json
|
||||
import time
|
||||
|
||||
import utils.utils
|
||||
from utils import constant
|
||||
from utils.utils import gen_collection_name
|
||||
from utils.utils import gen_collection_name, get_sorted_distance
|
||||
from utils.util_log import test_log as logger
|
||||
import pytest
|
||||
from base.testbase import TestBase
|
||||
|
@ -921,7 +923,6 @@ class TestUpsertVector(TestBase):
|
|||
@pytest.mark.L0
|
||||
class TestSearchVector(TestBase):
|
||||
|
||||
|
||||
@pytest.mark.parametrize("insert_round", [1])
|
||||
@pytest.mark.parametrize("auto_id", [True])
|
||||
@pytest.mark.parametrize("is_partition_key", [True])
|
||||
|
@ -1010,14 +1011,7 @@ class TestSearchVector(TestBase):
|
|||
"filter": "word_count > 100",
|
||||
"groupingField": "user_id",
|
||||
"outputFields": ["*"],
|
||||
"searchParams": {
|
||||
"metricType": "COSINE",
|
||||
"params": {
|
||||
"radius": "0.1",
|
||||
"range_filter": "0.8"
|
||||
}
|
||||
},
|
||||
"limit": 100,
|
||||
"limit": 100
|
||||
}
|
||||
rsp = self.vector_client.vector_search(payload)
|
||||
assert rsp['code'] == 0
|
||||
|
@ -1032,8 +1026,9 @@ class TestSearchVector(TestBase):
|
|||
@pytest.mark.parametrize("nb", [3000])
|
||||
@pytest.mark.parametrize("dim", [128])
|
||||
@pytest.mark.parametrize("nq", [1, 2])
|
||||
@pytest.mark.parametrize("metric_type", ['COSINE', "L2", "IP"])
|
||||
def test_search_vector_with_float_vector_datatype(self, nb, dim, insert_round, auto_id,
|
||||
is_partition_key, enable_dynamic_schema, nq):
|
||||
is_partition_key, enable_dynamic_schema, nq, metric_type):
|
||||
"""
|
||||
Insert a vector with a simple payload
|
||||
"""
|
||||
|
@ -1054,7 +1049,7 @@ class TestSearchVector(TestBase):
|
|||
]
|
||||
},
|
||||
"indexParams": [
|
||||
{"fieldName": "float_vector", "indexName": "float_vector", "metricType": "COSINE"},
|
||||
{"fieldName": "float_vector", "indexName": "float_vector", "metricType": metric_type},
|
||||
]
|
||||
}
|
||||
rsp = self.collection_client.collection_create(payload)
|
||||
|
@ -1098,13 +1093,6 @@ class TestSearchVector(TestBase):
|
|||
"filter": "word_count > 100",
|
||||
"groupingField": "user_id",
|
||||
"outputFields": ["*"],
|
||||
"searchParams": {
|
||||
"metricType": "COSINE",
|
||||
"params": {
|
||||
"radius": "0.1",
|
||||
"range_filter": "0.8"
|
||||
}
|
||||
},
|
||||
"limit": 100,
|
||||
}
|
||||
rsp = self.vector_client.vector_search(payload)
|
||||
|
@ -1225,7 +1213,8 @@ class TestSearchVector(TestBase):
|
|||
@pytest.mark.parametrize("enable_dynamic_schema", [True])
|
||||
@pytest.mark.parametrize("nb", [3000])
|
||||
@pytest.mark.parametrize("dim", [128])
|
||||
def test_search_vector_with_binary_vector_datatype(self, nb, dim, insert_round, auto_id,
|
||||
@pytest.mark.parametrize("metric_type", ['HAMMING'])
|
||||
def test_search_vector_with_binary_vector_datatype(self, metric_type, nb, dim, insert_round, auto_id,
|
||||
is_partition_key, enable_dynamic_schema):
|
||||
"""
|
||||
Insert a vector with a simple payload
|
||||
|
@ -1247,7 +1236,7 @@ class TestSearchVector(TestBase):
|
|||
]
|
||||
},
|
||||
"indexParams": [
|
||||
{"fieldName": "binary_vector", "indexName": "binary_vector", "metricType": "HAMMING",
|
||||
{"fieldName": "binary_vector", "indexName": "binary_vector", "metricType": metric_type,
|
||||
"params": {"index_type": "BIN_IVF_FLAT", "nlist": "512"}}
|
||||
]
|
||||
}
|
||||
|
@ -1298,13 +1287,6 @@ class TestSearchVector(TestBase):
|
|||
"data": [gen_vector(datatype="BinaryVector", dim=dim)],
|
||||
"filter": "word_count > 100",
|
||||
"outputFields": ["*"],
|
||||
"searchParams": {
|
||||
"metricType": "HAMMING",
|
||||
"params": {
|
||||
"radius": "0.1",
|
||||
"range_filter": "0.8"
|
||||
}
|
||||
},
|
||||
"limit": 100,
|
||||
}
|
||||
rsp = self.vector_client.vector_search(payload)
|
||||
|
@ -1546,6 +1528,130 @@ class TestSearchVector(TestBase):
|
|||
if "like" in varchar_expr:
|
||||
assert name.startswith(prefix)
|
||||
|
||||
@pytest.mark.parametrize("consistency_level", ["Strong", "Bounded", "Eventually", "Session"])
|
||||
def test_search_vector_with_consistency_level(self, consistency_level):
|
||||
"""
|
||||
Search a vector with different consistency level
|
||||
"""
|
||||
name = gen_collection_name()
|
||||
self.name = name
|
||||
nb = 200
|
||||
dim = 128
|
||||
limit = 100
|
||||
schema_payload, data = self.init_collection(name, dim=dim, nb=nb)
|
||||
names = []
|
||||
for item in data:
|
||||
names.append(item.get("name"))
|
||||
names.sort()
|
||||
logger.info(f"names: {names}")
|
||||
mid = len(names) // 2
|
||||
prefix = names[mid][0:2]
|
||||
vector_field = schema_payload.get("vectorField")
|
||||
# search data
|
||||
vector_to_search = preprocessing.normalize([np.array([random.random() for i in range(dim)])])[0].tolist()
|
||||
output_fields = get_common_fields_by_data(data, exclude_fields=[vector_field])
|
||||
payload = {
|
||||
"collectionName": name,
|
||||
"data": [vector_to_search],
|
||||
"outputFields": output_fields,
|
||||
"limit": limit,
|
||||
"offset": 0,
|
||||
"consistencyLevel": consistency_level
|
||||
}
|
||||
rsp = self.vector_client.vector_search(payload)
|
||||
assert rsp['code'] == 0
|
||||
res = rsp['data']
|
||||
logger.info(f"res: {len(res)}")
|
||||
assert len(res) == limit
|
||||
|
||||
@pytest.mark.parametrize("metric_type", ["L2", "COSINE", "IP"])
|
||||
def test_search_vector_with_range_search(self, metric_type):
|
||||
"""
|
||||
Search a vector with range search with different metric type
|
||||
"""
|
||||
name = gen_collection_name()
|
||||
self.name = name
|
||||
nb = 3000
|
||||
dim = 128
|
||||
limit = 100
|
||||
schema_payload, data = self.init_collection(name, dim=dim, nb=nb, metric_type=metric_type)
|
||||
vector_field = schema_payload.get("vectorField")
|
||||
# search data
|
||||
vector_to_search = preprocessing.normalize([np.array([random.random() for i in range(dim)])])[0].tolist()
|
||||
training_data = [item[vector_field] for item in data]
|
||||
distance_sorted = get_sorted_distance(training_data, [vector_to_search], metric_type)
|
||||
r1, r2 = distance_sorted[0][nb//2], distance_sorted[0][nb//2+limit+int((0.2*limit))] # recall is not 100% so add 20% to make sure the range is correct
|
||||
if metric_type == "L2":
|
||||
r1, r2 = r2, r1
|
||||
output_fields = get_common_fields_by_data(data, exclude_fields=[vector_field])
|
||||
payload = {
|
||||
"collectionName": name,
|
||||
"data": [vector_to_search],
|
||||
"outputFields": output_fields,
|
||||
"limit": limit,
|
||||
"offset": 0,
|
||||
"searchParams": {
|
||||
"params": {
|
||||
"radius": r1,
|
||||
"range_filter": r2,
|
||||
}
|
||||
}
|
||||
}
|
||||
rsp = self.vector_client.vector_search(payload)
|
||||
assert rsp['code'] == 0
|
||||
res = rsp['data']
|
||||
logger.info(f"res: {len(res)}")
|
||||
assert len(res) == limit
|
||||
for item in res:
|
||||
distance = item.get("distance")
|
||||
if metric_type == "L2":
|
||||
assert r1 > distance > r2
|
||||
else:
|
||||
assert r1 < distance < r2
|
||||
|
||||
@pytest.mark.parametrize("ignore_growing", [True, False])
|
||||
def test_search_vector_with_ignore_growing(self, ignore_growing):
|
||||
"""
|
||||
Search a vector with range search with different metric type
|
||||
"""
|
||||
name = gen_collection_name()
|
||||
self.name = name
|
||||
metric_type = "COSINE"
|
||||
nb = 1000
|
||||
dim = 128
|
||||
limit = 100
|
||||
schema_payload, data = self.init_collection(name, dim=dim, nb=nb, metric_type=metric_type)
|
||||
vector_field = schema_payload.get("vectorField")
|
||||
# search data
|
||||
vector_to_search = preprocessing.normalize([np.array([random.random() for i in range(dim)])])[0].tolist()
|
||||
training_data = [item[vector_field] for item in data]
|
||||
distance_sorted = get_sorted_distance(training_data, [vector_to_search], metric_type)
|
||||
r1, r2 = distance_sorted[0][nb//2], distance_sorted[0][nb//2+limit+int((0.2*limit))] # recall is not 100% so add 20% to make sure the range is correct
|
||||
if metric_type == "L2":
|
||||
r1, r2 = r2, r1
|
||||
output_fields = get_common_fields_by_data(data, exclude_fields=[vector_field])
|
||||
|
||||
payload = {
|
||||
"collectionName": name,
|
||||
"data": [vector_to_search],
|
||||
"outputFields": output_fields,
|
||||
"limit": limit,
|
||||
"offset": 0,
|
||||
"searchParams": {
|
||||
"ignoreGrowing": ignore_growing
|
||||
|
||||
}
|
||||
}
|
||||
rsp = self.vector_client.vector_search(payload)
|
||||
assert rsp['code'] == 0
|
||||
res = rsp['data']
|
||||
logger.info(f"res: {len(res)}")
|
||||
if ignore_growing is True:
|
||||
assert len(res) == 0
|
||||
else:
|
||||
assert len(res) == limit
|
||||
|
||||
|
||||
|
||||
@pytest.mark.L1
|
||||
class TestSearchVectorNegative(TestBase):
|
||||
|
|
|
@ -10,7 +10,7 @@ import base64
|
|||
import requests
|
||||
from loguru import logger
|
||||
import datetime
|
||||
|
||||
from sklearn.metrics import pairwise_distances
|
||||
fake = Faker()
|
||||
rng = np.random.default_rng()
|
||||
|
||||
|
@ -240,4 +240,28 @@ def get_all_fields_by_data(data, exclude_fields=None):
|
|||
return list(fields)
|
||||
|
||||
|
||||
def ip_distance(x, y):
|
||||
return np.dot(x, y)
|
||||
|
||||
|
||||
def cosine_distance(u, v, epsilon=1e-8):
|
||||
dot_product = np.dot(u, v)
|
||||
norm_u = np.linalg.norm(u)
|
||||
norm_v = np.linalg.norm(v)
|
||||
return dot_product / (max(norm_u * norm_v, epsilon))
|
||||
|
||||
|
||||
def l2_distance(u, v):
|
||||
return np.sum((u - v) ** 2)
|
||||
|
||||
|
||||
def get_sorted_distance(train_emb, test_emb, metric_type):
|
||||
milvus_sklearn_metric_map = {
|
||||
"L2": l2_distance,
|
||||
"COSINE": cosine_distance,
|
||||
"IP": ip_distance
|
||||
}
|
||||
distance = pairwise_distances(train_emb, Y=test_emb, metric=milvus_sklearn_metric_map[metric_type], n_jobs=-1)
|
||||
distance = np.array(distance.T, order='C', dtype=np.float16)
|
||||
distance_sorted = np.sort(distance, axis=1).tolist()
|
||||
return distance_sorted
|
||||
|
|
Loading…
Reference in New Issue