enhance: add some apis in Restful (#39105)

- drop/alter database properties
- simplify the structure of search_params
- flush
- compact
- get_compact_status
issue: #38709

---------

Signed-off-by: lixinguo <xinguo.li@zilliz.com>
Co-authored-by: lixinguo <xinguo.li@zilliz.com>
master
smellthemoon 2025-01-20 18:15:22 +08:00 committed by GitHub
parent 45d49df89b
commit 513b489c85
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
6 changed files with 450 additions and 43 deletions

View File

@ -68,6 +68,9 @@ const (
AlterAction = "alter"
AlterPropertiesAction = "alter_properties"
DropPropertiesAction = "drop_properties"
CompactAction = "compact"
CompactionStateAction = "get_compaction_state"
FlushAction = "flush"
GetProgressAction = "get_progress" // deprecated, keep it for compatibility, use `/v2/vectordb/jobs/import/describe` instead
AddPrivilegesToGroupAction = "add_privileges_to_group"
RemovePrivilegesFromGroupAction = "remove_privileges_from_group"

View File

@ -82,14 +82,19 @@ func (h *HandlersV2) RegisterRoutesToV2(router gin.IRouter) {
router.POST(CollectionCategory+ReleaseAction, timeoutMiddleware(wrapperPost(func() any { return &CollectionNameReq{} }, wrapperTraceLog(h.releaseCollection))))
router.POST(CollectionCategory+AlterPropertiesAction, timeoutMiddleware(wrapperPost(func() any { return &CollectionReqWithProperties{} }, wrapperTraceLog(h.alterCollectionProperties))))
router.POST(CollectionCategory+DropPropertiesAction, timeoutMiddleware(wrapperPost(func() any { return &DropCollectionPropertiesReq{} }, wrapperTraceLog(h.dropCollectionProperties))))
router.POST(CollectionCategory+CompactAction, timeoutMiddleware(wrapperPost(func() any { return &CompactReq{} }, wrapperTraceLog(h.compact))))
router.POST(CollectionCategory+CompactionStateAction, timeoutMiddleware(wrapperPost(func() any { return &GetCompactionStateReq{} }, wrapperTraceLog(h.getcompactionState))))
router.POST(CollectionCategory+FlushAction, timeoutMiddleware(wrapperPost(func() any { return &FlushReq{} }, wrapperTraceLog(h.flush))))
router.POST(CollectionFieldCategory+AlterPropertiesAction, timeoutMiddleware(wrapperPost(func() any { return &CollectionFieldReqWithParams{} }, wrapperTraceLog(h.alterCollectionFieldProperties))))
router.POST(DataBaseCategory+CreateAction, timeoutMiddleware(wrapperPost(func() any { return &DatabaseReqWithProperties{} }, wrapperTraceLog(h.createDatabase))))
router.POST(DataBaseCategory+DropAction, timeoutMiddleware(wrapperPost(func() any { return &DatabaseReqRequiredName{} }, wrapperTraceLog(h.dropDatabase))))
router.POST(DataBaseCategory+DropPropertiesAction, timeoutMiddleware(wrapperPost(func() any { return &DropDatabasePropertiesReq{} }, wrapperTraceLog(h.dropDatabaseProperties))))
router.POST(DataBaseCategory+ListAction, timeoutMiddleware(wrapperPost(func() any { return &EmptyReq{} }, wrapperTraceLog(h.listDatabases))))
router.POST(DataBaseCategory+DescribeAction, timeoutMiddleware(wrapperPost(func() any { return &DatabaseReqRequiredName{} }, wrapperTraceLog(h.describeDatabase))))
router.POST(DataBaseCategory+AlterAction, timeoutMiddleware(wrapperPost(func() any { return &DatabaseReqWithProperties{} }, wrapperTraceLog(h.alterDatabase))))
router.POST(DataBaseCategory+AlterPropertiesAction, timeoutMiddleware(wrapperPost(func() any { return &DatabaseReqWithProperties{} }, wrapperTraceLog(h.alterDatabase))))
// Query
router.POST(EntityCategory+QueryAction, restfulSizeMiddleware(timeoutMiddleware(wrapperPost(func() any {
return &QueryReqV2{
@ -686,7 +691,7 @@ func (h *HandlersV2) dropCollectionProperties(ctx context.Context, c *gin.Contex
req := &milvuspb.AlterCollectionRequest{
DbName: dbName,
CollectionName: httpReq.CollectionName,
DeleteKeys: httpReq.DeleteKeys,
DeleteKeys: httpReq.PropertyKeys,
}
c.Set(ContextRequest, req)
resp, err := wrapperProxyWithLimit(ctx, c, req, h.checkAuth, false, "/milvus.proto.milvus.MilvusService/AlterCollection", true, h.proxy, func(reqCtx context.Context, req any) (interface{}, error) {
@ -698,6 +703,62 @@ func (h *HandlersV2) dropCollectionProperties(ctx context.Context, c *gin.Contex
return resp, err
}
func (h *HandlersV2) compact(ctx context.Context, c *gin.Context, anyReq any, dbName string) (interface{}, error) {
httpReq := anyReq.(*CompactReq)
req := &milvuspb.ManualCompactionRequest{
DbName: dbName,
CollectionName: httpReq.CollectionName,
MajorCompaction: httpReq.IsClustering,
}
c.Set(ContextRequest, req)
resp, err := wrapperProxyWithLimit(ctx, c, req, h.checkAuth, false, "/milvus.proto.milvus.MilvusService/ManualCompaction", true, h.proxy, func(reqCtx context.Context, req any) (interface{}, error) {
return h.proxy.ManualCompaction(reqCtx, req.(*milvuspb.ManualCompactionRequest))
})
if err == nil {
resp := resp.(*milvuspb.ManualCompactionResponse)
HTTPReturn(c, http.StatusOK, gin.H{
HTTPReturnCode: merr.Code(nil),
HTTPReturnData: gin.H{"compactionID": resp.CompactionID},
})
}
return resp, err
}
func (h *HandlersV2) getcompactionState(ctx context.Context, c *gin.Context, anyReq any, dbName string) (interface{}, error) {
httpReq := anyReq.(*GetCompactionStateReq)
req := &milvuspb.GetCompactionStateRequest{
CompactionID: httpReq.JobID,
}
c.Set(ContextRequest, req)
resp, err := wrapperProxyWithLimit(ctx, c, req, h.checkAuth, false, "/milvus.proto.milvus.MilvusService/GetCompactionState", true, h.proxy, func(reqCtx context.Context, req any) (interface{}, error) {
return h.proxy.GetCompactionState(reqCtx, req.(*milvuspb.GetCompactionStateRequest))
})
if err == nil {
resp := resp.(*milvuspb.GetCompactionStateResponse)
HTTPReturn(c, http.StatusOK, gin.H{
HTTPReturnCode: merr.Code(nil),
HTTPReturnData: gin.H{"compactionID": httpReq.JobID, "state": resp.State.String(), "executingPlanNumber": resp.ExecutingPlanNo, "timeoutPlanNumber": resp.TimeoutPlanNo, "completedPlanNumber": resp.CompletedPlanNo},
})
}
return resp, err
}
func (h *HandlersV2) flush(ctx context.Context, c *gin.Context, anyReq any, dbName string) (interface{}, error) {
httpReq := anyReq.(*FlushReq)
req := &milvuspb.FlushRequest{
DbName: dbName,
CollectionNames: []string{httpReq.CollectionName},
}
c.Set(ContextRequest, req)
resp, err := wrapperProxyWithLimit(ctx, c, req, h.checkAuth, false, "/milvus.proto.milvus.MilvusService/Flush", true, h.proxy, func(reqCtx context.Context, req any) (interface{}, error) {
return h.proxy.Flush(reqCtx, req.(*milvuspb.FlushRequest))
})
if err == nil {
HTTPReturn(c, http.StatusOK, wrapperReturnDefault())
}
return resp, err
}
func (h *HandlersV2) alterCollectionFieldProperties(ctx context.Context, c *gin.Context, anyReq any, dbName string) (interface{}, error) {
httpReq := anyReq.(*CollectionFieldReqWithParams)
req := &milvuspb.AlterCollectionFieldRequest{
@ -1058,20 +1119,6 @@ func generatePlaceholderGroup(ctx context.Context, body string, collSchema *sche
})
}
func generateSearchParams(reqSearchParams searchParams) []*commonpb.KeyValuePair {
var searchParams []*commonpb.KeyValuePair
if reqSearchParams.Params == nil {
reqSearchParams.Params = make(map[string]any)
}
bs, _ := json.Marshal(reqSearchParams.Params)
searchParams = append(searchParams, &commonpb.KeyValuePair{Key: Params, Value: string(bs)})
searchParams = append(searchParams, &commonpb.KeyValuePair{Key: common.IgnoreGrowing, Value: strconv.FormatBool(reqSearchParams.IgnoreGrowing)})
searchParams = append(searchParams, &commonpb.KeyValuePair{Key: common.HintsKey, Value: reqSearchParams.Hints})
// need to exposure ParamRoundDecimal in req?
searchParams = append(searchParams, &commonpb.KeyValuePair{Key: ParamRoundDecimal, Value: "-1"})
return searchParams
}
func (h *HandlersV2) search(ctx context.Context, c *gin.Context, anyReq any, dbName string) (interface{}, error) {
httpReq := anyReq.(*SearchReqV2)
req := &milvuspb.SearchRequest{
@ -1100,7 +1147,15 @@ func (h *HandlersV2) search(ctx context.Context, c *gin.Context, anyReq any, dbN
return nil, err
}
searchParams := generateSearchParams(httpReq.SearchParams)
searchParams, err := generateSearchParams(httpReq.SearchParams)
if err != nil {
log.Ctx(ctx).Warn("high level restful api, generate SearchParams failed", zap.Error(err))
HTTPAbortReturn(c, http.StatusOK, gin.H{
HTTPReturnCode: merr.Code(err),
HTTPReturnMessage: err.Error(),
})
return nil, err
}
searchParams = append(searchParams, &commonpb.KeyValuePair{Key: common.TopKKey, Value: strconv.FormatInt(int64(httpReq.Limit), 10)})
searchParams = append(searchParams, &commonpb.KeyValuePair{Key: ParamOffset, Value: strconv.FormatInt(int64(httpReq.Offset), 10)})
if httpReq.GroupByField != "" {
@ -1181,7 +1236,15 @@ func (h *HandlersV2) advancedSearch(ctx context.Context, c *gin.Context, anyReq
body, _ := c.Get(gin.BodyBytesKey)
searchArray := gjson.Get(string(body.([]byte)), "search").Array()
for i, subReq := range httpReq.Search {
searchParams := generateSearchParams(subReq.SearchParams)
searchParams, err := generateSearchParams(subReq.SearchParams)
if err != nil {
log.Ctx(ctx).Warn("high level restful api, generate SearchParams failed", zap.Error(err))
HTTPAbortReturn(c, http.StatusOK, gin.H{
HTTPReturnCode: merr.Code(err),
HTTPReturnMessage: err.Error(),
})
return nil, err
}
searchParams = append(searchParams, &commonpb.KeyValuePair{Key: common.TopKKey, Value: strconv.FormatInt(int64(subReq.Limit), 10)})
searchParams = append(searchParams, &commonpb.KeyValuePair{Key: ParamOffset, Value: strconv.FormatInt(int64(subReq.Offset), 10)})
searchParams = append(searchParams, &commonpb.KeyValuePair{Key: proxy.AnnsFieldKey, Value: subReq.AnnsField})
@ -1613,6 +1676,22 @@ func (h *HandlersV2) dropDatabase(ctx context.Context, c *gin.Context, anyReq an
return resp, err
}
func (h *HandlersV2) dropDatabaseProperties(ctx context.Context, c *gin.Context, anyReq any, dbName string) (interface{}, error) {
httpReq := anyReq.(*DropDatabasePropertiesReq)
req := &milvuspb.AlterDatabaseRequest{
DbName: dbName,
DeleteKeys: httpReq.PropertyKeys,
}
c.Set(ContextRequest, req)
resp, err := wrapperProxyWithLimit(ctx, c, req, h.checkAuth, false, "/milvus.proto.milvus.MilvusService/AlterDatabase", true, h.proxy, func(reqCtx context.Context, req any) (interface{}, error) {
return h.proxy.AlterDatabase(reqCtx, req.(*milvuspb.AlterDatabaseRequest))
})
if err == nil {
HTTPReturn(c, http.StatusOK, wrapperReturnDefault())
}
return resp, err
}
// todo: use a more flexible way to handle the number of input parameters of req
func (h *HandlersV2) listDatabases(ctx context.Context, c *gin.Context, anyReq any, dbName string) (interface{}, error) {
req := &milvuspb.ListDatabasesRequest{}
@ -2267,7 +2346,7 @@ func (h *HandlersV2) dropIndexProperties(ctx context.Context, c *gin.Context, an
req := &milvuspb.AlterIndexRequest{
DbName: dbName,
CollectionName: httpReq.CollectionName,
DeleteKeys: httpReq.DeleteKeys,
DeleteKeys: httpReq.PropertyKeys,
}
c.Set(ContextRequest, req)
resp, err := wrapperProxyWithLimit(ctx, c, req, h.checkAuth, false, "/milvus.proto.milvus.MilvusService/AlterIndex", true, h.proxy, func(reqCtx context.Context, req any) (interface{}, error) {

View File

@ -673,6 +673,126 @@ func TestCreateIndex(t *testing.T) {
}
}
func TestCompact(t *testing.T) {
paramtable.Init()
// disable rate limit
paramtable.Get().Save(paramtable.Get().QuotaConfig.QuotaAndLimitsEnabled.Key, "false")
defer paramtable.Get().Reset(paramtable.Get().QuotaConfig.QuotaAndLimitsEnabled.Key)
postTestCases := []requestBodyTestCase{}
mp := mocks.NewMockProxy(t)
mp.EXPECT().ManualCompaction(mock.Anything, mock.Anything).Return(&milvuspb.ManualCompactionResponse{CompactionID: 1}, nil).Once()
mp.EXPECT().ManualCompaction(mock.Anything, mock.Anything).Return(
&milvuspb.ManualCompactionResponse{
Status: &commonpb.Status{
Code: 1100,
Reason: "mock",
},
}, nil).Once()
testEngine := initHTTPServerV2(mp, false)
path := versionalV2(CollectionCategory, CompactAction)
// success
postTestCases = append(postTestCases, requestBodyTestCase{
path: path,
requestBody: []byte(`{"collectionName":"test"}`),
})
// mock fail
postTestCases = append(postTestCases, requestBodyTestCase{
path: path,
requestBody: []byte(`{"collectionName":"invalid_name"}`),
errMsg: "mock",
errCode: 1100, // ErrParameterInvalid
})
mp.EXPECT().GetCompactionState(mock.Anything, mock.Anything).Return(&milvuspb.GetCompactionStateResponse{}, nil).Once()
mp.EXPECT().GetCompactionState(mock.Anything, mock.Anything).Return(
&milvuspb.GetCompactionStateResponse{
Status: &commonpb.Status{
Code: 1100,
Reason: "mock",
},
}, nil).Once()
path = versionalV2(CollectionCategory, CompactionStateAction)
// success
postTestCases = append(postTestCases, requestBodyTestCase{
path: path,
requestBody: []byte(`{"collectionName":"test"}`),
})
// mock fail
postTestCases = append(postTestCases, requestBodyTestCase{
path: path,
requestBody: []byte(`{"collectionName":"invalid_name"}`),
errMsg: "mock",
errCode: 1100, // ErrParameterInvalid
})
for _, testcase := range postTestCases {
t.Run("post"+testcase.path, func(t *testing.T) {
req := httptest.NewRequest(http.MethodPost, testcase.path, bytes.NewReader(testcase.requestBody))
w := httptest.NewRecorder()
testEngine.ServeHTTP(w, req)
assert.Equal(t, http.StatusOK, w.Code)
fmt.Println(w.Body.String())
returnBody := &ReturnErrMsg{}
err := json.Unmarshal(w.Body.Bytes(), returnBody)
assert.Nil(t, err)
assert.Equal(t, testcase.errCode, returnBody.Code)
if testcase.errCode != 0 {
assert.Equal(t, testcase.errMsg, returnBody.Message)
}
})
}
}
func TestFlush(t *testing.T) {
paramtable.Init()
// disable rate limit
paramtable.Get().Save(paramtable.Get().QuotaConfig.QuotaAndLimitsEnabled.Key, "false")
defer paramtable.Get().Reset(paramtable.Get().QuotaConfig.QuotaAndLimitsEnabled.Key)
postTestCases := []requestBodyTestCase{}
mp := mocks.NewMockProxy(t)
mp.EXPECT().Flush(mock.Anything, mock.Anything).Return(&milvuspb.FlushResponse{}, nil).Once()
mp.EXPECT().Flush(mock.Anything, mock.Anything).Return(
&milvuspb.FlushResponse{
Status: &commonpb.Status{
Code: 1100,
Reason: "mock",
},
}, nil).Once()
testEngine := initHTTPServerV2(mp, false)
path := versionalV2(CollectionCategory, FlushAction)
// success
postTestCases = append(postTestCases, requestBodyTestCase{
path: path,
requestBody: []byte(`{"collectionName":"test"}`),
})
// mock fail
postTestCases = append(postTestCases, requestBodyTestCase{
path: path,
requestBody: []byte(`{"collectionName":"invalid_name"}`),
errMsg: "mock",
errCode: 1100, // ErrParameterInvalid
})
for _, testcase := range postTestCases {
t.Run("post"+testcase.path, func(t *testing.T) {
req := httptest.NewRequest(http.MethodPost, testcase.path, bytes.NewReader(testcase.requestBody))
w := httptest.NewRecorder()
testEngine.ServeHTTP(w, req)
assert.Equal(t, http.StatusOK, w.Code)
fmt.Println(w.Body.String())
returnBody := &ReturnErrMsg{}
err := json.Unmarshal(w.Body.Bytes(), returnBody)
assert.Nil(t, err)
assert.Equal(t, testcase.errCode, returnBody.Code)
if testcase.errCode != 0 {
assert.Equal(t, testcase.errMsg, returnBody.Message)
}
})
}
}
func TestDatabase(t *testing.T) {
paramtable.Init()
// disable rate limit
@ -784,6 +904,46 @@ func TestDatabase(t *testing.T) {
errCode: 1100, // ErrParameterInvalid
})
mp.EXPECT().AlterDatabase(mock.Anything, mock.Anything).Return(commonSuccessStatus, nil).Once()
mp.EXPECT().AlterDatabase(mock.Anything, mock.Anything).Return(
&commonpb.Status{
Code: 1100,
Reason: "mock",
}, nil).Once()
path = versionalV2(DataBaseCategory, DropPropertiesAction)
// success
postTestCases = append(postTestCases, requestBodyTestCase{
path: path,
requestBody: []byte(`{"dbName":"test"}`),
})
// mock fail
postTestCases = append(postTestCases, requestBodyTestCase{
path: path,
requestBody: []byte(`{"dbName":"mock"}`),
errMsg: "mock",
errCode: 1100, // ErrParameterInvalid
})
mp.EXPECT().AlterDatabase(mock.Anything, mock.Anything).Return(commonSuccessStatus, nil).Once()
mp.EXPECT().AlterDatabase(mock.Anything, mock.Anything).Return(
&commonpb.Status{
Code: 1100,
Reason: "mock",
}, nil).Once()
path = versionalV2(DataBaseCategory, AlterPropertiesAction)
// success
postTestCases = append(postTestCases, requestBodyTestCase{
path: path,
requestBody: []byte(`{"dbName":"test"}`),
})
// mock fail
postTestCases = append(postTestCases, requestBodyTestCase{
path: path,
requestBody: []byte(`{"dbName":"mock"}`),
errMsg: "mock",
errCode: 1100, // ErrParameterInvalid
})
for _, testcase := range postTestCases {
t.Run("post"+testcase.path, func(t *testing.T) {
req := httptest.NewRequest(http.MethodPost, testcase.path, bytes.NewReader(testcase.requestBody))
@ -841,12 +1001,12 @@ func TestColletcionProperties(t *testing.T) {
// success
postTestCases = append(postTestCases, requestBodyTestCase{
path: path,
requestBody: []byte(`{"collectionName":"test", "deleteKeys":["mmap"]}`),
requestBody: []byte(`{"collectionName":"test", "propertyKeys":["mmap"]}`),
})
// mock fail
postTestCases = append(postTestCases, requestBodyTestCase{
path: path,
requestBody: []byte(`{"collectionName":"mock", "deleteKeys":["mmap"]}`),
requestBody: []byte(`{"collectionName":"mock", "propertyKeys":["mmap"]}`),
errMsg: "mock",
errCode: 1100, // ErrParameterInvalid
})
@ -908,12 +1068,12 @@ func TestIndexProperties(t *testing.T) {
// success
postTestCases = append(postTestCases, requestBodyTestCase{
path: path,
requestBody: []byte(`{"collectionName":"test","indexName":"test", "deleteKeys":["test"]}`),
requestBody: []byte(`{"collectionName":"test","indexName":"test", "propertyKeys":["test"]}`),
})
// mock fail
postTestCases = append(postTestCases, requestBodyTestCase{
path: path,
requestBody: []byte(`{"collectionName":"mock","indexName":"test", "deleteKeys":["test"]}`),
requestBody: []byte(`{"collectionName":"mock","indexName":"test", "propertyKeys":["test"]}`),
errMsg: "mock",
errCode: 1100, // ErrParameterInvalid
})
@ -2205,7 +2365,7 @@ func TestSearchV2(t *testing.T) {
Schema: collSchema,
ShardsNum: ShardNumDefault,
Status: &StatusSuccess,
}, nil).Times(14)
}, nil).Times(15)
mp.EXPECT().Search(mock.Anything, mock.Anything).Return(&milvuspb.SearchResults{Status: commonSuccessStatus, Results: &schemapb.SearchResultData{TopK: int64(0)}}, nil).Times(3)
mp.EXPECT().DescribeCollection(mock.Anything, mock.Anything).Return(&milvuspb.DescribeCollectionResponse{
Status: &commonpb.Status{
@ -2227,12 +2387,6 @@ func TestSearchV2(t *testing.T) {
path: SearchAction,
requestBody: []byte(`{"collectionName": "book", "data": [[0.1, 0.2]], "filter": "book_id in [2, 4, 6, 8]", "limit": 4, "outputFields": ["word_count"], "params": {"radius":0.9}}`),
})
queryTestCases = append(queryTestCases, requestBodyTestCase{
path: SearchAction,
requestBody: []byte(`{"collectionName": "book", "data": [[0.1, 0.2]], "filter": "book_id in [2, 4, 6, 8]", "limit": 4, "outputFields": ["word_count"], "searchParams": {"ignoreGrowing": "true"}}`),
errMsg: "can only accept json format request, error: json: cannot unmarshal string into Go struct field searchParams.searchParams.ignoreGrowing of type bool",
errCode: 1801, // ErrIncorrectParameterFormat
})
queryTestCases = append(queryTestCases, requestBodyTestCase{
path: SearchAction,
requestBody: []byte(`{"collectionName": "book", "data": [[0.1, 0.2]], "filter": "book_id in [2, 4, 6, 8]", "limit": 4, "outputFields": ["word_count"], "params": {"radius":0.9, "range_filter": 0.1}, "groupingField": "word_count"}`),
@ -2410,8 +2564,8 @@ func TestSearchV2(t *testing.T) {
queryTestCases = append(queryTestCases, requestBodyTestCase{
path: SearchAction,
requestBody: []byte(`{"collectionName": "book", "data": [[0.1, 0.2]], "filter": "book_id in [2, 4, 6, 8]", "limit": 4, "outputFields": ["word_count"], "searchParams": {"params":"a"}}`),
errMsg: "can only accept json format request, error: json: cannot unmarshal string into Go struct field searchParams.searchParams.params of type map[string]interface {}",
errCode: 1801, // ErrIncorrectParameterFormat
errMsg: "searchParams.params must be a dict: invalid parameter",
errCode: 1100, // ErrParameterInvalid
})
queryTestCases = append(queryTestCases, requestBodyTestCase{
path: SearchAction,

View File

@ -48,6 +48,13 @@ type DatabaseReqWithProperties struct {
func (req *DatabaseReqWithProperties) GetDbName() string { return req.DbName }
type DropDatabasePropertiesReq struct {
DbName string `json:"dbName" binding:"required"`
PropertyKeys []string `json:"propertyKeys"`
}
func (req *DatabaseReqWithProperties) DropDatabasPropertiesReq() string { return req.DbName }
type CollectionNameReq struct {
DbName string `json:"dbName"`
CollectionName string `json:"collectionName" binding:"required"`
@ -103,7 +110,7 @@ func (req *RenameCollectionReq) GetDbName() string { return req.DbName }
type DropCollectionPropertiesReq struct {
DbName string `json:"dbName"`
CollectionName string `json:"collectionName" binding:"required"`
DeleteKeys []string `json:"deleteKeys"`
PropertyKeys []string `json:"propertyKeys"`
}
func (req *DropCollectionPropertiesReq) GetDbName() string { return req.DbName }
@ -112,6 +119,33 @@ func (req *DropCollectionPropertiesReq) GetCollectionName() string {
return req.CollectionName
}
type CompactReq struct {
DbName string `json:"dbName"`
CollectionName string `json:"collectionName" binding:"required"`
IsClustering bool `json:"isClustering"`
}
func (req *CompactReq) GetDbName() string { return req.DbName }
func (req *CompactReq) GetCollectionName() string {
return req.CollectionName
}
type GetCompactionStateReq struct {
JobID int64 `json:"jobID"`
}
type FlushReq struct {
DbName string `json:"dbName"`
CollectionName string `json:"collectionName" binding:"required"`
}
func (req *FlushReq) GetDbName() string { return req.DbName }
func (req *FlushReq) GetCollectionName() string {
return req.CollectionName
}
type CollectionFieldReqWithParams struct {
DbName string `json:"dbName"`
CollectionName string `json:"collectionName" binding:"required"`
@ -217,14 +251,6 @@ type CollectionDataReq struct {
func (req *CollectionDataReq) GetDbName() string { return req.DbName }
type searchParams struct {
// not use metricType any more, just for compatibility
MetricType string `json:"metricType"`
Params map[string]interface{} `json:"params"`
IgnoreGrowing bool `json:"ignoreGrowing"`
Hints string `json:"hints"`
}
type SearchReqV2 struct {
DbName string `json:"dbName"`
CollectionName string `json:"collectionName" binding:"required"`
@ -238,7 +264,7 @@ type SearchReqV2 struct {
Limit int32 `json:"limit"`
Offset int32 `json:"offset"`
OutputFields []string `json:"outputFields"`
SearchParams searchParams `json:"searchParams"`
SearchParams map[string]interface{} `json:"searchParams"`
ConsistencyLevel string `json:"consistencyLevel"`
ExprParams map[string]interface{} `json:"exprParams"`
// not use Params any more, just for compatibility
@ -260,7 +286,7 @@ type SubSearchReq struct {
MetricType string `json:"metricType"`
Limit int32 `json:"limit"`
Offset int32 `json:"offset"`
SearchParams searchParams `json:"params"`
SearchParams map[string]interface{} `json:"params"`
ExprParams map[string]interface{} `json:"exprParams"`
}
@ -426,7 +452,7 @@ type DropIndexPropertiesReq struct {
DbName string `json:"dbName"`
CollectionName string `json:"collectionName" binding:"required"`
IndexName string `json:"indexName" binding:"required"`
DeleteKeys []string `json:"deleteKeys"`
PropertyKeys []string `json:"propertyKeys"`
}
func (req *DropIndexPropertiesReq) GetDbName() string { return req.DbName }

View File

@ -1853,3 +1853,64 @@ func WrapErrorToResponse(err error) *milvuspb.BoolResponse {
Status: merr.Status(err),
}
}
// after 2.5.2, all parameters of search_params can be written into one layer
// no more parameters will be written searchParams.params
// to ensure compatibility and milvus can still get a json format parameter
// try to write all the parameters under searchParams into searchParams.Params
func generateSearchParams(reqSearchParams map[string]interface{}) ([]*commonpb.KeyValuePair, error) {
var searchParams []*commonpb.KeyValuePair
var params interface{}
if val, ok := reqSearchParams[Params]; ok {
params = val
}
paramsMap := make(map[string]interface{})
if params != nil {
var ok bool
if paramsMap, ok = params.(map[string]interface{}); !ok {
return nil, merr.WrapErrParameterInvalidMsg("searchParams.params must be a dict")
}
}
deepEqual := func(value1, value2 interface{}) bool {
// try to handle 10.0==10
switch v1 := value1.(type) {
case float64:
if v2, ok := value2.(int); ok {
return v1 == float64(v2)
}
case int:
if v2, ok := value2.(float64); ok {
return float64(v1) == v2
}
}
return reflect.DeepEqual(value1, value2)
}
for key, value := range reqSearchParams {
if val, ok := paramsMap[key]; ok {
if !deepEqual(val, value) {
return nil, merr.WrapErrParameterInvalidMsg(fmt.Sprintf("ambiguous parameter: %s, in search_param: %v, in search_param.params: %v", key, value, val))
}
} else if key != Params {
paramsMap[key] = value
}
}
bs, _ := json.Marshal(paramsMap)
searchParams = append(searchParams, &commonpb.KeyValuePair{Key: Params, Value: string(bs)})
for key, value := range reqSearchParams {
if key != Params {
// for compatibility
if key == "ignoreGrowing" {
key = common.IgnoreGrowing
}
searchParams = append(searchParams, &commonpb.KeyValuePair{Key: key, Value: fmt.Sprintf("%v", value)})
}
}
// need to exposure ParamRoundDecimal in req?
searchParams = append(searchParams, &commonpb.KeyValuePair{Key: ParamRoundDecimal, Value: "-1"})
return searchParams, nil
}

View File

@ -2377,3 +2377,87 @@ func TestGenerateExpressionTemplate(t *testing.T) {
assert.Equal(t, actual, ans[i])
}
}
func TestGenerateSearchParams(t *testing.T) {
t.Run("searchParams.params must be a dict", func(t *testing.T) {
reqSearchParams := map[string]interface{}{"params": 0}
_, err := generateSearchParams(reqSearchParams)
assert.NotNil(t, err)
})
t.Run("ambiguous parameter", func(t *testing.T) {
reqSearchParams := map[string]interface{}{"radius": 100, "params": map[string]interface{}{"radius": 10}}
_, err := generateSearchParams(reqSearchParams)
assert.NotNil(t, err)
})
t.Run("no ambiguous parameter", func(t *testing.T) {
reqSearchParams := map[string]interface{}{"radius": 10, "params": map[string]interface{}{"radius": 10.0}}
_, err := generateSearchParams(reqSearchParams)
assert.Nil(t, err)
reqSearchParams = map[string]interface{}{"radius": 10.0, "params": map[string]interface{}{"radius": 10}}
_, err = generateSearchParams(reqSearchParams)
assert.Nil(t, err)
reqSearchParams = map[string]interface{}{"radius": 10, "params": map[string]interface{}{"radius": 10}}
searchParams, err := generateSearchParams(reqSearchParams)
assert.Equal(t, 3, len(searchParams))
assert.Nil(t, err)
for _, kvs := range searchParams {
if kvs.Key == "radius" {
assert.Equal(t, "10", kvs.Value)
}
if kvs.Key == "params" {
var paramsMap map[string]interface{}
err := json.Unmarshal([]byte(kvs.Value), &paramsMap)
assert.Nil(t, err)
assert.Equal(t, 1, len(paramsMap))
assert.Equal(t, paramsMap["radius"], float64(10))
}
}
})
t.Run("old format", func(t *testing.T) {
reqSearchParams := map[string]interface{}{"metric_type": "L2", "params": map[string]interface{}{"radius": 10}}
searchParams, err := generateSearchParams(reqSearchParams)
assert.Nil(t, err)
assert.Equal(t, 3, len(searchParams))
for _, kvs := range searchParams {
if kvs.Key == "metric_type" {
assert.Equal(t, "L2", kvs.Value)
}
if kvs.Key == "params" {
var paramsMap map[string]interface{}
err := json.Unmarshal([]byte(kvs.Value), &paramsMap)
assert.Nil(t, err)
assert.Equal(t, 2, len(paramsMap))
assert.Equal(t, paramsMap["radius"], float64(10))
assert.Equal(t, paramsMap["metric_type"], "L2")
}
}
})
t.Run("new format", func(t *testing.T) {
reqSearchParams := map[string]interface{}{"metric_type": "L2", "radius": 10}
searchParams, err := generateSearchParams(reqSearchParams)
assert.Nil(t, err)
assert.Equal(t, 4, len(searchParams))
for _, kvs := range searchParams {
if kvs.Key == "metric_type" {
assert.Equal(t, "L2", kvs.Value)
}
if kvs.Key == "radius" {
assert.Equal(t, "10", kvs.Value)
}
if kvs.Key == "params" {
var paramsMap map[string]interface{}
err := json.Unmarshal([]byte(kvs.Value), &paramsMap)
assert.Nil(t, err)
assert.Equal(t, 2, len(paramsMap))
assert.Equal(t, paramsMap["radius"], float64(10))
assert.Equal(t, paramsMap["metric_type"], "L2")
}
}
})
}