fix: [restful] use default search parameter `nq: 0` (#32362)

issue: #32225
master pr: #32355 #32485
v2.4 pr: #32356 #32486

1. v1 can only accept one vector, but v2 accepts a list of vectors
2. cannot get dbName from AliasReq #31978
3. create-collection `params` now accept non-string JSON values (numbers, booleans) as well as strings #31176

---------

Signed-off-by: PowderLi <min.li@zilliz.com>
pull/32551/head
PowderLi 2024-04-23 17:42:19 +08:00 committed by GitHub
parent 0f118e7083
commit 2d5f674c78
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
7 changed files with 146 additions and 29 deletions

View File

@ -240,6 +240,10 @@ jobs:
runs-on: ubuntu-latest runs-on: ubuntu-latest
timeout-minutes: 5 timeout-minutes: 5
steps: steps:
- name: Checkout
uses: actions/checkout@v2
with:
fetch-depth: 0
- name: Download Cpp code coverage results - name: Download Cpp code coverage results
uses: actions/download-artifact@v3.0.1 uses: actions/download-artifact@v3.0.1
with: with:

View File

@ -2,6 +2,7 @@ from conans import ConanFile
class MilvusConan(ConanFile): class MilvusConan(ConanFile):
keep_imports = True
settings = "os", "compiler", "build_type", "arch" settings = "os", "compiler", "build_type", "arch"
requires = ( requires = (
"rocksdb/6.29.5@milvus/dev", "rocksdb/6.29.5@milvus/dev",

View File

@ -78,6 +78,7 @@ const (
DefaultAliasName = "the_alias" DefaultAliasName = "the_alias"
DefaultOutputFields = "*" DefaultOutputFields = "*"
HTTPHeaderAllowInt64 = "Accept-Type-Allow-Int64" HTTPHeaderAllowInt64 = "Accept-Type-Allow-Int64"
HTTPHeaderDBName = "DB-Name"
HTTPHeaderRequestTimeout = "Request-Timeout" HTTPHeaderRequestTimeout = "Request-Timeout"
HTTPDefaultTimeout = 30 * time.Second HTTPDefaultTimeout = 30 * time.Second
HTTPReturnCode = "code" HTTPReturnCode = "code"

View File

@ -158,7 +158,10 @@ func wrapperPost(newReq newReqFunc, v2 handlerFuncV2) gin.HandlerFunc {
dbName = getter.GetDbName() dbName = getter.GetDbName()
} }
if dbName == "" { if dbName == "" {
dbName = DefaultDbName dbName = c.Request.Header.Get(HTTPHeaderDBName)
if dbName == "" {
dbName = DefaultDbName
}
} }
username, _ := c.Get(ContextUsername) username, _ := c.Get(ContextUsername)
ctx, span := otel.Tracer(typeutil.ProxyRole).Start(context.Background(), c.Request.URL.Path) ctx, span := otel.Tracer(typeutil.ProxyRole).Start(context.Background(), c.Request.URL.Path)
@ -867,7 +870,6 @@ func (h *HandlersV2) search(ctx context.Context, c *gin.Context, anyReq any, dbN
PartitionNames: httpReq.PartitionNames, PartitionNames: httpReq.PartitionNames,
SearchParams: searchParams, SearchParams: searchParams,
GuaranteeTimestamp: BoundedTimestamp, GuaranteeTimestamp: BoundedTimestamp,
Nq: int64(1),
} }
resp, err := wrapperProxy(ctx, c, req, h.checkAuth, false, func(reqCtx context.Context, req any) (interface{}, error) { resp, err := wrapperProxy(ctx, c, req, h.checkAuth, false, func(reqCtx context.Context, req any) (interface{}, error) {
return h.proxy.Search(reqCtx, req.(*milvuspb.SearchRequest)) return h.proxy.Search(reqCtx, req.(*milvuspb.SearchRequest))
@ -917,7 +919,7 @@ func (h *HandlersV2) createCollection(ctx context.Context, c *gin.Context, anyRe
idDataType = schemapb.DataType_VarChar idDataType = schemapb.DataType_VarChar
idParams = append(idParams, &commonpb.KeyValuePair{ idParams = append(idParams, &commonpb.KeyValuePair{
Key: common.MaxLengthKey, Key: common.MaxLengthKey,
Value: httpReq.Params["max_length"], Value: fmt.Sprintf("%v", httpReq.Params["max_length"]),
}) })
httpReq.IDType = "VarChar" httpReq.IDType = "VarChar"
case "", "Int64", "int64": case "", "Int64", "int64":
@ -939,8 +941,10 @@ func (h *HandlersV2) createCollection(ctx context.Context, c *gin.Context, anyRe
httpReq.VectorFieldName = DefaultVectorFieldName httpReq.VectorFieldName = DefaultVectorFieldName
} }
enableDynamic := EnableDynamic enableDynamic := EnableDynamic
if en, err := strconv.ParseBool(httpReq.Params["enableDynamicField"]); err == nil { if enStr, ok := httpReq.Params["enableDynamicField"]; ok {
enableDynamic = en if en, err := strconv.ParseBool(fmt.Sprintf("%v", enStr)); err == nil {
enableDynamic = en
}
} }
schema, err = proto.Marshal(&schemapb.CollectionSchema{ schema, err = proto.Marshal(&schemapb.CollectionSchema{
Name: httpReq.CollectionName, Name: httpReq.CollectionName,
@ -1010,8 +1014,10 @@ func (h *HandlersV2) createCollection(ctx context.Context, c *gin.Context, anyRe
} }
if field.IsPartitionKey { if field.IsPartitionKey {
partitionsNum = int64(64) partitionsNum = int64(64)
if partitions, err := strconv.ParseInt(httpReq.Params["partitionsNum"], 10, 64); err == nil { if partitionsNumStr, ok := httpReq.Params["partitionsNum"]; ok {
partitionsNum = partitions if partitions, err := strconv.ParseInt(fmt.Sprintf("%v", partitionsNumStr), 10, 64); err == nil {
partitionsNum = partitions
}
} }
} }
for key, fieldParam := range field.ElementTypeParams { for key, fieldParam := range field.ElementTypeParams {
@ -1031,14 +1037,16 @@ func (h *HandlersV2) createCollection(ctx context.Context, c *gin.Context, anyRe
return nil, err return nil, err
} }
shardsNum := int32(ShardNumDefault) shardsNum := int32(ShardNumDefault)
if shards, err := strconv.ParseInt(httpReq.Params["shardsNum"], 10, 64); err == nil { if shardsNumStr, ok := httpReq.Params["shardsNum"]; ok {
shardsNum = int32(shards) if shards, err := strconv.ParseInt(fmt.Sprintf("%v", shardsNumStr), 10, 64); err == nil {
shardsNum = int32(shards)
}
} }
consistencyLevel := commonpb.ConsistencyLevel_Bounded consistencyLevel := commonpb.ConsistencyLevel_Bounded
if level, ok := commonpb.ConsistencyLevel_value[httpReq.Params["consistencyLevel"]]; ok { if _, ok := httpReq.Params["consistencyLevel"]; ok {
consistencyLevel = commonpb.ConsistencyLevel(level) if level, ok := commonpb.ConsistencyLevel_value[fmt.Sprintf("%s", httpReq.Params["consistencyLevel"])]; ok {
} else { consistencyLevel = commonpb.ConsistencyLevel(level)
if len(httpReq.Params["consistencyLevel"]) > 0 { } else {
err := merr.WrapErrParameterInvalid("Strong, Session, Bounded, Eventually, Customized", httpReq.Params["consistencyLevel"], err := merr.WrapErrParameterInvalid("Strong, Session, Bounded, Eventually, Customized", httpReq.Params["consistencyLevel"],
"consistencyLevel can only be [Strong, Session, Bounded, Eventually, Customized], default: Bounded") "consistencyLevel can only be [Strong, Session, Bounded, Eventually, Customized], default: Bounded")
log.Ctx(ctx).Warn("high level restful api, create collection fail", zap.Error(err), zap.Any("request", anyReq)) log.Ctx(ctx).Warn("high level restful api, create collection fail", zap.Error(err), zap.Any("request", anyReq))
@ -1063,7 +1071,7 @@ func (h *HandlersV2) createCollection(ctx context.Context, c *gin.Context, anyRe
if _, ok := httpReq.Params["ttlSeconds"]; ok { if _, ok := httpReq.Params["ttlSeconds"]; ok {
req.Properties = append(req.Properties, &commonpb.KeyValuePair{ req.Properties = append(req.Properties, &commonpb.KeyValuePair{
Key: common.CollectionTTLConfigKey, Key: common.CollectionTTLConfigKey,
Value: httpReq.Params["ttlSeconds"], Value: fmt.Sprintf("%v", httpReq.Params["ttlSeconds"]),
}) })
} }
resp, err := wrapperProxy(ctx, c, req, h.checkAuth, false, func(reqCtx context.Context, req any) (interface{}, error) { resp, err := wrapperProxy(ctx, c, req, h.checkAuth, false, func(reqCtx context.Context, req any) (interface{}, error) {

View File

@ -423,14 +423,52 @@ func TestDatabaseWrapper(t *testing.T) {
} }
}) })
} }
mp.EXPECT().ListDatabases(mock.Anything, mock.Anything).Return(&milvuspb.ListDatabasesResponse{
Status: &StatusSuccess,
DbNames: []string{DefaultCollectionName, "default"},
}, nil).Once()
mp.EXPECT().ListDatabases(mock.Anything, mock.Anything).Return(&milvuspb.ListDatabasesResponse{
Status: &StatusSuccess,
DbNames: []string{DefaultCollectionName, "test"},
}, nil).Once()
mp.EXPECT().ListDatabases(mock.Anything, mock.Anything).Return(&milvuspb.ListDatabasesResponse{Status: commonErrorStatus}, nil).Once()
rawTestCases := []rawTestCase{
{
errMsg: "database not found, database: test",
errCode: 800, // ErrDatabaseNotFound
},
{},
{
errMsg: "",
errCode: 65535,
},
}
for _, testcase := range rawTestCases {
t.Run("post with db"+testcase.path, func(t *testing.T) {
req := httptest.NewRequest(http.MethodPost, path, bytes.NewReader([]byte(`{}`)))
req.Header.Set(HTTPHeaderDBName, "test")
w := httptest.NewRecorder()
ginHandler.ServeHTTP(w, req)
assert.Equal(t, http.StatusOK, w.Code)
fmt.Println(w.Body.String())
if testcase.errCode != 0 {
returnBody := &ReturnErrMsg{}
err := json.Unmarshal(w.Body.Bytes(), returnBody)
assert.Nil(t, err)
assert.Equal(t, testcase.errCode, returnBody.Code)
assert.Equal(t, testcase.errMsg, returnBody.Message)
}
})
}
} }
func TestCreateCollection(t *testing.T) { func TestCreateCollection(t *testing.T) {
postTestCases := []requestBodyTestCase{} postTestCases := []requestBodyTestCase{}
mp := mocks.NewMockProxy(t) mp := mocks.NewMockProxy(t)
mp.EXPECT().CreateCollection(mock.Anything, mock.Anything).Return(commonSuccessStatus, nil).Times(9) mp.EXPECT().CreateCollection(mock.Anything, mock.Anything).Return(commonSuccessStatus, nil).Times(11)
mp.EXPECT().CreateIndex(mock.Anything, mock.Anything).Return(commonSuccessStatus, nil).Times(4) mp.EXPECT().CreateIndex(mock.Anything, mock.Anything).Return(commonSuccessStatus, nil).Times(6)
mp.EXPECT().LoadCollection(mock.Anything, mock.Anything).Return(commonSuccessStatus, nil).Times(4) mp.EXPECT().LoadCollection(mock.Anything, mock.Anything).Return(commonSuccessStatus, nil).Times(6)
mp.EXPECT().CreateIndex(mock.Anything, mock.Anything).Return(commonErrorStatus, nil).Twice() mp.EXPECT().CreateIndex(mock.Anything, mock.Anything).Return(commonErrorStatus, nil).Twice()
mp.EXPECT().CreateCollection(mock.Anything, mock.Anything).Return(commonErrorStatus, nil).Once() mp.EXPECT().CreateCollection(mock.Anything, mock.Anything).Return(commonErrorStatus, nil).Once()
testEngine := initHTTPServerV2(mp, false) testEngine := initHTTPServerV2(mp, false)
@ -450,7 +488,17 @@ func TestCreateCollection(t *testing.T) {
postTestCases = append(postTestCases, requestBodyTestCase{ postTestCases = append(postTestCases, requestBodyTestCase{
path: path, path: path,
requestBody: []byte(`{"collectionName": "` + DefaultCollectionName + `", "dimension": 2, "idType": "Varchar",` + requestBody: []byte(`{"collectionName": "` + DefaultCollectionName + `", "dimension": 2, "idType": "Varchar",` +
`"params": {"max_length": "256", "enableDynamicField": "false", "shardsNum": "2", "consistencyLevel": "unknown", "ttlSeconds": "3600"}}`), `"params": {"max_length": "256", "enableDynamicField": false, "shardsNum": "2", "consistencyLevel": "Strong", "ttlSeconds": "3600"}}`),
})
postTestCases = append(postTestCases, requestBodyTestCase{
path: path,
requestBody: []byte(`{"collectionName": "` + DefaultCollectionName + `", "dimension": 2, "idType": "Varchar",` +
`"params": {"max_length": 256, "enableDynamicField": false, "shardsNum": 2, "consistencyLevel": "Strong", "ttlSeconds": 3600}}`),
})
postTestCases = append(postTestCases, requestBodyTestCase{
path: path,
requestBody: []byte(`{"collectionName": "` + DefaultCollectionName + `", "dimension": 2, "idType": "Varchar",` +
`"params": {"max_length": 256, "enableDynamicField": false, "shardsNum": 2, "consistencyLevel": "unknown", "ttlSeconds": 3600}}`),
errMsg: "consistencyLevel can only be [Strong, Session, Bounded, Eventually, Customized], default: Bounded: invalid parameter[expected=Strong, Session, Bounded, Eventually, Customized][actual=unknown]", errMsg: "consistencyLevel can only be [Strong, Session, Bounded, Eventually, Customized], default: Bounded: invalid parameter[expected=Strong, Session, Bounded, Eventually, Customized][actual=unknown]",
errCode: 1100, // ErrParameterInvalid errCode: 1100, // ErrParameterInvalid
}) })

View File

@ -250,17 +250,17 @@ type CollectionSchema struct {
} }
type CollectionReq struct { type CollectionReq struct {
DbName string `json:"dbName"` DbName string `json:"dbName"`
CollectionName string `json:"collectionName" binding:"required"` CollectionName string `json:"collectionName" binding:"required"`
Dimension int32 `json:"dimension"` Dimension int32 `json:"dimension"`
IDType string `json:"idType"` IDType string `json:"idType"`
AutoID bool `json:"autoID"` AutoID bool `json:"autoID"`
MetricType string `json:"metricType"` MetricType string `json:"metricType"`
PrimaryFieldName string `json:"primaryFieldName"` PrimaryFieldName string `json:"primaryFieldName"`
VectorFieldName string `json:"vectorFieldName"` VectorFieldName string `json:"vectorFieldName"`
Schema CollectionSchema `json:"schema"` Schema CollectionSchema `json:"schema"`
IndexParams []IndexParam `json:"indexParams"` IndexParams []IndexParam `json:"indexParams"`
Params map[string]string `json:"params"` Params map[string]interface{} `json:"params"`
} }
func (req *CollectionReq) GetDbName() string { return req.DbName } func (req *CollectionReq) GetDbName() string { return req.DbName }

View File

@ -83,6 +83,61 @@ class TestCreateCollection(TestBase):
for index in rsp['data']['indexes']: for index in rsp['data']['indexes']:
assert index['metricType'] == metric_type assert index['metricType'] == metric_type
# Each optional param is sent both as a native JSON value and as a string.
@pytest.mark.parametrize("enable_dynamic_field", [False, "False", "0"])
@pytest.mark.parametrize("request_shards_num", [2, "2"])
@pytest.mark.parametrize("request_ttl_seconds", [360, "360"])
def test_create_collections_without_params(self, enable_dynamic_field, request_shards_num, request_ttl_seconds):
"""
target: test create collection when the optional "params" values are given
        either as native JSON types (bool/int) or as their string forms
method: create a collection with enableDynamicField, shardsNum,
        consistencyLevel and ttlSeconds set under "params", then list and
        describe it
expected: create collection success; the described collection reports
          enableDynamicField == False, shardsNum == 2,
          consistencyLevel == "Strong" and a TTL of 360 seconds
"""
name = gen_collection_name()
dim = 128
metric_type = "COSINE"
client = self.collection_client
# Expected values after creation, independent of how the request encoded them.
num_shards = 2
consistency_level = "Strong"
ttl_seconds = 360
payload = {
"collectionName": name,
"dimension": dim,
"metricType": metric_type,
"params":{
"enableDynamicField": enable_dynamic_field,
"shardsNum": request_shards_num,
"consistencyLevel": f"{consistency_level}",
"ttlSeconds": request_ttl_seconds,
},
}
logging.info(f"create collection {name} with payload: {payload}")
rsp = client.collection_create(payload)
assert rsp['code'] == 200
# The new collection must appear in the collection listing.
rsp = client.collection_list()
all_collections = rsp['data']
assert name in all_collections
# describe collection by pymilvus
c = Collection(name)
res = c.describe()
logger.info(f"describe collection: {res}")
# describe collection through the collection client under test
# NOTE(review): fixed 10s wait before describing — presumably lets the
# collection state settle; confirm it is actually required.
time.sleep(10)
rsp = client.collection_describe(name)
logger.info(f"describe collection: {rsp}")
# Pull the TTL out of the returned key/value properties list.
ttl_seconds_actual = None
for d in rsp["data"]["properties"]:
if d["key"] == "collection.ttl.seconds":
ttl_seconds_actual = int(d["value"])
assert rsp['code'] == 200
assert rsp['data']['enableDynamicField'] == False
assert rsp['data']['collectionName'] == name
assert rsp['data']['shardsNum'] == num_shards
assert rsp['data']['consistencyLevel'] == consistency_level
assert ttl_seconds_actual == ttl_seconds
def test_create_collections_with_all_params(self): def test_create_collections_with_all_params(self):
""" """
target: test create collection target: test create collection