diff --git a/internal/distributed/proxy/httpserver/constant.go b/internal/distributed/proxy/httpserver/constant.go index c66cb0e5d1..fce1980004 100644 --- a/internal/distributed/proxy/httpserver/constant.go +++ b/internal/distributed/proxy/httpserver/constant.go @@ -68,6 +68,9 @@ const ( AlterAction = "alter" AlterPropertiesAction = "alter_properties" DropPropertiesAction = "drop_properties" + CompactAction = "compact" + CompactionStateAction = "get_compaction_state" + FlushAction = "flush" GetProgressAction = "get_progress" // deprecated, keep it for compatibility, use `/v2/vectordb/jobs/import/describe` instead AddPrivilegesToGroupAction = "add_privileges_to_group" RemovePrivilegesFromGroupAction = "remove_privileges_from_group" diff --git a/internal/distributed/proxy/httpserver/handler_v2.go b/internal/distributed/proxy/httpserver/handler_v2.go index be64935412..c9f8c0f44c 100644 --- a/internal/distributed/proxy/httpserver/handler_v2.go +++ b/internal/distributed/proxy/httpserver/handler_v2.go @@ -82,14 +82,19 @@ func (h *HandlersV2) RegisterRoutesToV2(router gin.IRouter) { router.POST(CollectionCategory+ReleaseAction, timeoutMiddleware(wrapperPost(func() any { return &CollectionNameReq{} }, wrapperTraceLog(h.releaseCollection)))) router.POST(CollectionCategory+AlterPropertiesAction, timeoutMiddleware(wrapperPost(func() any { return &CollectionReqWithProperties{} }, wrapperTraceLog(h.alterCollectionProperties)))) router.POST(CollectionCategory+DropPropertiesAction, timeoutMiddleware(wrapperPost(func() any { return &DropCollectionPropertiesReq{} }, wrapperTraceLog(h.dropCollectionProperties)))) + router.POST(CollectionCategory+CompactAction, timeoutMiddleware(wrapperPost(func() any { return &CompactReq{} }, wrapperTraceLog(h.compact)))) + router.POST(CollectionCategory+CompactionStateAction, timeoutMiddleware(wrapperPost(func() any { return &GetCompactionStateReq{} }, wrapperTraceLog(h.getcompactionState)))) + router.POST(CollectionCategory+FlushAction, timeoutMiddleware(wrapperPost(func() any { return &FlushReq{} }, wrapperTraceLog(h.flush)))) router.POST(CollectionFieldCategory+AlterPropertiesAction, timeoutMiddleware(wrapperPost(func() any { return &CollectionFieldReqWithParams{} }, wrapperTraceLog(h.alterCollectionFieldProperties)))) router.POST(DataBaseCategory+CreateAction, timeoutMiddleware(wrapperPost(func() any { return &DatabaseReqWithProperties{} }, wrapperTraceLog(h.createDatabase)))) router.POST(DataBaseCategory+DropAction, timeoutMiddleware(wrapperPost(func() any { return &DatabaseReqRequiredName{} }, wrapperTraceLog(h.dropDatabase)))) + router.POST(DataBaseCategory+DropPropertiesAction, timeoutMiddleware(wrapperPost(func() any { return &DropDatabasePropertiesReq{} }, wrapperTraceLog(h.dropDatabaseProperties)))) router.POST(DataBaseCategory+ListAction, timeoutMiddleware(wrapperPost(func() any { return &EmptyReq{} }, wrapperTraceLog(h.listDatabases)))) router.POST(DataBaseCategory+DescribeAction, timeoutMiddleware(wrapperPost(func() any { return &DatabaseReqRequiredName{} }, wrapperTraceLog(h.describeDatabase)))) router.POST(DataBaseCategory+AlterAction, timeoutMiddleware(wrapperPost(func() any { return &DatabaseReqWithProperties{} }, wrapperTraceLog(h.alterDatabase)))) + router.POST(DataBaseCategory+AlterPropertiesAction, timeoutMiddleware(wrapperPost(func() any { return &DatabaseReqWithProperties{} }, wrapperTraceLog(h.alterDatabase)))) // Query router.POST(EntityCategory+QueryAction, restfulSizeMiddleware(timeoutMiddleware(wrapperPost(func() any { return &QueryReqV2{ @@ -686,7 +691,7 @@ func (h *HandlersV2) dropCollectionProperties(ctx context.Context, c *gin.Contex req := &milvuspb.AlterCollectionRequest{ DbName: dbName, CollectionName: httpReq.CollectionName, - DeleteKeys: httpReq.DeleteKeys, + DeleteKeys: httpReq.PropertyKeys, } c.Set(ContextRequest, req) resp, err := wrapperProxyWithLimit(ctx, c, req, h.checkAuth, false, "/milvus.proto.milvus.MilvusService/AlterCollection", true, h.proxy, func(reqCtx context.Context, req any) (interface{}, error) { @@ -698,6 +703,62 @@ func (h *HandlersV2) dropCollectionProperties(ctx context.Context, c *gin.Contex return resp, err } +func (h *HandlersV2) compact(ctx context.Context, c *gin.Context, anyReq any, dbName string) (interface{}, error) { + httpReq := anyReq.(*CompactReq) + req := &milvuspb.ManualCompactionRequest{ + DbName: dbName, + CollectionName: httpReq.CollectionName, + MajorCompaction: httpReq.IsClustering, + } + c.Set(ContextRequest, req) + resp, err := wrapperProxyWithLimit(ctx, c, req, h.checkAuth, false, "/milvus.proto.milvus.MilvusService/ManualCompaction", true, h.proxy, func(reqCtx context.Context, req any) (interface{}, error) { + return h.proxy.ManualCompaction(reqCtx, req.(*milvuspb.ManualCompactionRequest)) + }) + if err == nil { + resp := resp.(*milvuspb.ManualCompactionResponse) + HTTPReturn(c, http.StatusOK, gin.H{ + HTTPReturnCode: merr.Code(nil), + HTTPReturnData: gin.H{"compactionID": resp.CompactionID}, + }) + } + return resp, err +} + +func (h *HandlersV2) getcompactionState(ctx context.Context, c *gin.Context, anyReq any, dbName string) (interface{}, error) { + httpReq := anyReq.(*GetCompactionStateReq) + req := &milvuspb.GetCompactionStateRequest{ + CompactionID: httpReq.JobID, + } + c.Set(ContextRequest, req) + resp, err := wrapperProxyWithLimit(ctx, c, req, h.checkAuth, false, "/milvus.proto.milvus.MilvusService/GetCompactionState", true, h.proxy, func(reqCtx context.Context, req any) (interface{}, error) { + return h.proxy.GetCompactionState(reqCtx, req.(*milvuspb.GetCompactionStateRequest)) + }) + if err == nil { + resp := resp.(*milvuspb.GetCompactionStateResponse) + HTTPReturn(c, http.StatusOK, gin.H{ + HTTPReturnCode: merr.Code(nil), + HTTPReturnData: gin.H{"compactionID": httpReq.JobID, "state": resp.State.String(), "executingPlanNumber": resp.ExecutingPlanNo, "timeoutPlanNumber": resp.TimeoutPlanNo, "completedPlanNumber": resp.CompletedPlanNo}, + }) + } + return resp, err +} + +func (h *HandlersV2) flush(ctx context.Context, c *gin.Context, anyReq any, dbName string) (interface{}, error) { + httpReq := anyReq.(*FlushReq) + req := &milvuspb.FlushRequest{ + DbName: dbName, + CollectionNames: []string{httpReq.CollectionName}, + } + c.Set(ContextRequest, req) + resp, err := wrapperProxyWithLimit(ctx, c, req, h.checkAuth, false, "/milvus.proto.milvus.MilvusService/Flush", true, h.proxy, func(reqCtx context.Context, req any) (interface{}, error) { + return h.proxy.Flush(reqCtx, req.(*milvuspb.FlushRequest)) + }) + if err == nil { + HTTPReturn(c, http.StatusOK, wrapperReturnDefault()) + } + return resp, err +} + func (h *HandlersV2) alterCollectionFieldProperties(ctx context.Context, c *gin.Context, anyReq any, dbName string) (interface{}, error) { httpReq := anyReq.(*CollectionFieldReqWithParams) req := &milvuspb.AlterCollectionFieldRequest{ @@ -1058,20 +1119,6 @@ func generatePlaceholderGroup(ctx context.Context, body string, collSchema *sche }) } -func generateSearchParams(reqSearchParams searchParams) []*commonpb.KeyValuePair { - var searchParams []*commonpb.KeyValuePair - if reqSearchParams.Params == nil { - reqSearchParams.Params = make(map[string]any) - } - bs, _ := json.Marshal(reqSearchParams.Params) - searchParams = append(searchParams, &commonpb.KeyValuePair{Key: Params, Value: string(bs)}) - searchParams = append(searchParams, &commonpb.KeyValuePair{Key: common.IgnoreGrowing, Value: strconv.FormatBool(reqSearchParams.IgnoreGrowing)}) - searchParams = append(searchParams, &commonpb.KeyValuePair{Key: common.HintsKey, Value: reqSearchParams.Hints}) - // need to exposure ParamRoundDecimal in req? - searchParams = append(searchParams, &commonpb.KeyValuePair{Key: ParamRoundDecimal, Value: "-1"}) - return searchParams -} - func (h *HandlersV2) search(ctx context.Context, c *gin.Context, anyReq any, dbName string) (interface{}, error) { httpReq := anyReq.(*SearchReqV2) req := &milvuspb.SearchRequest{ @@ -1100,7 +1147,15 @@ func (h *HandlersV2) search(ctx context.Context, c *gin.Context, anyReq any, dbN return nil, err } - searchParams := generateSearchParams(httpReq.SearchParams) + searchParams, err := generateSearchParams(httpReq.SearchParams) + if err != nil { + log.Ctx(ctx).Warn("high level restful api, generate SearchParams failed", zap.Error(err)) + HTTPAbortReturn(c, http.StatusOK, gin.H{ + HTTPReturnCode: merr.Code(err), + HTTPReturnMessage: err.Error(), + }) + return nil, err + } searchParams = append(searchParams, &commonpb.KeyValuePair{Key: common.TopKKey, Value: strconv.FormatInt(int64(httpReq.Limit), 10)}) searchParams = append(searchParams, &commonpb.KeyValuePair{Key: ParamOffset, Value: strconv.FormatInt(int64(httpReq.Offset), 10)}) if httpReq.GroupByField != "" { @@ -1181,7 +1236,15 @@ func (h *HandlersV2) advancedSearch(ctx context.Context, c *gin.Context, anyReq body, _ := c.Get(gin.BodyBytesKey) searchArray := gjson.Get(string(body.([]byte)), "search").Array() for i, subReq := range httpReq.Search { - searchParams := generateSearchParams(subReq.SearchParams) + searchParams, err := generateSearchParams(subReq.SearchParams) + if err != nil { + log.Ctx(ctx).Warn("high level restful api, generate SearchParams failed", zap.Error(err)) + HTTPAbortReturn(c, http.StatusOK, gin.H{ + HTTPReturnCode: merr.Code(err), + HTTPReturnMessage: err.Error(), + }) + return nil, err + } searchParams = append(searchParams, &commonpb.KeyValuePair{Key: common.TopKKey, Value: strconv.FormatInt(int64(subReq.Limit), 10)}) searchParams = append(searchParams, &commonpb.KeyValuePair{Key: ParamOffset, Value: strconv.FormatInt(int64(subReq.Offset), 10)}) searchParams = append(searchParams, &commonpb.KeyValuePair{Key: proxy.AnnsFieldKey, Value: subReq.AnnsField}) @@ -1613,6 +1676,22 @@ func (h *HandlersV2) dropDatabase(ctx context.Context, c *gin.Context, anyReq an return resp, err } +func (h *HandlersV2) dropDatabaseProperties(ctx context.Context, c *gin.Context, anyReq any, dbName string) (interface{}, error) { + httpReq := anyReq.(*DropDatabasePropertiesReq) + req := &milvuspb.AlterDatabaseRequest{ + DbName: dbName, + DeleteKeys: httpReq.PropertyKeys, + } + c.Set(ContextRequest, req) + resp, err := wrapperProxyWithLimit(ctx, c, req, h.checkAuth, false, "/milvus.proto.milvus.MilvusService/AlterDatabase", true, h.proxy, func(reqCtx context.Context, req any) (interface{}, error) { + return h.proxy.AlterDatabase(reqCtx, req.(*milvuspb.AlterDatabaseRequest)) + }) + if err == nil { + HTTPReturn(c, http.StatusOK, wrapperReturnDefault()) + } + return resp, err +} + // todo: use a more flexible way to handle the number of input parameters of req func (h *HandlersV2) listDatabases(ctx context.Context, c *gin.Context, anyReq any, dbName string) (interface{}, error) { req := &milvuspb.ListDatabasesRequest{} @@ -2267,7 +2346,7 @@ func (h *HandlersV2) dropIndexProperties(ctx context.Context, c *gin.Context, an req := &milvuspb.AlterIndexRequest{ DbName: dbName, CollectionName: httpReq.CollectionName, - DeleteKeys: httpReq.DeleteKeys, + DeleteKeys: httpReq.PropertyKeys, } c.Set(ContextRequest, req) resp, err := wrapperProxyWithLimit(ctx, c, req, h.checkAuth, false, "/milvus.proto.milvus.MilvusService/AlterIndex", true, h.proxy, func(reqCtx context.Context, req any) (interface{}, error) { diff --git a/internal/distributed/proxy/httpserver/handler_v2_test.go b/internal/distributed/proxy/httpserver/handler_v2_test.go index 891284f62b..2cbdb22131 100644 --- a/internal/distributed/proxy/httpserver/handler_v2_test.go +++ b/internal/distributed/proxy/httpserver/handler_v2_test.go @@ -673,6 +673,126 @@ func TestCreateIndex(t *testing.T) { } } +func TestCompact(t *testing.T) { + paramtable.Init() + // disable rate limit + paramtable.Get().Save(paramtable.Get().QuotaConfig.QuotaAndLimitsEnabled.Key, "false") + defer paramtable.Get().Reset(paramtable.Get().QuotaConfig.QuotaAndLimitsEnabled.Key) + + postTestCases := []requestBodyTestCase{} + mp := mocks.NewMockProxy(t) + mp.EXPECT().ManualCompaction(mock.Anything, mock.Anything).Return(&milvuspb.ManualCompactionResponse{CompactionID: 1}, nil).Once() + mp.EXPECT().ManualCompaction(mock.Anything, mock.Anything).Return( + &milvuspb.ManualCompactionResponse{ + Status: &commonpb.Status{ + Code: 1100, + Reason: "mock", + }, + }, nil).Once() + testEngine := initHTTPServerV2(mp, false) + path := versionalV2(CollectionCategory, CompactAction) + // success + postTestCases = append(postTestCases, requestBodyTestCase{ + path: path, + requestBody: []byte(`{"collectionName":"test"}`), + }) + // mock fail + postTestCases = append(postTestCases, requestBodyTestCase{ + path: path, + requestBody: []byte(`{"collectionName":"invalid_name"}`), + errMsg: "mock", + errCode: 1100, // ErrParameterInvalid + }) + + mp.EXPECT().GetCompactionState(mock.Anything, mock.Anything).Return(&milvuspb.GetCompactionStateResponse{}, nil).Once() + mp.EXPECT().GetCompactionState(mock.Anything, mock.Anything).Return( + &milvuspb.GetCompactionStateResponse{ + Status: &commonpb.Status{ + Code: 1100, + Reason: "mock", + }, + }, nil).Once() + path = versionalV2(CollectionCategory, CompactionStateAction) + // success + postTestCases = append(postTestCases, requestBodyTestCase{ + path: path, + requestBody: []byte(`{"collectionName":"test"}`), + }) + // mock fail + postTestCases = append(postTestCases, requestBodyTestCase{ + path: path, + requestBody: []byte(`{"collectionName":"invalid_name"}`), + errMsg: "mock", + errCode: 1100, // ErrParameterInvalid + }) + + for _, testcase := range postTestCases { + t.Run("post"+testcase.path, func(t *testing.T) { + req := httptest.NewRequest(http.MethodPost, testcase.path, bytes.NewReader(testcase.requestBody)) + w := httptest.NewRecorder() + testEngine.ServeHTTP(w, req) + assert.Equal(t, http.StatusOK, w.Code) + fmt.Println(w.Body.String()) + returnBody := &ReturnErrMsg{} + err := json.Unmarshal(w.Body.Bytes(), returnBody) + assert.Nil(t, err) + assert.Equal(t, testcase.errCode, returnBody.Code) + if testcase.errCode != 0 { + assert.Equal(t, testcase.errMsg, returnBody.Message) + } + }) + } +} + +func TestFlush(t *testing.T) { + paramtable.Init() + // disable rate limit + paramtable.Get().Save(paramtable.Get().QuotaConfig.QuotaAndLimitsEnabled.Key, "false") + defer paramtable.Get().Reset(paramtable.Get().QuotaConfig.QuotaAndLimitsEnabled.Key) + + postTestCases := []requestBodyTestCase{} + mp := mocks.NewMockProxy(t) + mp.EXPECT().Flush(mock.Anything, mock.Anything).Return(&milvuspb.FlushResponse{}, nil).Once() + mp.EXPECT().Flush(mock.Anything, mock.Anything).Return( + &milvuspb.FlushResponse{ + Status: &commonpb.Status{ + Code: 1100, + Reason: "mock", + }, + }, nil).Once() + testEngine := initHTTPServerV2(mp, false) + path := versionalV2(CollectionCategory, FlushAction) + // success + postTestCases = append(postTestCases, requestBodyTestCase{ + path: path, + requestBody: []byte(`{"collectionName":"test"}`), + }) + // mock fail + postTestCases = append(postTestCases, requestBodyTestCase{ + path: path, + requestBody: []byte(`{"collectionName":"invalid_name"}`), + errMsg: "mock", + errCode: 1100, // ErrParameterInvalid + }) + + for _, testcase := range postTestCases { + t.Run("post"+testcase.path, func(t *testing.T) { + req := httptest.NewRequest(http.MethodPost, testcase.path, bytes.NewReader(testcase.requestBody)) + w := httptest.NewRecorder() + testEngine.ServeHTTP(w, req) + assert.Equal(t, http.StatusOK, w.Code) + fmt.Println(w.Body.String()) + returnBody := &ReturnErrMsg{} + err := json.Unmarshal(w.Body.Bytes(), returnBody) + assert.Nil(t, err) + assert.Equal(t, testcase.errCode, returnBody.Code) + if testcase.errCode != 0 { + assert.Equal(t, testcase.errMsg, returnBody.Message) + } + }) + } +} + func TestDatabase(t *testing.T) { paramtable.Init() // disable rate limit @@ -784,6 +904,46 @@ func TestDatabase(t *testing.T) { errCode: 1100, // ErrParameterInvalid }) + mp.EXPECT().AlterDatabase(mock.Anything, mock.Anything).Return(commonSuccessStatus, nil).Once() + mp.EXPECT().AlterDatabase(mock.Anything, mock.Anything).Return( + &commonpb.Status{ + Code: 1100, + Reason: "mock", + }, nil).Once() + path = versionalV2(DataBaseCategory, DropPropertiesAction) + // success + postTestCases = append(postTestCases, requestBodyTestCase{ + path: path, + requestBody: []byte(`{"dbName":"test"}`), + }) + // mock fail + postTestCases = append(postTestCases, requestBodyTestCase{ + path: path, + requestBody: []byte(`{"dbName":"mock"}`), + errMsg: "mock", + errCode: 1100, // ErrParameterInvalid + }) + + mp.EXPECT().AlterDatabase(mock.Anything, mock.Anything).Return(commonSuccessStatus, nil).Once() + mp.EXPECT().AlterDatabase(mock.Anything, mock.Anything).Return( + &commonpb.Status{ + Code: 1100, + Reason: "mock", + }, nil).Once() + path = versionalV2(DataBaseCategory, AlterPropertiesAction) + // success + postTestCases = append(postTestCases, requestBodyTestCase{ + path: path, + requestBody: []byte(`{"dbName":"test"}`), + }) + // mock fail + postTestCases = append(postTestCases, requestBodyTestCase{ + path: path, + requestBody: []byte(`{"dbName":"mock"}`), + errMsg: "mock", + errCode: 1100, // ErrParameterInvalid + }) + for _, testcase := range postTestCases { t.Run("post"+testcase.path, func(t *testing.T) { req := httptest.NewRequest(http.MethodPost, testcase.path, bytes.NewReader(testcase.requestBody)) @@ -841,12 +1001,12 @@ func TestColletcionProperties(t *testing.T) { // success postTestCases = append(postTestCases, requestBodyTestCase{ path: path, - requestBody: []byte(`{"collectionName":"test", "deleteKeys":["mmap"]}`), + requestBody: []byte(`{"collectionName":"test", "propertyKeys":["mmap"]}`), }) // mock fail postTestCases = append(postTestCases, requestBodyTestCase{ path: path, - requestBody: []byte(`{"collectionName":"mock", "deleteKeys":["mmap"]}`), + requestBody: []byte(`{"collectionName":"mock", "propertyKeys":["mmap"]}`), errMsg: "mock", errCode: 1100, // ErrParameterInvalid }) @@ -908,12 +1068,12 @@ func TestIndexProperties(t *testing.T) { // success postTestCases = append(postTestCases, requestBodyTestCase{ path: path, - requestBody: []byte(`{"collectionName":"test","indexName":"test", "deleteKeys":["test"]}`), + requestBody: []byte(`{"collectionName":"test","indexName":"test", "propertyKeys":["test"]}`), }) // mock fail postTestCases = append(postTestCases, requestBodyTestCase{ path: path, - requestBody: []byte(`{"collectionName":"mock","indexName":"test", "deleteKeys":["test"]}`), + requestBody: []byte(`{"collectionName":"mock","indexName":"test", "propertyKeys":["test"]}`), errMsg: "mock", errCode: 1100, // ErrParameterInvalid }) @@ -2205,7 +2365,7 @@ func TestSearchV2(t *testing.T) { Schema: collSchema, ShardsNum: ShardNumDefault, Status: &StatusSuccess, - }, nil).Times(14) + }, nil).Times(15) mp.EXPECT().Search(mock.Anything, mock.Anything).Return(&milvuspb.SearchResults{Status: commonSuccessStatus, Results: &schemapb.SearchResultData{TopK: int64(0)}}, nil).Times(3) mp.EXPECT().DescribeCollection(mock.Anything, mock.Anything).Return(&milvuspb.DescribeCollectionResponse{ Status: &commonpb.Status{ @@ -2227,12 +2387,6 @@ func TestSearchV2(t *testing.T) { path: SearchAction, requestBody: []byte(`{"collectionName": "book", "data": [[0.1, 0.2]], "filter": "book_id in [2, 4, 6, 8]", "limit": 4, "outputFields": ["word_count"], "params": {"radius":0.9}}`), }) - queryTestCases = append(queryTestCases, requestBodyTestCase{ - path: SearchAction, - requestBody: []byte(`{"collectionName": "book", "data": [[0.1, 0.2]], "filter": "book_id in [2, 4, 6, 8]", "limit": 4, "outputFields": ["word_count"], "searchParams": {"ignoreGrowing": "true"}}`), - errMsg: "can only accept json format request, error: json: cannot unmarshal string into Go struct field searchParams.searchParams.ignoreGrowing of type bool", - errCode: 1801, // ErrIncorrectParameterFormat - }) queryTestCases = append(queryTestCases, requestBodyTestCase{ path: SearchAction, requestBody: []byte(`{"collectionName": "book", "data": [[0.1, 0.2]], "filter": "book_id in [2, 4, 6, 8]", "limit": 4, "outputFields": ["word_count"], "params": {"radius":0.9, "range_filter": 0.1}, "groupingField": "word_count"}`), @@ -2410,8 +2564,8 @@ func TestSearchV2(t *testing.T) { queryTestCases = append(queryTestCases, requestBodyTestCase{ path: SearchAction, requestBody: []byte(`{"collectionName": "book", "data": [[0.1, 0.2]], "filter": "book_id in [2, 4, 6, 8]", "limit": 4, "outputFields": ["word_count"], "searchParams": {"params":"a"}}`), - errMsg: "can only accept json format request, error: json: cannot unmarshal string into Go struct field searchParams.searchParams.params of type map[string]interface {}", - errCode: 1801, // ErrIncorrectParameterFormat + errMsg: "searchParams.params must be a dict: invalid parameter", + errCode: 1100, // ErrParameterInvalid }) queryTestCases = append(queryTestCases, requestBodyTestCase{ path: SearchAction, diff --git a/internal/distributed/proxy/httpserver/request_v2.go b/internal/distributed/proxy/httpserver/request_v2.go index 409c8fb799..d1da6db93a 100644 --- a/internal/distributed/proxy/httpserver/request_v2.go +++ b/internal/distributed/proxy/httpserver/request_v2.go @@ -48,6 +48,13 @@ type DatabaseReqWithProperties struct { func (req *DatabaseReqWithProperties) GetDbName() string { return req.DbName } +type DropDatabasePropertiesReq struct { + DbName string `json:"dbName" binding:"required"` + PropertyKeys []string `json:"propertyKeys"` +} + +func (req *DatabaseReqWithProperties) DropDatabasPropertiesReq() string { return req.DbName } + type CollectionNameReq struct { DbName string `json:"dbName"` CollectionName string `json:"collectionName" binding:"required"` @@ -103,7 +110,7 @@ func (req *RenameCollectionReq) GetDbName() string { return req.DbName } type DropCollectionPropertiesReq struct { DbName string `json:"dbName"` CollectionName string `json:"collectionName" binding:"required"` - DeleteKeys []string `json:"deleteKeys"` + PropertyKeys []string `json:"propertyKeys"` } func (req *DropCollectionPropertiesReq) GetDbName() string { return req.DbName } @@ -112,6 +119,33 @@ func (req *DropCollectionPropertiesReq) GetCollectionName() string { return req.CollectionName } +type CompactReq struct { + DbName string `json:"dbName"` + CollectionName string `json:"collectionName" binding:"required"` + IsClustering bool `json:"isClustering"` +} + +func (req *CompactReq) GetDbName() string { return req.DbName } + +func (req *CompactReq) GetCollectionName() string { + return req.CollectionName +} + +type GetCompactionStateReq struct { + JobID int64 `json:"jobID"` +} + +type FlushReq struct { + DbName string `json:"dbName"` + CollectionName string `json:"collectionName" binding:"required"` +} + +func (req *FlushReq) GetDbName() string { return req.DbName } + +func (req *FlushReq) GetCollectionName() string { + return req.CollectionName +} + type CollectionFieldReqWithParams struct { DbName string `json:"dbName"` CollectionName string `json:"collectionName" binding:"required"` @@ -217,14 +251,6 @@ type CollectionDataReq struct { func (req *CollectionDataReq) GetDbName() string { return req.DbName } -type searchParams struct { - // not use metricType any more, just for compatibility - MetricType string `json:"metricType"` - Params map[string]interface{} `json:"params"` - IgnoreGrowing bool `json:"ignoreGrowing"` - Hints string `json:"hints"` -} - type SearchReqV2 struct { DbName string `json:"dbName"` CollectionName string `json:"collectionName" binding:"required"` @@ -238,7 +264,7 @@ type SearchReqV2 struct { Limit int32 `json:"limit"` Offset int32 `json:"offset"` OutputFields []string `json:"outputFields"` - SearchParams searchParams `json:"searchParams"` + SearchParams map[string]interface{} `json:"searchParams"` ConsistencyLevel string `json:"consistencyLevel"` ExprParams map[string]interface{} `json:"exprParams"` // not use Params any more, just for compatibility @@ -260,7 +286,7 @@ type SubSearchReq struct { MetricType string `json:"metricType"` Limit int32 `json:"limit"` Offset int32 `json:"offset"` - SearchParams searchParams `json:"params"` + SearchParams map[string]interface{} `json:"params"` ExprParams map[string]interface{} `json:"exprParams"` } @@ -426,7 +452,7 @@ type DropIndexPropertiesReq struct { DbName string `json:"dbName"` CollectionName string `json:"collectionName" binding:"required"` IndexName string `json:"indexName" binding:"required"` - DeleteKeys []string `json:"deleteKeys"` + PropertyKeys []string `json:"propertyKeys"` } func (req *DropIndexPropertiesReq) GetDbName() string { return req.DbName } diff --git a/internal/distributed/proxy/httpserver/utils.go b/internal/distributed/proxy/httpserver/utils.go index 9ed396d220..f7276679b9 100644 --- a/internal/distributed/proxy/httpserver/utils.go +++ b/internal/distributed/proxy/httpserver/utils.go @@ -1853,3 +1853,64 @@ func WrapErrorToResponse(err error) *milvuspb.BoolResponse { Status: merr.Status(err), } } + +// after 2.5.2, all parameters of search_params can be written into one layer +// no more parameters will be written searchParams.params +// to ensure compatibility and milvus can still get a json format parameter +// try to write all the parameters under searchParams into searchParams.Params +func generateSearchParams(reqSearchParams map[string]interface{}) ([]*commonpb.KeyValuePair, error) { + var searchParams []*commonpb.KeyValuePair + var params interface{} + if val, ok := reqSearchParams[Params]; ok { + params = val + } + + paramsMap := make(map[string]interface{}) + if params != nil { + var ok bool + if paramsMap, ok = params.(map[string]interface{}); !ok { + return nil, merr.WrapErrParameterInvalidMsg("searchParams.params must be a dict") + } + } + + deepEqual := func(value1, value2 interface{}) bool { + // try to handle 10.0==10 + switch v1 := value1.(type) { + case float64: + if v2, ok := value2.(int); ok { + return v1 == float64(v2) + } + case int: + if v2, ok := value2.(float64); ok { + return float64(v1) == v2 + } + } + return reflect.DeepEqual(value1, value2) + } + + for key, value := range reqSearchParams { + if val, ok := paramsMap[key]; ok { + if !deepEqual(val, value) { + return nil, merr.WrapErrParameterInvalidMsg(fmt.Sprintf("ambiguous parameter: %s, in search_param: %v, in search_param.params: %v", key, value, val)) + } + } else if key != Params { + paramsMap[key] = value + } + } + + bs, _ := json.Marshal(paramsMap) + searchParams = append(searchParams, &commonpb.KeyValuePair{Key: Params, Value: string(bs)}) + + for key, value := range reqSearchParams { + if key != Params { + // for compatibility + if key == "ignoreGrowing" { + key = common.IgnoreGrowing + } + searchParams = append(searchParams, &commonpb.KeyValuePair{Key: key, Value: fmt.Sprintf("%v", value)}) + } + } + // need to exposure ParamRoundDecimal in req? + searchParams = append(searchParams, &commonpb.KeyValuePair{Key: ParamRoundDecimal, Value: "-1"}) + return searchParams, nil +} diff --git a/internal/distributed/proxy/httpserver/utils_test.go b/internal/distributed/proxy/httpserver/utils_test.go index 78004ebaea..9ac84c07e2 100644 --- a/internal/distributed/proxy/httpserver/utils_test.go +++ b/internal/distributed/proxy/httpserver/utils_test.go @@ -2377,3 +2377,87 @@ func TestGenerateExpressionTemplate(t *testing.T) { assert.Equal(t, actual, ans[i]) } } + +func TestGenerateSearchParams(t *testing.T) { + t.Run("searchParams.params must be a dict", func(t *testing.T) { + reqSearchParams := map[string]interface{}{"params": 0} + _, err := generateSearchParams(reqSearchParams) + assert.NotNil(t, err) + }) + + t.Run("ambiguous parameter", func(t *testing.T) { + reqSearchParams := map[string]interface{}{"radius": 100, "params": map[string]interface{}{"radius": 10}} + _, err := generateSearchParams(reqSearchParams) + assert.NotNil(t, err) + }) + + t.Run("no ambiguous parameter", func(t *testing.T) { + reqSearchParams := map[string]interface{}{"radius": 10, "params": map[string]interface{}{"radius": 10.0}} + _, err := generateSearchParams(reqSearchParams) + assert.Nil(t, err) + + reqSearchParams = map[string]interface{}{"radius": 10.0, "params": map[string]interface{}{"radius": 10}} + _, err = generateSearchParams(reqSearchParams) + assert.Nil(t, err) + + reqSearchParams = map[string]interface{}{"radius": 10, "params": map[string]interface{}{"radius": 10}} + searchParams, err := generateSearchParams(reqSearchParams) + assert.Equal(t, 3, len(searchParams)) + assert.Nil(t, err) + for _, kvs := range searchParams { + if kvs.Key == "radius" { + assert.Equal(t, "10", kvs.Value) + } + if kvs.Key == "params" { + var paramsMap map[string]interface{} + err := json.Unmarshal([]byte(kvs.Value), ¶msMap) + assert.Nil(t, err) + assert.Equal(t, 1, len(paramsMap)) + assert.Equal(t, paramsMap["radius"], float64(10)) + } + } + }) + + t.Run("old format", func(t *testing.T) { + reqSearchParams := map[string]interface{}{"metric_type": "L2", "params": map[string]interface{}{"radius": 10}} + searchParams, err := generateSearchParams(reqSearchParams) + assert.Nil(t, err) + assert.Equal(t, 3, len(searchParams)) + for _, kvs := range searchParams { + if kvs.Key == "metric_type" { + assert.Equal(t, "L2", kvs.Value) + } + if kvs.Key == "params" { + var paramsMap map[string]interface{} + err := json.Unmarshal([]byte(kvs.Value), ¶msMap) + assert.Nil(t, err) + assert.Equal(t, 2, len(paramsMap)) + assert.Equal(t, paramsMap["radius"], float64(10)) + assert.Equal(t, paramsMap["metric_type"], "L2") + } + } + }) + + t.Run("new format", func(t *testing.T) { + reqSearchParams := map[string]interface{}{"metric_type": "L2", "radius": 10} + searchParams, err := generateSearchParams(reqSearchParams) + assert.Nil(t, err) + assert.Equal(t, 4, len(searchParams)) + for _, kvs := range searchParams { + if kvs.Key == "metric_type" { + assert.Equal(t, "L2", kvs.Value) + } + if kvs.Key == "radius" { + assert.Equal(t, "10", kvs.Value) + } + if kvs.Key == "params" { + var paramsMap map[string]interface{} + err := json.Unmarshal([]byte(kvs.Value), ¶msMap) + assert.Nil(t, err) + assert.Equal(t, 2, len(paramsMap)) + assert.Equal(t, paramsMap["radius"], float64(10)) + assert.Equal(t, paramsMap["metric_type"], "L2") + } + } + }) +}