enhance: support db request in Restful api (#38140)

#38077

Signed-off-by: lixinguo <xinguo.li@zilliz.com>
Co-authored-by: lixinguo <xinguo.li@zilliz.com>
pull/38154/head
smellthemoon 2024-12-03 20:04:41 +08:00 committed by GitHub
parent 6f1b1ad78b
commit e359725530
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
6 changed files with 262 additions and 15 deletions

View File

@ -21,8 +21,9 @@ import (
"sort"
"github.com/bits-and-blooms/bitset"
"github.com/milvus-io/milvus/pkg/log"
"go.uber.org/zap"
"github.com/milvus-io/milvus/pkg/log"
)
type Sizable interface {

View File

@ -25,6 +25,7 @@ import (
// v2
const (
// --- category ---
DataBaseCategory = "/databases/"
CollectionCategory = "/collections/"
EntityCategory = "/entities/"
PartitionCategory = "/partitions/"
@ -90,6 +91,8 @@ const (
HTTPCollectionName = "collectionName"
HTTPCollectionID = "collectionID"
HTTPDbName = "dbName"
HTTPDbID = "dbID"
HTTPProperties = "properties"
HTTPPartitionName = "partitionName"
HTTPPartitionNames = "partitionNames"
HTTPUserName = "userName"

View File

@ -78,6 +78,11 @@ func (h *HandlersV2) RegisterRoutesToV2(router gin.IRouter) {
router.POST(CollectionCategory+LoadAction, timeoutMiddleware(wrapperPost(func() any { return &CollectionNameReq{} }, wrapperTraceLog(h.loadCollection))))
router.POST(CollectionCategory+ReleaseAction, timeoutMiddleware(wrapperPost(func() any { return &CollectionNameReq{} }, wrapperTraceLog(h.releaseCollection))))
router.POST(DataBaseCategory+CreateAction, timeoutMiddleware(wrapperPost(func() any { return &DatabaseReqWithProperties{} }, wrapperTraceLog(h.createDatabase))))
router.POST(DataBaseCategory+DropAction, timeoutMiddleware(wrapperPost(func() any { return &DatabaseReq{} }, wrapperTraceLog(h.dropDatabase))))
router.POST(DataBaseCategory+ListAction, timeoutMiddleware(wrapperPost(func() any { return &EmptyReq{} }, wrapperTraceLog(h.listDatabases))))
router.POST(DataBaseCategory+DescribeAction, timeoutMiddleware(wrapperPost(func() any { return &DatabaseReq{} }, wrapperTraceLog(h.describeDatabase))))
router.POST(DataBaseCategory+AlterAction, timeoutMiddleware(wrapperPost(func() any { return &DatabaseReqWithProperties{} }, wrapperTraceLog(h.alterDatabase))))
// Query
router.POST(EntityCategory+QueryAction, restfulSizeMiddleware(timeoutMiddleware(wrapperPost(func() any {
return &QueryReqV2{
@ -207,13 +212,15 @@ func wrapperPost(newReq newReqFunc, v2 handlerFuncV2) gin.HandlerFunc {
return
}
dbName := ""
if getter, ok := req.(requestutil.DBNameGetter); ok {
dbName = getter.GetDbName()
}
if dbName == "" {
dbName = c.Request.Header.Get(HTTPHeaderDBName)
if req != nil {
if getter, ok := req.(requestutil.DBNameGetter); ok {
dbName = getter.GetDbName()
}
if dbName == "" {
dbName = DefaultDbName
dbName = c.Request.Header.Get(HTTPHeaderDBName)
if dbName == "" {
dbName = DefaultDbName
}
}
}
username, _ := c.Get(ContextUsername)
@ -277,7 +284,7 @@ func wrapperTraceLog(v2 handlerFuncV2) handlerFuncV2 {
if err != nil {
log.Ctx(ctx).Info("trace info: all, error", zap.Error(err))
} else {
log.Ctx(ctx).Info("trace info: all, unknown", zap.Any("resp", resp))
log.Ctx(ctx).Info("trace info: all, unknown")
}
}
return resp, err
@ -1149,7 +1156,7 @@ func (h *HandlersV2) createCollection(ctx context.Context, c *gin.Context, anyRe
var err error
fieldNames := map[string]bool{}
partitionsNum := int64(-1)
if httpReq.Schema.Fields == nil || len(httpReq.Schema.Fields) == 0 {
if len(httpReq.Schema.Fields) == 0 {
if len(httpReq.Schema.Functions) > 0 {
err := merr.WrapErrParameterInvalid("schema", "functions",
"functions are not supported for quickly create collection")
@ -1468,6 +1475,99 @@ func (h *HandlersV2) createCollection(ctx context.Context, c *gin.Context, anyRe
return statusResponse, err
}
func (h *HandlersV2) createDatabase(ctx context.Context, c *gin.Context, anyReq any, dbName string) (interface{}, error) {
httpReq := anyReq.(*DatabaseReqWithProperties)
req := &milvuspb.CreateDatabaseRequest{
DbName: dbName,
}
properties := make([]*commonpb.KeyValuePair, 0, len(httpReq.Properties))
for key, value := range httpReq.Properties {
properties = append(properties, &commonpb.KeyValuePair{Key: key, Value: fmt.Sprintf("%v", value)})
}
req.Properties = properties
c.Set(ContextRequest, req)
resp, err := wrapperProxyWithLimit(ctx, c, req, h.checkAuth, false, "/milvus.proto.milvus.MilvusService/CreateDatabase", true, h.proxy, func(reqCtx context.Context, req any) (interface{}, error) {
return h.proxy.CreateDatabase(reqCtx, req.(*milvuspb.CreateDatabaseRequest))
})
if err == nil {
HTTPReturn(c, http.StatusOK, wrapperReturnDefault())
}
return resp, err
}
func (h *HandlersV2) dropDatabase(ctx context.Context, c *gin.Context, anyReq any, dbName string) (interface{}, error) {
req := &milvuspb.DropDatabaseRequest{
DbName: dbName,
}
c.Set(ContextRequest, req)
resp, err := wrapperProxyWithLimit(ctx, c, req, h.checkAuth, false, "/milvus.proto.milvus.MilvusService/DropDatabase", true, h.proxy, func(reqCtx context.Context, req any) (interface{}, error) {
return h.proxy.DropDatabase(reqCtx, req.(*milvuspb.DropDatabaseRequest))
})
if err == nil {
HTTPReturn(c, http.StatusOK, wrapperReturnDefault())
}
return resp, err
}
// todo: use a more flexible way to handle the number of input parameters of req
func (h *HandlersV2) listDatabases(ctx context.Context, c *gin.Context, anyReq any, dbName string) (interface{}, error) {
req := &milvuspb.ListDatabasesRequest{}
c.Set(ContextRequest, req)
resp, err := wrapperProxy(ctx, c, req, false, false, "/milvus.proto.milvus.MilvusService/ListDatabases", func(reqCtx context.Context, req any) (interface{}, error) {
return h.proxy.ListDatabases(reqCtx, req.(*milvuspb.ListDatabasesRequest))
})
if err == nil {
HTTPReturn(c, http.StatusOK, wrapperReturnList(resp.(*milvuspb.ListDatabasesResponse).DbNames))
}
return resp, err
}
func (h *HandlersV2) describeDatabase(ctx context.Context, c *gin.Context, anyReq any, dbName string) (interface{}, error) {
req := &milvuspb.DescribeDatabaseRequest{
DbName: dbName,
}
c.Set(ContextRequest, req)
resp, err := wrapperProxy(ctx, c, req, h.checkAuth, false, "/milvus.proto.milvus.MilvusService/DescribeDatabase", func(reqCtx context.Context, req any) (interface{}, error) {
return h.proxy.DescribeDatabase(reqCtx, req.(*milvuspb.DescribeDatabaseRequest))
})
if err != nil {
return nil, err
}
info, _ := resp.(*milvuspb.DescribeDatabaseResponse)
if info.Properties == nil {
info.Properties = []*commonpb.KeyValuePair{}
}
dataBaseInfo := map[string]any{
HTTPDbName: info.DbName,
HTTPDbID: info.DbID,
HTTPProperties: info.Properties,
}
HTTPReturn(c, http.StatusOK, gin.H{HTTPReturnCode: merr.Code(nil), HTTPReturnData: dataBaseInfo})
return resp, err
}
func (h *HandlersV2) alterDatabase(ctx context.Context, c *gin.Context, anyReq any, dbName string) (interface{}, error) {
httpReq := anyReq.(*DatabaseReqWithProperties)
req := &milvuspb.AlterDatabaseRequest{
DbName: dbName,
}
properties := make([]*commonpb.KeyValuePair, 0, len(httpReq.Properties))
for key, value := range httpReq.Properties {
properties = append(properties, &commonpb.KeyValuePair{Key: key, Value: fmt.Sprintf("%v", value)})
}
req.Properties = properties
c.Set(ContextRequest, req)
resp, err := wrapperProxy(ctx, c, req, h.checkAuth, false, "/milvus.proto.milvus.MilvusService/AlterDatabase", func(reqCtx context.Context, req any) (interface{}, error) {
return h.proxy.AlterDatabase(reqCtx, req.(*milvuspb.AlterDatabaseRequest))
})
if err == nil {
HTTPReturn(c, http.StatusOK, wrapperReturnDefault())
}
return resp, err
}
func (h *HandlersV2) listPartitions(ctx context.Context, c *gin.Context, anyReq any, dbName string) (interface{}, error) {
collectionGetter, _ := anyReq.(requestutil.CollectionNameGetter)
req := &milvuspb.ShowPartitionsRequest{

View File

@ -673,6 +673,135 @@ func TestCreateIndex(t *testing.T) {
}
}
func TestDatabase(t *testing.T) {
paramtable.Init()
// disable rate limit
paramtable.Get().Save(paramtable.Get().QuotaConfig.QuotaAndLimitsEnabled.Key, "false")
defer paramtable.Get().Reset(paramtable.Get().QuotaConfig.QuotaAndLimitsEnabled.Key)
postTestCases := []requestBodyTestCase{}
mp := mocks.NewMockProxy(t)
mp.EXPECT().CreateDatabase(mock.Anything, mock.Anything).Return(commonSuccessStatus, nil).Once()
mp.EXPECT().CreateDatabase(mock.Anything, mock.Anything).Return(
&commonpb.Status{
Code: 1100,
Reason: "mock",
}, nil).Once()
testEngine := initHTTPServerV2(mp, false)
path := versionalV2(DataBaseCategory, CreateAction)
// success
postTestCases = append(postTestCases, requestBodyTestCase{
path: path,
requestBody: []byte(`{"dbName":"test"}`),
})
// mock fail
postTestCases = append(postTestCases, requestBodyTestCase{
path: path,
requestBody: []byte(`{"dbName":"invalid_name"}`),
errMsg: "mock",
errCode: 1100, // ErrParameterInvalid
})
mp.EXPECT().DropDatabase(mock.Anything, mock.Anything).Return(commonSuccessStatus, nil).Once()
mp.EXPECT().DropDatabase(mock.Anything, mock.Anything).Return(
&commonpb.Status{
Code: 1100,
Reason: "mock",
}, nil).Once()
path = versionalV2(DataBaseCategory, DropAction)
// success
postTestCases = append(postTestCases, requestBodyTestCase{
path: path,
requestBody: []byte(`{"dbName":"test"}`),
})
// mock fail
postTestCases = append(postTestCases, requestBodyTestCase{
path: path,
requestBody: []byte(`{"dbName":"mock"}`),
errMsg: "mock",
errCode: 1100, // ErrParameterInvalid
})
mp.EXPECT().ListDatabases(mock.Anything, mock.Anything).Return(&milvuspb.ListDatabasesResponse{DbNames: []string{"a", "b", "c"}, DbIds: []int64{100, 101, 102}}, nil).Once()
mp.EXPECT().ListDatabases(mock.Anything, mock.Anything).Return(&milvuspb.ListDatabasesResponse{
Status: &commonpb.Status{
Code: 1100,
Reason: "mock",
},
}, nil).Once()
path = versionalV2(DataBaseCategory, ListAction)
// success
postTestCases = append(postTestCases, requestBodyTestCase{
path: path,
requestBody: []byte(`{"dbName":"test"}`),
})
// mock fail
postTestCases = append(postTestCases, requestBodyTestCase{
path: path,
requestBody: []byte(`{"dbName":"mock"}`),
errMsg: "mock",
errCode: 1100, // ErrParameterInvalid
})
mp.EXPECT().DescribeDatabase(mock.Anything, mock.Anything).Return(&milvuspb.DescribeDatabaseResponse{DbName: "test", DbID: 100}, nil).Once()
mp.EXPECT().DescribeDatabase(mock.Anything, mock.Anything).Return(&milvuspb.DescribeDatabaseResponse{
Status: &commonpb.Status{
Code: 1100,
Reason: "mock",
},
}, nil).Once()
path = versionalV2(DataBaseCategory, DescribeAction)
// success
postTestCases = append(postTestCases, requestBodyTestCase{
path: path,
requestBody: []byte(`{"dbName":"test"}`),
})
// mock fail
postTestCases = append(postTestCases, requestBodyTestCase{
path: path,
requestBody: []byte(`{"dbName":"mock"}`),
errMsg: "mock",
errCode: 1100, // ErrParameterInvalid
})
mp.EXPECT().AlterDatabase(mock.Anything, mock.Anything).Return(commonSuccessStatus, nil).Once()
mp.EXPECT().AlterDatabase(mock.Anything, mock.Anything).Return(
&commonpb.Status{
Code: 1100,
Reason: "mock",
}, nil).Once()
path = versionalV2(DataBaseCategory, AlterAction)
// success
postTestCases = append(postTestCases, requestBodyTestCase{
path: path,
requestBody: []byte(`{"dbName":"test"}`),
})
// mock fail
postTestCases = append(postTestCases, requestBodyTestCase{
path: path,
requestBody: []byte(`{"dbName":"mock"}`),
errMsg: "mock",
errCode: 1100, // ErrParameterInvalid
})
for _, testcase := range postTestCases {
t.Run("post"+testcase.path, func(t *testing.T) {
req := httptest.NewRequest(http.MethodPost, testcase.path, bytes.NewReader(testcase.requestBody))
w := httptest.NewRecorder()
testEngine.ServeHTTP(w, req)
assert.Equal(t, http.StatusOK, w.Code)
fmt.Println(w.Body.String())
returnBody := &ReturnErrMsg{}
err := json.Unmarshal(w.Body.Bytes(), returnBody)
assert.Nil(t, err)
assert.Equal(t, testcase.errCode, returnBody.Code)
if testcase.errCode != 0 {
assert.Equal(t, testcase.errMsg, returnBody.Message)
}
})
}
}
func TestCreateCollection(t *testing.T) {
paramtable.Init()
// disable rate limit

View File

@ -25,12 +25,23 @@ import (
"github.com/milvus-io/milvus/pkg/util/merr"
)
type EmptyReq struct{}
func (req *EmptyReq) GetDbName() string { return "" }
type DatabaseReq struct {
DbName string `json:"dbName"`
}
func (req *DatabaseReq) GetDbName() string { return req.DbName }
type DatabaseReqWithProperties struct {
DbName string `json:"dbName" binding:"required"`
Properties map[string]interface{} `json:"properties"`
}
func (req *DatabaseReqWithProperties) GetDbName() string { return req.DbName }
type CollectionNameReq struct {
DbName string `json:"dbName"`
CollectionName string `json:"collectionName" binding:"required"`

View File

@ -1161,13 +1161,16 @@ func GetCurDBNameFromContextOrDefault(ctx context.Context) string {
func NewContextWithMetadata(ctx context.Context, username string, dbName string) context.Context {
dbKey := strings.ToLower(util.HeaderDBName)
if username == "" {
return contextutil.AppendToIncomingContext(ctx, dbKey, dbName)
if dbName != "" {
ctx = contextutil.AppendToIncomingContext(ctx, dbKey, dbName)
}
originValue := fmt.Sprintf("%s%s%s", username, util.CredentialSeperator, username)
authKey := strings.ToLower(util.HeaderAuthorize)
authValue := crypto.Base64Encode(originValue)
return contextutil.AppendToIncomingContext(ctx, authKey, authValue, dbKey, dbName)
if username != "" {
originValue := fmt.Sprintf("%s%s%s", username, util.CredentialSeperator, username)
authKey := strings.ToLower(util.HeaderAuthorize)
authValue := crypto.Base64Encode(originValue)
ctx = contextutil.AppendToIncomingContext(ctx, authKey, authValue)
}
return ctx
}
func AppendUserInfoForRPC(ctx context.Context) context.Context {