fix: [restful]use default search parameter `nq: 0` (#32362)

issue: #32225
master pr: #32355 #32485
v2.4 pr: #32356 #32486

1. v1 can only accept one vector, but v2 accept list of vectors
2. cannot get dbName from AliasReq #31978
3. parameters of create collection #31176

---------

Signed-off-by: PowderLi <min.li@zilliz.com>
pull/32551/head
PowderLi 2024-04-23 17:42:19 +08:00 committed by GitHub
parent 0f118e7083
commit 2d5f674c78
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
7 changed files with 146 additions and 29 deletions

View File

@ -240,6 +240,10 @@ jobs:
runs-on: ubuntu-latest
timeout-minutes: 5
steps:
- name: Checkout
uses: actions/checkout@v2
with:
fetch-depth: 0
- name: Download Cpp code coverage results
uses: actions/download-artifact@v3.0.1
with:

View File

@ -2,6 +2,7 @@ from conans import ConanFile
class MilvusConan(ConanFile):
keep_imports = True
settings = "os", "compiler", "build_type", "arch"
requires = (
"rocksdb/6.29.5@milvus/dev",

View File

@ -78,6 +78,7 @@ const (
DefaultAliasName = "the_alias"
DefaultOutputFields = "*"
HTTPHeaderAllowInt64 = "Accept-Type-Allow-Int64"
HTTPHeaderDBName = "DB-Name"
HTTPHeaderRequestTimeout = "Request-Timeout"
HTTPDefaultTimeout = 30 * time.Second
HTTPReturnCode = "code"

View File

@ -158,7 +158,10 @@ func wrapperPost(newReq newReqFunc, v2 handlerFuncV2) gin.HandlerFunc {
dbName = getter.GetDbName()
}
if dbName == "" {
dbName = DefaultDbName
dbName = c.Request.Header.Get(HTTPHeaderDBName)
if dbName == "" {
dbName = DefaultDbName
}
}
username, _ := c.Get(ContextUsername)
ctx, span := otel.Tracer(typeutil.ProxyRole).Start(context.Background(), c.Request.URL.Path)
@ -867,7 +870,6 @@ func (h *HandlersV2) search(ctx context.Context, c *gin.Context, anyReq any, dbN
PartitionNames: httpReq.PartitionNames,
SearchParams: searchParams,
GuaranteeTimestamp: BoundedTimestamp,
Nq: int64(1),
}
resp, err := wrapperProxy(ctx, c, req, h.checkAuth, false, func(reqCtx context.Context, req any) (interface{}, error) {
return h.proxy.Search(reqCtx, req.(*milvuspb.SearchRequest))
@ -917,7 +919,7 @@ func (h *HandlersV2) createCollection(ctx context.Context, c *gin.Context, anyRe
idDataType = schemapb.DataType_VarChar
idParams = append(idParams, &commonpb.KeyValuePair{
Key: common.MaxLengthKey,
Value: httpReq.Params["max_length"],
Value: fmt.Sprintf("%v", httpReq.Params["max_length"]),
})
httpReq.IDType = "VarChar"
case "", "Int64", "int64":
@ -939,8 +941,10 @@ func (h *HandlersV2) createCollection(ctx context.Context, c *gin.Context, anyRe
httpReq.VectorFieldName = DefaultVectorFieldName
}
enableDynamic := EnableDynamic
if en, err := strconv.ParseBool(httpReq.Params["enableDynamicField"]); err == nil {
enableDynamic = en
if enStr, ok := httpReq.Params["enableDynamicField"]; ok {
if en, err := strconv.ParseBool(fmt.Sprintf("%v", enStr)); err == nil {
enableDynamic = en
}
}
schema, err = proto.Marshal(&schemapb.CollectionSchema{
Name: httpReq.CollectionName,
@ -1010,8 +1014,10 @@ func (h *HandlersV2) createCollection(ctx context.Context, c *gin.Context, anyRe
}
if field.IsPartitionKey {
partitionsNum = int64(64)
if partitions, err := strconv.ParseInt(httpReq.Params["partitionsNum"], 10, 64); err == nil {
partitionsNum = partitions
if partitionsNumStr, ok := httpReq.Params["partitionsNum"]; ok {
if partitions, err := strconv.ParseInt(fmt.Sprintf("%v", partitionsNumStr), 10, 64); err == nil {
partitionsNum = partitions
}
}
}
for key, fieldParam := range field.ElementTypeParams {
@ -1031,14 +1037,16 @@ func (h *HandlersV2) createCollection(ctx context.Context, c *gin.Context, anyRe
return nil, err
}
shardsNum := int32(ShardNumDefault)
if shards, err := strconv.ParseInt(httpReq.Params["shardsNum"], 10, 64); err == nil {
shardsNum = int32(shards)
if shardsNumStr, ok := httpReq.Params["shardsNum"]; ok {
if shards, err := strconv.ParseInt(fmt.Sprintf("%v", shardsNumStr), 10, 64); err == nil {
shardsNum = int32(shards)
}
}
consistencyLevel := commonpb.ConsistencyLevel_Bounded
if level, ok := commonpb.ConsistencyLevel_value[httpReq.Params["consistencyLevel"]]; ok {
consistencyLevel = commonpb.ConsistencyLevel(level)
} else {
if len(httpReq.Params["consistencyLevel"]) > 0 {
if _, ok := httpReq.Params["consistencyLevel"]; ok {
if level, ok := commonpb.ConsistencyLevel_value[fmt.Sprintf("%s", httpReq.Params["consistencyLevel"])]; ok {
consistencyLevel = commonpb.ConsistencyLevel(level)
} else {
err := merr.WrapErrParameterInvalid("Strong, Session, Bounded, Eventually, Customized", httpReq.Params["consistencyLevel"],
"consistencyLevel can only be [Strong, Session, Bounded, Eventually, Customized], default: Bounded")
log.Ctx(ctx).Warn("high level restful api, create collection fail", zap.Error(err), zap.Any("request", anyReq))
@ -1063,7 +1071,7 @@ func (h *HandlersV2) createCollection(ctx context.Context, c *gin.Context, anyRe
if _, ok := httpReq.Params["ttlSeconds"]; ok {
req.Properties = append(req.Properties, &commonpb.KeyValuePair{
Key: common.CollectionTTLConfigKey,
Value: httpReq.Params["ttlSeconds"],
Value: fmt.Sprintf("%v", httpReq.Params["ttlSeconds"]),
})
}
resp, err := wrapperProxy(ctx, c, req, h.checkAuth, false, func(reqCtx context.Context, req any) (interface{}, error) {

View File

@ -423,14 +423,52 @@ func TestDatabaseWrapper(t *testing.T) {
}
})
}
mp.EXPECT().ListDatabases(mock.Anything, mock.Anything).Return(&milvuspb.ListDatabasesResponse{
Status: &StatusSuccess,
DbNames: []string{DefaultCollectionName, "default"},
}, nil).Once()
mp.EXPECT().ListDatabases(mock.Anything, mock.Anything).Return(&milvuspb.ListDatabasesResponse{
Status: &StatusSuccess,
DbNames: []string{DefaultCollectionName, "test"},
}, nil).Once()
mp.EXPECT().ListDatabases(mock.Anything, mock.Anything).Return(&milvuspb.ListDatabasesResponse{Status: commonErrorStatus}, nil).Once()
rawTestCases := []rawTestCase{
{
errMsg: "database not found, database: test",
errCode: 800, // ErrDatabaseNotFound
},
{},
{
errMsg: "",
errCode: 65535,
},
}
for _, testcase := range rawTestCases {
t.Run("post with db"+testcase.path, func(t *testing.T) {
req := httptest.NewRequest(http.MethodPost, path, bytes.NewReader([]byte(`{}`)))
req.Header.Set(HTTPHeaderDBName, "test")
w := httptest.NewRecorder()
ginHandler.ServeHTTP(w, req)
assert.Equal(t, http.StatusOK, w.Code)
fmt.Println(w.Body.String())
if testcase.errCode != 0 {
returnBody := &ReturnErrMsg{}
err := json.Unmarshal(w.Body.Bytes(), returnBody)
assert.Nil(t, err)
assert.Equal(t, testcase.errCode, returnBody.Code)
assert.Equal(t, testcase.errMsg, returnBody.Message)
}
})
}
}
func TestCreateCollection(t *testing.T) {
postTestCases := []requestBodyTestCase{}
mp := mocks.NewMockProxy(t)
mp.EXPECT().CreateCollection(mock.Anything, mock.Anything).Return(commonSuccessStatus, nil).Times(9)
mp.EXPECT().CreateIndex(mock.Anything, mock.Anything).Return(commonSuccessStatus, nil).Times(4)
mp.EXPECT().LoadCollection(mock.Anything, mock.Anything).Return(commonSuccessStatus, nil).Times(4)
mp.EXPECT().CreateCollection(mock.Anything, mock.Anything).Return(commonSuccessStatus, nil).Times(11)
mp.EXPECT().CreateIndex(mock.Anything, mock.Anything).Return(commonSuccessStatus, nil).Times(6)
mp.EXPECT().LoadCollection(mock.Anything, mock.Anything).Return(commonSuccessStatus, nil).Times(6)
mp.EXPECT().CreateIndex(mock.Anything, mock.Anything).Return(commonErrorStatus, nil).Twice()
mp.EXPECT().CreateCollection(mock.Anything, mock.Anything).Return(commonErrorStatus, nil).Once()
testEngine := initHTTPServerV2(mp, false)
@ -450,7 +488,17 @@ func TestCreateCollection(t *testing.T) {
postTestCases = append(postTestCases, requestBodyTestCase{
path: path,
requestBody: []byte(`{"collectionName": "` + DefaultCollectionName + `", "dimension": 2, "idType": "Varchar",` +
`"params": {"max_length": "256", "enableDynamicField": "false", "shardsNum": "2", "consistencyLevel": "unknown", "ttlSeconds": "3600"}}`),
`"params": {"max_length": "256", "enableDynamicField": false, "shardsNum": "2", "consistencyLevel": "Strong", "ttlSeconds": "3600"}}`),
})
postTestCases = append(postTestCases, requestBodyTestCase{
path: path,
requestBody: []byte(`{"collectionName": "` + DefaultCollectionName + `", "dimension": 2, "idType": "Varchar",` +
`"params": {"max_length": 256, "enableDynamicField": false, "shardsNum": 2, "consistencyLevel": "Strong", "ttlSeconds": 3600}}`),
})
postTestCases = append(postTestCases, requestBodyTestCase{
path: path,
requestBody: []byte(`{"collectionName": "` + DefaultCollectionName + `", "dimension": 2, "idType": "Varchar",` +
`"params": {"max_length": 256, "enableDynamicField": false, "shardsNum": 2, "consistencyLevel": "unknown", "ttlSeconds": 3600}}`),
errMsg: "consistencyLevel can only be [Strong, Session, Bounded, Eventually, Customized], default: Bounded: invalid parameter[expected=Strong, Session, Bounded, Eventually, Customized][actual=unknown]",
errCode: 1100, // ErrParameterInvalid
})

View File

@ -250,17 +250,17 @@ type CollectionSchema struct {
}
type CollectionReq struct {
DbName string `json:"dbName"`
CollectionName string `json:"collectionName" binding:"required"`
Dimension int32 `json:"dimension"`
IDType string `json:"idType"`
AutoID bool `json:"autoID"`
MetricType string `json:"metricType"`
PrimaryFieldName string `json:"primaryFieldName"`
VectorFieldName string `json:"vectorFieldName"`
Schema CollectionSchema `json:"schema"`
IndexParams []IndexParam `json:"indexParams"`
Params map[string]string `json:"params"`
DbName string `json:"dbName"`
CollectionName string `json:"collectionName" binding:"required"`
Dimension int32 `json:"dimension"`
IDType string `json:"idType"`
AutoID bool `json:"autoID"`
MetricType string `json:"metricType"`
PrimaryFieldName string `json:"primaryFieldName"`
VectorFieldName string `json:"vectorFieldName"`
Schema CollectionSchema `json:"schema"`
IndexParams []IndexParam `json:"indexParams"`
Params map[string]interface{} `json:"params"`
}
func (req *CollectionReq) GetDbName() string { return req.DbName }

View File

@ -83,6 +83,61 @@ class TestCreateCollection(TestBase):
for index in rsp['data']['indexes']:
assert index['metricType'] == metric_type
@pytest.mark.parametrize("enable_dynamic_field", [False, "False", "0"])
@pytest.mark.parametrize("request_shards_num", [2, "2"])
@pytest.mark.parametrize("request_ttl_seconds", [360, "360"])
def test_create_collections_without_params(self, enable_dynamic_field, request_shards_num, request_ttl_seconds):
"""
target: test create collection
method: create a collection with a simple schema
expected: create collection success
"""
name = gen_collection_name()
dim = 128
metric_type = "COSINE"
client = self.collection_client
num_shards = 2
consistency_level = "Strong"
ttl_seconds = 360
payload = {
"collectionName": name,
"dimension": dim,
"metricType": metric_type,
"params":{
"enableDynamicField": enable_dynamic_field,
"shardsNum": request_shards_num,
"consistencyLevel": f"{consistency_level}",
"ttlSeconds": request_ttl_seconds,
},
}
logging.info(f"create collection {name} with payload: {payload}")
rsp = client.collection_create(payload)
assert rsp['code'] == 200
rsp = client.collection_list()
all_collections = rsp['data']
assert name in all_collections
# describe collection by pymilvus
c = Collection(name)
res = c.describe()
logger.info(f"describe collection: {res}")
# describe collection
time.sleep(10)
rsp = client.collection_describe(name)
logger.info(f"describe collection: {rsp}")
ttl_seconds_actual = None
for d in rsp["data"]["properties"]:
if d["key"] == "collection.ttl.seconds":
ttl_seconds_actual = int(d["value"])
assert rsp['code'] == 200
assert rsp['data']['enableDynamicField'] == False
assert rsp['data']['collectionName'] == name
assert rsp['data']['shardsNum'] == num_shards
assert rsp['data']['consistencyLevel'] == consistency_level
assert ttl_seconds_actual == ttl_seconds
def test_create_collections_with_all_params(self):
"""
target: test create collection