mirror of https://github.com/milvus-io/milvus.git
enhance: support db request in Restful api (#38140)
#38077 Signed-off-by: lixinguo <xinguo.li@zilliz.com> Co-authored-by: lixinguo <xinguo.li@zilliz.com>pull/38154/head
parent
6f1b1ad78b
commit
e359725530
|
@ -21,8 +21,9 @@ import (
|
|||
"sort"
|
||||
|
||||
"github.com/bits-and-blooms/bitset"
|
||||
"github.com/milvus-io/milvus/pkg/log"
|
||||
"go.uber.org/zap"
|
||||
|
||||
"github.com/milvus-io/milvus/pkg/log"
|
||||
)
|
||||
|
||||
type Sizable interface {
|
||||
|
|
|
@ -25,6 +25,7 @@ import (
|
|||
// v2
|
||||
const (
|
||||
// --- category ---
|
||||
DataBaseCategory = "/databases/"
|
||||
CollectionCategory = "/collections/"
|
||||
EntityCategory = "/entities/"
|
||||
PartitionCategory = "/partitions/"
|
||||
|
@ -90,6 +91,8 @@ const (
|
|||
HTTPCollectionName = "collectionName"
|
||||
HTTPCollectionID = "collectionID"
|
||||
HTTPDbName = "dbName"
|
||||
HTTPDbID = "dbID"
|
||||
HTTPProperties = "properties"
|
||||
HTTPPartitionName = "partitionName"
|
||||
HTTPPartitionNames = "partitionNames"
|
||||
HTTPUserName = "userName"
|
||||
|
|
|
@ -78,6 +78,11 @@ func (h *HandlersV2) RegisterRoutesToV2(router gin.IRouter) {
|
|||
router.POST(CollectionCategory+LoadAction, timeoutMiddleware(wrapperPost(func() any { return &CollectionNameReq{} }, wrapperTraceLog(h.loadCollection))))
|
||||
router.POST(CollectionCategory+ReleaseAction, timeoutMiddleware(wrapperPost(func() any { return &CollectionNameReq{} }, wrapperTraceLog(h.releaseCollection))))
|
||||
|
||||
router.POST(DataBaseCategory+CreateAction, timeoutMiddleware(wrapperPost(func() any { return &DatabaseReqWithProperties{} }, wrapperTraceLog(h.createDatabase))))
|
||||
router.POST(DataBaseCategory+DropAction, timeoutMiddleware(wrapperPost(func() any { return &DatabaseReq{} }, wrapperTraceLog(h.dropDatabase))))
|
||||
router.POST(DataBaseCategory+ListAction, timeoutMiddleware(wrapperPost(func() any { return &EmptyReq{} }, wrapperTraceLog(h.listDatabases))))
|
||||
router.POST(DataBaseCategory+DescribeAction, timeoutMiddleware(wrapperPost(func() any { return &DatabaseReq{} }, wrapperTraceLog(h.describeDatabase))))
|
||||
router.POST(DataBaseCategory+AlterAction, timeoutMiddleware(wrapperPost(func() any { return &DatabaseReqWithProperties{} }, wrapperTraceLog(h.alterDatabase))))
|
||||
// Query
|
||||
router.POST(EntityCategory+QueryAction, restfulSizeMiddleware(timeoutMiddleware(wrapperPost(func() any {
|
||||
return &QueryReqV2{
|
||||
|
@ -207,13 +212,15 @@ func wrapperPost(newReq newReqFunc, v2 handlerFuncV2) gin.HandlerFunc {
|
|||
return
|
||||
}
|
||||
dbName := ""
|
||||
if getter, ok := req.(requestutil.DBNameGetter); ok {
|
||||
dbName = getter.GetDbName()
|
||||
}
|
||||
if dbName == "" {
|
||||
dbName = c.Request.Header.Get(HTTPHeaderDBName)
|
||||
if req != nil {
|
||||
if getter, ok := req.(requestutil.DBNameGetter); ok {
|
||||
dbName = getter.GetDbName()
|
||||
}
|
||||
if dbName == "" {
|
||||
dbName = DefaultDbName
|
||||
dbName = c.Request.Header.Get(HTTPHeaderDBName)
|
||||
if dbName == "" {
|
||||
dbName = DefaultDbName
|
||||
}
|
||||
}
|
||||
}
|
||||
username, _ := c.Get(ContextUsername)
|
||||
|
@ -277,7 +284,7 @@ func wrapperTraceLog(v2 handlerFuncV2) handlerFuncV2 {
|
|||
if err != nil {
|
||||
log.Ctx(ctx).Info("trace info: all, error", zap.Error(err))
|
||||
} else {
|
||||
log.Ctx(ctx).Info("trace info: all, unknown", zap.Any("resp", resp))
|
||||
log.Ctx(ctx).Info("trace info: all, unknown")
|
||||
}
|
||||
}
|
||||
return resp, err
|
||||
|
@ -1149,7 +1156,7 @@ func (h *HandlersV2) createCollection(ctx context.Context, c *gin.Context, anyRe
|
|||
var err error
|
||||
fieldNames := map[string]bool{}
|
||||
partitionsNum := int64(-1)
|
||||
if httpReq.Schema.Fields == nil || len(httpReq.Schema.Fields) == 0 {
|
||||
if len(httpReq.Schema.Fields) == 0 {
|
||||
if len(httpReq.Schema.Functions) > 0 {
|
||||
err := merr.WrapErrParameterInvalid("schema", "functions",
|
||||
"functions are not supported for quickly create collection")
|
||||
|
@ -1468,6 +1475,99 @@ func (h *HandlersV2) createCollection(ctx context.Context, c *gin.Context, anyRe
|
|||
return statusResponse, err
|
||||
}
|
||||
|
||||
func (h *HandlersV2) createDatabase(ctx context.Context, c *gin.Context, anyReq any, dbName string) (interface{}, error) {
|
||||
httpReq := anyReq.(*DatabaseReqWithProperties)
|
||||
req := &milvuspb.CreateDatabaseRequest{
|
||||
DbName: dbName,
|
||||
}
|
||||
properties := make([]*commonpb.KeyValuePair, 0, len(httpReq.Properties))
|
||||
for key, value := range httpReq.Properties {
|
||||
properties = append(properties, &commonpb.KeyValuePair{Key: key, Value: fmt.Sprintf("%v", value)})
|
||||
}
|
||||
req.Properties = properties
|
||||
|
||||
c.Set(ContextRequest, req)
|
||||
resp, err := wrapperProxyWithLimit(ctx, c, req, h.checkAuth, false, "/milvus.proto.milvus.MilvusService/CreateDatabase", true, h.proxy, func(reqCtx context.Context, req any) (interface{}, error) {
|
||||
return h.proxy.CreateDatabase(reqCtx, req.(*milvuspb.CreateDatabaseRequest))
|
||||
})
|
||||
if err == nil {
|
||||
HTTPReturn(c, http.StatusOK, wrapperReturnDefault())
|
||||
}
|
||||
return resp, err
|
||||
}
|
||||
|
||||
func (h *HandlersV2) dropDatabase(ctx context.Context, c *gin.Context, anyReq any, dbName string) (interface{}, error) {
|
||||
req := &milvuspb.DropDatabaseRequest{
|
||||
DbName: dbName,
|
||||
}
|
||||
c.Set(ContextRequest, req)
|
||||
resp, err := wrapperProxyWithLimit(ctx, c, req, h.checkAuth, false, "/milvus.proto.milvus.MilvusService/DropDatabase", true, h.proxy, func(reqCtx context.Context, req any) (interface{}, error) {
|
||||
return h.proxy.DropDatabase(reqCtx, req.(*milvuspb.DropDatabaseRequest))
|
||||
})
|
||||
if err == nil {
|
||||
HTTPReturn(c, http.StatusOK, wrapperReturnDefault())
|
||||
}
|
||||
return resp, err
|
||||
}
|
||||
|
||||
// todo: use a more flexible way to handle the number of input parameters of req
|
||||
func (h *HandlersV2) listDatabases(ctx context.Context, c *gin.Context, anyReq any, dbName string) (interface{}, error) {
|
||||
req := &milvuspb.ListDatabasesRequest{}
|
||||
c.Set(ContextRequest, req)
|
||||
resp, err := wrapperProxy(ctx, c, req, false, false, "/milvus.proto.milvus.MilvusService/ListDatabases", func(reqCtx context.Context, req any) (interface{}, error) {
|
||||
return h.proxy.ListDatabases(reqCtx, req.(*milvuspb.ListDatabasesRequest))
|
||||
})
|
||||
if err == nil {
|
||||
HTTPReturn(c, http.StatusOK, wrapperReturnList(resp.(*milvuspb.ListDatabasesResponse).DbNames))
|
||||
}
|
||||
return resp, err
|
||||
}
|
||||
|
||||
func (h *HandlersV2) describeDatabase(ctx context.Context, c *gin.Context, anyReq any, dbName string) (interface{}, error) {
|
||||
req := &milvuspb.DescribeDatabaseRequest{
|
||||
DbName: dbName,
|
||||
}
|
||||
c.Set(ContextRequest, req)
|
||||
resp, err := wrapperProxy(ctx, c, req, h.checkAuth, false, "/milvus.proto.milvus.MilvusService/DescribeDatabase", func(reqCtx context.Context, req any) (interface{}, error) {
|
||||
return h.proxy.DescribeDatabase(reqCtx, req.(*milvuspb.DescribeDatabaseRequest))
|
||||
})
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
info, _ := resp.(*milvuspb.DescribeDatabaseResponse)
|
||||
if info.Properties == nil {
|
||||
info.Properties = []*commonpb.KeyValuePair{}
|
||||
}
|
||||
dataBaseInfo := map[string]any{
|
||||
HTTPDbName: info.DbName,
|
||||
HTTPDbID: info.DbID,
|
||||
HTTPProperties: info.Properties,
|
||||
}
|
||||
HTTPReturn(c, http.StatusOK, gin.H{HTTPReturnCode: merr.Code(nil), HTTPReturnData: dataBaseInfo})
|
||||
return resp, err
|
||||
}
|
||||
|
||||
func (h *HandlersV2) alterDatabase(ctx context.Context, c *gin.Context, anyReq any, dbName string) (interface{}, error) {
|
||||
httpReq := anyReq.(*DatabaseReqWithProperties)
|
||||
req := &milvuspb.AlterDatabaseRequest{
|
||||
DbName: dbName,
|
||||
}
|
||||
properties := make([]*commonpb.KeyValuePair, 0, len(httpReq.Properties))
|
||||
for key, value := range httpReq.Properties {
|
||||
properties = append(properties, &commonpb.KeyValuePair{Key: key, Value: fmt.Sprintf("%v", value)})
|
||||
}
|
||||
req.Properties = properties
|
||||
|
||||
c.Set(ContextRequest, req)
|
||||
resp, err := wrapperProxy(ctx, c, req, h.checkAuth, false, "/milvus.proto.milvus.MilvusService/AlterDatabase", func(reqCtx context.Context, req any) (interface{}, error) {
|
||||
return h.proxy.AlterDatabase(reqCtx, req.(*milvuspb.AlterDatabaseRequest))
|
||||
})
|
||||
if err == nil {
|
||||
HTTPReturn(c, http.StatusOK, wrapperReturnDefault())
|
||||
}
|
||||
return resp, err
|
||||
}
|
||||
|
||||
func (h *HandlersV2) listPartitions(ctx context.Context, c *gin.Context, anyReq any, dbName string) (interface{}, error) {
|
||||
collectionGetter, _ := anyReq.(requestutil.CollectionNameGetter)
|
||||
req := &milvuspb.ShowPartitionsRequest{
|
||||
|
|
|
@ -673,6 +673,135 @@ func TestCreateIndex(t *testing.T) {
|
|||
}
|
||||
}
|
||||
|
||||
func TestDatabase(t *testing.T) {
|
||||
paramtable.Init()
|
||||
// disable rate limit
|
||||
paramtable.Get().Save(paramtable.Get().QuotaConfig.QuotaAndLimitsEnabled.Key, "false")
|
||||
defer paramtable.Get().Reset(paramtable.Get().QuotaConfig.QuotaAndLimitsEnabled.Key)
|
||||
|
||||
postTestCases := []requestBodyTestCase{}
|
||||
mp := mocks.NewMockProxy(t)
|
||||
mp.EXPECT().CreateDatabase(mock.Anything, mock.Anything).Return(commonSuccessStatus, nil).Once()
|
||||
mp.EXPECT().CreateDatabase(mock.Anything, mock.Anything).Return(
|
||||
&commonpb.Status{
|
||||
Code: 1100,
|
||||
Reason: "mock",
|
||||
}, nil).Once()
|
||||
testEngine := initHTTPServerV2(mp, false)
|
||||
path := versionalV2(DataBaseCategory, CreateAction)
|
||||
// success
|
||||
postTestCases = append(postTestCases, requestBodyTestCase{
|
||||
path: path,
|
||||
requestBody: []byte(`{"dbName":"test"}`),
|
||||
})
|
||||
// mock fail
|
||||
postTestCases = append(postTestCases, requestBodyTestCase{
|
||||
path: path,
|
||||
requestBody: []byte(`{"dbName":"invalid_name"}`),
|
||||
errMsg: "mock",
|
||||
errCode: 1100, // ErrParameterInvalid
|
||||
})
|
||||
|
||||
mp.EXPECT().DropDatabase(mock.Anything, mock.Anything).Return(commonSuccessStatus, nil).Once()
|
||||
mp.EXPECT().DropDatabase(mock.Anything, mock.Anything).Return(
|
||||
&commonpb.Status{
|
||||
Code: 1100,
|
||||
Reason: "mock",
|
||||
}, nil).Once()
|
||||
path = versionalV2(DataBaseCategory, DropAction)
|
||||
// success
|
||||
postTestCases = append(postTestCases, requestBodyTestCase{
|
||||
path: path,
|
||||
requestBody: []byte(`{"dbName":"test"}`),
|
||||
})
|
||||
// mock fail
|
||||
postTestCases = append(postTestCases, requestBodyTestCase{
|
||||
path: path,
|
||||
requestBody: []byte(`{"dbName":"mock"}`),
|
||||
errMsg: "mock",
|
||||
errCode: 1100, // ErrParameterInvalid
|
||||
})
|
||||
|
||||
mp.EXPECT().ListDatabases(mock.Anything, mock.Anything).Return(&milvuspb.ListDatabasesResponse{DbNames: []string{"a", "b", "c"}, DbIds: []int64{100, 101, 102}}, nil).Once()
|
||||
mp.EXPECT().ListDatabases(mock.Anything, mock.Anything).Return(&milvuspb.ListDatabasesResponse{
|
||||
Status: &commonpb.Status{
|
||||
Code: 1100,
|
||||
Reason: "mock",
|
||||
},
|
||||
}, nil).Once()
|
||||
path = versionalV2(DataBaseCategory, ListAction)
|
||||
// success
|
||||
postTestCases = append(postTestCases, requestBodyTestCase{
|
||||
path: path,
|
||||
requestBody: []byte(`{"dbName":"test"}`),
|
||||
})
|
||||
// mock fail
|
||||
postTestCases = append(postTestCases, requestBodyTestCase{
|
||||
path: path,
|
||||
requestBody: []byte(`{"dbName":"mock"}`),
|
||||
errMsg: "mock",
|
||||
errCode: 1100, // ErrParameterInvalid
|
||||
})
|
||||
|
||||
mp.EXPECT().DescribeDatabase(mock.Anything, mock.Anything).Return(&milvuspb.DescribeDatabaseResponse{DbName: "test", DbID: 100}, nil).Once()
|
||||
mp.EXPECT().DescribeDatabase(mock.Anything, mock.Anything).Return(&milvuspb.DescribeDatabaseResponse{
|
||||
Status: &commonpb.Status{
|
||||
Code: 1100,
|
||||
Reason: "mock",
|
||||
},
|
||||
}, nil).Once()
|
||||
path = versionalV2(DataBaseCategory, DescribeAction)
|
||||
// success
|
||||
postTestCases = append(postTestCases, requestBodyTestCase{
|
||||
path: path,
|
||||
requestBody: []byte(`{"dbName":"test"}`),
|
||||
})
|
||||
// mock fail
|
||||
postTestCases = append(postTestCases, requestBodyTestCase{
|
||||
path: path,
|
||||
requestBody: []byte(`{"dbName":"mock"}`),
|
||||
errMsg: "mock",
|
||||
errCode: 1100, // ErrParameterInvalid
|
||||
})
|
||||
|
||||
mp.EXPECT().AlterDatabase(mock.Anything, mock.Anything).Return(commonSuccessStatus, nil).Once()
|
||||
mp.EXPECT().AlterDatabase(mock.Anything, mock.Anything).Return(
|
||||
&commonpb.Status{
|
||||
Code: 1100,
|
||||
Reason: "mock",
|
||||
}, nil).Once()
|
||||
path = versionalV2(DataBaseCategory, AlterAction)
|
||||
// success
|
||||
postTestCases = append(postTestCases, requestBodyTestCase{
|
||||
path: path,
|
||||
requestBody: []byte(`{"dbName":"test"}`),
|
||||
})
|
||||
// mock fail
|
||||
postTestCases = append(postTestCases, requestBodyTestCase{
|
||||
path: path,
|
||||
requestBody: []byte(`{"dbName":"mock"}`),
|
||||
errMsg: "mock",
|
||||
errCode: 1100, // ErrParameterInvalid
|
||||
})
|
||||
|
||||
for _, testcase := range postTestCases {
|
||||
t.Run("post"+testcase.path, func(t *testing.T) {
|
||||
req := httptest.NewRequest(http.MethodPost, testcase.path, bytes.NewReader(testcase.requestBody))
|
||||
w := httptest.NewRecorder()
|
||||
testEngine.ServeHTTP(w, req)
|
||||
assert.Equal(t, http.StatusOK, w.Code)
|
||||
fmt.Println(w.Body.String())
|
||||
returnBody := &ReturnErrMsg{}
|
||||
err := json.Unmarshal(w.Body.Bytes(), returnBody)
|
||||
assert.Nil(t, err)
|
||||
assert.Equal(t, testcase.errCode, returnBody.Code)
|
||||
if testcase.errCode != 0 {
|
||||
assert.Equal(t, testcase.errMsg, returnBody.Message)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestCreateCollection(t *testing.T) {
|
||||
paramtable.Init()
|
||||
// disable rate limit
|
||||
|
|
|
@ -25,12 +25,23 @@ import (
|
|||
"github.com/milvus-io/milvus/pkg/util/merr"
|
||||
)
|
||||
|
||||
type EmptyReq struct{}
|
||||
|
||||
func (req *EmptyReq) GetDbName() string { return "" }
|
||||
|
||||
type DatabaseReq struct {
|
||||
DbName string `json:"dbName"`
|
||||
}
|
||||
|
||||
func (req *DatabaseReq) GetDbName() string { return req.DbName }
|
||||
|
||||
type DatabaseReqWithProperties struct {
|
||||
DbName string `json:"dbName" binding:"required"`
|
||||
Properties map[string]interface{} `json:"properties"`
|
||||
}
|
||||
|
||||
func (req *DatabaseReqWithProperties) GetDbName() string { return req.DbName }
|
||||
|
||||
type CollectionNameReq struct {
|
||||
DbName string `json:"dbName"`
|
||||
CollectionName string `json:"collectionName" binding:"required"`
|
||||
|
|
|
@ -1161,13 +1161,16 @@ func GetCurDBNameFromContextOrDefault(ctx context.Context) string {
|
|||
|
||||
func NewContextWithMetadata(ctx context.Context, username string, dbName string) context.Context {
|
||||
dbKey := strings.ToLower(util.HeaderDBName)
|
||||
if username == "" {
|
||||
return contextutil.AppendToIncomingContext(ctx, dbKey, dbName)
|
||||
if dbName != "" {
|
||||
ctx = contextutil.AppendToIncomingContext(ctx, dbKey, dbName)
|
||||
}
|
||||
originValue := fmt.Sprintf("%s%s%s", username, util.CredentialSeperator, username)
|
||||
authKey := strings.ToLower(util.HeaderAuthorize)
|
||||
authValue := crypto.Base64Encode(originValue)
|
||||
return contextutil.AppendToIncomingContext(ctx, authKey, authValue, dbKey, dbName)
|
||||
if username != "" {
|
||||
originValue := fmt.Sprintf("%s%s%s", username, util.CredentialSeperator, username)
|
||||
authKey := strings.ToLower(util.HeaderAuthorize)
|
||||
authValue := crypto.Base64Encode(originValue)
|
||||
ctx = contextutil.AppendToIncomingContext(ctx, authKey, authValue)
|
||||
}
|
||||
return ctx
|
||||
}
|
||||
|
||||
func AppendUserInfoForRPC(ctx context.Context) context.Context {
|
||||
|
|
Loading…
Reference in New Issue