feat: Revise the RESTful bulk insert API (#29698)

Revise the RESTful bulk insert API from version 1 to version 2.

issue: https://github.com/milvus-io/milvus/issues/28521

---------

Signed-off-by: bigsheeper <yihao.dai@zilliz.com>
pull/31000/head
yihao.dai 2024-03-05 15:03:00 +08:00 committed by GitHub
parent 8c2615f840
commit 3b66c17279
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
4 changed files with 115 additions and 137 deletions

View File

@ -12,6 +12,7 @@ import (
"github.com/gin-gonic/gin/binding"
"github.com/go-playground/validator/v10"
"github.com/golang/protobuf/proto"
"github.com/samber/lo"
"github.com/tidwall/gjson"
"go.opentelemetry.io/otel"
"go.opentelemetry.io/otel/trace"
@ -21,11 +22,13 @@ import (
"github.com/milvus-io/milvus-proto/go-api/v2/commonpb"
"github.com/milvus-io/milvus-proto/go-api/v2/milvuspb"
"github.com/milvus-io/milvus-proto/go-api/v2/schemapb"
"github.com/milvus-io/milvus/internal/proto/internalpb"
"github.com/milvus-io/milvus/internal/proxy"
"github.com/milvus-io/milvus/internal/types"
"github.com/milvus-io/milvus/pkg/common"
"github.com/milvus-io/milvus/pkg/log"
"github.com/milvus-io/milvus/pkg/util/crypto"
"github.com/milvus-io/milvus/pkg/util/funcutil"
"github.com/milvus-io/milvus/pkg/util/merr"
"github.com/milvus-io/milvus/pkg/util/requestutil"
"github.com/milvus-io/milvus/pkg/util/typeutil"
@ -128,8 +131,8 @@ func (h *HandlersV2) RegisterRoutesToV2(router gin.IRouter) {
router.POST(AliasCategory+AlterAction, timeoutMiddleware(wrapperPost(func() any { return &AliasCollectionReq{} }, wrapperTraceLog(h.wrapperCheckDatabase(h.alterAlias)))))
router.POST(ImportJobCategory+ListAction, timeoutMiddleware(wrapperPost(func() any { return &CollectionNameReq{} }, wrapperTraceLog(h.wrapperCheckDatabase(h.listImportJob)))))
router.POST(ImportJobCategory+CreateAction, timeoutMiddleware(wrapperPost(func() any { return &DataFilesReq{} }, wrapperTraceLog(h.wrapperCheckDatabase(h.createImportJob)))))
router.POST(ImportJobCategory+GetProgressAction, timeoutMiddleware(wrapperPost(func() any { return &TaskIDReq{} }, wrapperTraceLog(h.wrapperCheckDatabase(h.getImportJobProcess)))))
router.POST(ImportJobCategory+CreateAction, timeoutMiddleware(wrapperPost(func() any { return &ImportReq{} }, wrapperTraceLog(h.wrapperCheckDatabase(h.createImportJob)))))
router.POST(ImportJobCategory+GetProgressAction, timeoutMiddleware(wrapperPost(func() any { return &JobIDReq{} }, wrapperTraceLog(h.wrapperCheckDatabase(h.getImportJobProcess)))))
}
type (
@ -1536,43 +1539,27 @@ func (h *HandlersV2) alterAlias(ctx context.Context, c *gin.Context, anyReq any,
}
func (h *HandlersV2) listImportJob(ctx context.Context, c *gin.Context, anyReq any, dbName string) (interface{}, error) {
collectionGetter, _ := anyReq.(requestutil.CollectionNameGetter)
limitGetter, _ := anyReq.(LimitGetter)
req := &milvuspb.ListImportTasksRequest{
CollectionName: collectionGetter.GetCollectionName(),
Limit: int64(limitGetter.GetLimit()),
collectionGetter := anyReq.(requestutil.CollectionNameGetter)
req := &internalpb.ListImportsRequest{
DbName: dbName,
CollectionName: collectionGetter.GetCollectionName(),
}
resp, err := wrapperProxy(ctx, c, req, h.checkAuth, false, func(reqCtx context.Context, req any) (interface{}, error) {
return h.proxy.ListImportTasks(reqCtx, req.(*milvuspb.ListImportTasksRequest))
return h.proxy.ListImports(reqCtx, req.(*internalpb.ListImportsRequest))
})
if err == nil {
returnData := []map[string]interface{}{}
for _, job := range resp.(*milvuspb.ListImportTasksResponse).Tasks {
taskDetail := map[string]interface{}{
"taskID": job.Id,
"state": job.State.String(),
"dbName": dbName,
"collectionName": collectionGetter.GetCollectionName(),
"createTimestamp": strconv.FormatInt(job.CreateTs, 10),
returnData := make([]map[string]interface{}, 0)
response := resp.(*internalpb.ListImportsResponse)
for i, jobID := range response.GetJobIDs() {
jobDetail := make(map[string]interface{})
jobDetail["jobID"] = jobID
jobDetail["state"] = response.GetStates()[i].String()
jobDetail["progress"] = response.GetProgresses()[i]
reason := response.GetReasons()[i]
if reason != "" {
jobDetail["reason"] = reason
}
for _, info := range job.Infos {
switch info.Key {
case "collection":
taskDetail["collectionName"] = info.Value
case "partition":
taskDetail["partitionName"] = info.Value
case "persist_cost":
taskDetail["persistCost"] = info.Value
case "progress_percent":
taskDetail["progressPercent"] = info.Value
case "failed_reason":
if info.Value != "" {
taskDetail[HTTPReturnIndexFailReason] = info.Value
}
}
}
returnData = append(returnData, taskDetail)
returnData = append(returnData, jobDetail)
}
c.JSON(http.StatusOK, gin.H{HTTPReturnCode: http.StatusOK, HTTPReturnData: returnData})
}
@ -1580,53 +1567,50 @@ func (h *HandlersV2) listImportJob(ctx context.Context, c *gin.Context, anyReq a
}
// createImportJob handles the import-job create route registered on
// ImportJobCategory+CreateAction. It builds an import request from anyReq and
// dbName, forwards it to the proxy through wrapperProxy, and on success writes
// an HTTP 200 JSON body. It returns the proxy response and error unchanged.
//
// NOTE(review): this is a rendered diff, so the body below shows both versions
// of the changed statements back to back — presumably the milvuspb/Import lines
// are the removed v1 code and the internalpb/ImportV2 lines are the added v2
// code (per the commit title); confirm against the raw patch.
func (h *HandlersV2) createImportJob(ctx context.Context, c *gin.Context, anyReq any, dbName string) (interface{}, error) {
collectionGetter, _ := anyReq.(requestutil.CollectionNameGetter)
fileNamesGetter, _ := anyReq.(FileNamesGetter)
req := &milvuspb.ImportRequest{
CollectionName: collectionGetter.GetCollectionName(),
// NOTE(review): the assertions in this var block have no comma-ok form, so a
// request value that does not implement one of these getters panics here
// instead of producing an error. The route registers &ImportReq{}, which
// defines all four getter methods.
var (
collectionGetter = anyReq.(requestutil.CollectionNameGetter)
partitionGetter = anyReq.(requestutil.PartitionNameGetter)
filesGetter = anyReq.(FilesGetter)
optionsGetter = anyReq.(OptionsGetter)
)
req := &internalpb.ImportRequest{
DbName: dbName,
Files: fileNamesGetter.GetFileNames(),
CollectionName: collectionGetter.GetCollectionName(),
PartitionName: partitionGetter.GetPartitionName(),
// Each inner []string of the request's files becomes one ImportFile whose
// Paths is that slice.
Files: lo.Map(filesGetter.GetFiles(), func(paths []string, _ int) *internalpb.ImportFile {
return &internalpb.ImportFile{Paths: paths}
}),
Options: funcutil.Map2KeyValuePair(optionsGetter.GetOptions()),
}
resp, err := wrapperProxy(ctx, c, req, h.checkAuth, false, func(reqCtx context.Context, req any) (interface{}, error) {
return h.proxy.Import(reqCtx, req.(*milvuspb.ImportRequest))
return h.proxy.ImportV2(reqCtx, req.(*internalpb.ImportRequest))
})
if err == nil {
c.JSON(http.StatusOK, gin.H{HTTPReturnCode: http.StatusOK, HTTPReturnData: resp.(*milvuspb.ImportResponse).Tasks})
// The v2 response body carries a single "jobID" entry.
returnData := make(map[string]interface{})
returnData["jobID"] = resp.(*internalpb.ImportResponse).GetJobID()
c.JSON(http.StatusOK, gin.H{HTTPReturnCode: http.StatusOK, HTTPReturnData: returnData})
}
return resp, err
}
func (h *HandlersV2) getImportJobProcess(ctx context.Context, c *gin.Context, anyReq any, dbName string) (interface{}, error) {
taskIDGetter, _ := anyReq.(TaskIDGetter)
req := &milvuspb.GetImportStateRequest{
Task: taskIDGetter.GetTaskID(),
jobIDGetter := anyReq.(JobIDGetter)
req := &internalpb.GetImportProgressRequest{
DbName: dbName,
JobID: jobIDGetter.GetJobID(),
}
resp, err := wrapperProxy(ctx, c, req, h.checkAuth, false, func(reqCtx context.Context, req any) (interface{}, error) {
return h.proxy.GetImportState(reqCtx, req.(*milvuspb.GetImportStateRequest))
return h.proxy.GetImportProgress(reqCtx, req.(*internalpb.GetImportProgressRequest))
})
if err == nil {
response := resp.(*milvuspb.GetImportStateResponse)
returnData := map[string]interface{}{
"taskID": response.Id,
"state": response.State.String(),
"dbName": dbName,
"createTimestamp": strconv.FormatInt(response.CreateTs, 10),
}
for _, info := range response.Infos {
switch info.Key {
case "collection":
returnData["collectionName"] = info.Value
case "partition":
returnData["partitionName"] = info.Value
case "persist_cost":
returnData["persistCost"] = info.Value
case "progress_percent":
returnData["progressPercent"] = info.Value
case "failed_reason":
if info.Value != "" {
returnData[HTTPReturnIndexFailReason] = info.Value
}
}
response := resp.(*internalpb.GetImportProgressResponse)
returnData := make(map[string]interface{})
returnData["jobID"] = jobIDGetter.GetJobID()
returnData["state"] = response.GetState().String()
returnData["progress"] = response.GetProgress()
reason := response.GetReason()
if reason != "" {
returnData["reason"] = reason
}
c.JSON(http.StatusOK, gin.H{HTTPReturnCode: http.StatusOK, HTTPReturnData: returnData})
}

View File

@ -18,6 +18,7 @@ import (
"github.com/milvus-io/milvus-proto/go-api/v2/milvuspb"
"github.com/milvus-io/milvus-proto/go-api/v2/schemapb"
"github.com/milvus-io/milvus/internal/mocks"
"github.com/milvus-io/milvus/internal/proto/internalpb"
"github.com/milvus-io/milvus/internal/proxy"
"github.com/milvus-io/milvus/internal/types"
"github.com/milvus-io/milvus/pkg/util"
@ -52,6 +53,10 @@ func (DefaultReq) GetBase() *commonpb.MsgBase {
func (req *DefaultReq) GetDbName() string { return req.DbName }
// init calls paramtable.Init() when the test package is initialized, i.e.
// before any test function in this package runs.
func init() {
paramtable.Init()
}
func TestHTTPWrapper(t *testing.T) {
postTestCases := []requestBodyTestCase{}
postTestCasesTrace := []requestBodyTestCase{}
@ -678,46 +683,6 @@ func TestMethodGet(t *testing.T) {
Status: &StatusSuccess,
Alias: DefaultAliasName,
}, nil).Once()
mp.EXPECT().ListImportTasks(mock.Anything, mock.Anything).Return(&milvuspb.ListImportTasksResponse{
Status: &StatusSuccess,
Tasks: []*milvuspb.GetImportStateResponse{
{
Status: &StatusSuccess,
State: 6,
Infos: []*commonpb.KeyValuePair{
{Key: "collection", Value: DefaultCollectionName},
{Key: "partition", Value: DefaultPartitionName},
{Key: "persist_cost", Value: "0.23"},
{Key: "progress_percent", Value: "100"},
{Key: "failed_reason"},
},
Id: 1234567890,
},
{
Status: &StatusSuccess,
State: 0,
Infos: []*commonpb.KeyValuePair{
{Key: "collection", Value: DefaultCollectionName},
{Key: "partition", Value: DefaultPartitionName},
{Key: "progress_percent", Value: "0"},
{Key: "failed_reason", Value: "failed to get file size of "},
},
Id: 123456789,
},
},
}, nil).Once()
mp.EXPECT().GetImportState(mock.Anything, mock.Anything).Return(&milvuspb.GetImportStateResponse{
Status: &StatusSuccess,
State: 6,
Infos: []*commonpb.KeyValuePair{
{Key: "collection", Value: DefaultCollectionName},
{Key: "partition", Value: DefaultPartitionName},
{Key: "persist_cost", Value: "0.23"},
{Key: "progress_percent", Value: "100"},
{Key: "failed_reason"},
},
Id: 1234567890,
}, nil).Once()
testEngine := initHTTPServerV2(mp, false)
queryTestCases := []rawTestCase{}
@ -791,12 +756,6 @@ func TestMethodGet(t *testing.T) {
queryTestCases = append(queryTestCases, rawTestCase{
path: versionalV2(AliasCategory, DescribeAction),
})
queryTestCases = append(queryTestCases, rawTestCase{
path: versionalV2(ImportJobCategory, ListAction),
})
queryTestCases = append(queryTestCases, rawTestCase{
path: versionalV2(ImportJobCategory, GetProgressAction),
})
for _, testcase := range queryTestCases {
t.Run("query", func(t *testing.T) {
@ -806,8 +765,7 @@ func TestMethodGet(t *testing.T) {
`"indexName": "` + DefaultIndexName + `",` +
`"userName": "` + util.UserRoot + `",` +
`"roleName": "` + util.RoleAdmin + `",` +
`"aliasName": "` + DefaultAliasName + `",` +
`"taskID": 1234567890` +
`"aliasName": "` + DefaultAliasName + `"` +
`}`))
req := httptest.NewRequest(http.MethodPost, testcase.path, bodyReader)
w := httptest.NewRecorder()
@ -907,7 +865,27 @@ func TestMethodPost(t *testing.T) {
mp.EXPECT().CreateIndex(mock.Anything, mock.Anything).Return(commonErrorStatus, nil).Once()
mp.EXPECT().CreateAlias(mock.Anything, mock.Anything).Return(commonSuccessStatus, nil).Once()
mp.EXPECT().AlterAlias(mock.Anything, mock.Anything).Return(commonSuccessStatus, nil).Once()
mp.EXPECT().Import(mock.Anything, mock.Anything).Return(&milvuspb.ImportResponse{Status: commonSuccessStatus, Tasks: []int64{int64(1234567890)}}, nil).Once()
mp.EXPECT().ImportV2(mock.Anything, mock.Anything).Return(&internalpb.ImportResponse{
Status: commonSuccessStatus, JobID: "1234567890",
}, nil).Once()
mp.EXPECT().ListImports(mock.Anything, mock.Anything).Return(&internalpb.ListImportsResponse{
Status: &StatusSuccess,
JobIDs: []string{"1", "2", "3", "4"},
States: []internalpb.ImportJobState{
internalpb.ImportJobState_Pending,
internalpb.ImportJobState_Importing,
internalpb.ImportJobState_Failed,
internalpb.ImportJobState_Completed,
},
Reasons: []string{"", "", "mock reason", ""},
Progresses: []int64{0, 30, 0, 100},
}, nil).Once()
mp.EXPECT().GetImportProgress(mock.Anything, mock.Anything).Return(&internalpb.GetImportProgressResponse{
Status: &StatusSuccess,
State: internalpb.ImportJobState_Completed,
Reason: "",
Progress: 100,
}, nil).Once()
testEngine := initHTTPServerV2(mp, false)
queryTestCases := []rawTestCase{}
queryTestCases = append(queryTestCases, rawTestCase{
@ -969,6 +947,12 @@ func TestMethodPost(t *testing.T) {
queryTestCases = append(queryTestCases, rawTestCase{
path: versionalV2(ImportJobCategory, CreateAction),
})
queryTestCases = append(queryTestCases, rawTestCase{
path: versionalV2(ImportJobCategory, ListAction),
})
queryTestCases = append(queryTestCases, rawTestCase{
path: versionalV2(ImportJobCategory, GetProgressAction),
})
for _, testcase := range queryTestCases {
t.Run("query", func(t *testing.T) {
@ -980,7 +964,8 @@ func TestMethodPost(t *testing.T) {
`"userName": "` + util.UserRoot + `", "password": "Milvus", "newPassword": "milvus", "roleName": "` + util.RoleAdmin + `",` +
`"roleName": "` + util.RoleAdmin + `", "objectType": "Global", "objectName": "*", "privilege": "*",` +
`"aliasName": "` + DefaultAliasName + `",` +
`"files": ["book.json"]` +
`"jobID": "1234567890",` +
`"files": [["book.json"]]` +
`}`))
req := httptest.NewRequest(http.MethodPost, testcase.path, bodyReader)
w := httptest.NewRecorder()

View File

@ -18,7 +18,6 @@ func (req *DatabaseReq) GetDbName() string { return req.DbName }
type CollectionNameReq struct {
DbName string `json:"dbName"`
CollectionName string `json:"collectionName" binding:"required"`
Limit int32 `json:"limit"` // list import jobs
PartitionNames []string `json:"partitionNames"` // get partitions load state
}
@ -30,10 +29,6 @@ func (req *CollectionNameReq) GetCollectionName() string {
return req.CollectionName
}
func (req *CollectionNameReq) GetLimit() int32 {
return req.Limit
}
// GetPartitionNames returns the PartitionNames field of the request.
func (req *CollectionNameReq) GetPartitionNames() []string {
return req.PartitionNames
}
@ -58,29 +53,39 @@ func (req *PartitionReq) GetDbName() string { return req.DbName }
// GetCollectionName returns the CollectionName field of the request.
func (req *PartitionReq) GetCollectionName() string { return req.CollectionName }
// GetPartitionName returns the PartitionName field of the request.
func (req *PartitionReq) GetPartitionName() string { return req.PartitionName }
type DataFilesReq struct {
DbName string `json:"dbName"`
CollectionName string `json:"collectionName" binding:"required"`
Files []string `json:"files" binding:"required"`
type ImportReq struct {
DbName string `json:"dbName"`
CollectionName string `json:"collectionName" binding:"required"`
PartitionName string `json:"partitionName"`
Files [][]string `json:"files" binding:"required"`
Options map[string]string `json:"options"`
}
func (req *DataFilesReq) GetDbName() string {
func (req *ImportReq) GetDbName() string {
return req.DbName
}
func (req *DataFilesReq) GetCollectionName() string {
func (req *ImportReq) GetCollectionName() string {
return req.CollectionName
}
func (req *DataFilesReq) GetFileNames() []string {
func (req *ImportReq) GetPartitionName() string {
return req.PartitionName
}
func (req *ImportReq) GetFiles() [][]string {
return req.Files
}
type TaskIDReq struct {
TaskID int64 `json:"taskID" binding:"required"`
// GetOptions returns the Options map of the request; it is nil when the
// "options" JSON field was not supplied.
func (req *ImportReq) GetOptions() map[string]string {
return req.Options
}
func (req *TaskIDReq) GetTaskID() int64 { return req.TaskID }
// JobIDReq is a request body that identifies an import job by its string ID.
// The "jobID" JSON field is marked required by the binding tag.
type JobIDReq struct {
JobID string `json:"jobID" binding:"required"`
}
// GetJobID returns the JobID field of the request.
func (req *JobIDReq) GetJobID() string { return req.JobID }
type QueryReqV2 struct {
DbName string `json:"dbName"`
@ -203,14 +208,14 @@ type IndexNameGetter interface {
// AliasNameGetter is satisfied by any request type that exposes an alias name
// through GetAliasName.
type AliasNameGetter interface {
GetAliasName() string
}
type LimitGetter interface {
GetLimit() int32
type FilesGetter interface {
GetFiles() [][]string
}
type FileNamesGetter interface {
GetFileNames() []string
type OptionsGetter interface {
GetOptions() map[string]string
}
type TaskIDGetter interface {
GetTaskID() int64
type JobIDGetter interface {
GetJobID() string
}
type PasswordReq struct {

View File

@ -220,6 +220,10 @@ type Proxy interface {
Component
proxypb.ProxyServer
milvuspb.MilvusServiceServer
ImportV2(context.Context, *internalpb.ImportRequest) (*internalpb.ImportResponse, error)
GetImportProgress(context.Context, *internalpb.GetImportProgressRequest) (*internalpb.GetImportProgressResponse, error)
ListImports(context.Context, *internalpb.ListImportsRequest) (*internalpb.ListImportsResponse, error)
}
// ProxyComponent defines the interface of proxy component.