mirror of https://github.com/milvus-io/milvus.git
Restful api error information (#25823)
Fix get/query without outputFields return all fields include vector field Fix check db exists after authorization Signed-off-by: PowderLi <min.li@zilliz.com>pull/25862/head
parent
c15a165d76
commit
d32bf3922f
|
@ -25,7 +25,6 @@ const (
|
|||
|
||||
HTTPReturnCode = "code"
|
||||
HTTPReturnMessage = "message"
|
||||
HTTPReturnError = "error"
|
||||
HTTPReturnData = "data"
|
||||
|
||||
HTTPReturnFieldName = "name"
|
||||
|
|
|
@ -6,6 +6,8 @@ import (
|
|||
"net/http"
|
||||
"strconv"
|
||||
|
||||
"github.com/milvus-io/milvus/pkg/util/merr"
|
||||
|
||||
"github.com/gin-gonic/gin"
|
||||
"github.com/gin-gonic/gin/binding"
|
||||
"github.com/golang/protobuf/proto"
|
||||
|
@ -19,6 +21,27 @@ import (
|
|||
"go.uber.org/zap"
|
||||
)
|
||||
|
||||
func (h *Handlers) checkDatabase(c *gin.Context, dbName string) bool {
|
||||
if dbName == DefaultDbName {
|
||||
return true
|
||||
}
|
||||
resp, err := h.proxy.ListDatabases(c, &milvuspb.ListDatabasesRequest{})
|
||||
if err != nil {
|
||||
c.AbortWithStatusJSON(http.StatusOK, gin.H{HTTPReturnCode: Code(err), HTTPReturnMessage: err.Error()})
|
||||
return false
|
||||
} else if resp.Status.ErrorCode != commonpb.ErrorCode_Success {
|
||||
c.AbortWithStatusJSON(http.StatusOK, gin.H{HTTPReturnCode: code(int32(resp.Status.ErrorCode)), HTTPReturnMessage: resp.Status.Reason})
|
||||
return false
|
||||
}
|
||||
for _, db := range resp.DbNames {
|
||||
if db == dbName {
|
||||
return true
|
||||
}
|
||||
}
|
||||
c.AbortWithStatusJSON(http.StatusOK, gin.H{HTTPReturnCode: Code(merr.ErrDatabaseNotfound), HTTPReturnMessage: merr.ErrDatabaseNotfound.Error()})
|
||||
return false
|
||||
}
|
||||
|
||||
func (h *Handlers) describeCollection(c *gin.Context, dbName string, collectionName string, needAuth bool) (*milvuspb.DescribeCollectionResponse, error) {
|
||||
req := milvuspb.DescribeCollectionRequest{
|
||||
DbName: dbName,
|
||||
|
@ -27,22 +50,21 @@ func (h *Handlers) describeCollection(c *gin.Context, dbName string, collectionN
|
|||
if needAuth {
|
||||
username, ok := c.Get(ContextUsername)
|
||||
if !ok {
|
||||
msg := "the user hasn't authenticate"
|
||||
c.JSON(http.StatusOK, gin.H{HTTPReturnCode: http.StatusProxyAuthRequired, HTTPReturnMessage: msg})
|
||||
return nil, errors.New(msg)
|
||||
c.JSON(http.StatusProxyAuthRequired, gin.H{HTTPReturnCode: Code(merr.ErrNeedAuthenticate), HTTPReturnMessage: merr.ErrNeedAuthenticate.Error()})
|
||||
return nil, merr.ErrNeedAuthenticate
|
||||
}
|
||||
_, authErr := proxy.PrivilegeInterceptorWithUsername(c, username.(string), &req)
|
||||
if authErr != nil {
|
||||
c.JSON(http.StatusOK, gin.H{HTTPReturnCode: http.StatusUnauthorized, HTTPReturnMessage: authErr.Error()})
|
||||
c.JSON(http.StatusUnauthorized, gin.H{HTTPReturnCode: Code(authErr), HTTPReturnMessage: authErr.Error()})
|
||||
return nil, authErr
|
||||
}
|
||||
}
|
||||
response, err := h.proxy.DescribeCollection(c, &req)
|
||||
if err != nil {
|
||||
c.AbortWithStatusJSON(http.StatusOK, gin.H{HTTPReturnCode: http.StatusBadRequest, HTTPReturnMessage: "describe collection " + collectionName + " fail", HTTPReturnError: err.Error()})
|
||||
c.AbortWithStatusJSON(http.StatusOK, gin.H{HTTPReturnCode: Code(err), HTTPReturnMessage: err.Error()})
|
||||
return nil, err
|
||||
} else if response.Status.ErrorCode != commonpb.ErrorCode_Success {
|
||||
c.AbortWithStatusJSON(http.StatusOK, gin.H{HTTPReturnCode: response.Status.ErrorCode, HTTPReturnMessage: response.Status.Reason})
|
||||
c.AbortWithStatusJSON(http.StatusOK, gin.H{HTTPReturnCode: code(int32(response.Status.ErrorCode)), HTTPReturnMessage: response.Status.Reason})
|
||||
return nil, errors.New(response.Status.Reason)
|
||||
}
|
||||
primaryField, ok := getPrimaryField(response.Schema)
|
||||
|
@ -60,10 +82,10 @@ func (h *Handlers) hasCollection(c *gin.Context, dbName string, collectionName s
|
|||
}
|
||||
response, err := h.proxy.HasCollection(c, &req)
|
||||
if err != nil {
|
||||
c.AbortWithStatusJSON(http.StatusOK, gin.H{HTTPReturnCode: http.StatusBadRequest, HTTPReturnMessage: "check collections " + req.CollectionName + " exists fail", HTTPReturnError: err.Error()})
|
||||
c.AbortWithStatusJSON(http.StatusOK, gin.H{HTTPReturnCode: Code(err), HTTPReturnMessage: err.Error()})
|
||||
return false, err
|
||||
} else if response.Status.ErrorCode != commonpb.ErrorCode_Success {
|
||||
c.AbortWithStatusJSON(http.StatusOK, gin.H{HTTPReturnCode: response.Status.ErrorCode, HTTPReturnMessage: "check collections " + req.CollectionName + " exists fail", HTTPReturnError: response.Status.Reason})
|
||||
c.AbortWithStatusJSON(http.StatusOK, gin.H{HTTPReturnCode: code(int32(response.Status.ErrorCode)), HTTPReturnMessage: response.Status.Reason})
|
||||
return false, errors.New(response.Status.Reason)
|
||||
} else {
|
||||
return response.Value, nil
|
||||
|
@ -90,14 +112,17 @@ func (h *Handlers) listCollections(c *gin.Context) {
|
|||
username, _ := c.Get(ContextUsername)
|
||||
_, authErr := proxy.PrivilegeInterceptorWithUsername(c, username.(string), &req)
|
||||
if authErr != nil {
|
||||
c.JSON(http.StatusOK, gin.H{HTTPReturnCode: http.StatusUnauthorized, HTTPReturnMessage: authErr.Error()})
|
||||
c.JSON(http.StatusUnauthorized, gin.H{HTTPReturnCode: Code(authErr), HTTPReturnMessage: authErr.Error()})
|
||||
return
|
||||
}
|
||||
if !h.checkDatabase(c, dbName) {
|
||||
return
|
||||
}
|
||||
response, err := h.proxy.ShowCollections(c, &req)
|
||||
if err != nil {
|
||||
c.JSON(http.StatusOK, gin.H{HTTPReturnCode: http.StatusBadRequest, HTTPReturnMessage: "show collections fail", HTTPReturnError: err.Error()})
|
||||
c.JSON(http.StatusOK, gin.H{HTTPReturnCode: Code(err), HTTPReturnMessage: err.Error()})
|
||||
} else if response.Status.ErrorCode != commonpb.ErrorCode_Success {
|
||||
c.JSON(http.StatusOK, gin.H{HTTPReturnCode: response.Status.ErrorCode, HTTPReturnMessage: "show collections fail", HTTPReturnError: response.Status.Reason})
|
||||
c.JSON(http.StatusOK, gin.H{HTTPReturnCode: code(int32(response.Status.ErrorCode)), HTTPReturnMessage: response.Status.Reason})
|
||||
} else {
|
||||
var collections []string
|
||||
if response.CollectionNames != nil {
|
||||
|
@ -117,11 +142,13 @@ func (h *Handlers) createCollection(c *gin.Context) {
|
|||
VectorField: DefaultVectorFieldName,
|
||||
}
|
||||
if err := c.ShouldBindBodyWith(&httpReq, binding.JSON); err != nil {
|
||||
c.AbortWithStatusJSON(http.StatusOK, gin.H{HTTPReturnCode: http.StatusBadRequest, HTTPReturnMessage: "check your parameters conform to the json format", HTTPReturnError: err.Error()})
|
||||
log.Warn("high level restful api, the parameter of create collection is incorrect", zap.Any("request", httpReq), zap.Error(err))
|
||||
c.AbortWithStatusJSON(http.StatusOK, gin.H{HTTPReturnCode: Code(merr.ErrIncorrectParameterFormat), HTTPReturnMessage: merr.ErrIncorrectParameterFormat.Error()})
|
||||
return
|
||||
}
|
||||
if httpReq.CollectionName == "" || httpReq.Dimension == 0 {
|
||||
c.AbortWithStatusJSON(http.StatusOK, gin.H{HTTPReturnCode: http.StatusBadRequest, HTTPReturnMessage: "collectionName and dimension are both required."})
|
||||
log.Warn("high level restful api, create collection require parameters: [collectionName, dimension], but miss", zap.Any("request", httpReq))
|
||||
c.AbortWithStatusJSON(http.StatusOK, gin.H{HTTPReturnCode: Code(merr.ErrMissingRequiredParameters), HTTPReturnMessage: merr.ErrMissingRequiredParameters.Error()})
|
||||
return
|
||||
}
|
||||
schema, err := proto.Marshal(&schemapb.CollectionSchema{
|
||||
|
@ -152,7 +179,8 @@ func (h *Handlers) createCollection(c *gin.Context) {
|
|||
EnableDynamicField: EnableDynamic,
|
||||
})
|
||||
if err != nil {
|
||||
c.AbortWithStatusJSON(http.StatusOK, gin.H{HTTPReturnCode: http.StatusBadRequest, HTTPReturnMessage: "marshal collection schema to string", HTTPReturnError: err.Error()})
|
||||
log.Warn("high level restful api, marshal collection schema fail", zap.Any("request", httpReq), zap.Error(err))
|
||||
c.AbortWithStatusJSON(http.StatusOK, gin.H{HTTPReturnCode: Code(merr.ErrMarshalCollectionSchema), HTTPReturnMessage: merr.ErrMarshalCollectionSchema.Error()})
|
||||
return
|
||||
}
|
||||
req := milvuspb.CreateCollectionRequest{
|
||||
|
@ -165,15 +193,18 @@ func (h *Handlers) createCollection(c *gin.Context) {
|
|||
username, _ := c.Get(ContextUsername)
|
||||
_, authErr := proxy.PrivilegeInterceptorWithUsername(c, username.(string), &req)
|
||||
if authErr != nil {
|
||||
c.JSON(http.StatusOK, gin.H{HTTPReturnCode: http.StatusUnauthorized, HTTPReturnMessage: authErr.Error()})
|
||||
c.JSON(http.StatusUnauthorized, gin.H{HTTPReturnCode: Code(authErr), HTTPReturnMessage: authErr.Error()})
|
||||
return
|
||||
}
|
||||
if !h.checkDatabase(c, req.DbName) {
|
||||
return
|
||||
}
|
||||
resp, err := h.proxy.CreateCollection(c, &req)
|
||||
if err != nil {
|
||||
c.AbortWithStatusJSON(http.StatusOK, gin.H{HTTPReturnCode: http.StatusBadRequest, HTTPReturnMessage: "create collection " + httpReq.CollectionName + " fail", HTTPReturnError: err.Error()})
|
||||
c.AbortWithStatusJSON(http.StatusOK, gin.H{HTTPReturnCode: Code(err), HTTPReturnMessage: err.Error()})
|
||||
return
|
||||
} else if resp.ErrorCode != commonpb.ErrorCode_Success {
|
||||
c.AbortWithStatusJSON(http.StatusOK, gin.H{HTTPReturnCode: resp.ErrorCode, HTTPReturnMessage: "create collection " + httpReq.CollectionName + " fail", HTTPReturnError: resp.Reason})
|
||||
c.AbortWithStatusJSON(http.StatusOK, gin.H{HTTPReturnCode: code(int32(resp.ErrorCode)), HTTPReturnMessage: resp.Reason})
|
||||
return
|
||||
}
|
||||
|
||||
|
@ -184,10 +215,10 @@ func (h *Handlers) createCollection(c *gin.Context) {
|
|||
ExtraParams: []*commonpb.KeyValuePair{{Key: common.MetricTypeKey, Value: httpReq.MetricType}},
|
||||
})
|
||||
if err != nil {
|
||||
c.AbortWithStatusJSON(http.StatusOK, gin.H{HTTPReturnCode: http.StatusBadRequest, HTTPReturnMessage: "create index for collection " + httpReq.CollectionName + " fail, after the collection was created", HTTPReturnError: err.Error()})
|
||||
c.AbortWithStatusJSON(http.StatusOK, gin.H{HTTPReturnCode: Code(err), HTTPReturnMessage: err.Error()})
|
||||
return
|
||||
} else if resp.ErrorCode != commonpb.ErrorCode_Success {
|
||||
c.AbortWithStatusJSON(http.StatusOK, gin.H{HTTPReturnCode: resp.ErrorCode, HTTPReturnMessage: "create index for collection " + httpReq.CollectionName + " fail, after the collection was created", HTTPReturnError: resp.Reason})
|
||||
c.AbortWithStatusJSON(http.StatusOK, gin.H{HTTPReturnCode: code(int32(resp.ErrorCode)), HTTPReturnMessage: resp.Reason})
|
||||
return
|
||||
}
|
||||
resp, err = h.proxy.LoadCollection(c, &milvuspb.LoadCollectionRequest{
|
||||
|
@ -195,10 +226,10 @@ func (h *Handlers) createCollection(c *gin.Context) {
|
|||
CollectionName: httpReq.CollectionName,
|
||||
})
|
||||
if err != nil {
|
||||
c.AbortWithStatusJSON(http.StatusOK, gin.H{HTTPReturnCode: http.StatusBadRequest, HTTPReturnMessage: "load collection " + httpReq.CollectionName + " fail, after the index was created", HTTPReturnError: err.Error()})
|
||||
c.AbortWithStatusJSON(http.StatusOK, gin.H{HTTPReturnCode: Code(err), HTTPReturnMessage: err.Error()})
|
||||
return
|
||||
} else if resp.ErrorCode != commonpb.ErrorCode_Success {
|
||||
c.AbortWithStatusJSON(http.StatusOK, gin.H{HTTPReturnCode: resp.ErrorCode, HTTPReturnMessage: "load collection " + httpReq.CollectionName + " fail, after the index was created", HTTPReturnError: resp.Reason})
|
||||
c.AbortWithStatusJSON(http.StatusOK, gin.H{HTTPReturnCode: code(int32(resp.ErrorCode)), HTTPReturnMessage: resp.Reason})
|
||||
return
|
||||
}
|
||||
c.JSON(http.StatusOK, gin.H{HTTPReturnCode: http.StatusOK, HTTPReturnData: gin.H{}})
|
||||
|
@ -207,12 +238,16 @@ func (h *Handlers) createCollection(c *gin.Context) {
|
|||
func (h *Handlers) getCollectionDetails(c *gin.Context) {
|
||||
collectionName := c.Query(HTTPCollectionName)
|
||||
if collectionName == "" {
|
||||
c.AbortWithStatusJSON(http.StatusOK, gin.H{HTTPReturnCode: http.StatusBadRequest, HTTPReturnMessage: "collectionName is required."})
|
||||
log.Warn("high level restful api, desc collection require parameter: [collectionName], but miss")
|
||||
c.AbortWithStatusJSON(http.StatusOK, gin.H{HTTPReturnCode: Code(merr.ErrMissingRequiredParameters), HTTPReturnMessage: merr.ErrMissingRequiredParameters.Error()})
|
||||
return
|
||||
}
|
||||
dbName := c.DefaultQuery(HTTPDbName, DefaultDbName)
|
||||
if !h.checkDatabase(c, dbName) {
|
||||
return
|
||||
}
|
||||
coll, err := h.describeCollection(c, dbName, collectionName, true)
|
||||
if err != nil {
|
||||
if err != nil || coll == nil {
|
||||
return
|
||||
}
|
||||
stateResp, stateErr := h.proxy.GetLoadState(c, &milvuspb.GetLoadStateRequest{
|
||||
|
@ -265,19 +300,13 @@ func (h *Handlers) dropCollection(c *gin.Context) {
|
|||
DbName: DefaultDbName,
|
||||
}
|
||||
if err := c.ShouldBindBodyWith(&httpReq, binding.JSON); err != nil {
|
||||
c.AbortWithStatusJSON(http.StatusOK, gin.H{HTTPReturnCode: http.StatusBadRequest, HTTPReturnMessage: "check your parameters conform to the json format", HTTPReturnError: err.Error()})
|
||||
log.Warn("high level restful api, the parameter of drop collection is incorrect", zap.Any("request", httpReq), zap.Error(err))
|
||||
c.AbortWithStatusJSON(http.StatusOK, gin.H{HTTPReturnCode: Code(merr.ErrIncorrectParameterFormat), HTTPReturnMessage: merr.ErrIncorrectParameterFormat.Error()})
|
||||
return
|
||||
}
|
||||
if httpReq.CollectionName == "" {
|
||||
c.AbortWithStatusJSON(http.StatusOK, gin.H{HTTPReturnCode: http.StatusBadRequest, HTTPReturnMessage: "collectionName is required."})
|
||||
return
|
||||
}
|
||||
has, err := h.hasCollection(c, httpReq.DbName, httpReq.CollectionName)
|
||||
if err != nil {
|
||||
return
|
||||
}
|
||||
if !has {
|
||||
c.AbortWithStatusJSON(http.StatusOK, gin.H{HTTPReturnCode: http.StatusBadRequest, HTTPReturnMessage: "can't find collection: " + httpReq.CollectionName})
|
||||
log.Warn("high level restful api, drop collection require parameter: [collectionName], but miss")
|
||||
c.AbortWithStatusJSON(http.StatusOK, gin.H{HTTPReturnCode: Code(merr.ErrMissingRequiredParameters), HTTPReturnMessage: merr.ErrMissingRequiredParameters.Error()})
|
||||
return
|
||||
}
|
||||
req := milvuspb.DropCollectionRequest{
|
||||
|
@ -287,14 +316,25 @@ func (h *Handlers) dropCollection(c *gin.Context) {
|
|||
username, _ := c.Get(ContextUsername)
|
||||
_, authErr := proxy.PrivilegeInterceptorWithUsername(c, username.(string), &req)
|
||||
if authErr != nil {
|
||||
c.JSON(http.StatusOK, gin.H{HTTPReturnCode: http.StatusUnauthorized, HTTPReturnMessage: authErr.Error()})
|
||||
c.JSON(http.StatusUnauthorized, gin.H{HTTPReturnCode: Code(authErr), HTTPReturnMessage: authErr.Error()})
|
||||
return
|
||||
}
|
||||
if !h.checkDatabase(c, req.DbName) {
|
||||
return
|
||||
}
|
||||
has, err := h.hasCollection(c, httpReq.DbName, httpReq.CollectionName)
|
||||
if err != nil {
|
||||
return
|
||||
}
|
||||
if !has {
|
||||
c.AbortWithStatusJSON(http.StatusOK, gin.H{HTTPReturnCode: Code(merr.ErrCollectionNotFound), HTTPReturnMessage: merr.ErrCollectionNotFound.Error()})
|
||||
return
|
||||
}
|
||||
response, err := h.proxy.DropCollection(c, &req)
|
||||
if err != nil {
|
||||
c.JSON(http.StatusOK, gin.H{HTTPReturnCode: http.StatusBadRequest, HTTPReturnMessage: "drop collection " + httpReq.CollectionName + " fail", HTTPReturnError: err.Error()})
|
||||
c.JSON(http.StatusOK, gin.H{HTTPReturnCode: Code(err), HTTPReturnMessage: err.Error()})
|
||||
} else if response.ErrorCode != commonpb.ErrorCode_Success {
|
||||
c.JSON(http.StatusOK, gin.H{HTTPReturnCode: response.ErrorCode, HTTPReturnMessage: "drop collection " + httpReq.CollectionName + " fail", HTTPReturnError: response.Reason})
|
||||
c.JSON(http.StatusOK, gin.H{HTTPReturnCode: code(int32(response.ErrorCode)), HTTPReturnMessage: response.Reason})
|
||||
} else {
|
||||
c.JSON(http.StatusOK, gin.H{HTTPReturnCode: http.StatusOK, HTTPReturnData: gin.H{}})
|
||||
}
|
||||
|
@ -302,16 +342,17 @@ func (h *Handlers) dropCollection(c *gin.Context) {
|
|||
|
||||
func (h *Handlers) query(c *gin.Context) {
|
||||
httpReq := QueryReq{
|
||||
DbName: DefaultDbName,
|
||||
Limit: 100,
|
||||
OutputFields: []string{DefaultOutputFields},
|
||||
DbName: DefaultDbName,
|
||||
Limit: 100,
|
||||
}
|
||||
if err := c.ShouldBindBodyWith(&httpReq, binding.JSON); err != nil {
|
||||
c.AbortWithStatusJSON(http.StatusOK, gin.H{HTTPReturnCode: http.StatusBadRequest, HTTPReturnMessage: "check your parameters conform to the json format", HTTPReturnError: err.Error()})
|
||||
log.Warn("high level restful api, the parameter of query is incorrect", zap.Any("request", httpReq), zap.Error(err))
|
||||
c.AbortWithStatusJSON(http.StatusOK, gin.H{HTTPReturnCode: Code(merr.ErrIncorrectParameterFormat), HTTPReturnMessage: merr.ErrIncorrectParameterFormat.Error()})
|
||||
return
|
||||
}
|
||||
if httpReq.CollectionName == "" {
|
||||
c.AbortWithStatusJSON(http.StatusOK, gin.H{HTTPReturnCode: http.StatusBadRequest, HTTPReturnMessage: "collectionName is required."})
|
||||
log.Warn("high level restful api, query require parameter: [collectionName], but miss")
|
||||
c.AbortWithStatusJSON(http.StatusOK, gin.H{HTTPReturnCode: Code(merr.ErrMissingRequiredParameters), HTTPReturnMessage: merr.ErrMissingRequiredParameters.Error()})
|
||||
return
|
||||
}
|
||||
req := milvuspb.QueryRequest{
|
||||
|
@ -331,18 +372,34 @@ func (h *Handlers) query(c *gin.Context) {
|
|||
username, _ := c.Get(ContextUsername)
|
||||
_, authErr := proxy.PrivilegeInterceptorWithUsername(c, username.(string), &req)
|
||||
if authErr != nil {
|
||||
c.JSON(http.StatusOK, gin.H{HTTPReturnCode: http.StatusUnauthorized, HTTPReturnMessage: authErr.Error()})
|
||||
c.JSON(http.StatusUnauthorized, gin.H{HTTPReturnCode: Code(authErr), HTTPReturnMessage: authErr.Error()})
|
||||
return
|
||||
}
|
||||
if !h.checkDatabase(c, req.DbName) {
|
||||
return
|
||||
}
|
||||
if req.OutputFields == nil {
|
||||
req.OutputFields = []string{DefaultOutputFields}
|
||||
coll, err := h.describeCollection(c, httpReq.DbName, httpReq.CollectionName, false)
|
||||
if err != nil || coll == nil {
|
||||
return
|
||||
}
|
||||
for _, field := range coll.Schema.Fields {
|
||||
if field.DataType == schemapb.DataType_BinaryVector || field.DataType == schemapb.DataType_FloatVector {
|
||||
req.OutputFields = append(req.OutputFields, field.Name)
|
||||
}
|
||||
}
|
||||
}
|
||||
response, err := h.proxy.Query(c, &req)
|
||||
if err != nil {
|
||||
c.JSON(http.StatusOK, gin.H{HTTPReturnCode: http.StatusBadRequest, HTTPReturnMessage: "query fail", HTTPReturnError: err.Error()})
|
||||
c.JSON(http.StatusOK, gin.H{HTTPReturnCode: Code(err), HTTPReturnMessage: err.Error()})
|
||||
} else if response.Status.ErrorCode != commonpb.ErrorCode_Success {
|
||||
c.JSON(http.StatusOK, gin.H{HTTPReturnCode: response.Status.ErrorCode, HTTPReturnMessage: response.Status.Reason})
|
||||
c.JSON(http.StatusOK, gin.H{HTTPReturnCode: code(int32(response.Status.ErrorCode)), HTTPReturnMessage: response.Status.Reason})
|
||||
} else {
|
||||
outputData, err := buildQueryResp(int64(0), response.OutputFields, response.FieldsData, nil, nil)
|
||||
if err != nil {
|
||||
c.JSON(http.StatusOK, gin.H{HTTPReturnCode: http.StatusBadRequest, HTTPReturnMessage: "show result by row wrong", "originData": response.FieldsData, HTTPReturnError: err.Error()})
|
||||
log.Warn("high level restful api, fail to deal with query result", zap.Any("response", response), zap.Error(err))
|
||||
c.JSON(http.StatusOK, gin.H{HTTPReturnCode: Code(merr.ErrInvalidSearchResult), HTTPReturnMessage: merr.ErrInvalidSearchResult.Error()})
|
||||
} else {
|
||||
c.JSON(http.StatusOK, gin.H{HTTPReturnCode: http.StatusOK, HTTPReturnData: outputData})
|
||||
}
|
||||
|
@ -351,15 +408,31 @@ func (h *Handlers) query(c *gin.Context) {
|
|||
|
||||
func (h *Handlers) get(c *gin.Context) {
|
||||
httpReq := GetReq{
|
||||
DbName: DefaultDbName,
|
||||
OutputFields: []string{DefaultOutputFields},
|
||||
DbName: DefaultDbName,
|
||||
}
|
||||
if err := c.ShouldBindBodyWith(&httpReq, binding.JSON); err != nil {
|
||||
c.AbortWithStatusJSON(http.StatusOK, gin.H{HTTPReturnCode: http.StatusBadRequest, HTTPReturnMessage: "check your parameters conform to the json format", HTTPReturnError: err.Error()})
|
||||
log.Warn("high level restful api, the parameter of get is incorrect", zap.Any("request", httpReq), zap.Error(err))
|
||||
c.AbortWithStatusJSON(http.StatusOK, gin.H{HTTPReturnCode: Code(merr.ErrIncorrectParameterFormat), HTTPReturnMessage: merr.ErrIncorrectParameterFormat.Error()})
|
||||
return
|
||||
}
|
||||
if httpReq.CollectionName == "" {
|
||||
c.AbortWithStatusJSON(http.StatusOK, gin.H{HTTPReturnCode: http.StatusBadRequest, HTTPReturnMessage: "collectionName is required."})
|
||||
log.Warn("high level restful api, get require parameter: [collectionName], but miss")
|
||||
c.AbortWithStatusJSON(http.StatusOK, gin.H{HTTPReturnCode: Code(merr.ErrMissingRequiredParameters), HTTPReturnMessage: merr.ErrMissingRequiredParameters})
|
||||
return
|
||||
}
|
||||
req := milvuspb.QueryRequest{
|
||||
DbName: httpReq.DbName,
|
||||
CollectionName: httpReq.CollectionName,
|
||||
OutputFields: httpReq.OutputFields,
|
||||
GuaranteeTimestamp: BoundedTimestamp,
|
||||
}
|
||||
username, _ := c.Get(ContextUsername)
|
||||
_, authErr := proxy.PrivilegeInterceptorWithUsername(c, username.(string), &req)
|
||||
if authErr != nil {
|
||||
c.JSON(http.StatusUnauthorized, gin.H{HTTPReturnCode: Code(authErr), HTTPReturnMessage: authErr.Error()})
|
||||
return
|
||||
}
|
||||
if !h.checkDatabase(c, req.DbName) {
|
||||
return
|
||||
}
|
||||
coll, err := h.describeCollection(c, httpReq.DbName, httpReq.CollectionName, false)
|
||||
|
@ -369,31 +442,28 @@ func (h *Handlers) get(c *gin.Context) {
|
|||
body, _ := c.Get(gin.BodyBytesKey)
|
||||
filter, err := checkGetPrimaryKey(coll.Schema, gjson.Get(string(body.([]byte)), DefaultPrimaryFieldName))
|
||||
if err != nil {
|
||||
c.JSON(http.StatusOK, gin.H{HTTPReturnCode: http.StatusBadRequest, HTTPReturnMessage: "make sure the collection's primary field", HTTPReturnError: err.Error()})
|
||||
c.JSON(http.StatusOK, gin.H{HTTPReturnCode: Code(merr.ErrCheckPrimaryKey), HTTPReturnMessage: merr.ErrCheckPrimaryKey.Error()})
|
||||
return
|
||||
}
|
||||
req := milvuspb.QueryRequest{
|
||||
DbName: httpReq.DbName,
|
||||
CollectionName: httpReq.CollectionName,
|
||||
Expr: filter,
|
||||
OutputFields: httpReq.OutputFields,
|
||||
GuaranteeTimestamp: BoundedTimestamp,
|
||||
}
|
||||
username, _ := c.Get(ContextUsername)
|
||||
_, authErr := proxy.PrivilegeInterceptorWithUsername(c, username.(string), &req)
|
||||
if authErr != nil {
|
||||
c.JSON(http.StatusOK, gin.H{HTTPReturnCode: http.StatusUnauthorized, HTTPReturnMessage: authErr.Error()})
|
||||
return
|
||||
req.Expr = filter
|
||||
if req.OutputFields == nil {
|
||||
req.OutputFields = []string{DefaultOutputFields}
|
||||
for _, field := range coll.Schema.Fields {
|
||||
if field.DataType == schemapb.DataType_BinaryVector || field.DataType == schemapb.DataType_FloatVector {
|
||||
req.OutputFields = append(req.OutputFields, field.Name)
|
||||
}
|
||||
}
|
||||
}
|
||||
response, err := h.proxy.Query(c, &req)
|
||||
if err != nil {
|
||||
c.JSON(http.StatusOK, gin.H{HTTPReturnCode: http.StatusBadRequest, HTTPReturnMessage: "query fail", HTTPReturnError: err.Error()})
|
||||
c.JSON(http.StatusOK, gin.H{HTTPReturnCode: Code(err), HTTPReturnMessage: err.Error()})
|
||||
} else if response.Status.ErrorCode != commonpb.ErrorCode_Success {
|
||||
c.JSON(http.StatusOK, gin.H{HTTPReturnCode: response.Status.ErrorCode, HTTPReturnMessage: response.Status.Reason})
|
||||
c.JSON(http.StatusOK, gin.H{HTTPReturnCode: code(int32(response.Status.ErrorCode)), HTTPReturnMessage: response.Status.Reason})
|
||||
} else {
|
||||
outputData, err := buildQueryResp(int64(0), response.OutputFields, response.FieldsData, nil, nil)
|
||||
if err != nil {
|
||||
c.JSON(http.StatusOK, gin.H{HTTPReturnCode: http.StatusBadRequest, HTTPReturnMessage: "show result by row wrong", "originData": response.FieldsData, HTTPReturnError: err.Error()})
|
||||
log.Warn("high level restful api, fail to deal with get result", zap.Any("response", response), zap.Error(err))
|
||||
c.JSON(http.StatusOK, gin.H{HTTPReturnCode: Code(merr.ErrInvalidSearchResult), HTTPReturnMessage: merr.ErrInvalidSearchResult.Error()})
|
||||
} else {
|
||||
c.JSON(http.StatusOK, gin.H{HTTPReturnCode: http.StatusOK, HTTPReturnData: outputData})
|
||||
log.Error("get resultIS: ", zap.Any("res", outputData))
|
||||
|
@ -406,11 +476,26 @@ func (h *Handlers) delete(c *gin.Context) {
|
|||
DbName: DefaultDbName,
|
||||
}
|
||||
if err := c.ShouldBindBodyWith(&httpReq, binding.JSON); err != nil {
|
||||
c.AbortWithStatusJSON(http.StatusOK, gin.H{HTTPReturnCode: http.StatusBadRequest, HTTPReturnMessage: "check your parameters conform to the json format", HTTPReturnError: err.Error()})
|
||||
log.Warn("high level restful api, the parameter of delete is incorrect", zap.Any("request", httpReq), zap.Error(err))
|
||||
c.AbortWithStatusJSON(http.StatusOK, gin.H{HTTPReturnCode: Code(merr.ErrIncorrectParameterFormat), HTTPReturnMessage: merr.ErrIncorrectParameterFormat.Error()})
|
||||
return
|
||||
}
|
||||
if httpReq.CollectionName == "" {
|
||||
c.AbortWithStatusJSON(http.StatusOK, gin.H{HTTPReturnCode: http.StatusBadRequest, HTTPReturnMessage: "collectionName is required."})
|
||||
log.Warn("high level restful api, delete require parameter: [collectionName], but miss")
|
||||
c.AbortWithStatusJSON(http.StatusOK, gin.H{HTTPReturnCode: Code(merr.ErrMissingRequiredParameters), HTTPReturnMessage: merr.ErrMissingRequiredParameters.Error()})
|
||||
return
|
||||
}
|
||||
req := milvuspb.DeleteRequest{
|
||||
DbName: httpReq.DbName,
|
||||
CollectionName: httpReq.CollectionName,
|
||||
}
|
||||
username, _ := c.Get(ContextUsername)
|
||||
_, authErr := proxy.PrivilegeInterceptorWithUsername(c, username.(string), &req)
|
||||
if authErr != nil {
|
||||
c.JSON(http.StatusUnauthorized, gin.H{HTTPReturnCode: Code(authErr), HTTPReturnMessage: authErr.Error()})
|
||||
return
|
||||
}
|
||||
if !h.checkDatabase(c, req.DbName) {
|
||||
return
|
||||
}
|
||||
coll, err := h.describeCollection(c, httpReq.DbName, httpReq.CollectionName, false)
|
||||
|
@ -420,25 +505,15 @@ func (h *Handlers) delete(c *gin.Context) {
|
|||
body, _ := c.Get(gin.BodyBytesKey)
|
||||
filter, err := checkGetPrimaryKey(coll.Schema, gjson.Get(string(body.([]byte)), DefaultPrimaryFieldName))
|
||||
if err != nil {
|
||||
c.JSON(http.StatusOK, gin.H{HTTPReturnCode: http.StatusBadRequest, HTTPReturnMessage: "make sure the collection's primary field", HTTPReturnError: err.Error()})
|
||||
return
|
||||
}
|
||||
req := milvuspb.DeleteRequest{
|
||||
DbName: httpReq.DbName,
|
||||
CollectionName: httpReq.CollectionName,
|
||||
Expr: filter,
|
||||
}
|
||||
username, _ := c.Get(ContextUsername)
|
||||
_, authErr := proxy.PrivilegeInterceptorWithUsername(c, username.(string), &req)
|
||||
if authErr != nil {
|
||||
c.JSON(http.StatusOK, gin.H{HTTPReturnCode: http.StatusUnauthorized, HTTPReturnMessage: authErr.Error()})
|
||||
c.JSON(http.StatusOK, gin.H{HTTPReturnCode: Code(merr.ErrCheckPrimaryKey), HTTPReturnMessage: merr.ErrCheckPrimaryKey.Error()})
|
||||
return
|
||||
}
|
||||
req.Expr = filter
|
||||
response, err := h.proxy.Delete(c, &req)
|
||||
if err != nil {
|
||||
c.JSON(http.StatusOK, gin.H{HTTPReturnCode: http.StatusBadRequest, HTTPReturnMessage: "delete fail", HTTPReturnError: err.Error()})
|
||||
c.JSON(http.StatusOK, gin.H{HTTPReturnCode: Code(err), HTTPReturnMessage: err.Error()})
|
||||
} else if response.Status.ErrorCode != commonpb.ErrorCode_Success {
|
||||
c.JSON(http.StatusOK, gin.H{HTTPReturnCode: response.Status.ErrorCode, HTTPReturnMessage: "delete fail", HTTPReturnError: response.Status.Reason})
|
||||
c.JSON(http.StatusOK, gin.H{HTTPReturnCode: code(int32(response.Status.ErrorCode)), HTTPReturnMessage: response.Status.Reason})
|
||||
} else {
|
||||
c.JSON(http.StatusOK, gin.H{HTTPReturnCode: http.StatusOK, HTTPReturnData: gin.H{}})
|
||||
}
|
||||
|
@ -449,11 +524,28 @@ func (h *Handlers) insert(c *gin.Context) {
|
|||
DbName: DefaultDbName,
|
||||
}
|
||||
if err := c.ShouldBindBodyWith(&httpReq, binding.JSON); err != nil {
|
||||
c.AbortWithStatusJSON(http.StatusOK, gin.H{HTTPReturnCode: http.StatusBadRequest, HTTPReturnMessage: "check your parameters conform to the json format", HTTPReturnError: err.Error()})
|
||||
log.Warn("high level restful api, the parameter of insert is incorrect", zap.Any("request", httpReq), zap.Error(err))
|
||||
c.AbortWithStatusJSON(http.StatusOK, gin.H{HTTPReturnCode: Code(merr.ErrIncorrectParameterFormat), HTTPReturnMessage: merr.ErrIncorrectParameterFormat.Error()})
|
||||
return
|
||||
}
|
||||
if httpReq.CollectionName == "" {
|
||||
c.AbortWithStatusJSON(http.StatusOK, gin.H{HTTPReturnCode: http.StatusBadRequest, HTTPReturnMessage: "collectionName is required."})
|
||||
log.Warn("high level restful api, insert require parameter: [collectionName], but miss")
|
||||
c.AbortWithStatusJSON(http.StatusOK, gin.H{HTTPReturnCode: Code(merr.ErrMissingRequiredParameters), HTTPReturnMessage: merr.ErrMissingRequiredParameters.Error()})
|
||||
return
|
||||
}
|
||||
req := milvuspb.InsertRequest{
|
||||
DbName: httpReq.DbName,
|
||||
CollectionName: httpReq.CollectionName,
|
||||
PartitionName: "_default",
|
||||
NumRows: uint32(len(httpReq.Data)),
|
||||
}
|
||||
username, _ := c.Get(ContextUsername)
|
||||
_, authErr := proxy.PrivilegeInterceptorWithUsername(c, username.(string), &req)
|
||||
if authErr != nil {
|
||||
c.JSON(http.StatusUnauthorized, gin.H{HTTPReturnCode: Code(authErr), HTTPReturnMessage: authErr.Error()})
|
||||
return
|
||||
}
|
||||
if !h.checkDatabase(c, req.DbName) {
|
||||
return
|
||||
}
|
||||
coll, err := h.describeCollection(c, httpReq.DbName, httpReq.CollectionName, false)
|
||||
|
@ -463,31 +555,21 @@ func (h *Handlers) insert(c *gin.Context) {
|
|||
body, _ := c.Get(gin.BodyBytesKey)
|
||||
err = checkAndSetData(string(body.([]byte)), coll, &httpReq)
|
||||
if err != nil {
|
||||
c.AbortWithStatusJSON(http.StatusOK, gin.H{HTTPReturnCode: http.StatusBadRequest, HTTPReturnMessage: "checkout your params", HTTPReturnError: err.Error()})
|
||||
log.Warn("high level restful api, fail to deal with insert data", zap.Any("body", body), zap.Error(err))
|
||||
c.AbortWithStatusJSON(http.StatusOK, gin.H{HTTPReturnCode: Code(merr.ErrInvalidInsertData), HTTPReturnMessage: merr.ErrInvalidInsertData.Error()})
|
||||
return
|
||||
}
|
||||
req := milvuspb.InsertRequest{
|
||||
DbName: httpReq.DbName,
|
||||
CollectionName: httpReq.CollectionName,
|
||||
PartitionName: "_default",
|
||||
NumRows: uint32(len(httpReq.Data)),
|
||||
}
|
||||
req.FieldsData, err = anyToColumns(httpReq.Data, coll.Schema)
|
||||
if err != nil {
|
||||
c.AbortWithStatusJSON(http.StatusOK, gin.H{HTTPReturnCode: http.StatusBadRequest, HTTPReturnMessage: "insert data by column wrong", HTTPReturnError: err.Error()})
|
||||
return
|
||||
}
|
||||
username, _ := c.Get(ContextUsername)
|
||||
_, authErr := proxy.PrivilegeInterceptorWithUsername(c, username.(string), &req)
|
||||
if authErr != nil {
|
||||
c.JSON(http.StatusOK, gin.H{HTTPReturnCode: http.StatusUnauthorized, HTTPReturnMessage: authErr.Error()})
|
||||
log.Warn("high level restful api, fail to deal with insert data", zap.Any("data", httpReq.Data), zap.Error(err))
|
||||
c.AbortWithStatusJSON(http.StatusOK, gin.H{HTTPReturnCode: Code(merr.ErrInvalidInsertData), HTTPReturnMessage: merr.ErrInvalidInsertData.Error()})
|
||||
return
|
||||
}
|
||||
response, err := h.proxy.Insert(c, &req)
|
||||
if err != nil {
|
||||
c.JSON(http.StatusOK, gin.H{HTTPReturnCode: http.StatusBadRequest, HTTPReturnMessage: "insert fail", HTTPReturnError: err.Error()})
|
||||
c.JSON(http.StatusOK, gin.H{HTTPReturnCode: Code(err), HTTPReturnMessage: err.Error()})
|
||||
} else if response.Status.ErrorCode != commonpb.ErrorCode_Success {
|
||||
c.JSON(http.StatusOK, gin.H{HTTPReturnCode: response.Status.ErrorCode, HTTPReturnMessage: "insert fail", HTTPReturnError: response.Status.Reason})
|
||||
c.JSON(http.StatusOK, gin.H{HTTPReturnCode: code(int32(response.Status.ErrorCode)), HTTPReturnMessage: response.Status.Reason})
|
||||
} else {
|
||||
switch response.IDs.GetIdField().(type) {
|
||||
case *schemapb.IDs_IntId:
|
||||
|
@ -495,7 +577,7 @@ func (h *Handlers) insert(c *gin.Context) {
|
|||
case *schemapb.IDs_StrId:
|
||||
c.JSON(http.StatusOK, gin.H{HTTPReturnCode: http.StatusOK, HTTPReturnData: gin.H{"insertCount": response.InsertCnt, "insertIds": response.IDs.IdField.(*schemapb.IDs_StrId).StrId.Data}})
|
||||
default:
|
||||
c.JSON(http.StatusOK, gin.H{HTTPReturnCode: http.StatusBadRequest, HTTPReturnMessage: "ids' type neither int or string"})
|
||||
c.JSON(http.StatusOK, gin.H{HTTPReturnCode: Code(merr.ErrCheckPrimaryKey), HTTPReturnMessage: merr.ErrCheckPrimaryKey.Error()})
|
||||
}
|
||||
}
|
||||
}
|
||||
|
@ -506,11 +588,13 @@ func (h *Handlers) search(c *gin.Context) {
|
|||
Limit: 100,
|
||||
}
|
||||
if err := c.ShouldBindBodyWith(&httpReq, binding.JSON); err != nil {
|
||||
c.AbortWithStatusJSON(http.StatusOK, gin.H{HTTPReturnCode: http.StatusBadRequest, HTTPReturnMessage: "check your parameters conform to the json format", HTTPReturnError: err.Error()})
|
||||
log.Warn("high level restful api, the parameter of search is incorrect", zap.Any("request", httpReq), zap.Error(err))
|
||||
c.AbortWithStatusJSON(http.StatusOK, gin.H{HTTPReturnCode: Code(merr.ErrIncorrectParameterFormat), HTTPReturnMessage: merr.ErrIncorrectParameterFormat.Error()})
|
||||
return
|
||||
}
|
||||
if httpReq.CollectionName == "" {
|
||||
c.AbortWithStatusJSON(http.StatusOK, gin.H{HTTPReturnCode: http.StatusBadRequest, HTTPReturnMessage: "collectionName is required."})
|
||||
log.Warn("high level restful api, search require parameter: [collectionName], but miss")
|
||||
c.AbortWithStatusJSON(http.StatusOK, gin.H{HTTPReturnCode: Code(merr.ErrMissingRequiredParameters), HTTPReturnMessage: merr.ErrMissingRequiredParameters})
|
||||
return
|
||||
}
|
||||
|
||||
|
@ -538,24 +622,35 @@ func (h *Handlers) search(c *gin.Context) {
|
|||
username, _ := c.Get(ContextUsername)
|
||||
_, authErr := proxy.PrivilegeInterceptorWithUsername(c, username.(string), &req)
|
||||
if authErr != nil {
|
||||
c.JSON(http.StatusOK, gin.H{HTTPReturnCode: http.StatusUnauthorized, HTTPReturnMessage: authErr.Error()})
|
||||
c.JSON(http.StatusUnauthorized, gin.H{HTTPReturnCode: Code(authErr), HTTPReturnMessage: authErr.Error()})
|
||||
return
|
||||
}
|
||||
if !h.checkDatabase(c, req.DbName) {
|
||||
return
|
||||
}
|
||||
response, err := h.proxy.Search(c, &req)
|
||||
if err != nil {
|
||||
c.JSON(http.StatusOK, gin.H{HTTPReturnCode: http.StatusBadRequest, HTTPReturnMessage: "search fail", HTTPReturnError: err.Error()})
|
||||
c.JSON(http.StatusOK, gin.H{HTTPReturnCode: Code(err), HTTPReturnMessage: err.Error()})
|
||||
} else if response.Status.ErrorCode != commonpb.ErrorCode_Success {
|
||||
c.JSON(http.StatusOK, gin.H{HTTPReturnCode: response.Status.ErrorCode, HTTPReturnMessage: "search fail", HTTPReturnError: response.Status.Reason})
|
||||
c.JSON(http.StatusOK, gin.H{HTTPReturnCode: code(int32(response.Status.ErrorCode)), HTTPReturnMessage: response.Status.Reason})
|
||||
} else {
|
||||
if response.Results.TopK == int64(0) {
|
||||
c.JSON(http.StatusOK, gin.H{HTTPReturnCode: http.StatusOK, HTTPReturnData: []interface{}{}})
|
||||
} else {
|
||||
outputData, err := buildQueryResp(response.Results.TopK, response.Results.OutputFields, response.Results.FieldsData, response.Results.Ids, response.Results.Scores)
|
||||
if err != nil {
|
||||
c.JSON(http.StatusOK, gin.H{HTTPReturnCode: http.StatusBadRequest, HTTPReturnMessage: "show result by row wrong", HTTPReturnError: err.Error()})
|
||||
log.Warn("high level restful api, fail to deal with search result", zap.Any("result", response.Results), zap.Error(err))
|
||||
c.JSON(http.StatusOK, gin.H{HTTPReturnCode: Code(merr.ErrInvalidSearchResult), HTTPReturnMessage: merr.ErrInvalidSearchResult.Error()})
|
||||
} else {
|
||||
c.JSON(http.StatusOK, gin.H{HTTPReturnCode: http.StatusOK, HTTPReturnData: outputData})
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func code(code int32) int32 {
|
||||
return merr.RootReasonCodeMask & code
|
||||
}
|
||||
func Code(err error) int32 {
|
||||
return code(merr.Code(err))
|
||||
}
|
||||
|
|
|
@ -3,15 +3,20 @@ package httpserver
|
|||
import (
|
||||
"bytes"
|
||||
"encoding/json"
|
||||
"errors"
|
||||
"fmt"
|
||||
"net/http"
|
||||
"net/http/httptest"
|
||||
"testing"
|
||||
|
||||
"github.com/gin-gonic/gin"
|
||||
"github.com/milvus-io/milvus-proto/go-api/v2/commonpb"
|
||||
"github.com/milvus-io/milvus-proto/go-api/v2/milvuspb"
|
||||
"github.com/milvus-io/milvus-proto/go-api/v2/schemapb"
|
||||
|
||||
"github.com/milvus-io/milvus/pkg/util/merr"
|
||||
|
||||
"github.com/cockroachdb/errors"
|
||||
|
||||
"github.com/gin-gonic/gin"
|
||||
"github.com/milvus-io/milvus/internal/mocks"
|
||||
"github.com/milvus-io/milvus/internal/proxy"
|
||||
"github.com/milvus-io/milvus/internal/types"
|
||||
|
@ -100,9 +105,9 @@ func genAuthMiddleWare(needAuth bool) gin.HandlerFunc {
|
|||
return func(c *gin.Context) {
|
||||
username, password, ok := ParseUsernamePassword(c)
|
||||
if !ok {
|
||||
c.AbortWithStatusJSON(http.StatusOK, gin.H{"code": http.StatusProxyAuthRequired, "message": proxy.ErrUnauthenticated().Error()})
|
||||
c.AbortWithStatusJSON(http.StatusProxyAuthRequired, gin.H{HTTPReturnCode: Code(merr.ErrNeedAuthenticate), HTTPReturnMessage: merr.ErrNeedAuthenticate.Error()})
|
||||
} else if username == util.UserRoot && password != util.DefaultRootPassword {
|
||||
c.AbortWithStatusJSON(http.StatusOK, gin.H{"code": http.StatusProxyAuthRequired, "message": proxy.ErrUnauthenticated().Error()})
|
||||
c.AbortWithStatusJSON(http.StatusProxyAuthRequired, gin.H{HTTPReturnCode: Code(merr.ErrNeedAuthenticate), HTTPReturnMessage: merr.ErrNeedAuthenticate.Error()})
|
||||
} else {
|
||||
c.Set(ContextUsername, username)
|
||||
}
|
||||
|
@ -113,6 +118,14 @@ func genAuthMiddleWare(needAuth bool) gin.HandlerFunc {
|
|||
}
|
||||
}
|
||||
|
||||
func Print(code int32, message string) string {
|
||||
return fmt.Sprintf("{\"%s\":%d,\"%s\":\"%s\"}", HTTPReturnCode, code, HTTPReturnMessage, message)
|
||||
}
|
||||
|
||||
func PrintErr(err error) string {
|
||||
return Print(Code(err), err.Error())
|
||||
}
|
||||
|
||||
func TestVectorAuthenticate(t *testing.T) {
|
||||
mp := mocks.NewProxy(t)
|
||||
mp.EXPECT().
|
||||
|
@ -126,8 +139,8 @@ func TestVectorAuthenticate(t *testing.T) {
|
|||
req := httptest.NewRequest(http.MethodGet, "/v1/vector/collections", nil)
|
||||
w := httptest.NewRecorder()
|
||||
testEngine.ServeHTTP(w, req)
|
||||
assert.Equal(t, w.Code, http.StatusOK)
|
||||
assert.Equal(t, w.Body.String(), "{\"code\":407,\"message\":\"auth check failure, please check username and password are correct\"}")
|
||||
assert.Equal(t, w.Code, http.StatusProxyAuthRequired)
|
||||
assert.Equal(t, w.Body.String(), PrintErr(merr.ErrNeedAuthenticate))
|
||||
})
|
||||
|
||||
t.Run("username or password incorrect", func(t *testing.T) {
|
||||
|
@ -135,8 +148,8 @@ func TestVectorAuthenticate(t *testing.T) {
|
|||
req.SetBasicAuth(util.UserRoot, util.UserRoot)
|
||||
w := httptest.NewRecorder()
|
||||
testEngine.ServeHTTP(w, req)
|
||||
assert.Equal(t, w.Code, http.StatusOK)
|
||||
assert.Equal(t, w.Body.String(), "{\"code\":407,\"message\":\"auth check failure, please check username and password are correct\"}")
|
||||
assert.Equal(t, w.Code, http.StatusProxyAuthRequired)
|
||||
assert.Equal(t, w.Body.String(), PrintErr(merr.ErrNeedAuthenticate))
|
||||
})
|
||||
|
||||
t.Run("root's password correct", func(t *testing.T) {
|
||||
|
@ -166,7 +179,7 @@ func TestVectorListCollection(t *testing.T) {
|
|||
name: "show collections fail",
|
||||
mp: mp0,
|
||||
exceptCode: 200,
|
||||
expectedBody: "{\"code\":400,\"error\":\"" + DefaultErrorMessage + "\",\"message\":\"show collections fail\"}",
|
||||
expectedBody: PrintErr(ErrDefault),
|
||||
})
|
||||
|
||||
reason := "cannot create folder"
|
||||
|
@ -181,7 +194,7 @@ func TestVectorListCollection(t *testing.T) {
|
|||
name: "show collections fail",
|
||||
mp: mp1,
|
||||
exceptCode: 200,
|
||||
expectedBody: "{\"code\":17,\"error\":\"" + reason + "\",\"message\":\"show collections fail\"}",
|
||||
expectedBody: Print(int32(commonpb.ErrorCode_CannotCreateFolder), reason),
|
||||
})
|
||||
|
||||
mp := mocks.NewProxy(t)
|
||||
|
@ -269,7 +282,7 @@ func TestVectorCollectionsDescribe(t *testing.T) {
|
|||
w := httptest.NewRecorder()
|
||||
testEngine.ServeHTTP(w, req)
|
||||
assert.Equal(t, w.Code, 200)
|
||||
assert.Equal(t, w.Body.String(), "{\"code\":400,\"message\":\"collectionName is required.\"}")
|
||||
assert.Equal(t, w.Body.String(), PrintErr(merr.ErrMissingRequiredParameters))
|
||||
})
|
||||
}
|
||||
|
||||
|
@ -282,7 +295,7 @@ func TestVectorCreateCollection(t *testing.T) {
|
|||
name: "create collection fail",
|
||||
mp: mp1,
|
||||
exceptCode: 200,
|
||||
expectedBody: "{\"code\":400,\"error\":\"" + DefaultErrorMessage + "\",\"message\":\"create collection " + DefaultCollectionName + " fail\"}",
|
||||
expectedBody: PrintErr(ErrDefault),
|
||||
})
|
||||
|
||||
reason := "collection " + DefaultCollectionName + " already exists"
|
||||
|
@ -295,7 +308,7 @@ func TestVectorCreateCollection(t *testing.T) {
|
|||
name: "create collection fail",
|
||||
mp: mp2,
|
||||
exceptCode: 200,
|
||||
expectedBody: "{\"code\":18,\"error\":\"" + reason + "\",\"message\":\"create collection " + DefaultCollectionName + " fail\"}",
|
||||
expectedBody: Print(int32(commonpb.ErrorCode_CannotCreateFile), reason),
|
||||
})
|
||||
|
||||
mp3 := mocks.NewProxy(t)
|
||||
|
@ -305,7 +318,7 @@ func TestVectorCreateCollection(t *testing.T) {
|
|||
name: "create index fail",
|
||||
mp: mp3,
|
||||
exceptCode: 200,
|
||||
expectedBody: "{\"code\":400,\"error\":\"" + DefaultErrorMessage + "\",\"message\":\"create index for collection " + DefaultCollectionName + " fail, after the collection was created\"}",
|
||||
expectedBody: PrintErr(ErrDefault),
|
||||
})
|
||||
|
||||
mp4 := mocks.NewProxy(t)
|
||||
|
@ -316,7 +329,7 @@ func TestVectorCreateCollection(t *testing.T) {
|
|||
name: "load collection fail",
|
||||
mp: mp4,
|
||||
exceptCode: 200,
|
||||
expectedBody: "{\"code\":400,\"error\":\"" + DefaultErrorMessage + "\",\"message\":\"load collection " + DefaultCollectionName + " fail, after the index was created\"}",
|
||||
expectedBody: PrintErr(ErrDefault),
|
||||
})
|
||||
|
||||
mp5 := mocks.NewProxy(t)
|
||||
|
@ -355,13 +368,13 @@ func TestVectorDropCollection(t *testing.T) {
|
|||
mp1, _ = wrapWithHasCollection(t, mp1, ReturnTrue, 1, nil)
|
||||
mp1.EXPECT().DropCollection(mock.Anything, mock.Anything).Return(nil, ErrDefault).Once()
|
||||
testCases = append(testCases, testCase{
|
||||
name: "create collection fail",
|
||||
name: "drop collection fail",
|
||||
mp: mp1,
|
||||
exceptCode: 200,
|
||||
expectedBody: "{\"code\":400,\"error\":\"" + DefaultErrorMessage + "\",\"message\":\"drop collection " + DefaultCollectionName + " fail\"}",
|
||||
expectedBody: PrintErr(ErrDefault),
|
||||
})
|
||||
|
||||
reason := "collection " + DefaultCollectionName + " already exists"
|
||||
reason := "cannot find collection " + DefaultCollectionName
|
||||
mp2 := mocks.NewProxy(t)
|
||||
mp2, _ = wrapWithHasCollection(t, mp2, ReturnTrue, 1, nil)
|
||||
mp2.EXPECT().DropCollection(mock.Anything, mock.Anything).Return(&commonpb.Status{
|
||||
|
@ -369,10 +382,10 @@ func TestVectorDropCollection(t *testing.T) {
|
|||
Reason: reason,
|
||||
}, nil).Once()
|
||||
testCases = append(testCases, testCase{
|
||||
name: "create collection fail",
|
||||
name: "drop collection fail",
|
||||
mp: mp2,
|
||||
exceptCode: 200,
|
||||
expectedBody: "{\"code\":4,\"error\":\"" + reason + "\",\"message\":\"drop collection " + DefaultCollectionName + " fail\"}",
|
||||
expectedBody: Print(int32(commonpb.ErrorCode_CollectionNotExists), reason),
|
||||
})
|
||||
|
||||
mp3 := mocks.NewProxy(t)
|
||||
|
@ -404,39 +417,39 @@ func TestQuery(t *testing.T) {
|
|||
testCases := []testCase{}
|
||||
|
||||
mp2 := mocks.NewProxy(t)
|
||||
mp2, _ = wrapWithDescribeColl(t, mp2, ReturnSuccess, 1, nil)
|
||||
mp2.EXPECT().Query(mock.Anything, mock.Anything).Return(nil, ErrDefault).Twice()
|
||||
mp2, _ = wrapWithDescribeColl(t, mp2, ReturnSuccess, 2, nil)
|
||||
mp2.EXPECT().Query(mock.Anything, mock.Anything).Return(nil, ErrDefault).Times(3)
|
||||
testCases = append(testCases, testCase{
|
||||
name: "query fail",
|
||||
mp: mp2,
|
||||
exceptCode: 200,
|
||||
expectedBody: "{\"code\":400,\"error\":\"" + DefaultErrorMessage + "\",\"message\":\"query fail\"}",
|
||||
expectedBody: PrintErr(ErrDefault),
|
||||
})
|
||||
|
||||
reason := DefaultCollectionName + " name not found"
|
||||
mp3 := mocks.NewProxy(t)
|
||||
mp3, _ = wrapWithDescribeColl(t, mp3, ReturnSuccess, 1, nil)
|
||||
mp3, _ = wrapWithDescribeColl(t, mp3, ReturnSuccess, 2, nil)
|
||||
mp3.EXPECT().Query(mock.Anything, mock.Anything).Return(&milvuspb.QueryResults{
|
||||
Status: &commonpb.Status{
|
||||
ErrorCode: commonpb.ErrorCode_CollectionNameNotFound, // 28
|
||||
Reason: reason,
|
||||
},
|
||||
}, nil).Twice()
|
||||
}, nil).Times(3)
|
||||
testCases = append(testCases, testCase{
|
||||
name: "query fail",
|
||||
mp: mp3,
|
||||
exceptCode: 200,
|
||||
expectedBody: "{\"code\":28,\"message\":\"" + reason + "\"}",
|
||||
expectedBody: Print(int32(commonpb.ErrorCode_CollectionNameNotFound), reason),
|
||||
})
|
||||
|
||||
mp4 := mocks.NewProxy(t)
|
||||
mp4, _ = wrapWithDescribeColl(t, mp4, ReturnSuccess, 1, nil)
|
||||
mp4, _ = wrapWithDescribeColl(t, mp4, ReturnSuccess, 2, nil)
|
||||
mp4.EXPECT().Query(mock.Anything, mock.Anything).Return(&milvuspb.QueryResults{
|
||||
Status: &StatusSuccess,
|
||||
FieldsData: generateFieldData(),
|
||||
CollectionName: DefaultCollectionName,
|
||||
OutputFields: []string{FieldBookID, FieldWordCount, FieldBookIntro},
|
||||
}, nil).Twice()
|
||||
}, nil).Times(3)
|
||||
testCases = append(testCases, testCase{
|
||||
name: "query success",
|
||||
mp: mp4,
|
||||
|
@ -445,7 +458,34 @@ func TestQuery(t *testing.T) {
|
|||
})
|
||||
|
||||
for _, tt := range testCases {
|
||||
reqs := []*http.Request{genQueryRequest(), genGetRequest()}
|
||||
reqs := []*http.Request{genQueryRequest(true), genQueryRequest(false), genGetRequest()}
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
testEngine := initHTTPServer(tt.mp, true)
|
||||
for _, req := range reqs {
|
||||
req.SetBasicAuth(util.UserRoot, util.DefaultRootPassword)
|
||||
w := httptest.NewRecorder()
|
||||
testEngine.ServeHTTP(w, req)
|
||||
assert.Equal(t, w.Code, tt.exceptCode)
|
||||
assert.Equal(t, w.Body.String(), tt.expectedBody)
|
||||
resp := map[string]interface{}{}
|
||||
err := json.Unmarshal(w.Body.Bytes(), &resp)
|
||||
assert.Equal(t, err, nil)
|
||||
if resp[HTTPReturnCode] == float64(200) {
|
||||
data := resp[HTTPReturnData].([]interface{})
|
||||
rows := generateQueryResult64(false)
|
||||
for i, row := range data {
|
||||
assert.Equal(t, compareRow64(row.(map[string]interface{}), rows[i]), true)
|
||||
}
|
||||
}
|
||||
}
|
||||
})
|
||||
}
|
||||
|
||||
testCases = []testCase{}
|
||||
_, testCases = wrapWithDescribeColl(t, nil, ReturnFail, 2, testCases)
|
||||
_, testCases = wrapWithDescribeColl(t, nil, ReturnWrongStatus, 2, testCases)
|
||||
for _, tt := range testCases {
|
||||
reqs := []*http.Request{genQueryRequest(false), genGetRequest()}
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
testEngine := initHTTPServer(tt.mp, true)
|
||||
for _, req := range reqs {
|
||||
|
@ -469,8 +509,13 @@ func TestQuery(t *testing.T) {
|
|||
}
|
||||
}
|
||||
|
||||
func genQueryRequest() *http.Request {
|
||||
jsonBody := []byte(`{"collectionName": "` + DefaultCollectionName + `" , "filter": "book_id in [1,2,3]"}`)
|
||||
func genQueryRequest(withOutputFields bool) *http.Request {
|
||||
var jsonBody []byte
|
||||
if withOutputFields {
|
||||
jsonBody = []byte(`{"collectionName": "` + DefaultCollectionName + `" , "filter": "book_id in [1,2,3]", "outputFields": ["*"]}`)
|
||||
} else {
|
||||
jsonBody = []byte(`{"collectionName": "` + DefaultCollectionName + `" , "filter": "book_id in [1,2,3]"}`)
|
||||
}
|
||||
bodyReader := bytes.NewReader(jsonBody)
|
||||
req := httptest.NewRequest(http.MethodPost, "/v1/vector/query", bodyReader)
|
||||
return req
|
||||
|
@ -495,7 +540,7 @@ func TestDelete(t *testing.T) {
|
|||
name: "delete fail",
|
||||
mp: mp2,
|
||||
exceptCode: 200,
|
||||
expectedBody: "{\"code\":400,\"error\":\"" + DefaultErrorMessage + "\",\"message\":\"delete fail\"}",
|
||||
expectedBody: PrintErr(ErrDefault),
|
||||
})
|
||||
|
||||
reason := DefaultCollectionName + " name not found"
|
||||
|
@ -511,7 +556,7 @@ func TestDelete(t *testing.T) {
|
|||
name: "delete fail",
|
||||
mp: mp3,
|
||||
exceptCode: 200,
|
||||
expectedBody: "{\"code\":28,\"error\":\"" + reason + "\",\"message\":\"delete fail\"}",
|
||||
expectedBody: Print(int32(commonpb.ErrorCode_CollectionNameNotFound), reason),
|
||||
})
|
||||
|
||||
mp4 := mocks.NewProxy(t)
|
||||
|
@ -556,7 +601,7 @@ func TestInsert(t *testing.T) {
|
|||
name: "insert fail",
|
||||
mp: mp2,
|
||||
exceptCode: 200,
|
||||
expectedBody: "{\"code\":400,\"error\":\"" + DefaultErrorMessage + "\",\"message\":\"insert fail\"}",
|
||||
expectedBody: PrintErr(ErrDefault),
|
||||
})
|
||||
|
||||
reason := DefaultCollectionName + " name not found"
|
||||
|
@ -572,7 +617,7 @@ func TestInsert(t *testing.T) {
|
|||
name: "insert fail",
|
||||
mp: mp3,
|
||||
exceptCode: 200,
|
||||
expectedBody: "{\"code\":28,\"error\":\"" + reason + "\",\"message\":\"insert fail\"}",
|
||||
expectedBody: Print(int32(commonpb.ErrorCode_CollectionNameNotFound), reason),
|
||||
})
|
||||
|
||||
mp4 := mocks.NewProxy(t)
|
||||
|
@ -584,7 +629,7 @@ func TestInsert(t *testing.T) {
|
|||
name: "id type invalid",
|
||||
mp: mp4,
|
||||
exceptCode: 200,
|
||||
expectedBody: "{\"code\":400,\"message\":\"ids' type neither int or string\"}",
|
||||
expectedBody: PrintErr(merr.ErrCheckPrimaryKey),
|
||||
})
|
||||
|
||||
mp5 := mocks.NewProxy(t)
|
||||
|
@ -646,7 +691,7 @@ func TestInsert(t *testing.T) {
|
|||
w := httptest.NewRecorder()
|
||||
testEngine.ServeHTTP(w, req)
|
||||
assert.Equal(t, w.Code, 200)
|
||||
assert.Equal(t, w.Body.String(), "{\"code\":400,\"error\":\"data is required\",\"message\":\"checkout your params\"}")
|
||||
assert.Equal(t, w.Body.String(), PrintErr(merr.ErrInvalidInsertData))
|
||||
resp := map[string]interface{}{}
|
||||
err := json.Unmarshal(w.Body.Bytes(), &resp)
|
||||
assert.Equal(t, err, nil)
|
||||
|
@ -712,7 +757,7 @@ func TestInsertForDataType(t *testing.T) {
|
|||
w := httptest.NewRecorder()
|
||||
testEngine.ServeHTTP(w, req)
|
||||
assert.Equal(t, w.Code, 200)
|
||||
assert.Equal(t, w.Body.String(), "{\"code\":400,\"error\":\"not support fieldName field-array dataType Array\",\"message\":\"checkout your params\"}")
|
||||
assert.Equal(t, w.Body.String(), PrintErr(merr.ErrInvalidInsertData))
|
||||
})
|
||||
}
|
||||
}
|
||||
|
@ -748,7 +793,7 @@ func TestSearch(t *testing.T) {
|
|||
name: "search fail",
|
||||
mp: mp2,
|
||||
exceptCode: 200,
|
||||
expectedBody: "{\"code\":400,\"error\":\"" + DefaultErrorMessage + "\",\"message\":\"search fail\"}",
|
||||
expectedBody: PrintErr(ErrDefault),
|
||||
})
|
||||
|
||||
reason := DefaultCollectionName + " name not found"
|
||||
|
@ -763,7 +808,7 @@ func TestSearch(t *testing.T) {
|
|||
name: "search fail",
|
||||
mp: mp3,
|
||||
exceptCode: 200,
|
||||
expectedBody: "{\"code\":28,\"error\":\"" + reason + "\",\"message\":\"search fail\"}",
|
||||
expectedBody: Print(int32(commonpb.ErrorCode_CollectionNameNotFound), reason),
|
||||
})
|
||||
|
||||
mp4 := mocks.NewProxy(t)
|
||||
|
@ -834,7 +879,7 @@ func wrapWithDescribeColl(t *testing.T, mp *mocks.Proxy, returnType ReturnType,
|
|||
name: "[share] describe coll fail",
|
||||
mp: mp,
|
||||
exceptCode: 200,
|
||||
expectedBody: "{\"code\":400,\"error\":\"" + DefaultErrorMessage + "\",\"message\":\"describe collection " + DefaultCollectionName + " fail\"}",
|
||||
expectedBody: PrintErr(ErrDefault),
|
||||
}
|
||||
case ReturnWrongStatus:
|
||||
call = mp.EXPECT().DescribeCollection(mock.Anything, mock.Anything).Return(&milvuspb.DescribeCollectionResponse{
|
||||
|
@ -882,7 +927,7 @@ func wrapWithHasCollection(t *testing.T, mp *mocks.Proxy, returnType ReturnType,
|
|||
name: "[share] collection not found",
|
||||
mp: mp,
|
||||
exceptCode: 200,
|
||||
expectedBody: "{\"code\":400,\"message\":\"can't find collection: " + DefaultCollectionName + "\"}",
|
||||
expectedBody: PrintErr(merr.ErrCollectionNotFound),
|
||||
}
|
||||
case ReturnFail:
|
||||
call = mp.EXPECT().HasCollection(mock.Anything, mock.Anything).Return(nil, ErrDefault)
|
||||
|
@ -890,7 +935,7 @@ func wrapWithHasCollection(t *testing.T, mp *mocks.Proxy, returnType ReturnType,
|
|||
name: "[share] check collection fail",
|
||||
mp: mp,
|
||||
exceptCode: 200,
|
||||
expectedBody: "{\"code\":400,\"error\":\"" + DefaultErrorMessage + "\",\"message\":\"check collections " + DefaultCollectionName + " exists fail\"}",
|
||||
expectedBody: PrintErr(ErrDefault),
|
||||
}
|
||||
case ReturnWrongStatus:
|
||||
reason := "can't find collection: " + DefaultCollectionName
|
||||
|
@ -904,7 +949,7 @@ func wrapWithHasCollection(t *testing.T, mp *mocks.Proxy, returnType ReturnType,
|
|||
name: "[share] unexpected error",
|
||||
mp: mp,
|
||||
exceptCode: 200,
|
||||
expectedBody: "{\"code\":1,\"error\":\"" + reason + "\",\"message\":\"check collections book exists fail\"}",
|
||||
expectedBody: Print(int32(commonpb.ErrorCode_UnexpectedError), reason),
|
||||
}
|
||||
}
|
||||
if times == 2 {
|
||||
|
@ -919,9 +964,9 @@ func wrapWithHasCollection(t *testing.T, mp *mocks.Proxy, returnType ReturnType,
|
|||
}
|
||||
|
||||
func TestHttpRequestFormat(t *testing.T) {
|
||||
parseErrStr := "{\"code\":400,\"error\":\"invalid character ',' after object key\",\"message\":\"check your parameters conform to the json format\"}"
|
||||
collnameErrStr := "{\"code\":400,\"message\":\"collectionName is required.\"}"
|
||||
collnameDimErrStr := "{\"code\":400,\"message\":\"collectionName and dimension are both required.\"}"
|
||||
parseErrStr := PrintErr(merr.ErrIncorrectParameterFormat)
|
||||
collnameErrStr := PrintErr(merr.ErrMissingRequiredParameters)
|
||||
collnameDimErrStr := PrintErr(merr.ErrMissingRequiredParameters)
|
||||
dataErrStr := "check and set data"
|
||||
jsons := map[string][]byte{
|
||||
parseErrStr: []byte(`{"collectionName": {"` + DefaultCollectionName + `", "dimension": 2}`),
|
||||
|
@ -969,7 +1014,7 @@ func TestHttpRequestFormat(t *testing.T) {
|
|||
|
||||
func TestAuthorization(t *testing.T) {
|
||||
proxy.Params.CommonCfg.AuthorizationEnabled = true
|
||||
errorStr := "{\"code\":401,\"message\":\"rpc error: code = Unavailable desc = internal: Milvus Proxy is not ready yet. please wait\"}"
|
||||
errorStr := Print(int32(65535), "rpc error: code = Unavailable desc = internal: Milvus Proxy is not ready yet. please wait")
|
||||
jsons := map[string][]byte{
|
||||
errorStr: []byte(`{"collectionName": "` + DefaultCollectionName + `", "dimension": 2, "data":[{"book_id":1,"book_intro":[0.1,0.11],"distance":0.01,"word_count":1000},{"book_id":2,"book_intro":[0.2,0.22],"distance":0.04,"word_count":2000},{"book_id":3,"book_intro":[0.3,0.33],"distance":0.09,"word_count":3000}]}`),
|
||||
}
|
||||
|
@ -984,14 +1029,13 @@ func TestAuthorization(t *testing.T) {
|
|||
for _, path := range pathArr {
|
||||
t.Run("proxy is not ready", func(t *testing.T) {
|
||||
mp := mocks.NewProxy(t)
|
||||
mp, _ = wrapWithDescribeColl(t, mp, ReturnSuccess, 1, nil)
|
||||
testEngine := initHTTPServer(mp, true)
|
||||
bodyReader := bytes.NewReader(jsons[res])
|
||||
req := httptest.NewRequest(http.MethodPost, path, bodyReader)
|
||||
req.Header.Set("authorization", "Bearer test:test")
|
||||
w := httptest.NewRecorder()
|
||||
testEngine.ServeHTTP(w, req)
|
||||
assert.Equal(t, w.Code, 200)
|
||||
assert.Equal(t, w.Code, http.StatusUnauthorized)
|
||||
assert.Equal(t, w.Body.String(), res)
|
||||
})
|
||||
}
|
||||
|
@ -1012,7 +1056,7 @@ func TestAuthorization(t *testing.T) {
|
|||
req.Header.Set("authorization", "Bearer test:test")
|
||||
w := httptest.NewRecorder()
|
||||
testEngine.ServeHTTP(w, req)
|
||||
assert.Equal(t, w.Code, 200)
|
||||
assert.Equal(t, w.Code, http.StatusUnauthorized)
|
||||
assert.Equal(t, w.Body.String(), res)
|
||||
})
|
||||
}
|
||||
|
@ -1027,14 +1071,13 @@ func TestAuthorization(t *testing.T) {
|
|||
for _, path := range pathArr {
|
||||
t.Run("proxy is not ready", func(t *testing.T) {
|
||||
mp := mocks.NewProxy(t)
|
||||
mp, _ = wrapWithHasCollection(t, mp, ReturnTrue, 1, nil)
|
||||
testEngine := initHTTPServer(mp, true)
|
||||
bodyReader := bytes.NewReader(jsons[res])
|
||||
req := httptest.NewRequest(http.MethodPost, path, bodyReader)
|
||||
req.Header.Set("authorization", "Bearer test:test")
|
||||
w := httptest.NewRecorder()
|
||||
testEngine.ServeHTTP(w, req)
|
||||
assert.Equal(t, w.Code, 200)
|
||||
assert.Equal(t, w.Code, http.StatusUnauthorized)
|
||||
assert.Equal(t, w.Body.String(), res)
|
||||
})
|
||||
}
|
||||
|
@ -1055,7 +1098,7 @@ func TestAuthorization(t *testing.T) {
|
|||
req.Header.Set("authorization", "Bearer test:test")
|
||||
w := httptest.NewRecorder()
|
||||
testEngine.ServeHTTP(w, req)
|
||||
assert.Equal(t, w.Code, 200)
|
||||
assert.Equal(t, w.Code, http.StatusUnauthorized)
|
||||
assert.Equal(t, w.Body.String(), res)
|
||||
})
|
||||
}
|
||||
|
@ -1077,7 +1120,7 @@ func TestAuthorization(t *testing.T) {
|
|||
req.Header.Set("authorization", "Bearer test:test")
|
||||
w := httptest.NewRecorder()
|
||||
testEngine.ServeHTTP(w, req)
|
||||
assert.Equal(t, w.Code, 200)
|
||||
assert.Equal(t, w.Code, http.StatusUnauthorized)
|
||||
assert.Equal(t, w.Body.String(), res)
|
||||
})
|
||||
}
|
||||
|
@ -1085,6 +1128,116 @@ func TestAuthorization(t *testing.T) {
|
|||
|
||||
}
|
||||
|
||||
func TestDatabaseNotFound(t *testing.T) {
|
||||
t.Run("list database fail", func(t *testing.T) {
|
||||
mp := mocks.NewProxy(t)
|
||||
mp.EXPECT().ListDatabases(mock.Anything, mock.Anything).Return(nil, ErrDefault).Once()
|
||||
testEngine := initHTTPServer(mp, true)
|
||||
req := httptest.NewRequest(http.MethodGet, URIPrefix+VectorCollectionsPath+"?dbName=test", nil)
|
||||
req.Header.Set("authorization", "Bearer root:Milvus")
|
||||
w := httptest.NewRecorder()
|
||||
testEngine.ServeHTTP(w, req)
|
||||
assert.Equal(t, w.Code, http.StatusOK)
|
||||
assert.Equal(t, w.Body.String(), PrintErr(ErrDefault))
|
||||
})
|
||||
|
||||
t.Run("list database without success code", func(t *testing.T) {
|
||||
mp := mocks.NewProxy(t)
|
||||
mp.EXPECT().ListDatabases(mock.Anything, mock.Anything).Return(&milvuspb.ListDatabasesResponse{
|
||||
Status: &commonpb.Status{
|
||||
ErrorCode: commonpb.ErrorCode_UnexpectedError,
|
||||
Reason: "",
|
||||
},
|
||||
}, nil).Once()
|
||||
testEngine := initHTTPServer(mp, true)
|
||||
req := httptest.NewRequest(http.MethodGet, URIPrefix+VectorCollectionsPath+"?dbName=test", nil)
|
||||
req.Header.Set("authorization", "Bearer root:Milvus")
|
||||
w := httptest.NewRecorder()
|
||||
testEngine.ServeHTTP(w, req)
|
||||
assert.Equal(t, w.Code, http.StatusOK)
|
||||
assert.Equal(t, w.Body.String(), Print(int32(commonpb.ErrorCode_UnexpectedError), ""))
|
||||
})
|
||||
|
||||
t.Run("list database success", func(t *testing.T) {
|
||||
mp := mocks.NewProxy(t)
|
||||
mp.EXPECT().ListDatabases(mock.Anything, mock.Anything).Return(&milvuspb.ListDatabasesResponse{
|
||||
Status: &StatusSuccess,
|
||||
DbNames: []string{"default", "test"},
|
||||
}, nil).Once()
|
||||
mp.EXPECT().
|
||||
ShowCollections(mock.Anything, mock.Anything).
|
||||
Return(&milvuspb.ShowCollectionsResponse{
|
||||
Status: &StatusSuccess,
|
||||
CollectionNames: nil,
|
||||
}, nil).Once()
|
||||
testEngine := initHTTPServer(mp, true)
|
||||
req := httptest.NewRequest(http.MethodGet, URIPrefix+VectorCollectionsPath+"?dbName=test", nil)
|
||||
req.Header.Set("authorization", "Bearer root:Milvus")
|
||||
w := httptest.NewRecorder()
|
||||
testEngine.ServeHTTP(w, req)
|
||||
assert.Equal(t, w.Code, http.StatusOK)
|
||||
assert.Equal(t, w.Body.String(), "{\"code\":200,\"data\":[]}")
|
||||
})
|
||||
|
||||
errorStr := PrintErr(merr.ErrDatabaseNotfound)
|
||||
paths := map[string][]string{
|
||||
errorStr: {
|
||||
URIPrefix + VectorCollectionsPath + "?dbName=test",
|
||||
URIPrefix + VectorCollectionsDescribePath + "?dbName=test&collectionName=" + DefaultCollectionName,
|
||||
},
|
||||
}
|
||||
for res, pathArr := range paths {
|
||||
for _, path := range pathArr {
|
||||
t.Run("GET dbName", func(t *testing.T) {
|
||||
mp := mocks.NewProxy(t)
|
||||
mp.EXPECT().ListDatabases(mock.Anything, mock.Anything).Return(&milvuspb.ListDatabasesResponse{
|
||||
Status: &StatusSuccess,
|
||||
DbNames: []string{"default"},
|
||||
}, nil).Once()
|
||||
testEngine := initHTTPServer(mp, true)
|
||||
req := httptest.NewRequest(http.MethodGet, path, nil)
|
||||
req.Header.Set("authorization", "Bearer root:Milvus")
|
||||
w := httptest.NewRecorder()
|
||||
testEngine.ServeHTTP(w, req)
|
||||
assert.Equal(t, w.Code, http.StatusOK)
|
||||
assert.Equal(t, w.Body.String(), res)
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
requestBody := `{"dbName": "test", "collectionName": "` + DefaultCollectionName + `", "dimension": 2}`
|
||||
paths = map[string][]string{
|
||||
requestBody: {
|
||||
URIPrefix + VectorCollectionsCreatePath,
|
||||
URIPrefix + VectorCollectionsDropPath,
|
||||
URIPrefix + VectorInsertPath,
|
||||
URIPrefix + VectorDeletePath,
|
||||
URIPrefix + VectorQueryPath,
|
||||
URIPrefix + VectorGetPath,
|
||||
URIPrefix + VectorSearchPath,
|
||||
},
|
||||
}
|
||||
for res, pathArr := range paths {
|
||||
for _, path := range pathArr {
|
||||
t.Run("POST dbName", func(t *testing.T) {
|
||||
mp := mocks.NewProxy(t)
|
||||
mp.EXPECT().ListDatabases(mock.Anything, mock.Anything).Return(&milvuspb.ListDatabasesResponse{
|
||||
Status: &StatusSuccess,
|
||||
DbNames: []string{"default"},
|
||||
}, nil).Once()
|
||||
testEngine := initHTTPServer(mp, true)
|
||||
bodyReader := bytes.NewReader([]byte(res))
|
||||
req := httptest.NewRequest(http.MethodPost, path, bodyReader)
|
||||
req.Header.Set("authorization", "Bearer root:Milvus")
|
||||
w := httptest.NewRecorder()
|
||||
testEngine.ServeHTTP(w, req)
|
||||
assert.Equal(t, w.Code, http.StatusOK)
|
||||
assert.Equal(t, w.Body.String(), errorStr)
|
||||
})
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func wrapWithDescribeIndex(t *testing.T, mp *mocks.Proxy, returnType int, times int, testCases []testCase) (*mocks.Proxy, []testCase) {
|
||||
if mp == nil {
|
||||
mp = mocks.NewProxy(t)
|
||||
|
@ -1106,7 +1259,7 @@ func wrapWithDescribeIndex(t *testing.T, mp *mocks.Proxy, returnType int, times
|
|||
name: "[share] describe index fail",
|
||||
mp: mp,
|
||||
exceptCode: 200,
|
||||
expectedBody: "{\"code\":400,\"error\":\"" + DefaultErrorMessage + "\",\"message\":\"find index of collection " + DefaultCollectionName + " fail\"}",
|
||||
expectedBody: PrintErr(ErrDefault),
|
||||
}
|
||||
case ReturnWrongStatus:
|
||||
reason := "index is not exists"
|
||||
|
@ -1118,7 +1271,7 @@ func wrapWithDescribeIndex(t *testing.T, mp *mocks.Proxy, returnType int, times
|
|||
name: "[share] describe index fail",
|
||||
mp: mp,
|
||||
exceptCode: 200,
|
||||
expectedBody: "{\"code\":400,\"error\":\"" + reason + "\",\"message\":\"find index of collection " + DefaultCollectionName + " fail\"}",
|
||||
expectedBody: Print(int32(commonpb.ErrorCode_IndexNotExist), reason),
|
||||
}
|
||||
}
|
||||
if times == 2 {
|
||||
|
@ -1215,7 +1368,7 @@ func Test_Handles_VectorCollectionsDescribe(t *testing.T) {
|
|||
req := httptest.NewRequest(http.MethodGet, "/vector/collections/describe?collectionName=book", nil)
|
||||
w := httptest.NewRecorder()
|
||||
testEngine.ServeHTTP(w, req)
|
||||
assert.Equal(t, w.Code, http.StatusOK)
|
||||
assert.Equal(t, w.Code, http.StatusProxyAuthRequired)
|
||||
})
|
||||
|
||||
t.Run("auth fail", func(t *testing.T) {
|
||||
|
@ -1224,22 +1377,22 @@ func Test_Handles_VectorCollectionsDescribe(t *testing.T) {
|
|||
req.SetBasicAuth("test", util.DefaultRootPassword)
|
||||
w := httptest.NewRecorder()
|
||||
testEngine.ServeHTTP(w, req)
|
||||
assert.Equal(t, w.Code, http.StatusOK)
|
||||
assert.Equal(t, w.Body.String(), "{\"code\":401,\"message\":\"rpc error: code = Unavailable desc = internal: Milvus Proxy is not ready yet. please wait\"}")
|
||||
assert.Equal(t, w.Code, http.StatusUnauthorized)
|
||||
assert.Equal(t, w.Body.String(), Print(int32(65535), "rpc error: code = Unavailable desc = internal: Milvus Proxy is not ready yet. please wait"))
|
||||
})
|
||||
|
||||
t.Run("describe collection fail with error", func(t *testing.T) {
|
||||
proxy.Params.CommonCfg.AuthorizationEnabled = false
|
||||
mp.EXPECT().
|
||||
DescribeCollection(mock.Anything, mock.Anything).
|
||||
Return(nil, errors.New("error")).
|
||||
Return(nil, ErrDefault).
|
||||
Once()
|
||||
req := httptest.NewRequest(http.MethodGet, "/vector/collections/describe?collectionName=book", nil)
|
||||
req.SetBasicAuth(util.UserRoot, util.DefaultRootPassword)
|
||||
w := httptest.NewRecorder()
|
||||
testEngine.ServeHTTP(w, req)
|
||||
assert.Equal(t, w.Code, http.StatusOK)
|
||||
assert.Equal(t, w.Body.String(), "{\"code\":400,\"error\":\"error\",\"message\":\"describe collection book fail\"}")
|
||||
assert.Equal(t, w.Body.String(), PrintErr(ErrDefault))
|
||||
})
|
||||
|
||||
t.Run("describe collection fail with status code", func(t *testing.T) {
|
||||
|
|
|
@ -30,6 +30,8 @@ import (
|
|||
"sync"
|
||||
"time"
|
||||
|
||||
"github.com/milvus-io/milvus/pkg/util/merr"
|
||||
|
||||
"google.golang.org/grpc/credentials"
|
||||
|
||||
"github.com/milvus-io/milvus/internal/management"
|
||||
|
@ -130,7 +132,7 @@ func authenticate(c *gin.Context) {
|
|||
return
|
||||
}
|
||||
}
|
||||
c.AbortWithStatusJSON(http.StatusProxyAuthRequired, gin.H{"code": http.StatusProxyAuthRequired, "message": proxy.ErrUnauthenticated().Error()})
|
||||
c.AbortWithStatusJSON(http.StatusProxyAuthRequired, gin.H{httpserver.HTTPReturnCode: httpserver.Code(merr.ErrNeedAuthenticate), httpserver.HTTPReturnMessage: merr.ErrNeedAuthenticate.Error()})
|
||||
}
|
||||
|
||||
// registerHTTPServer register the http server, panic when failed
|
||||
|
|
|
@ -0,0 +1,187 @@
|
|||
// Licensed to the LF AI & Data foundation under one
|
||||
// or more contributor license agreements. See the NOTICE file
|
||||
// distributed with this work for additional information
|
||||
// regarding copyright ownership. The ASF licenses this file
|
||||
// to you under the Apache License, Version 2.0 (the
|
||||
// "License"); you may not use this file except in compliance
|
||||
// with the License. You may obtain a copy of the License at
|
||||
//
|
||||
// http://www.apache.org/licenses/LICENSE-2.0
|
||||
//
|
||||
// Unless required by applicable law or agreed to in writing, software
|
||||
// distributed under the License is distributed on an "AS IS" BASIS,
|
||||
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
// See the License for the specific language governing permissions and
|
||||
// limitations under the License.
|
||||
|
||||
package merr
|
||||
|
||||
import (
|
||||
"context"
|
||||
|
||||
"github.com/cockroachdb/errors"
|
||||
)
|
||||
|
||||
const (
|
||||
rootCoordBits = (iota + 1) << 16
|
||||
dataCoordBits
|
||||
queryCoordBits
|
||||
dataNodeBits
|
||||
queryNodeBits
|
||||
indexNodeBits
|
||||
proxyBits
|
||||
standaloneBits
|
||||
embededBits
|
||||
|
||||
retriableFlag = 1 << 20
|
||||
RootReasonCodeMask = (1 << 16) - 1
|
||||
|
||||
CanceledCode int32 = 10000
|
||||
TimeoutCode int32 = 10001
|
||||
)
|
||||
|
||||
// Define leaf errors here,
|
||||
// WARN: take care to add new error,
|
||||
// check whehter you can use the erorrs below before adding a new one.
|
||||
// Name: Err + related prefix + error name
|
||||
var (
|
||||
// Service related
|
||||
ErrServiceNotReady = newMilvusError("service not ready", 1, true) // This indicates the service is still in init
|
||||
ErrServiceUnavailable = newMilvusError("service unavailable", 2, true)
|
||||
ErrServiceMemoryLimitExceeded = newMilvusError("memory limit exceeded", 3, false)
|
||||
ErrServiceRequestLimitExceeded = newMilvusError("request limit exceeded", 4, true)
|
||||
ErrServiceInternal = newMilvusError("service internal error", 5, false) // Never return this error out of Milvus
|
||||
ErrCrossClusterRouting = newMilvusError("cross cluster routing", 6, false)
|
||||
|
||||
// Collection related
|
||||
ErrCollectionNotFound = newMilvusError("collection not found", 100, false)
|
||||
ErrCollectionNotLoaded = newMilvusError("collection not loaded", 101, false)
|
||||
ErrCollectionNumLimitExceeded = newMilvusError("exceeded the limit number of collections", 102, false)
|
||||
ErrCollectionNotFullyLoaded = newMilvusError("collection not fully loaded", 103, true)
|
||||
|
||||
// Partition related
|
||||
ErrPartitionNotFound = newMilvusError("partition not found", 202, false)
|
||||
ErrPartitionNotLoaded = newMilvusError("partition not loaded", 203, false)
|
||||
ErrPartitionNotFullyLoaded = newMilvusError("collection not fully loaded", 103, true)
|
||||
|
||||
// ResourceGroup related
|
||||
ErrResourceGroupNotFound = newMilvusError("resource group not found", 300, false)
|
||||
|
||||
// Replica related
|
||||
ErrReplicaNotFound = newMilvusError("replica not found", 400, false)
|
||||
ErrReplicaNotAvailable = newMilvusError("replica not available", 401, false)
|
||||
|
||||
// Channel related
|
||||
ErrChannelNotFound = newMilvusError("channel not found", 500, false)
|
||||
ErrChannelLack = newMilvusError("channel lacks", 501, false)
|
||||
ErrChannelReduplicate = newMilvusError("channel reduplicates", 502, false)
|
||||
ErrChannelNotAvailable = newMilvusError("channel not available", 503, false)
|
||||
|
||||
// Segment related
|
||||
ErrSegmentNotFound = newMilvusError("segment not found", 600, false)
|
||||
ErrSegmentNotLoaded = newMilvusError("segment not loaded", 601, false)
|
||||
ErrSegmentLack = newMilvusError("segment lacks", 602, false)
|
||||
ErrSegmentReduplicate = newMilvusError("segment reduplicates", 603, false)
|
||||
|
||||
// Index related
|
||||
ErrIndexNotFound = newMilvusError("index not found", 700, false)
|
||||
|
||||
// Database related
|
||||
ErrDatabaseNotfound = newMilvusError("database not found", 800, false)
|
||||
ErrDatabaseNumLimitExceeded = newMilvusError("exceeded the limit number of database", 801, false)
|
||||
ErrInvalidedDatabaseName = newMilvusError("invalided database name", 802, false)
|
||||
|
||||
// Node related
|
||||
ErrNodeNotFound = newMilvusError("node not found", 901, false)
|
||||
ErrNodeOffline = newMilvusError("node offline", 902, false)
|
||||
ErrNodeLack = newMilvusError("node lacks", 903, false)
|
||||
ErrNodeNotMatch = newMilvusError("node not match", 904, false)
|
||||
ErrNodeNotAvailable = newMilvusError("node not available", 905, false)
|
||||
|
||||
// IO related
|
||||
ErrIoKeyNotFound = newMilvusError("key not found", 1000, false)
|
||||
ErrIoFailed = newMilvusError("IO failed", 1001, false)
|
||||
|
||||
// Parameter related
|
||||
ErrParameterInvalid = newMilvusError("invalid parameter", 1100, false)
|
||||
|
||||
// Metrics related
|
||||
ErrMetricNotFound = newMilvusError("metric not found", 1200, false)
|
||||
|
||||
// Topic related
|
||||
ErrTopicNotFound = newMilvusError("topic not found", 1300, false)
|
||||
ErrTopicNotEmpty = newMilvusError("topic not empty", 1301, false)
|
||||
|
||||
// shard delegator related
|
||||
ErrShardDelegatorNotFound = newMilvusError("shard delegator not found", 1500, false)
|
||||
ErrShardDelegatorAccessFailed = newMilvusError("fail to access shard delegator", 1501, true)
|
||||
ErrShardDelegatorSearchFailed = newMilvusError("fail to search on all shard leaders", 1502, true)
|
||||
ErrShardDelegatorQueryFailed = newMilvusError("fail to query on all shard leaders", 1503, true)
|
||||
ErrShardDelegatorStatisticFailed = newMilvusError("get statistics on all shard leaders", 1504, true)
|
||||
|
||||
// field related
|
||||
ErrFieldNotFound = newMilvusError("field not found", 1700, false)
|
||||
|
||||
// high-level restful api related
|
||||
ErrNeedAuthenticate = newMilvusError("user hasn't authenticate", 1800, false)
|
||||
ErrIncorrectParameterFormat = newMilvusError("can only accept json format request", 1801, false)
|
||||
ErrMissingRequiredParameters = newMilvusError("missing required parameters", 1802, false)
|
||||
ErrMarshalCollectionSchema = newMilvusError("fail to marshal collection schema", 1803, false)
|
||||
ErrInvalidInsertData = newMilvusError("fail to deal the insert data", 1804, false)
|
||||
ErrInvalidSearchResult = newMilvusError("fail to parse search result", 1805, false)
|
||||
ErrCheckPrimaryKey = newMilvusError("please check the primary key and its' type can only in [int, string]", 1806, false)
|
||||
|
||||
// Do NOT export this,
|
||||
// never allow programmer using this, keep only for converting unknown error to milvusError
|
||||
errUnexpected = newMilvusError("unexpected error", (1<<16)-1, false)
|
||||
)
|
||||
|
||||
func maskComponentBits(code int32) int32 {
|
||||
return code | proxyBits
|
||||
}
|
||||
|
||||
type milvusError struct {
|
||||
msg string
|
||||
errCode int32
|
||||
}
|
||||
|
||||
func newMilvusError(msg string, code int32, retriable bool) milvusError {
|
||||
if retriable {
|
||||
code |= retriableFlag
|
||||
}
|
||||
return milvusError{
|
||||
msg: msg,
|
||||
errCode: code,
|
||||
}
|
||||
}
|
||||
|
||||
func (e milvusError) code() int32 {
|
||||
return maskComponentBits(e.errCode)
|
||||
}
|
||||
|
||||
func (e milvusError) Error() string {
|
||||
return e.msg
|
||||
}
|
||||
|
||||
// Code returns the error code of the given error,
|
||||
// WARN: DO NOT use this for now
|
||||
func Code(err error) int32 {
|
||||
if err == nil {
|
||||
return 0
|
||||
}
|
||||
|
||||
cause := errors.Cause(err)
|
||||
switch cause := cause.(type) {
|
||||
case milvusError:
|
||||
return cause.code()
|
||||
|
||||
default:
|
||||
if errors.Is(cause, context.Canceled) {
|
||||
return CanceledCode
|
||||
} else if errors.Is(cause, context.DeadlineExceeded) {
|
||||
return TimeoutCode
|
||||
} else {
|
||||
return errUnexpected.code()
|
||||
}
|
||||
}
|
||||
}
|
|
@ -0,0 +1,48 @@
|
|||
// Licensed to the LF AI & Data foundation under one
|
||||
// or more contributor license agreements. See the NOTICE file
|
||||
// distributed with this work for additional information
|
||||
// regarding copyright ownership. The ASF licenses this file
|
||||
// to you under the Apache License, Version 2.0 (the
|
||||
// "License"); you may not use this file except in compliance
|
||||
// with the License. You may obtain a copy of the License at
|
||||
//
|
||||
// http://www.apache.org/licenses/LICENSE-2.0
|
||||
//
|
||||
// Unless required by applicable law or agreed to in writing, software
|
||||
// distributed under the License is distributed on an "AS IS" BASIS,
|
||||
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
// See the License for the specific language governing permissions and
|
||||
// limitations under the License.
|
||||
|
||||
package merr
|
||||
|
||||
import (
|
||||
"context"
|
||||
"testing"
|
||||
|
||||
"github.com/cockroachdb/errors"
|
||||
"github.com/stretchr/testify/suite"
|
||||
)
|
||||
|
||||
type ErrSuite struct {
|
||||
suite.Suite
|
||||
}
|
||||
|
||||
func (s *ErrSuite) SetupSuite() {
|
||||
}
|
||||
|
||||
func (s *ErrSuite) TestCode() {
|
||||
err := ErrCollectionNotFound
|
||||
errors.Wrap(err, "failed to get collection")
|
||||
s.ErrorIs(err, ErrCollectionNotFound)
|
||||
s.Equal(Code(ErrCollectionNotFound), Code(err))
|
||||
s.Equal(ErrCollectionNotFound.msg, err.Error())
|
||||
s.Equal(TimeoutCode, Code(context.DeadlineExceeded))
|
||||
s.Equal(CanceledCode, Code(context.Canceled))
|
||||
s.Equal(errUnexpected.code(), Code(errors.New("unexpected")))
|
||||
s.Equal(int32(0), Code(nil))
|
||||
}
|
||||
|
||||
func TestErrors(t *testing.T) {
|
||||
suite.Run(t, new(ErrSuite))
|
||||
}
|
Loading…
Reference in New Issue