support high-level RESTFUL API, listen on the same port as grpc. (#25108)

Signed-off-by: PowderLi <min.li@zilliz.com>
pull/26133/head
PowderLi 2023-08-08 10:15:07 +08:00 committed by GitHub
parent 28cea5e763
commit a7eecb1be0
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
16 changed files with 4229 additions and 30 deletions

View File

@ -77,6 +77,8 @@ issues:
- G304
# Deferring unsafe method like *os.File Close
- G307
# TLS MinVersion too low
- G402
# Use of weak random number generator math/rand
- G404
# Maximum issues count per one linter. Set to 0 to disable. Default is 50.

17
go.mod
View File

@ -4,7 +4,7 @@ go 1.18
require (
github.com/DATA-DOG/go-sqlmock v1.5.0
github.com/aliyun/credentials-go v1.2.6
github.com/aliyun/credentials-go v1.2.7
github.com/antlr/antlr4/runtime/Go/antlr v0.0.0-20210826220005-b48c857c3a0e
github.com/antonmedv/expr v1.8.9
github.com/apache/arrow/go/v8 v8.0.0-20220322092137-778b1772fd20
@ -50,6 +50,9 @@ require (
golang.org/x/sync v0.1.0
golang.org/x/text v0.9.0
google.golang.org/grpc v1.54.0
google.golang.org/grpc/examples v0.0.0-20220617181431-3e7b97febc7f
google.golang.org/protobuf v1.30.0
gopkg.in/natefinch/lumberjack.v2 v2.0.0
gorm.io/driver/mysql v1.3.5
gorm.io/gorm v1.23.8
stathat.com/c/consistent v1.0.0
@ -169,7 +172,7 @@ require (
github.com/rs/xid v1.2.1 // indirect
github.com/shirou/gopsutil/v3 v3.22.9
github.com/sirupsen/logrus v1.8.1 // indirect
github.com/soheilhy/cmux v0.1.5 // indirect
github.com/soheilhy/cmux v0.1.5
github.com/spaolacci/murmur3 v1.1.0
github.com/spf13/afero v1.6.0 // indirect
github.com/spf13/cast v1.3.1
@ -209,15 +212,19 @@ require (
golang.org/x/xerrors v0.0.0-20220907171357-04be3eba64a2 // indirect
gonum.org/v1/gonum v0.9.3 // indirect
google.golang.org/appengine v1.6.7 // indirect
google.golang.org/grpc/examples v0.0.0-20220617181431-3e7b97febc7f
google.golang.org/protobuf v1.30.0
gopkg.in/ini.v1 v1.62.0 // indirect
gopkg.in/natefinch/lumberjack.v2 v2.0.0
gopkg.in/yaml.v2 v2.4.0 // indirect
gopkg.in/yaml.v3 v3.0.1 // indirect
sigs.k8s.io/yaml v1.2.0 // indirect
)
require github.com/tidwall/gjson v1.14.4
require (
github.com/tidwall/match v1.1.1 // indirect
github.com/tidwall/pretty v1.2.0 // indirect
)
replace (
github.com/apache/pulsar-client-go => github.com/milvus-io/pulsar-client-go v0.6.10
github.com/bketelsen/crypt => github.com/bketelsen/crypt v0.0.4 // Fix security alert for core-os/etcd

10
go.sum
View File

@ -82,8 +82,8 @@ github.com/alibabacloud-go/debug v0.0.0-20190504072949-9472017b5c68 h1:NqugFkGxx
github.com/alibabacloud-go/debug v0.0.0-20190504072949-9472017b5c68/go.mod h1:6pb/Qy8c+lqua8cFpEy7g39NRRqOWc3rOwAy8m5Y2BY=
github.com/alibabacloud-go/tea v1.1.8 h1:vFF0707fqjGiQTxrtMnIXRjOCvQXf49CuDVRtTopmwU=
github.com/alibabacloud-go/tea v1.1.8/go.mod h1:/tmnEaQMyb4Ky1/5D+SE1BAsa5zj/KeGOFfwYm3N/p4=
github.com/aliyun/credentials-go v1.2.6 h1:dSMxpj4uXZj0MYOsEyljlssHzfdHw/M84iQ5QKF0Uxg=
github.com/aliyun/credentials-go v1.2.6/go.mod h1:/KowD1cfGSLrLsH28Jr8W+xwoId0ywIy5lNzDz6O1vw=
github.com/aliyun/credentials-go v1.2.7 h1:gLtFylxLZ1TWi1pStIt1O6a53GFU1zkNwjtJir2B4ow=
github.com/aliyun/credentials-go v1.2.7/go.mod h1:/KowD1cfGSLrLsH28Jr8W+xwoId0ywIy5lNzDz6O1vw=
github.com/andybalholm/brotli v1.0.4 h1:V7DdXeJtZscaqfNuAdSRuRFzuiKlHSC/Zh3zl9qY3JY=
github.com/andybalholm/brotli v1.0.4/go.mod h1:fO7iG3H7G2nSZ7m0zPUDn85XEX2GTukHGRSepvi9Eig=
github.com/antihax/optional v1.0.0/go.mod h1:uupD/76wgC+ih3iEmQUL+0Ugr19nfwCT1kdvxnR2qWY=
@ -812,6 +812,12 @@ github.com/stretchr/testify v1.8.3/go.mod h1:sz/lmYIOXD/1dqDmKjjqLyZ2RngseejIcXl
github.com/subosito/gotenv v1.2.0 h1:Slr1R9HxAlEKefgq5jn9U+DnETlIUa6HfgEzj0g5d7s=
github.com/subosito/gotenv v1.2.0/go.mod h1:N0PQaV/YGNqwC0u51sEeR/aUtSLEXKX9iv69rRypqCw=
github.com/thoas/go-funk v0.9.1 h1:O549iLZqPpTUQ10ykd26sZhzD+rmR5pWhuElrhbC20M=
github.com/tidwall/gjson v1.14.4 h1:uo0p8EbA09J7RQaflQ1aBRffTR7xedD2bcIVSYxLnkM=
github.com/tidwall/gjson v1.14.4/go.mod h1:/wbyibRr2FHMks5tjHJ5F8dMZh3AcwJEMf5vlfC0lxk=
github.com/tidwall/match v1.1.1 h1:+Ho715JplO36QYgwN9PGYNhgZvoUSc9X2c80KVTi+GA=
github.com/tidwall/match v1.1.1/go.mod h1:eRSPERbgtNPcGhD8UCthc6PmLEQXEWd3PRB5JTxsfmM=
github.com/tidwall/pretty v1.2.0 h1:RWIZEg2iJ8/g6fDDYzMpobmaoGh5OLl4AXtGUGPcqCs=
github.com/tidwall/pretty v1.2.0/go.mod h1:ITEVvHYasfjBbM0u2Pg8T2nJnzm8xPwvNhhsoaGGjNU=
github.com/tklauser/go-sysconf v0.3.10 h1:IJ1AZGZRWbY8T5Vfk04D9WOA5WSejdflXxP03OUqALw=
github.com/tklauser/go-sysconf v0.3.10/go.mod h1:C8XykCvCb+Gn0oNCWPIlcb0RuglQTYaQ2hGm7jmxEFk=
github.com/tklauser/numcpus v0.4.0 h1:E53Dm1HjH1/R2/aoCtXtPgzmElmn51aOkhCFSuZq//o=

View File

@ -0,0 +1,57 @@
package httpserver
const (
ContextUsername = "username"
VectorCollectionsPath = "/vector/collections"
VectorCollectionsCreatePath = "/vector/collections/create"
VectorCollectionsDescribePath = "/vector/collections/describe"
VectorCollectionsDropPath = "/vector/collections/drop"
VectorInsertPath = "/vector/insert"
VectorSearchPath = "/vector/search"
VectorGetPath = "/vector/get"
VectorQueryPath = "/vector/query"
VectorDeletePath = "/vector/delete"
ShardNumDefault = 1
EnableDynamic = true
EnableAutoID = true
DisableAutoID = false
HTTPCollectionName = "collectionName"
HTTPDbName = "dbName"
DefaultDbName = "default"
DefaultIndexName = "vector_idx"
DefaultOutputFields = "*"
HTTPReturnCode = "code"
HTTPReturnMessage = "message"
HTTPReturnData = "data"
HTTPReturnFieldName = "name"
HTTPReturnFieldType = "type"
HTTPReturnFieldPrimaryKey = "primaryKey"
HTTPReturnFieldAutoID = "autoId"
HTTPReturnDescription = "description"
HTTPReturnIndexName = "indexName"
HTTPReturnIndexField = "fieldName"
HTTPReturnIndexMetricsType = "metricType"
HTTPReturnDistance = "distance"
DefaultMetricType = "L2"
DefaultPrimaryFieldName = "id"
DefaultVectorFieldName = "vector"
Dim = "dim"
)
const (
ParamAnnsField = "anns_field"
Params = "params"
ParamRoundDecimal = "round_decimal"
ParamOffset = "offset"
ParamLimit = "limit"
BoundedTimestamp = 2
)

View File

@ -0,0 +1,625 @@
package httpserver
import (
"encoding/json"
"net/http"
"strconv"
"github.com/milvus-io/milvus/pkg/util/merr"
"github.com/cockroachdb/errors"
"github.com/gin-gonic/gin"
"github.com/gin-gonic/gin/binding"
"github.com/golang/protobuf/proto"
"github.com/milvus-io/milvus-proto/go-api/v2/commonpb"
"github.com/milvus-io/milvus-proto/go-api/v2/milvuspb"
"github.com/milvus-io/milvus-proto/go-api/v2/schemapb"
"github.com/milvus-io/milvus/internal/proxy"
"github.com/milvus-io/milvus/pkg/common"
"github.com/milvus-io/milvus/pkg/log"
"github.com/tidwall/gjson"
"go.uber.org/zap"
)
func checkAuthorization(c *gin.Context, req interface{}) error {
if proxy.Params.CommonCfg.AuthorizationEnabled.GetAsBool() {
username, ok := c.Get(ContextUsername)
if !ok {
c.JSON(http.StatusUnauthorized, gin.H{HTTPReturnCode: Code(merr.ErrNeedAuthenticate), HTTPReturnMessage: merr.ErrNeedAuthenticate.Error()})
return merr.ErrNeedAuthenticate
}
_, authErr := proxy.PrivilegeInterceptorWithUsername(c, username.(string), req)
if authErr != nil {
c.JSON(http.StatusForbidden, gin.H{HTTPReturnCode: Code(authErr), HTTPReturnMessage: authErr.Error()})
return authErr
}
}
return nil
}
func (h *Handlers) checkDatabase(c *gin.Context, dbName string) bool {
if dbName == DefaultDbName {
return true
}
response, err := h.proxy.ListDatabases(c, &milvuspb.ListDatabasesRequest{})
if err != nil {
c.AbortWithStatusJSON(http.StatusOK, gin.H{HTTPReturnCode: Code(err), HTTPReturnMessage: err.Error()})
return false
} else if response.Status.ErrorCode != commonpb.ErrorCode_Success {
c.AbortWithStatusJSON(http.StatusOK, gin.H{HTTPReturnCode: code(int32(response.Status.ErrorCode)), HTTPReturnMessage: response.Status.Reason})
return false
}
for _, db := range response.DbNames {
if db == dbName {
return true
}
}
c.AbortWithStatusJSON(http.StatusOK, gin.H{HTTPReturnCode: Code(merr.ErrDatabaseNotfound), HTTPReturnMessage: merr.ErrDatabaseNotfound.Error()})
return false
}
func (h *Handlers) describeCollection(c *gin.Context, dbName string, collectionName string, needAuth bool) (*milvuspb.DescribeCollectionResponse, error) {
req := milvuspb.DescribeCollectionRequest{
DbName: dbName,
CollectionName: collectionName,
}
if needAuth {
if err := checkAuthorization(c, &req); err != nil {
return nil, err
}
}
response, err := h.proxy.DescribeCollection(c, &req)
if err != nil {
c.AbortWithStatusJSON(http.StatusOK, gin.H{HTTPReturnCode: Code(err), HTTPReturnMessage: err.Error()})
return nil, err
} else if response.Status.ErrorCode != commonpb.ErrorCode_Success {
c.AbortWithStatusJSON(http.StatusOK, gin.H{HTTPReturnCode: code(int32(response.Status.ErrorCode)), HTTPReturnMessage: response.Status.Reason})
return nil, errors.New(response.Status.Reason)
}
primaryField, ok := getPrimaryField(response.Schema)
if ok && primaryField.AutoID && !response.Schema.AutoID {
log.Warn("primary filed autoID VS schema autoID", zap.String("collectionName", collectionName), zap.Bool("primary Field", primaryField.AutoID), zap.Bool("schema", response.Schema.AutoID))
response.Schema.AutoID = EnableAutoID
}
return response, nil
}
func (h *Handlers) hasCollection(c *gin.Context, dbName string, collectionName string) (bool, error) {
req := milvuspb.HasCollectionRequest{
DbName: dbName,
CollectionName: collectionName,
}
response, err := h.proxy.HasCollection(c, &req)
if err != nil {
c.AbortWithStatusJSON(http.StatusOK, gin.H{HTTPReturnCode: Code(err), HTTPReturnMessage: err.Error()})
return false, err
} else if response.Status.ErrorCode != commonpb.ErrorCode_Success {
c.AbortWithStatusJSON(http.StatusOK, gin.H{HTTPReturnCode: code(int32(response.Status.ErrorCode)), HTTPReturnMessage: response.Status.Reason})
return false, errors.New(response.Status.Reason)
} else {
return response.Value, nil
}
}
func (h *Handlers) RegisterRoutesToV1(router gin.IRouter) {
router.GET(VectorCollectionsPath, h.listCollections)
router.POST(VectorCollectionsCreatePath, h.createCollection)
router.GET(VectorCollectionsDescribePath, h.getCollectionDetails)
router.POST(VectorCollectionsDropPath, h.dropCollection)
router.POST(VectorQueryPath, h.query)
router.POST(VectorGetPath, h.get)
router.POST(VectorDeletePath, h.delete)
router.POST(VectorInsertPath, h.insert)
router.POST(VectorSearchPath, h.search)
}
func (h *Handlers) listCollections(c *gin.Context) {
dbName := c.DefaultQuery(HTTPDbName, DefaultDbName)
req := milvuspb.ShowCollectionsRequest{
DbName: dbName,
}
if err := checkAuthorization(c, &req); err != nil {
return
}
if !h.checkDatabase(c, dbName) {
return
}
response, err := h.proxy.ShowCollections(c, &req)
if err != nil {
c.JSON(http.StatusOK, gin.H{HTTPReturnCode: Code(err), HTTPReturnMessage: err.Error()})
} else if response.Status.ErrorCode != commonpb.ErrorCode_Success {
c.JSON(http.StatusOK, gin.H{HTTPReturnCode: code(int32(response.Status.ErrorCode)), HTTPReturnMessage: response.Status.Reason})
} else {
var collections []string
if response.CollectionNames != nil {
collections = response.CollectionNames
} else {
collections = []string{}
}
c.JSON(http.StatusOK, gin.H{HTTPReturnCode: http.StatusOK, HTTPReturnData: collections})
}
}
func (h *Handlers) createCollection(c *gin.Context) {
httpReq := CreateCollectionReq{
DbName: DefaultDbName,
MetricType: DefaultMetricType,
PrimaryField: DefaultPrimaryFieldName,
VectorField: DefaultVectorFieldName,
}
if err := c.ShouldBindBodyWith(&httpReq, binding.JSON); err != nil {
log.Warn("high level restful api, the parameter of create collection is incorrect", zap.Any("request", httpReq), zap.Error(err))
c.AbortWithStatusJSON(http.StatusOK, gin.H{HTTPReturnCode: Code(merr.ErrIncorrectParameterFormat), HTTPReturnMessage: merr.ErrIncorrectParameterFormat.Error()})
return
}
if httpReq.CollectionName == "" || httpReq.Dimension == 0 {
log.Warn("high level restful api, create collection require parameters: [collectionName, dimension], but miss", zap.Any("request", httpReq))
c.AbortWithStatusJSON(http.StatusOK, gin.H{HTTPReturnCode: Code(merr.ErrMissingRequiredParameters), HTTPReturnMessage: merr.ErrMissingRequiredParameters.Error()})
return
}
schema, err := proto.Marshal(&schemapb.CollectionSchema{
Name: httpReq.CollectionName,
Description: httpReq.Description,
AutoID: EnableAutoID,
Fields: []*schemapb.FieldSchema{
{
FieldID: common.StartOfUserFieldID,
Name: httpReq.PrimaryField,
IsPrimaryKey: true,
DataType: schemapb.DataType_Int64,
AutoID: EnableAutoID,
}, {
FieldID: common.StartOfUserFieldID + 1,
Name: httpReq.VectorField,
IsPrimaryKey: false,
DataType: schemapb.DataType_FloatVector,
TypeParams: []*commonpb.KeyValuePair{
{
Key: Dim,
Value: strconv.FormatInt(int64(httpReq.Dimension), 10),
},
},
AutoID: DisableAutoID,
},
},
EnableDynamicField: EnableDynamic,
})
if err != nil {
log.Warn("high level restful api, marshal collection schema fail", zap.Any("request", httpReq), zap.Error(err))
c.AbortWithStatusJSON(http.StatusOK, gin.H{HTTPReturnCode: Code(merr.ErrMarshalCollectionSchema), HTTPReturnMessage: merr.ErrMarshalCollectionSchema.Error()})
return
}
req := milvuspb.CreateCollectionRequest{
DbName: httpReq.DbName,
CollectionName: httpReq.CollectionName,
Schema: schema,
ShardsNum: ShardNumDefault,
ConsistencyLevel: commonpb.ConsistencyLevel_Bounded,
}
if err := checkAuthorization(c, &req); err != nil {
return
}
if !h.checkDatabase(c, req.DbName) {
return
}
response, err := h.proxy.CreateCollection(c, &req)
if err != nil {
c.AbortWithStatusJSON(http.StatusOK, gin.H{HTTPReturnCode: Code(err), HTTPReturnMessage: err.Error()})
return
} else if response.ErrorCode != commonpb.ErrorCode_Success {
c.AbortWithStatusJSON(http.StatusOK, gin.H{HTTPReturnCode: code(int32(response.ErrorCode)), HTTPReturnMessage: response.Reason})
return
}
response, err = h.proxy.CreateIndex(c, &milvuspb.CreateIndexRequest{
DbName: httpReq.DbName,
CollectionName: httpReq.CollectionName,
FieldName: httpReq.VectorField,
IndexName: DefaultIndexName,
ExtraParams: []*commonpb.KeyValuePair{{Key: common.MetricTypeKey, Value: httpReq.MetricType}},
})
if err != nil {
c.AbortWithStatusJSON(http.StatusOK, gin.H{HTTPReturnCode: Code(err), HTTPReturnMessage: err.Error()})
return
} else if response.ErrorCode != commonpb.ErrorCode_Success {
c.AbortWithStatusJSON(http.StatusOK, gin.H{HTTPReturnCode: code(int32(response.ErrorCode)), HTTPReturnMessage: response.Reason})
return
}
response, err = h.proxy.LoadCollection(c, &milvuspb.LoadCollectionRequest{
DbName: httpReq.DbName,
CollectionName: httpReq.CollectionName,
})
if err != nil {
c.AbortWithStatusJSON(http.StatusOK, gin.H{HTTPReturnCode: Code(err), HTTPReturnMessage: err.Error()})
return
} else if response.ErrorCode != commonpb.ErrorCode_Success {
c.AbortWithStatusJSON(http.StatusOK, gin.H{HTTPReturnCode: code(int32(response.ErrorCode)), HTTPReturnMessage: response.Reason})
return
}
c.JSON(http.StatusOK, gin.H{HTTPReturnCode: http.StatusOK, HTTPReturnData: gin.H{}})
}
func (h *Handlers) getCollectionDetails(c *gin.Context) {
collectionName := c.Query(HTTPCollectionName)
if collectionName == "" {
log.Warn("high level restful api, desc collection require parameter: [collectionName], but miss")
c.AbortWithStatusJSON(http.StatusOK, gin.H{HTTPReturnCode: Code(merr.ErrMissingRequiredParameters), HTTPReturnMessage: merr.ErrMissingRequiredParameters.Error()})
return
}
dbName := c.DefaultQuery(HTTPDbName, DefaultDbName)
if !h.checkDatabase(c, dbName) {
return
}
coll, err := h.describeCollection(c, dbName, collectionName, true)
if err != nil {
return
}
stateResp, stateErr := h.proxy.GetLoadState(c, &milvuspb.GetLoadStateRequest{
DbName: dbName,
CollectionName: collectionName,
})
collLoadState := ""
if stateErr != nil {
log.Warn("get collection load state fail", zap.String("collection", collectionName), zap.String("err", stateErr.Error()))
} else if stateResp.Status.ErrorCode != commonpb.ErrorCode_Success {
log.Warn("get collection load state fail", zap.String("collection", collectionName), zap.String("err", stateResp.Status.Reason))
} else {
collLoadState = stateResp.State.String()
}
vectorField := ""
for _, field := range coll.Schema.Fields {
if field.DataType == schemapb.DataType_BinaryVector || field.DataType == schemapb.DataType_FloatVector {
vectorField = field.Name
break
}
}
indexResp, indexErr := h.proxy.DescribeIndex(c, &milvuspb.DescribeIndexRequest{
DbName: dbName,
CollectionName: collectionName,
FieldName: vectorField,
})
var indexDesc []gin.H
if indexErr != nil {
indexDesc = []gin.H{}
log.Warn("get indexes description fail", zap.String("collection", collectionName), zap.String("vectorField", vectorField), zap.String("err", indexErr.Error()))
} else if indexResp.Status.ErrorCode != commonpb.ErrorCode_Success {
indexDesc = []gin.H{}
log.Warn("get indexes description fail", zap.String("collection", collectionName), zap.String("vectorField", vectorField), zap.String("err", indexResp.Status.Reason))
} else {
indexDesc = printIndexes(indexResp.IndexDescriptions)
}
c.JSON(http.StatusOK, gin.H{HTTPReturnCode: http.StatusOK, HTTPReturnData: gin.H{
HTTPCollectionName: coll.CollectionName,
HTTPReturnDescription: coll.Schema.Description,
"fields": printFields(coll.Schema.Fields),
"indexes": indexDesc,
"load": collLoadState,
"shardsNum": coll.ShardsNum,
"enableDynamic": coll.Schema.EnableDynamicField,
}})
}
func (h *Handlers) dropCollection(c *gin.Context) {
httpReq := DropCollectionReq{
DbName: DefaultDbName,
}
if err := c.ShouldBindBodyWith(&httpReq, binding.JSON); err != nil {
log.Warn("high level restful api, the parameter of drop collection is incorrect", zap.Any("request", httpReq), zap.Error(err))
c.AbortWithStatusJSON(http.StatusOK, gin.H{HTTPReturnCode: Code(merr.ErrIncorrectParameterFormat), HTTPReturnMessage: merr.ErrIncorrectParameterFormat.Error()})
return
}
if httpReq.CollectionName == "" {
log.Warn("high level restful api, drop collection require parameter: [collectionName], but miss")
c.AbortWithStatusJSON(http.StatusOK, gin.H{HTTPReturnCode: Code(merr.ErrMissingRequiredParameters), HTTPReturnMessage: merr.ErrMissingRequiredParameters.Error()})
return
}
req := milvuspb.DropCollectionRequest{
DbName: httpReq.DbName,
CollectionName: httpReq.CollectionName,
}
if err := checkAuthorization(c, &req); err != nil {
return
}
if !h.checkDatabase(c, req.DbName) {
return
}
has, err := h.hasCollection(c, httpReq.DbName, httpReq.CollectionName)
if err != nil {
return
}
if !has {
c.AbortWithStatusJSON(http.StatusOK, gin.H{HTTPReturnCode: Code(merr.ErrCollectionNotFound), HTTPReturnMessage: merr.ErrCollectionNotFound.Error()})
return
}
response, err := h.proxy.DropCollection(c, &req)
if err != nil {
c.JSON(http.StatusOK, gin.H{HTTPReturnCode: Code(err), HTTPReturnMessage: err.Error()})
} else if response.ErrorCode != commonpb.ErrorCode_Success {
c.JSON(http.StatusOK, gin.H{HTTPReturnCode: code(int32(response.ErrorCode)), HTTPReturnMessage: response.Reason})
} else {
c.JSON(http.StatusOK, gin.H{HTTPReturnCode: http.StatusOK, HTTPReturnData: gin.H{}})
}
}
func (h *Handlers) query(c *gin.Context) {
httpReq := QueryReq{
DbName: DefaultDbName,
Limit: 100,
OutputFields: []string{DefaultOutputFields},
}
if err := c.ShouldBindBodyWith(&httpReq, binding.JSON); err != nil {
log.Warn("high level restful api, the parameter of query is incorrect", zap.Any("request", httpReq), zap.Error(err))
c.AbortWithStatusJSON(http.StatusOK, gin.H{HTTPReturnCode: Code(merr.ErrIncorrectParameterFormat), HTTPReturnMessage: merr.ErrIncorrectParameterFormat.Error()})
return
}
if httpReq.CollectionName == "" || httpReq.Filter == "" {
log.Warn("high level restful api, query require parameter: [collectionName, filter], but miss")
c.AbortWithStatusJSON(http.StatusOK, gin.H{HTTPReturnCode: Code(merr.ErrMissingRequiredParameters), HTTPReturnMessage: merr.ErrMissingRequiredParameters.Error()})
return
}
req := milvuspb.QueryRequest{
DbName: httpReq.DbName,
CollectionName: httpReq.CollectionName,
Expr: httpReq.Filter,
OutputFields: httpReq.OutputFields,
GuaranteeTimestamp: BoundedTimestamp,
QueryParams: []*commonpb.KeyValuePair{},
}
if httpReq.Offset > 0 {
req.QueryParams = append(req.QueryParams, &commonpb.KeyValuePair{Key: ParamOffset, Value: strconv.FormatInt(int64(httpReq.Offset), 10)})
}
if httpReq.Limit > 0 {
req.QueryParams = append(req.QueryParams, &commonpb.KeyValuePair{Key: ParamLimit, Value: strconv.FormatInt(int64(httpReq.Limit), 10)})
}
if err := checkAuthorization(c, &req); err != nil {
return
}
if !h.checkDatabase(c, req.DbName) {
return
}
response, err := h.proxy.Query(c, &req)
if err != nil {
c.JSON(http.StatusOK, gin.H{HTTPReturnCode: Code(err), HTTPReturnMessage: err.Error()})
} else if response.Status.ErrorCode != commonpb.ErrorCode_Success {
c.JSON(http.StatusOK, gin.H{HTTPReturnCode: code(int32(response.Status.ErrorCode)), HTTPReturnMessage: response.Status.Reason})
} else {
outputData, err := buildQueryResp(int64(0), response.OutputFields, response.FieldsData, nil, nil)
if err != nil {
log.Warn("high level restful api, fail to deal with query result", zap.Any("response", response), zap.Error(err))
c.JSON(http.StatusOK, gin.H{HTTPReturnCode: Code(merr.ErrInvalidSearchResult), HTTPReturnMessage: merr.ErrInvalidSearchResult.Error()})
} else {
c.JSON(http.StatusOK, gin.H{HTTPReturnCode: http.StatusOK, HTTPReturnData: outputData})
}
}
}
func (h *Handlers) get(c *gin.Context) {
httpReq := GetReq{
DbName: DefaultDbName,
OutputFields: []string{DefaultOutputFields},
}
if err := c.ShouldBindBodyWith(&httpReq, binding.JSON); err != nil {
log.Warn("high level restful api, the parameter of get is incorrect", zap.Any("request", httpReq), zap.Error(err))
c.AbortWithStatusJSON(http.StatusOK, gin.H{HTTPReturnCode: Code(merr.ErrIncorrectParameterFormat), HTTPReturnMessage: merr.ErrIncorrectParameterFormat.Error()})
return
}
if httpReq.CollectionName == "" || httpReq.ID == nil {
log.Warn("high level restful api, get require parameter: [collectionName, id], but miss")
c.AbortWithStatusJSON(http.StatusOK, gin.H{HTTPReturnCode: Code(merr.ErrMissingRequiredParameters), HTTPReturnMessage: merr.ErrMissingRequiredParameters.Error()})
return
}
req := milvuspb.QueryRequest{
DbName: httpReq.DbName,
CollectionName: httpReq.CollectionName,
OutputFields: httpReq.OutputFields,
GuaranteeTimestamp: BoundedTimestamp,
}
if err := checkAuthorization(c, &req); err != nil {
return
}
if !h.checkDatabase(c, req.DbName) {
return
}
coll, err := h.describeCollection(c, httpReq.DbName, httpReq.CollectionName, false)
if err != nil || coll == nil {
return
}
body, _ := c.Get(gin.BodyBytesKey)
filter, err := checkGetPrimaryKey(coll.Schema, gjson.Get(string(body.([]byte)), DefaultPrimaryFieldName))
if err != nil {
c.JSON(http.StatusOK, gin.H{HTTPReturnCode: Code(merr.ErrCheckPrimaryKey), HTTPReturnMessage: merr.ErrCheckPrimaryKey.Error()})
return
}
req.Expr = filter
response, err := h.proxy.Query(c, &req)
if err != nil {
c.JSON(http.StatusOK, gin.H{HTTPReturnCode: Code(err), HTTPReturnMessage: err.Error()})
} else if response.Status.ErrorCode != commonpb.ErrorCode_Success {
c.JSON(http.StatusOK, gin.H{HTTPReturnCode: code(int32(response.Status.ErrorCode)), HTTPReturnMessage: response.Status.Reason})
} else {
outputData, err := buildQueryResp(int64(0), response.OutputFields, response.FieldsData, nil, nil)
if err != nil {
log.Warn("high level restful api, fail to deal with get result", zap.Any("response", response), zap.Error(err))
c.JSON(http.StatusOK, gin.H{HTTPReturnCode: Code(merr.ErrInvalidSearchResult), HTTPReturnMessage: merr.ErrInvalidSearchResult.Error()})
} else {
c.JSON(http.StatusOK, gin.H{HTTPReturnCode: http.StatusOK, HTTPReturnData: outputData})
log.Error("get resultIS: ", zap.Any("res", outputData))
}
}
}
func (h *Handlers) delete(c *gin.Context) {
httpReq := DeleteReq{
DbName: DefaultDbName,
}
if err := c.ShouldBindBodyWith(&httpReq, binding.JSON); err != nil {
log.Warn("high level restful api, the parameter of delete is incorrect", zap.Any("request", httpReq), zap.Error(err))
c.AbortWithStatusJSON(http.StatusOK, gin.H{HTTPReturnCode: Code(merr.ErrIncorrectParameterFormat), HTTPReturnMessage: merr.ErrIncorrectParameterFormat.Error()})
return
}
if httpReq.CollectionName == "" || httpReq.ID == nil {
log.Warn("high level restful api, delete require parameter: [collectionName, id], but miss")
c.AbortWithStatusJSON(http.StatusOK, gin.H{HTTPReturnCode: Code(merr.ErrMissingRequiredParameters), HTTPReturnMessage: merr.ErrMissingRequiredParameters.Error()})
return
}
req := milvuspb.DeleteRequest{
DbName: httpReq.DbName,
CollectionName: httpReq.CollectionName,
}
if err := checkAuthorization(c, &req); err != nil {
return
}
if !h.checkDatabase(c, req.DbName) {
return
}
coll, err := h.describeCollection(c, httpReq.DbName, httpReq.CollectionName, false)
if err != nil || coll == nil {
return
}
body, _ := c.Get(gin.BodyBytesKey)
filter, err := checkGetPrimaryKey(coll.Schema, gjson.Get(string(body.([]byte)), DefaultPrimaryFieldName))
if err != nil {
c.JSON(http.StatusOK, gin.H{HTTPReturnCode: Code(merr.ErrCheckPrimaryKey), HTTPReturnMessage: merr.ErrCheckPrimaryKey.Error()})
return
}
req.Expr = filter
response, err := h.proxy.Delete(c, &req)
if err != nil {
c.JSON(http.StatusOK, gin.H{HTTPReturnCode: Code(err), HTTPReturnMessage: err.Error()})
} else if response.Status.ErrorCode != commonpb.ErrorCode_Success {
c.JSON(http.StatusOK, gin.H{HTTPReturnCode: code(int32(response.Status.ErrorCode)), HTTPReturnMessage: response.Status.Reason})
} else {
c.JSON(http.StatusOK, gin.H{HTTPReturnCode: http.StatusOK, HTTPReturnData: gin.H{}})
}
}
func (h *Handlers) insert(c *gin.Context) {
httpReq := InsertReq{
DbName: DefaultDbName,
}
if err := c.ShouldBindBodyWith(&httpReq, binding.JSON); err != nil {
singleInsertReq := SingleInsertReq{
DbName: DefaultDbName,
}
if err = c.ShouldBindBodyWith(&singleInsertReq, binding.JSON); err != nil {
log.Warn("high level restful api, the parameter of insert is incorrect", zap.Any("request", httpReq), zap.Error(err))
c.AbortWithStatusJSON(http.StatusOK, gin.H{HTTPReturnCode: Code(merr.ErrIncorrectParameterFormat), HTTPReturnMessage: merr.ErrIncorrectParameterFormat.Error()})
return
}
httpReq.DbName = singleInsertReq.DbName
httpReq.CollectionName = singleInsertReq.CollectionName
httpReq.Data = []map[string]interface{}{singleInsertReq.Data}
}
if httpReq.CollectionName == "" || httpReq.Data == nil {
log.Warn("high level restful api, insert require parameter: [collectionName, data], but miss")
c.AbortWithStatusJSON(http.StatusOK, gin.H{HTTPReturnCode: Code(merr.ErrMissingRequiredParameters), HTTPReturnMessage: merr.ErrMissingRequiredParameters.Error()})
return
}
req := milvuspb.InsertRequest{
DbName: httpReq.DbName,
CollectionName: httpReq.CollectionName,
PartitionName: "_default",
NumRows: uint32(len(httpReq.Data)),
}
if err := checkAuthorization(c, &req); err != nil {
return
}
if !h.checkDatabase(c, req.DbName) {
return
}
coll, err := h.describeCollection(c, httpReq.DbName, httpReq.CollectionName, false)
if err != nil || coll == nil {
return
}
body, _ := c.Get(gin.BodyBytesKey)
err = checkAndSetData(string(body.([]byte)), coll, &httpReq)
if err != nil {
log.Warn("high level restful api, fail to deal with insert data", zap.Any("body", body), zap.Error(err))
c.AbortWithStatusJSON(http.StatusOK, gin.H{HTTPReturnCode: Code(merr.ErrInvalidInsertData), HTTPReturnMessage: merr.ErrInvalidInsertData.Error()})
return
}
req.FieldsData, err = anyToColumns(httpReq.Data, coll.Schema)
if err != nil {
log.Warn("high level restful api, fail to deal with insert data", zap.Any("data", httpReq.Data), zap.Error(err))
c.AbortWithStatusJSON(http.StatusOK, gin.H{HTTPReturnCode: Code(merr.ErrInvalidInsertData), HTTPReturnMessage: merr.ErrInvalidInsertData.Error()})
return
}
response, err := h.proxy.Insert(c, &req)
if err != nil {
c.JSON(http.StatusOK, gin.H{HTTPReturnCode: Code(err), HTTPReturnMessage: err.Error()})
} else if response.Status.ErrorCode != commonpb.ErrorCode_Success {
c.JSON(http.StatusOK, gin.H{HTTPReturnCode: code(int32(response.Status.ErrorCode)), HTTPReturnMessage: response.Status.Reason})
} else {
switch response.IDs.GetIdField().(type) {
case *schemapb.IDs_IntId:
c.JSON(http.StatusOK, gin.H{HTTPReturnCode: http.StatusOK, HTTPReturnData: gin.H{"insertCount": response.InsertCnt, "insertIds": response.IDs.IdField.(*schemapb.IDs_IntId).IntId.Data}})
case *schemapb.IDs_StrId:
c.JSON(http.StatusOK, gin.H{HTTPReturnCode: http.StatusOK, HTTPReturnData: gin.H{"insertCount": response.InsertCnt, "insertIds": response.IDs.IdField.(*schemapb.IDs_StrId).StrId.Data}})
default:
c.JSON(http.StatusOK, gin.H{HTTPReturnCode: Code(merr.ErrCheckPrimaryKey), HTTPReturnMessage: merr.ErrCheckPrimaryKey.Error()})
}
}
}
func (h *Handlers) search(c *gin.Context) {
httpReq := SearchReq{
DbName: DefaultDbName,
Limit: 100,
}
if err := c.ShouldBindBodyWith(&httpReq, binding.JSON); err != nil {
log.Warn("high level restful api, the parameter of search is incorrect", zap.Any("request", httpReq), zap.Error(err))
c.AbortWithStatusJSON(http.StatusOK, gin.H{HTTPReturnCode: Code(merr.ErrIncorrectParameterFormat), HTTPReturnMessage: merr.ErrIncorrectParameterFormat.Error()})
return
}
if httpReq.CollectionName == "" || httpReq.Vector == nil {
log.Warn("high level restful api, search require parameter: [collectionName, vector], but miss")
c.AbortWithStatusJSON(http.StatusOK, gin.H{HTTPReturnCode: Code(merr.ErrMissingRequiredParameters), HTTPReturnMessage: merr.ErrMissingRequiredParameters.Error()})
return
}
params := map[string]interface{}{ //auto generated mapping
"level": int(commonpb.ConsistencyLevel_Bounded),
}
bs, _ := json.Marshal(params)
searchParams := []*commonpb.KeyValuePair{
{Key: common.TopKKey, Value: strconv.FormatInt(int64(httpReq.Limit), 10)},
{Key: Params, Value: string(bs)},
{Key: ParamRoundDecimal, Value: "-1"},
{Key: ParamOffset, Value: strconv.FormatInt(int64(httpReq.Offset), 10)},
}
req := milvuspb.SearchRequest{
DbName: httpReq.DbName,
CollectionName: httpReq.CollectionName,
Dsl: httpReq.Filter,
PlaceholderGroup: vector2PlaceholderGroupBytes(httpReq.Vector),
DslType: commonpb.DslType_BoolExprV1,
OutputFields: httpReq.OutputFields,
SearchParams: searchParams,
GuaranteeTimestamp: BoundedTimestamp,
Nq: int64(1),
}
if err := checkAuthorization(c, &req); err != nil {
return
}
if !h.checkDatabase(c, req.DbName) {
return
}
response, err := h.proxy.Search(c, &req)
if err != nil {
c.JSON(http.StatusOK, gin.H{HTTPReturnCode: Code(err), HTTPReturnMessage: err.Error()})
} else if response.Status.ErrorCode != commonpb.ErrorCode_Success {
c.JSON(http.StatusOK, gin.H{HTTPReturnCode: code(int32(response.Status.ErrorCode)), HTTPReturnMessage: response.Status.Reason})
} else {
if response.Results.TopK == int64(0) {
c.JSON(http.StatusOK, gin.H{HTTPReturnCode: http.StatusOK, HTTPReturnData: []interface{}{}})
} else {
outputData, err := buildQueryResp(response.Results.TopK, response.Results.OutputFields, response.Results.FieldsData, response.Results.Ids, response.Results.Scores)
if err != nil {
log.Warn("high level restful api, fail to deal with search result", zap.Any("result", response.Results), zap.Error(err))
c.JSON(http.StatusOK, gin.H{HTTPReturnCode: Code(merr.ErrInvalidSearchResult), HTTPReturnMessage: merr.ErrInvalidSearchResult.Error()})
} else {
c.JSON(http.StatusOK, gin.H{HTTPReturnCode: http.StatusOK, HTTPReturnData: outputData})
}
}
}
}

File diff suppressed because it is too large Load Diff

View File

@ -0,0 +1,60 @@
package httpserver
type CreateCollectionReq struct {
DbName string `json:"dbName"`
CollectionName string `json:"collectionName" validate:"required"`
Dimension int32 `json:"dimension" validate:"required"`
Description string `json:"description"`
MetricType string `json:"metricType"`
PrimaryField string `json:"primaryField"`
VectorField string `json:"vectorField"`
}
type DropCollectionReq struct {
DbName string `json:"dbName"`
CollectionName string `json:"collectionName" validate:"required"`
}
type QueryReq struct {
DbName string `json:"dbName"`
CollectionName string `json:"collectionName" validate:"required"`
OutputFields []string `json:"outputFields"`
Filter string `json:"filter" validate:"required"`
Limit int32 `json:"limit"`
Offset int32 `json:"offset"`
}
type GetReq struct {
DbName string `json:"dbName"`
CollectionName string `json:"collectionName" validate:"required"`
OutputFields []string `json:"outputFields"`
ID interface{} `json:"id" validate:"required"`
}
type DeleteReq struct {
DbName string `json:"dbName"`
CollectionName string `json:"collectionName" validate:"required"`
ID interface{} `json:"id" validate:"required"`
}
type InsertReq struct {
DbName string `json:"dbName"`
CollectionName string `json:"collectionName" validate:"required"`
Data []map[string]interface{} `json:"data" validate:"required"`
}
type SingleInsertReq struct {
DbName string `json:"dbName"`
CollectionName string `json:"collectionName" validate:"required"`
Data map[string]interface{} `json:"data" validate:"required"`
}
type SearchReq struct {
DbName string `json:"dbName"`
CollectionName string `json:"collectionName" validate:"required"`
Filter string `json:"filter"`
Limit int32 `json:"limit"`
Offset int32 `json:"offset"`
OutputFields []string `json:"outputFields"`
Vector []float32 `json:"vector"`
}

View File

@ -0,0 +1,871 @@
package httpserver
import (
"bytes"
"encoding/binary"
"encoding/json"
"fmt"
"math"
"reflect"
"strconv"
"strings"
"github.com/milvus-io/milvus/pkg/util/parameterutil.go"
"github.com/milvus-io/milvus/pkg/util/merr"
"github.com/cockroachdb/errors"
"github.com/gin-gonic/gin"
"github.com/golang/protobuf/proto"
"github.com/spf13/cast"
"github.com/tidwall/gjson"
"github.com/milvus-io/milvus-proto/go-api/v2/commonpb"
"github.com/milvus-io/milvus-proto/go-api/v2/milvuspb"
"github.com/milvus-io/milvus-proto/go-api/v2/schemapb"
"github.com/milvus-io/milvus/pkg/common"
"github.com/milvus-io/milvus/pkg/log"
"github.com/milvus-io/milvus/pkg/util/funcutil"
)
func ParseUsernamePassword(c *gin.Context) (string, string, bool) {
username, password, ok := c.Request.BasicAuth()
if !ok {
auth := c.Request.Header.Get("Authorization")
if auth != "" {
token := strings.TrimPrefix(auth, "Bearer ")
if token != auth {
i := strings.IndexAny(token, ":")
if i != -1 {
username = token[:i]
password = token[i+1:]
}
}
}
} else {
c.Header("WWW-Authenticate", `Basic realm="restricted", charset="UTF-8"`)
}
return username, password, username != "" && password != ""
}
// find the primary field of collection
func getPrimaryField(schema *schemapb.CollectionSchema) (*schemapb.FieldSchema, bool) {
for _, field := range schema.Fields {
if field.IsPrimaryKey {
return field, true
}
}
return nil, false
}
func joinArray(data interface{}) string {
var buffer bytes.Buffer
arr := reflect.ValueOf(data)
for i := 0; i < arr.Len(); i++ {
if i > 0 {
buffer.WriteString(",")
}
buffer.WriteString(fmt.Sprintf("%v", arr.Index(i)))
}
return buffer.String()
}
func convertRange(field *schemapb.FieldSchema, result gjson.Result) (string, error) {
var resultStr string
fieldType := field.DataType
if fieldType == schemapb.DataType_Int64 {
var dataArray []int64
for _, data := range result.Array() {
if data.Type == gjson.String {
value, err := cast.ToInt64E(data.Str)
if err != nil {
return "", err
}
dataArray = append(dataArray, value)
} else {
value, err := cast.ToInt64E(data.Raw)
if err != nil {
return "", err
}
dataArray = append(dataArray, value)
}
}
resultStr = joinArray(dataArray)
} else if fieldType == schemapb.DataType_VarChar {
var dataArray []string
for _, data := range result.Array() {
value, err := cast.ToStringE(data.Str)
if err != nil {
return "", err
}
dataArray = append(dataArray, value)
}
resultStr = joinArray(dataArray)
}
return resultStr, nil
}
// generate the expression: $primaryFieldName in [1,2,3]
func checkGetPrimaryKey(coll *schemapb.CollectionSchema, idResult gjson.Result) (string, error) {
primaryField, ok := getPrimaryField(coll)
if !ok {
return "", errors.New("fail to find primary key from collection description")
}
resultStr, err := convertRange(primaryField, idResult)
if err != nil {
return "", err
}
filter := primaryField.Name + " in [" + resultStr + "]"
return filter, nil
}
// --------------------- collection details --------------------- //
func printFields(fields []*schemapb.FieldSchema) []gin.H {
var res []gin.H
for _, field := range fields {
fieldDetail := gin.H{
HTTPReturnFieldName: field.Name,
HTTPReturnFieldPrimaryKey: field.IsPrimaryKey,
HTTPReturnFieldAutoID: field.AutoID,
HTTPReturnDescription: field.Description,
}
if field.DataType == schemapb.DataType_BinaryVector || field.DataType == schemapb.DataType_FloatVector {
dim, _ := getDim(field)
fieldDetail[HTTPReturnFieldType] = field.DataType.String() + "(" + strconv.FormatInt(dim, 10) + ")"
} else if field.DataType == schemapb.DataType_VarChar {
maxLength, _ := parameterutil.GetMaxLength(field)
fieldDetail[HTTPReturnFieldType] = field.DataType.String() + "(" + strconv.FormatInt(maxLength, 10) + ")"
} else {
fieldDetail[HTTPReturnFieldType] = field.DataType.String()
}
res = append(res, fieldDetail)
}
return res
}
func getMetricType(pairs []*commonpb.KeyValuePair) string {
metricType := DefaultMetricType
for _, pair := range pairs {
if pair.Key == common.MetricTypeKey {
metricType = pair.Value
break
}
}
return metricType
}
func printIndexes(indexes []*milvuspb.IndexDescription) []gin.H {
var res []gin.H
for _, index := range indexes {
res = append(res, gin.H{
HTTPReturnIndexName: index.IndexName,
HTTPReturnIndexField: index.FieldName,
HTTPReturnIndexMetricsType: getMetricType(index.Params),
})
}
return res
}
// --------------------- insert param --------------------- //
func checkAndSetData(body string, collDescResp *milvuspb.DescribeCollectionResponse, req *InsertReq) error {
var reallyDataArray []map[string]interface{}
dataResult := gjson.Get(body, "data")
dataResultArray := dataResult.Array()
if len(dataResultArray) == 0 {
return errors.New("data is required")
}
var fieldNames []string
for _, field := range collDescResp.Schema.Fields {
fieldNames = append(fieldNames, field.Name)
}
for _, data := range dataResultArray {
reallyData := map[string]interface{}{}
var vectorArray []float32
var binaryArray []byte
if data.Type == gjson.JSON {
for _, field := range collDescResp.Schema.Fields {
fieldType := field.DataType
fieldName := field.Name
dataString := gjson.Get(data.Raw, fieldName).String()
if field.IsPrimaryKey && collDescResp.Schema.AutoID {
if dataString != "" {
return fmt.Errorf("fieldName %s AutoId already open, not support insert data %s", fieldName, dataString)
}
continue
}
switch fieldType {
case schemapb.DataType_FloatVector:
for _, vector := range gjson.Get(data.Raw, fieldName).Array() {
vectorArray = append(vectorArray, cast.ToFloat32(vector.Num))
}
reallyData[fieldName] = vectorArray
case schemapb.DataType_BinaryVector:
for _, vector := range gjson.Get(data.Raw, fieldName).Array() {
binaryArray = append(binaryArray, cast.ToUint8(vector.Num))
}
reallyData[fieldName] = binaryArray
case schemapb.DataType_Bool:
result, err := cast.ToBoolE(dataString)
if err != nil {
return fmt.Errorf("dataString %s cast to bool error: %s", dataString, err.Error())
}
reallyData[fieldName] = result
case schemapb.DataType_Int8:
result, err := cast.ToInt8E(dataString)
if err != nil {
return fmt.Errorf("dataString %s cast to int8 error: %s", dataString, err.Error())
}
reallyData[fieldName] = result
case schemapb.DataType_Int16:
result, err := cast.ToInt16E(dataString)
if err != nil {
return fmt.Errorf("dataString %s cast to int16 error: %s", dataString, err.Error())
}
reallyData[fieldName] = result
case schemapb.DataType_Int32:
result, err := cast.ToInt32E(dataString)
if err != nil {
return fmt.Errorf("dataString %s cast to int32 error: %s", dataString, err.Error())
}
reallyData[fieldName] = result
case schemapb.DataType_Int64:
result, err := cast.ToInt64E(dataString)
if err != nil {
return fmt.Errorf("dataString %s cast to int64 error: %s", dataString, err.Error())
}
reallyData[fieldName] = result
case schemapb.DataType_JSON:
reallyData[fieldName] = []byte(dataString)
case schemapb.DataType_Float:
result, err := cast.ToFloat32E(dataString)
if err != nil {
return fmt.Errorf("dataString %s cast to float32 error: %s", dataString, err.Error())
}
reallyData[fieldName] = result
case schemapb.DataType_Double:
result, err := cast.ToFloat64E(dataString)
if err != nil {
return fmt.Errorf("dataString %s cast to float64 error: %s", dataString, err.Error())
}
reallyData[fieldName] = result
case schemapb.DataType_VarChar:
reallyData[fieldName] = dataString
case schemapb.DataType_String:
reallyData[fieldName] = dataString
default:
return fmt.Errorf("not support fieldName %s dataType %s", fieldName, fieldType)
}
}
// fill dynamic schema
if collDescResp.Schema.EnableDynamicField {
for mapKey, mapValue := range data.Map() {
if !containsString(fieldNames, mapKey) {
mapValueStr := mapValue.String()
if mapValue.Type == gjson.True || mapValue.Type == gjson.False {
reallyData[mapKey] = cast.ToBool(mapValueStr)
} else if mapValue.Type == gjson.String {
reallyData[mapKey] = mapValueStr
} else if mapValue.Type == gjson.Number {
if strings.Contains(mapValue.Raw, ".") {
reallyData[mapKey] = cast.ToFloat64(mapValue.Raw)
} else {
reallyData[mapKey] = cast.ToInt64(mapValueStr)
}
} else if mapValue.Type == gjson.JSON {
reallyData[mapKey] = mapValue.Value()
} else {
}
}
}
}
reallyDataArray = append(reallyDataArray, reallyData)
} else {
return fmt.Errorf("dataType %s not Json", data.Type)
}
}
req.Data = reallyDataArray
return nil
}
func containsString(arr []string, s string) bool {
for _, str := range arr {
if str == s {
return true
}
}
return false
}
func getDim(field *schemapb.FieldSchema) (int64, error) {
dimensionInSchema, err := funcutil.GetAttrByKeyFromRepeatedKV(common.DimKey, field.TypeParams)
if err != nil {
return 0, err
}
dim, err := strconv.Atoi(dimensionInSchema)
if err != nil {
return 0, err
}
return int64(dim), nil
}
func convertFloatVectorToArray(vector [][]float32, dim int64) ([]float32, error) {
floatArray := make([]float32, 0)
for _, arr := range vector {
if int64(len(arr)) != dim {
return nil, errors.New("vector length diff from dimension")
}
for i := int64(0); i < dim; i++ {
floatArray = append(floatArray, arr[i])
}
}
return floatArray, nil
}
func convertBinaryVectorToArray(vector [][]byte, dim int64) ([]byte, error) {
binaryArray := make([]byte, 0)
bytesLen := dim / 8
for _, arr := range vector {
if int64(len(arr)) != bytesLen {
return nil, errors.New("vector length diff from dimension")
}
for i := int64(0); i < bytesLen; i++ {
binaryArray = append(binaryArray, arr[i])
}
}
return binaryArray, nil
}
type fieldCandi struct {
name string
v reflect.Value
options map[string]string
}
func reflectValueCandi(v reflect.Value) (map[string]fieldCandi, error) {
if v.Kind() == reflect.Ptr {
v = v.Elem()
}
result := make(map[string]fieldCandi)
switch v.Kind() {
case reflect.Map: // map[string]interface{}
iter := v.MapRange()
for iter.Next() {
key := iter.Key().String()
result[key] = fieldCandi{
name: key,
v: iter.Value(),
}
}
return result, nil
default:
return nil, fmt.Errorf("unsupport row type: %s", v.Kind().String())
}
}
func convertToIntArray(dataType schemapb.DataType, arr interface{}) []int32 {
var res []int32
switch dataType {
case schemapb.DataType_Int8:
for _, num := range arr.([]int8) {
res = append(res, int32(num))
}
case schemapb.DataType_Int16:
for _, num := range arr.([]int16) {
res = append(res, int32(num))
}
}
return res
}
func anyToColumns(rows []map[string]interface{}, sch *schemapb.CollectionSchema) ([]*schemapb.FieldData, error) {
rowsLen := len(rows)
if rowsLen == 0 {
return []*schemapb.FieldData{}, errors.New("0 length column")
}
isDynamic := sch.EnableDynamicField
var dim int64
nameColumns := make(map[string]interface{})
fieldData := make(map[string]*schemapb.FieldData)
for _, field := range sch.Fields {
// skip auto id pk field
if field.IsPrimaryKey && field.AutoID {
continue
}
var data interface{}
switch field.DataType {
case schemapb.DataType_Bool:
data = make([]bool, 0, rowsLen)
case schemapb.DataType_Int8:
data = make([]int8, 0, rowsLen)
case schemapb.DataType_Int16:
data = make([]int16, 0, rowsLen)
case schemapb.DataType_Int32:
data = make([]int32, 0, rowsLen)
case schemapb.DataType_Int64:
data = make([]int64, 0, rowsLen)
case schemapb.DataType_Float:
data = make([]float32, 0, rowsLen)
case schemapb.DataType_Double:
data = make([]float64, 0, rowsLen)
case schemapb.DataType_String:
data = make([]string, 0, rowsLen)
case schemapb.DataType_VarChar:
data = make([]string, 0, rowsLen)
case schemapb.DataType_JSON:
data = make([][]byte, 0, rowsLen)
case schemapb.DataType_FloatVector:
data = make([][]float32, 0, rowsLen)
dim, _ = getDim(field)
case schemapb.DataType_BinaryVector:
data = make([][]byte, 0, rowsLen)
dim, _ = getDim(field)
default:
return nil, fmt.Errorf("the type(%v) of field(%v) is not supported, use other sdk please", field.DataType, field.Name)
}
nameColumns[field.Name] = data
fieldData[field.Name] = &schemapb.FieldData{
Type: field.DataType,
FieldName: field.Name,
FieldId: field.FieldID,
IsDynamic: field.IsDynamic,
}
}
if dim == 0 {
return nil, errors.New("cannot find dimension")
}
dynamicCol := make([][]byte, 0, rowsLen)
for _, row := range rows {
// collection schema name need not be same, since receiver could have other names
v := reflect.ValueOf(row)
set, err := reflectValueCandi(v)
if err != nil {
return nil, err
}
for idx, field := range sch.Fields {
// skip auto id pk field
if field.IsPrimaryKey && field.AutoID {
// remove pk field from candidates set, avoid adding it into dynamic column
delete(set, field.Name)
continue
}
candi, ok := set[field.Name]
if !ok {
return nil, fmt.Errorf("row %d does not has field %s", idx, field.Name)
}
switch field.DataType {
case schemapb.DataType_Bool:
nameColumns[field.Name] = append(nameColumns[field.Name].([]bool), candi.v.Interface().(bool))
case schemapb.DataType_Int8:
nameColumns[field.Name] = append(nameColumns[field.Name].([]int8), candi.v.Interface().(int8))
case schemapb.DataType_Int16:
nameColumns[field.Name] = append(nameColumns[field.Name].([]int16), candi.v.Interface().(int16))
case schemapb.DataType_Int32:
nameColumns[field.Name] = append(nameColumns[field.Name].([]int32), candi.v.Interface().(int32))
case schemapb.DataType_Int64:
nameColumns[field.Name] = append(nameColumns[field.Name].([]int64), candi.v.Interface().(int64))
case schemapb.DataType_Float:
nameColumns[field.Name] = append(nameColumns[field.Name].([]float32), candi.v.Interface().(float32))
case schemapb.DataType_Double:
nameColumns[field.Name] = append(nameColumns[field.Name].([]float64), candi.v.Interface().(float64))
case schemapb.DataType_String:
nameColumns[field.Name] = append(nameColumns[field.Name].([]string), candi.v.Interface().(string))
case schemapb.DataType_VarChar:
nameColumns[field.Name] = append(nameColumns[field.Name].([]string), candi.v.Interface().(string))
case schemapb.DataType_JSON:
nameColumns[field.Name] = append(nameColumns[field.Name].([][]byte), candi.v.Interface().([]byte))
case schemapb.DataType_FloatVector:
nameColumns[field.Name] = append(nameColumns[field.Name].([][]float32), candi.v.Interface().([]float32))
case schemapb.DataType_BinaryVector:
nameColumns[field.Name] = append(nameColumns[field.Name].([][]byte), candi.v.Interface().([]byte))
default:
return nil, fmt.Errorf("the type(%v) of field(%v) is not supported, use other sdk please", field.DataType, field.Name)
}
delete(set, field.Name)
}
if isDynamic {
m := make(map[string]interface{})
for name, candi := range set {
m[name] = candi.v.Interface()
}
bs, err := json.Marshal(m)
if err != nil {
return nil, fmt.Errorf("failed to marshal dynamic field %w", err)
}
dynamicCol = append(dynamicCol, bs)
if err != nil {
return nil, fmt.Errorf("failed to append value to dynamic field %w", err)
}
}
}
columns := make([]*schemapb.FieldData, 0, len(nameColumns))
for name, column := range nameColumns {
colData := fieldData[name]
switch colData.Type {
case schemapb.DataType_Bool:
colData.Field = &schemapb.FieldData_Scalars{
Scalars: &schemapb.ScalarField{
Data: &schemapb.ScalarField_BoolData{
BoolData: &schemapb.BoolArray{
Data: column.([]bool),
},
},
},
}
case schemapb.DataType_Int8:
colData.Field = &schemapb.FieldData_Scalars{
Scalars: &schemapb.ScalarField{
Data: &schemapb.ScalarField_IntData{
IntData: &schemapb.IntArray{
Data: convertToIntArray(colData.Type, column),
},
},
},
}
case schemapb.DataType_Int16:
colData.Field = &schemapb.FieldData_Scalars{
Scalars: &schemapb.ScalarField{
Data: &schemapb.ScalarField_IntData{
IntData: &schemapb.IntArray{
Data: convertToIntArray(colData.Type, column),
},
},
},
}
case schemapb.DataType_Int32:
colData.Field = &schemapb.FieldData_Scalars{
Scalars: &schemapb.ScalarField{
Data: &schemapb.ScalarField_IntData{
IntData: &schemapb.IntArray{
Data: column.([]int32),
},
},
},
}
case schemapb.DataType_Int64:
colData.Field = &schemapb.FieldData_Scalars{
Scalars: &schemapb.ScalarField{
Data: &schemapb.ScalarField_LongData{
LongData: &schemapb.LongArray{
Data: column.([]int64),
},
},
},
}
case schemapb.DataType_Float:
colData.Field = &schemapb.FieldData_Scalars{
Scalars: &schemapb.ScalarField{
Data: &schemapb.ScalarField_FloatData{
FloatData: &schemapb.FloatArray{
Data: column.([]float32),
},
},
},
}
case schemapb.DataType_Double:
colData.Field = &schemapb.FieldData_Scalars{
Scalars: &schemapb.ScalarField{
Data: &schemapb.ScalarField_DoubleData{
DoubleData: &schemapb.DoubleArray{
Data: column.([]float64),
},
},
},
}
case schemapb.DataType_String:
colData.Field = &schemapb.FieldData_Scalars{
Scalars: &schemapb.ScalarField{
Data: &schemapb.ScalarField_StringData{
StringData: &schemapb.StringArray{
Data: column.([]string),
},
},
},
}
case schemapb.DataType_VarChar:
colData.Field = &schemapb.FieldData_Scalars{
Scalars: &schemapb.ScalarField{
Data: &schemapb.ScalarField_StringData{
StringData: &schemapb.StringArray{
Data: column.([]string),
},
},
},
}
case schemapb.DataType_JSON:
colData.Field = &schemapb.FieldData_Scalars{
Scalars: &schemapb.ScalarField{
Data: &schemapb.ScalarField_BytesData{
BytesData: &schemapb.BytesArray{
Data: column.([][]byte),
},
},
},
}
case schemapb.DataType_FloatVector:
arr, err := convertFloatVectorToArray(column.([][]float32), dim)
if err != nil {
return nil, err
}
colData.Field = &schemapb.FieldData_Vectors{
Vectors: &schemapb.VectorField{
Dim: dim,
Data: &schemapb.VectorField_FloatVector{
FloatVector: &schemapb.FloatArray{
Data: arr,
},
},
},
}
case schemapb.DataType_BinaryVector:
arr, err := convertBinaryVectorToArray(column.([][]byte), dim)
if err != nil {
return nil, err
}
colData.Field = &schemapb.FieldData_Vectors{
Vectors: &schemapb.VectorField{
Dim: dim,
Data: &schemapb.VectorField_BinaryVector{
BinaryVector: arr,
},
},
}
default:
return nil, fmt.Errorf("the type(%v) of field(%v) is not supported, use other sdk please", colData.Type, name)
}
columns = append(columns, colData)
}
if isDynamic {
columns = append(columns, &schemapb.FieldData{
Type: schemapb.DataType_JSON,
FieldName: "",
Field: &schemapb.FieldData_Scalars{
Scalars: &schemapb.ScalarField{
Data: &schemapb.ScalarField_JsonData{
JsonData: &schemapb.JSONArray{
Data: dynamicCol,
},
},
},
},
IsDynamic: true,
})
}
return columns, nil
}
// --------------------- search param --------------------- //
func serialize(fv []float32) []byte {
data := make([]byte, 0, 4*len(fv)) // float32 occupies 4 bytes
buf := make([]byte, 4)
for _, f := range fv {
binary.LittleEndian.PutUint32(buf, math.Float32bits(f))
data = append(data, buf...)
}
return data
}
func vector2PlaceholderGroupBytes(vectors []float32) []byte {
var placeHolderType commonpb.PlaceholderType
ph := &commonpb.PlaceholderValue{
Tag: "$0",
Values: make([][]byte, 0, len(vectors)),
}
if len(vectors) != 0 {
placeHolderType = commonpb.PlaceholderType_FloatVector
ph.Type = placeHolderType
ph.Values = append(ph.Values, serialize(vectors))
}
phg := &commonpb.PlaceholderGroup{
Placeholders: []*commonpb.PlaceholderValue{
ph,
},
}
bs, _ := proto.Marshal(phg)
return bs
}
// --------------------- get/query/search response --------------------- //
func genDynamicFields(fields []string, list []*schemapb.FieldData) []string {
nonDynamicFieldNames := make(map[string]struct{})
for _, field := range list {
if !field.IsDynamic {
nonDynamicFieldNames[field.FieldName] = struct{}{}
}
}
dynamicFields := []string{}
for _, fieldName := range fields {
if _, exist := nonDynamicFieldNames[fieldName]; !exist {
dynamicFields = append(dynamicFields, fieldName)
}
}
return dynamicFields
}
func buildQueryResp(rowsNum int64, needFields []string, fieldDataList []*schemapb.FieldData, ids *schemapb.IDs, scores []float32) ([]map[string]interface{}, error) {
var queryResp []map[string]interface{}
columnNum := len(fieldDataList)
if rowsNum == int64(0) {
if columnNum > 0 {
switch fieldDataList[0].Type {
case schemapb.DataType_Bool:
rowsNum = int64(len(fieldDataList[0].GetScalars().GetBoolData().Data))
case schemapb.DataType_Int8:
rowsNum = int64(len(fieldDataList[0].GetScalars().GetIntData().Data))
case schemapb.DataType_Int16:
rowsNum = int64(len(fieldDataList[0].GetScalars().GetIntData().Data))
case schemapb.DataType_Int32:
rowsNum = int64(len(fieldDataList[0].GetScalars().GetIntData().Data))
case schemapb.DataType_Int64:
rowsNum = int64(len(fieldDataList[0].GetScalars().GetLongData().Data))
case schemapb.DataType_Float:
rowsNum = int64(len(fieldDataList[0].GetScalars().GetFloatData().Data))
case schemapb.DataType_Double:
rowsNum = int64(len(fieldDataList[0].GetScalars().GetDoubleData().Data))
case schemapb.DataType_String:
rowsNum = int64(len(fieldDataList[0].GetScalars().GetStringData().Data))
case schemapb.DataType_VarChar:
rowsNum = int64(len(fieldDataList[0].GetScalars().GetStringData().Data))
case schemapb.DataType_JSON:
rowsNum = int64(len(fieldDataList[0].GetScalars().GetJsonData().Data))
case schemapb.DataType_Array:
rowsNum = int64(len(fieldDataList[0].GetScalars().GetArrayData().Data))
case schemapb.DataType_BinaryVector:
rowsNum = int64(len(fieldDataList[0].GetVectors().GetBinaryVector())*8) / fieldDataList[0].GetVectors().GetDim()
case schemapb.DataType_FloatVector:
rowsNum = int64(len(fieldDataList[0].GetVectors().GetFloatVector().Data)) / fieldDataList[0].GetVectors().GetDim()
default:
return nil, fmt.Errorf("the type(%v) of field(%v) is not supported, use other sdk please", fieldDataList[0].Type, fieldDataList[0].FieldName)
}
} else if ids != nil {
switch ids.IdField.(type) {
case *schemapb.IDs_IntId:
int64Pks := ids.GetIntId().GetData()
rowsNum = int64(len(int64Pks))
case *schemapb.IDs_StrId:
stringPks := ids.GetStrId().GetData()
rowsNum = int64(len(stringPks))
default:
return nil, fmt.Errorf("the type of primary key(id) is not supported, use other sdk please")
}
}
}
if rowsNum == int64(0) {
return []map[string]interface{}{}, nil
}
dynamicOutputFields := genDynamicFields(needFields, fieldDataList)
for i := int64(0); i < rowsNum; i++ {
row := map[string]interface{}{}
if columnNum > 0 {
for j := 0; j < columnNum; j++ {
switch fieldDataList[j].Type {
case schemapb.DataType_Bool:
row[fieldDataList[j].FieldName] = fieldDataList[j].GetScalars().GetBoolData().Data[i]
case schemapb.DataType_Int8:
row[fieldDataList[j].FieldName] = int8(fieldDataList[j].GetScalars().GetIntData().Data[i])
case schemapb.DataType_Int16:
row[fieldDataList[j].FieldName] = int16(fieldDataList[j].GetScalars().GetIntData().Data[i])
case schemapb.DataType_Int32:
row[fieldDataList[j].FieldName] = fieldDataList[j].GetScalars().GetIntData().Data[i]
case schemapb.DataType_Int64:
row[fieldDataList[j].FieldName] = fieldDataList[j].GetScalars().GetLongData().Data[i]
case schemapb.DataType_Float:
row[fieldDataList[j].FieldName] = fieldDataList[j].GetScalars().GetFloatData().Data[i]
case schemapb.DataType_Double:
row[fieldDataList[j].FieldName] = fieldDataList[j].GetScalars().GetDoubleData().Data[i]
case schemapb.DataType_String:
row[fieldDataList[j].FieldName] = fieldDataList[j].GetScalars().GetStringData().Data[i]
case schemapb.DataType_VarChar:
row[fieldDataList[j].FieldName] = fieldDataList[j].GetScalars().GetStringData().Data[i]
case schemapb.DataType_BinaryVector:
row[fieldDataList[j].FieldName] = fieldDataList[j].GetVectors().GetBinaryVector()[i*(fieldDataList[j].GetVectors().GetDim()/8) : (i+1)*(fieldDataList[j].GetVectors().GetDim()/8)]
case schemapb.DataType_FloatVector:
row[fieldDataList[j].FieldName] = fieldDataList[j].GetVectors().GetFloatVector().Data[i*fieldDataList[j].GetVectors().GetDim() : (i+1)*fieldDataList[j].GetVectors().GetDim()]
case schemapb.DataType_Array:
row[fieldDataList[j].FieldName] = fieldDataList[j].GetScalars().GetArrayData().Data[i]
case schemapb.DataType_JSON:
data, ok := fieldDataList[j].GetScalars().Data.(*schemapb.ScalarField_JsonData)
if ok && !fieldDataList[j].IsDynamic {
row[fieldDataList[j].FieldName] = data.JsonData.Data[i]
} else {
var dataMap map[string]interface{}
err := json.Unmarshal(fieldDataList[j].GetScalars().GetJsonData().Data[i], &dataMap)
if err != nil {
log.Error(fmt.Sprintf("[BuildQueryResp] Unmarshal error %s", err.Error()))
return nil, err
}
if containsString(dynamicOutputFields, fieldDataList[j].FieldName) {
for key, value := range dataMap {
row[key] = value
}
} else {
for _, dynamicField := range dynamicOutputFields {
if _, ok := dataMap[dynamicField]; ok {
row[dynamicField] = dataMap[dynamicField]
}
}
}
}
default:
row[fieldDataList[j].FieldName] = ""
}
}
}
if ids != nil {
switch ids.IdField.(type) {
case *schemapb.IDs_IntId:
int64Pks := ids.GetIntId().GetData()
row[DefaultPrimaryFieldName] = int64Pks[i]
case *schemapb.IDs_StrId:
stringPks := ids.GetStrId().GetData()
row[DefaultPrimaryFieldName] = stringPks[i]
default:
return nil, fmt.Errorf("the type of primary key(id) is not supported, use other sdk please")
}
}
if scores != nil && int64(len(scores)) > i {
row[HTTPReturnDistance] = scores[i]
}
queryResp = append(queryResp, row)
}
return queryResp, nil
}
// --------------------- error code --------------------- //
func code(code int32) int32 {
return code & merr.RootReasonCodeMask
}
func Code(err error) int32 {
return code(merr.Code(err))
}

View File

@ -0,0 +1,808 @@
package httpserver
import (
"strconv"
"testing"
"github.com/gin-gonic/gin"
"github.com/milvus-io/milvus-proto/go-api/v2/commonpb"
"github.com/milvus-io/milvus-proto/go-api/v2/milvuspb"
"github.com/milvus-io/milvus-proto/go-api/v2/schemapb"
"github.com/milvus-io/milvus/pkg/common"
"github.com/stretchr/testify/assert"
"github.com/tidwall/gjson"
)
const (
FieldWordCount = "word_count"
FieldBookID = "book_id"
FieldBookIntro = "book_intro"
)
func generatePrimaryField(datatype schemapb.DataType) schemapb.FieldSchema {
return schemapb.FieldSchema{
FieldID: common.StartOfUserFieldID,
Name: FieldBookID,
IsPrimaryKey: true,
Description: "",
DataType: datatype,
AutoID: false,
}
}
func generateIds(num int) *schemapb.IDs {
var intArray []int64
if num == 0 {
intArray = []int64{}
} else {
for i := int64(1); i < int64(num+1); i++ {
intArray = append(intArray, i)
}
}
return &schemapb.IDs{
IdField: &schemapb.IDs_IntId{
IntId: &schemapb.LongArray{
Data: intArray,
},
},
}
}
func generateVectorFieldSchema(useBinary bool) schemapb.FieldSchema {
if useBinary {
return schemapb.FieldSchema{
FieldID: common.StartOfUserFieldID + 2,
Name: "field-binary",
IsPrimaryKey: false,
Description: "",
DataType: 100,
AutoID: false,
TypeParams: []*commonpb.KeyValuePair{
{
Key: common.DimKey,
Value: "8",
},
},
}
}
return schemapb.FieldSchema{
FieldID: common.StartOfUserFieldID + 2,
Name: FieldBookIntro,
IsPrimaryKey: false,
Description: "",
DataType: 101,
AutoID: false,
TypeParams: []*commonpb.KeyValuePair{
{
Key: common.DimKey,
Value: "2",
},
},
}
}
func generateCollectionSchema(useBinary bool) *schemapb.CollectionSchema {
primaryField := generatePrimaryField(schemapb.DataType_Int64)
vectorField := generateVectorFieldSchema(useBinary)
return &schemapb.CollectionSchema{
Name: DefaultCollectionName,
Description: "",
AutoID: false,
Fields: []*schemapb.FieldSchema{
&primaryField, {
FieldID: common.StartOfUserFieldID + 1,
Name: FieldWordCount,
IsPrimaryKey: false,
Description: "",
DataType: 5,
AutoID: false,
}, &vectorField,
},
EnableDynamicField: true,
}
}
func generateIndexes() []*milvuspb.IndexDescription {
return []*milvuspb.IndexDescription{
{
IndexName: DefaultIndexName,
IndexID: 442051985533243300,
Params: []*commonpb.KeyValuePair{
{
Key: common.MetricTypeKey,
Value: DefaultMetricType,
},
{
Key: "index_type",
Value: "IVF_FLAT",
}, {
Key: Params,
Value: "{\"nlist\":1024}",
},
},
State: 3,
FieldName: FieldBookIntro,
},
}
}
func generateVectorFieldData(useBinary bool) schemapb.FieldData {
if useBinary {
return schemapb.FieldData{
Type: schemapb.DataType_BinaryVector,
FieldName: "field-binary",
Field: &schemapb.FieldData_Vectors{
Vectors: &schemapb.VectorField{
Dim: 8,
Data: &schemapb.VectorField_BinaryVector{
BinaryVector: []byte{byte(0), byte(1), byte(2)},
},
},
},
IsDynamic: false,
}
}
return schemapb.FieldData{
Type: schemapb.DataType_FloatVector,
FieldName: FieldBookIntro,
Field: &schemapb.FieldData_Vectors{
Vectors: &schemapb.VectorField{
Dim: 2,
Data: &schemapb.VectorField_FloatVector{
FloatVector: &schemapb.FloatArray{
Data: []float32{0.1, 0.11, 0.2, 0.22, 0.3, 0.33},
},
},
},
},
IsDynamic: false,
}
}
func generateFieldData() []*schemapb.FieldData {
fieldData1 := schemapb.FieldData{
Type: schemapb.DataType_Int64,
FieldName: FieldBookID,
Field: &schemapb.FieldData_Scalars{
Scalars: &schemapb.ScalarField{
Data: &schemapb.ScalarField_LongData{
LongData: &schemapb.LongArray{
Data: []int64{1, 2, 3},
},
},
},
},
IsDynamic: false,
}
fieldData2 := schemapb.FieldData{
Type: schemapb.DataType_Int64,
FieldName: FieldWordCount,
Field: &schemapb.FieldData_Scalars{
Scalars: &schemapb.ScalarField{
Data: &schemapb.ScalarField_LongData{
LongData: &schemapb.LongArray{
Data: []int64{1000, 2000, 3000},
},
},
},
},
IsDynamic: false,
}
fieldData3 := generateVectorFieldData(false)
return []*schemapb.FieldData{&fieldData1, &fieldData2, &fieldData3}
}
func generateSearchResult() []map[string]interface{} {
row1 := map[string]interface{}{
DefaultPrimaryFieldName: int64(1),
FieldBookID: int64(1),
FieldWordCount: int64(1000),
FieldBookIntro: []float32{0.1, 0.11},
HTTPReturnDistance: float32(0.01),
}
row2 := map[string]interface{}{
DefaultPrimaryFieldName: int64(2),
FieldBookID: int64(2),
FieldWordCount: int64(2000),
FieldBookIntro: []float32{0.2, 0.22},
HTTPReturnDistance: float32(0.04),
}
row3 := map[string]interface{}{
DefaultPrimaryFieldName: int64(3),
FieldBookID: int64(3),
FieldWordCount: int64(3000),
FieldBookIntro: []float32{0.3, 0.33},
HTTPReturnDistance: float32(0.09),
}
return []map[string]interface{}{row1, row2, row3}
}
func generateQueryResult64(withDistance bool) []map[string]interface{} {
row1 := map[string]interface{}{
FieldBookID: float64(1),
FieldWordCount: float64(1000),
FieldBookIntro: []float64{0.1, 0.11},
}
row2 := map[string]interface{}{
FieldBookID: float64(2),
FieldWordCount: float64(2000),
FieldBookIntro: []float64{0.2, 0.22},
}
row3 := map[string]interface{}{
FieldBookID: float64(3),
FieldWordCount: float64(3000),
FieldBookIntro: []float64{0.3, 0.33},
}
if withDistance {
row1[HTTPReturnDistance] = float64(0.01)
row2[HTTPReturnDistance] = float64(0.04)
row3[HTTPReturnDistance] = float64(0.09)
}
return []map[string]interface{}{row1, row2, row3}
}
func TestPrintCollectionDetails(t *testing.T) {
coll := generateCollectionSchema(false)
indexes := generateIndexes()
assert.Equal(t, printFields(coll.Fields), []gin.H{
{
HTTPReturnFieldName: FieldBookID,
HTTPReturnFieldType: "Int64",
HTTPReturnFieldPrimaryKey: true,
HTTPReturnFieldAutoID: false,
HTTPReturnDescription: ""},
{
HTTPReturnFieldName: FieldWordCount,
HTTPReturnFieldType: "Int64",
HTTPReturnFieldPrimaryKey: false,
HTTPReturnFieldAutoID: false,
HTTPReturnDescription: ""},
{
HTTPReturnFieldName: FieldBookIntro,
HTTPReturnFieldType: "FloatVector(2)",
HTTPReturnFieldPrimaryKey: false,
HTTPReturnFieldAutoID: false,
HTTPReturnDescription: ""},
})
assert.Equal(t, printIndexes(indexes), []gin.H{
{
HTTPReturnIndexName: DefaultIndexName,
HTTPReturnIndexField: FieldBookIntro,
HTTPReturnIndexMetricsType: DefaultMetricType},
})
assert.Equal(t, getMetricType(indexes[0].Params), DefaultMetricType)
assert.Equal(t, getMetricType(nil), DefaultMetricType)
fields := []*schemapb.FieldSchema{}
for _, field := range newCollectionSchema(coll).Fields {
if field.DataType == schemapb.DataType_VarChar {
fields = append(fields, field)
}
}
assert.Equal(t, printFields(fields), []gin.H{
{
HTTPReturnFieldName: "field-varchar",
HTTPReturnFieldType: "VarChar(10)",
HTTPReturnFieldPrimaryKey: false,
HTTPReturnFieldAutoID: false,
HTTPReturnDescription: ""},
})
}
func TestPrimaryField(t *testing.T) {
coll := generateCollectionSchema(false)
primaryField := generatePrimaryField(schemapb.DataType_Int64)
field, ok := getPrimaryField(coll)
assert.Equal(t, ok, true)
assert.Equal(t, *field, primaryField)
assert.Equal(t, joinArray([]int64{1, 2, 3}), "1,2,3")
assert.Equal(t, joinArray([]string{"1", "2", "3"}), "1,2,3")
jsonStr := "{\"id\": [1, 2, 3]}"
idStr := gjson.Get(jsonStr, "id")
rangeStr, err := convertRange(&primaryField, idStr)
assert.Equal(t, err, nil)
assert.Equal(t, rangeStr, "1,2,3")
filter, err := checkGetPrimaryKey(coll, idStr)
assert.Equal(t, err, nil)
assert.Equal(t, filter, "book_id in [1,2,3]")
primaryField = generatePrimaryField(schemapb.DataType_VarChar)
jsonStr = "{\"id\": [\"1\", \"2\", \"3\"]}"
idStr = gjson.Get(jsonStr, "id")
rangeStr, err = convertRange(&primaryField, idStr)
assert.Equal(t, err, nil)
assert.Equal(t, rangeStr, "1,2,3")
filter, err = checkGetPrimaryKey(coll, idStr)
assert.Equal(t, err, nil)
assert.Equal(t, filter, "book_id in [1,2,3]")
}
func TestInsertWithDynamicFields(t *testing.T) {
body := "{\"data\": {\"id\": 0, \"book_id\": 1, \"book_intro\": [0.1, 0.2], \"word_count\": 2}}"
req := InsertReq{}
coll := generateCollectionSchema(false)
err := checkAndSetData(body, &milvuspb.DescribeCollectionResponse{
Status: &commonpb.Status{ErrorCode: commonpb.ErrorCode_Success},
Schema: coll,
}, &req)
assert.Equal(t, err, nil)
assert.Equal(t, req.Data[0]["id"], int64(0))
assert.Equal(t, req.Data[0]["book_id"], int64(1))
assert.Equal(t, req.Data[0]["word_count"], int64(2))
fieldsData, err := anyToColumns(req.Data, coll)
assert.Equal(t, err, nil)
assert.Equal(t, fieldsData[len(fieldsData)-1].IsDynamic, true)
assert.Equal(t, fieldsData[len(fieldsData)-1].Type, schemapb.DataType_JSON)
assert.Equal(t, string(fieldsData[len(fieldsData)-1].GetScalars().GetJsonData().GetData()[0]), "{\"id\":0}")
}
func TestSerialize(t *testing.T) {
parameters := []float32{0.11111, 0.22222}
//assert.Equal(t, string(serialize(parameters)), "\ufffd\ufffd\ufffd=\ufffd\ufffdc\u003e")
//assert.Equal(t, string(vector2PlaceholderGroupBytes(parameters)), "vector2PlaceholderGroupBytes") // todo
assert.Equal(t, string(serialize(parameters)), "\xa4\x8d\xe3=\xa4\x8dc>")
assert.Equal(t, string(vector2PlaceholderGroupBytes(parameters)), "\n\x10\n\x02$0\x10e\x1a\b\xa4\x8d\xe3=\xa4\x8dc>") // todo
}
func compareRow64(m1 map[string]interface{}, m2 map[string]interface{}) bool {
for key, value := range m1 {
if key == FieldBookIntro {
arr1 := value.([]interface{})
arr2 := m2[key].([]float64)
if len(arr1) != len(arr2) {
return false
}
for j, element := range arr1 {
if element != arr2[j] {
return false
}
}
} else if value != m2[key] {
return false
}
}
for key, value := range m2 {
if key == FieldBookIntro {
continue
} else if value != m1[key] {
return false
}
}
return true
}
func compareRow(m1 map[string]interface{}, m2 map[string]interface{}) bool {
for key, value := range m1 {
if key == FieldBookIntro {
arr1 := value.([]float32)
arr2 := m2[key].([]float32)
if len(arr1) != len(arr2) {
return false
}
for j, element := range arr1 {
if element != arr2[j] {
return false
}
}
} else if (key == "field-binary") || (key == "field-json") {
arr1 := value.([]byte)
arr2 := m2[key].([]byte)
if len(arr1) != len(arr2) {
return false
}
for j, element := range arr1 {
if element != arr2[j] {
return false
}
}
} else if value != m2[key] {
return false
}
}
for key, value := range m2 {
if (key == FieldBookIntro) || (key == "field-binary") || (key == "field-json") {
continue
} else if value != m1[key] {
return false
}
}
return true
}
type CompareFunc func(map[string]interface{}, map[string]interface{}) bool
func compareRows(row1 []map[string]interface{}, row2 []map[string]interface{}, compareFunc CompareFunc) bool {
if len(row1) != len(row2) {
return false
}
for i, row := range row1 {
if !compareFunc(row, row2[i]) {
return false
}
}
return true
}
func TestBuildQueryResp(t *testing.T) {
outputFields := []string{FieldBookID, FieldWordCount, "author", "date"}
rows, err := buildQueryResp(int64(0), outputFields, generateFieldData(), generateIds(3), []float32{0.01, 0.04, 0.09}) // []*schemapb.FieldData{&fieldData1, &fieldData2, &fieldData3}
assert.Equal(t, err, nil)
exceptRows := generateSearchResult()
assert.Equal(t, compareRows(rows, exceptRows, compareRow), true)
}
func newCollectionSchema(coll *schemapb.CollectionSchema) *schemapb.CollectionSchema {
fieldSchema1 := schemapb.FieldSchema{
Name: "field-bool",
DataType: schemapb.DataType_Bool,
}
coll.Fields = append(coll.Fields, &fieldSchema1)
fieldSchema2 := schemapb.FieldSchema{
Name: "field-int8",
DataType: schemapb.DataType_Int8,
}
coll.Fields = append(coll.Fields, &fieldSchema2)
fieldSchema3 := schemapb.FieldSchema{
Name: "field-int16",
DataType: schemapb.DataType_Int16,
}
coll.Fields = append(coll.Fields, &fieldSchema3)
fieldSchema4 := schemapb.FieldSchema{
Name: "field-int32",
DataType: schemapb.DataType_Int32,
}
coll.Fields = append(coll.Fields, &fieldSchema4)
fieldSchema5 := schemapb.FieldSchema{
Name: "field-float",
DataType: schemapb.DataType_Float,
}
coll.Fields = append(coll.Fields, &fieldSchema5)
fieldSchema6 := schemapb.FieldSchema{
Name: "field-double",
DataType: schemapb.DataType_Double,
}
coll.Fields = append(coll.Fields, &fieldSchema6)
fieldSchema7 := schemapb.FieldSchema{
Name: "field-string",
DataType: schemapb.DataType_String,
}
coll.Fields = append(coll.Fields, &fieldSchema7)
fieldSchema8 := schemapb.FieldSchema{
Name: "field-varchar",
DataType: schemapb.DataType_VarChar,
TypeParams: []*commonpb.KeyValuePair{
{Key: "max_length", Value: "10"},
},
}
coll.Fields = append(coll.Fields, &fieldSchema8)
fieldSchema9 := schemapb.FieldSchema{
Name: "field-json",
DataType: schemapb.DataType_JSON,
IsDynamic: false,
}
coll.Fields = append(coll.Fields, &fieldSchema9)
//fieldSchema10 := schemapb.FieldSchema{
// Name: "$meta",
// DataType: schemapb.DataType_JSON,
// IsDynamic: true,
//}
//coll.Fields = append(coll.Fields, &fieldSchema10)
return coll
}
func withUnsupportField(coll *schemapb.CollectionSchema) *schemapb.CollectionSchema {
fieldSchema10 := schemapb.FieldSchema{
Name: "field-array",
DataType: schemapb.DataType_Array,
IsDynamic: false,
}
coll.Fields = append(coll.Fields, &fieldSchema10)
return coll
}
func withDynamicField(coll *schemapb.CollectionSchema) *schemapb.CollectionSchema {
fieldSchema11 := schemapb.FieldSchema{
Name: "$meta",
DataType: schemapb.DataType_JSON,
IsDynamic: true,
}
coll.Fields = append(coll.Fields, &fieldSchema11)
return coll
}
func newFieldData(fieldDatas []*schemapb.FieldData, firstFieldType schemapb.DataType) []*schemapb.FieldData {
fieldData1 := schemapb.FieldData{
Type: schemapb.DataType_Bool,
FieldName: "field-bool",
Field: &schemapb.FieldData_Scalars{
Scalars: &schemapb.ScalarField{
Data: &schemapb.ScalarField_BoolData{
BoolData: &schemapb.BoolArray{
Data: []bool{true, true, true},
},
},
},
},
IsDynamic: false,
}
fieldDatas = append(fieldDatas, &fieldData1)
fieldData2 := schemapb.FieldData{
Type: schemapb.DataType_Int8,
FieldName: "field-int8",
Field: &schemapb.FieldData_Scalars{
Scalars: &schemapb.ScalarField{
Data: &schemapb.ScalarField_IntData{
IntData: &schemapb.IntArray{
Data: []int32{0, 1, 2},
},
},
},
},
IsDynamic: false,
}
fieldDatas = append(fieldDatas, &fieldData2)
fieldData3 := schemapb.FieldData{
Type: schemapb.DataType_Int16,
FieldName: "field-int16",
Field: &schemapb.FieldData_Scalars{
Scalars: &schemapb.ScalarField{
Data: &schemapb.ScalarField_IntData{
IntData: &schemapb.IntArray{
Data: []int32{0, 1, 2},
},
},
},
},
IsDynamic: false,
}
fieldDatas = append(fieldDatas, &fieldData3)
fieldData4 := schemapb.FieldData{
Type: schemapb.DataType_Int32,
FieldName: "field-int32",
Field: &schemapb.FieldData_Scalars{
Scalars: &schemapb.ScalarField{
Data: &schemapb.ScalarField_IntData{
IntData: &schemapb.IntArray{
Data: []int32{0, 1, 2},
},
},
},
},
IsDynamic: false,
}
fieldDatas = append(fieldDatas, &fieldData4)
fieldData5 := schemapb.FieldData{
Type: schemapb.DataType_Float,
FieldName: "field-float",
Field: &schemapb.FieldData_Scalars{
Scalars: &schemapb.ScalarField{
Data: &schemapb.ScalarField_FloatData{
FloatData: &schemapb.FloatArray{
Data: []float32{0, 1, 2},
},
},
},
},
IsDynamic: false,
}
fieldDatas = append(fieldDatas, &fieldData5)
fieldData6 := schemapb.FieldData{
Type: schemapb.DataType_Double,
FieldName: "field-double",
Field: &schemapb.FieldData_Scalars{
Scalars: &schemapb.ScalarField{
Data: &schemapb.ScalarField_DoubleData{
DoubleData: &schemapb.DoubleArray{
Data: []float64{0, 1, 2},
},
},
},
},
IsDynamic: false,
}
fieldDatas = append(fieldDatas, &fieldData6)
fieldData7 := schemapb.FieldData{
Type: schemapb.DataType_String,
FieldName: "field-string",
Field: &schemapb.FieldData_Scalars{
Scalars: &schemapb.ScalarField{
Data: &schemapb.ScalarField_StringData{
StringData: &schemapb.StringArray{
Data: []string{"0", "1", "2"},
},
},
},
},
IsDynamic: false,
}
fieldDatas = append(fieldDatas, &fieldData7)
fieldData8 := schemapb.FieldData{
Type: schemapb.DataType_VarChar,
FieldName: "field-varchar",
Field: &schemapb.FieldData_Scalars{
Scalars: &schemapb.ScalarField{
Data: &schemapb.ScalarField_StringData{
StringData: &schemapb.StringArray{
Data: []string{"0", "1", "2"},
},
},
},
},
IsDynamic: false,
}
fieldDatas = append(fieldDatas, &fieldData8)
fieldData9 := schemapb.FieldData{
Type: schemapb.DataType_JSON,
FieldName: "field-json",
Field: &schemapb.FieldData_Scalars{
Scalars: &schemapb.ScalarField{
Data: &schemapb.ScalarField_JsonData{
JsonData: &schemapb.JSONArray{
Data: [][]byte{[]byte(`{"XXX": 0}`), []byte(`{"XXX": 0}`), []byte(`{"XXX": 0}`)},
},
},
},
},
IsDynamic: false,
}
fieldDatas = append(fieldDatas, &fieldData9)
fieldData10 := schemapb.FieldData{
Type: schemapb.DataType_Array,
FieldName: "field-array",
Field: &schemapb.FieldData_Scalars{
Scalars: &schemapb.ScalarField{
Data: &schemapb.ScalarField_ArrayData{
ArrayData: &schemapb.ArrayArray{
Data: []*schemapb.ScalarField{
{Data: &schemapb.ScalarField_BoolData{BoolData: &schemapb.BoolArray{Data: []bool{true}}}},
{Data: &schemapb.ScalarField_BoolData{BoolData: &schemapb.BoolArray{Data: []bool{true}}}},
{Data: &schemapb.ScalarField_BoolData{BoolData: &schemapb.BoolArray{Data: []bool{true}}}},
},
},
},
},
},
IsDynamic: false,
}
fieldData11 := schemapb.FieldData{
Type: schemapb.DataType_JSON,
FieldName: "$meta",
Field: &schemapb.FieldData_Scalars{
Scalars: &schemapb.ScalarField{
Data: &schemapb.ScalarField_JsonData{
JsonData: &schemapb.JSONArray{
Data: [][]byte{[]byte(`{"XXX": 0, "YYY": "0"}`), []byte(`{"XXX": 1, "YYY": "1"}`), []byte(`{"XXX": 2, "YYY": "2"}`)},
},
},
},
},
IsDynamic: true,
}
fieldDatas = append(fieldDatas, &fieldData11)
switch firstFieldType {
case schemapb.DataType_None:
break
case schemapb.DataType_Bool:
return []*schemapb.FieldData{&fieldData1}
case schemapb.DataType_Int8:
return []*schemapb.FieldData{&fieldData2}
case schemapb.DataType_Int16:
return []*schemapb.FieldData{&fieldData3}
case schemapb.DataType_Int32:
return []*schemapb.FieldData{&fieldData4}
case schemapb.DataType_Float:
return []*schemapb.FieldData{&fieldData5}
case schemapb.DataType_Double:
return []*schemapb.FieldData{&fieldData6}
case schemapb.DataType_String:
return []*schemapb.FieldData{&fieldData7}
case schemapb.DataType_VarChar:
return []*schemapb.FieldData{&fieldData8}
case schemapb.DataType_BinaryVector:
vectorField := generateVectorFieldData(true)
return []*schemapb.FieldData{&vectorField}
case schemapb.DataType_FloatVector:
vectorField := generateVectorFieldData(false)
return []*schemapb.FieldData{&vectorField}
case schemapb.DataType_Array:
return []*schemapb.FieldData{&fieldData10}
case schemapb.DataType_JSON:
return []*schemapb.FieldData{&fieldData9}
default:
return []*schemapb.FieldData{
{
FieldName: "wrong-field-type",
Type: firstFieldType,
},
}
}
return fieldDatas
}
func newSearchResult(results []map[string]interface{}) []map[string]interface{} {
for i, result := range results {
result["field-bool"] = true
result["field-int8"] = int8(i)
result["field-int16"] = int16(i)
result["field-int32"] = int32(i)
result["field-float"] = float32(i)
result["field-double"] = float64(i)
result["field-varchar"] = strconv.Itoa(i)
result["field-string"] = strconv.Itoa(i)
result["field-binary"] = []byte{byte(i)}
result["field-json"] = []byte(`{"XXX": 0}`)
result["XXX"] = float64(i)
result["YYY"] = strconv.Itoa(i)
results[i] = result
}
return results
}
func TestAnyToColumn(t *testing.T) {
data, err := anyToColumns(newSearchResult(generateSearchResult()), newCollectionSchema(generateCollectionSchema(false)))
assert.Equal(t, err, nil)
assert.Equal(t, len(data), 13)
}
func TestBuildQueryResps(t *testing.T) {
outputFields := []string{"XXX", "YYY"}
outputFieldsList := [][]string{outputFields, {"$meta"}, {"$meta", FieldBookID, FieldBookIntro, "YYY"}}
for _, theOutputFields := range outputFieldsList {
rows, err := buildQueryResp(int64(0), theOutputFields, newFieldData(generateFieldData(), schemapb.DataType_None), generateIds(3), []float32{0.01, 0.04, 0.09})
assert.Equal(t, err, nil)
exceptRows := newSearchResult(generateSearchResult())
assert.Equal(t, compareRows(rows, exceptRows, compareRow), true)
}
dataTypes := []schemapb.DataType{schemapb.DataType_FloatVector, schemapb.DataType_BinaryVector,
schemapb.DataType_Bool, schemapb.DataType_Int8, schemapb.DataType_Int16, schemapb.DataType_Int32,
schemapb.DataType_Float, schemapb.DataType_Double,
schemapb.DataType_String, schemapb.DataType_VarChar,
schemapb.DataType_JSON, schemapb.DataType_Array}
for _, dateType := range dataTypes {
_, err := buildQueryResp(int64(0), outputFields, newFieldData([]*schemapb.FieldData{}, dateType), generateIds(3), []float32{0.01, 0.04, 0.09})
assert.Equal(t, err, nil)
}
_, err := buildQueryResp(int64(0), outputFields, newFieldData([]*schemapb.FieldData{}, 1000), generateIds(3), []float32{0.01, 0.04, 0.09})
assert.Equal(t, err.Error(), "the type(1000) of field(wrong-field-type) is not supported, use other sdk please")
res, err := buildQueryResp(int64(0), outputFields, []*schemapb.FieldData{}, generateIds(3), []float32{0.01, 0.04, 0.09})
assert.Equal(t, len(res), 3)
assert.Equal(t, err, nil)
// len(rows) != len(scores), didn't show distance
_, err = buildQueryResp(int64(0), outputFields, newFieldData(generateFieldData(), schemapb.DataType_None), generateIds(3), []float32{0.01, 0.04})
assert.Equal(t, err, nil)
}

View File

@ -30,12 +30,18 @@ import (
"sync"
"time"
"github.com/milvus-io/milvus/pkg/util/merr"
"google.golang.org/grpc/credentials"
management "github.com/milvus-io/milvus/internal/http"
"github.com/milvus-io/milvus/internal/proxy/accesslog"
"github.com/milvus-io/milvus/internal/util/componentutil"
"github.com/milvus-io/milvus/internal/util/dependency"
"github.com/milvus-io/milvus/pkg/tracer"
"github.com/milvus-io/milvus/pkg/util/interceptor"
"github.com/milvus-io/milvus/pkg/util/metricsinfo"
"github.com/soheilhy/cmux"
"go.opentelemetry.io/contrib/instrumentation/google.golang.org/grpc/otelgrpc"
"github.com/gin-gonic/gin"
@ -61,7 +67,6 @@ import (
"go.uber.org/zap"
"google.golang.org/grpc"
"google.golang.org/grpc/codes"
"google.golang.org/grpc/credentials"
"google.golang.org/grpc/health/grpc_health_v1"
"google.golang.org/grpc/keepalive"
"google.golang.org/grpc/status"
@ -81,6 +86,10 @@ type Server struct {
ctx context.Context
wg sync.WaitGroup
proxy types.ProxyComponent
httpListener net.Listener
grpcListener net.Listener
tcpServer cmux.CMux
httpServer *http.Server
grpcInternalServer *grpc.Server
grpcExternalServer *grpc.Server
@ -105,6 +114,21 @@ func NewServer(ctx context.Context, factory dependency.Factory) (*Server, error)
return server, err
}
func authenticate(c *gin.Context) {
if !proxy.Params.CommonCfg.AuthorizationEnabled.GetAsBool() {
return
}
username, password, ok := httpserver.ParseUsernamePassword(c)
if ok {
if proxy.PasswordVerify(c, username, password) {
log.Debug("auth successful", zap.String("username", username))
c.Set(httpserver.ContextUsername, username)
return
}
}
c.AbortWithStatusJSON(http.StatusUnauthorized, gin.H{httpserver.HTTPReturnCode: httpserver.Code(merr.ErrNeedAuthenticate), httpserver.HTTPReturnMessage: merr.ErrNeedAuthenticate.Error()})
}
// registerHTTPServer register the http server, panic when failed
func (s *Server) registerHTTPServer() {
// (Embedded Milvus Only) Discard gin logs if logging is disabled.
@ -117,10 +141,40 @@ func (s *Server) registerHTTPServer() {
if !proxy.Params.HTTPCfg.DebugMode.GetAsBool() {
gin.SetMode(gin.ReleaseMode)
}
ginHandler := gin.Default()
apiv1 := ginHandler.Group(apiPathPrefix)
metricsGinHandler := gin.Default()
apiv1 := metricsGinHandler.Group(apiPathPrefix)
httpserver.NewHandlers(s.proxy).RegisterRoutesTo(apiv1)
http.Handle("/", ginHandler)
management.Register(&management.Handler{
Path: "/",
HandlerFunc: nil,
Handler: metricsGinHandler.Handler(),
})
}
func (s *Server) startHTTPServer(errChan chan error) {
defer s.wg.Done()
ginHandler := gin.Default()
ginHandler.Use(func(c *gin.Context) {
c.Writer.Header().Set("Access-Control-Allow-Origin", "*")
c.Writer.Header().Set("Access-Control-Allow-Credentials", "true")
c.Writer.Header().Set("Access-Control-Allow-Headers", "Content-Type, Content-Length, Accept-Encoding, X-CSRF-Token, Authorization, accept, origin, Cache-Control, X-Requested-With")
c.Writer.Header().Set("Access-Control-Allow-Methods", "GET, HEAD, POST, PUT, DELETE, OPTIONS, PATCH")
if c.Request.Method == "OPTIONS" {
c.AbortWithStatus(204)
return
}
c.Next()
})
app := ginHandler.Group("/v1", authenticate)
httpserver.NewHandlers(s.proxy).RegisterRoutesToV1(app)
s.httpServer = &http.Server{Handler: ginHandler, ReadHeaderTimeout: time.Second}
errChan <- nil
if err := s.httpServer.Serve(s.httpListener); err != nil && err != cmux.ErrServerClosed {
log.Error("start Proxy http server to listen failed", zap.Error(err))
errChan <- err
return
}
log.Info("Proxy http server exited")
}
func (s *Server) startInternalRPCServer(grpcInternalPort int, errChan chan error) {
@ -146,15 +200,6 @@ func (s *Server) startExternalGrpc(grpcPort int, errChan chan error) {
Timeout: 10 * time.Second, // Wait 10 second for the ping ack before assuming the connection is dead
}
log.Debug("Proxy server listen on tcp", zap.Int("port", grpcPort))
lis, err := net.Listen("tcp", ":"+strconv.Itoa(grpcPort))
if err != nil {
log.Warn("Proxy server failed to listen on", zap.Error(err), zap.Int("port", grpcPort))
errChan <- err
return
}
log.Debug("Proxy server already listen on tcp", zap.Int("port", grpcPort))
limiter, err := s.proxy.GetRateLimiter()
if err != nil {
log.Error("Get proxy rate limiter failed", zap.Int("port", grpcPort), zap.Error(err))
@ -185,6 +230,7 @@ func (s *Server) startExternalGrpc(grpcPort int, errChan chan error) {
if Params.TLSMode.GetAsInt() == 1 {
creds, err := credentials.NewServerTLSFromFile(Params.ServerPemPath.GetValue(), Params.ServerKeyPath.GetValue())
if err != nil {
log.Warn("proxy can't create creds", zap.Error(err))
log.Warn("proxy can't create creds", zap.Error(err))
errChan <- err
return
@ -228,11 +274,12 @@ func (s *Server) startExternalGrpc(grpcPort int, errChan chan error) {
zap.Any("enforcement policy", kaep),
zap.Any("server parameters", kasp))
if err := s.grpcExternalServer.Serve(lis); err != nil {
if err := s.grpcExternalServer.Serve(s.grpcListener); err != nil && err != cmux.ErrServerClosed {
log.Error("failed to serve on Proxy's listener", zap.Error(err))
errChan <- err
return
}
log.Info("Proxy external grpc server exited")
}
func (s *Server) startInternalGrpc(grpcPort int, errChan chan error) {
@ -283,6 +330,7 @@ func (s *Server) startInternalGrpc(grpcPort int, errChan chan error) {
errChan <- err
return
}
log.Info("Proxy internal grpc server exited")
}
// Start start the Proxy Server
@ -301,6 +349,17 @@ func (s *Server) Run() error {
}
log.Debug("start Proxy server done")
if s.tcpServer != nil {
s.wg.Add(1)
go func() {
defer s.wg.Done()
if err := s.tcpServer.Serve(); err != nil && err != cmux.ErrServerClosed {
log.Warn("Proxy server for tcp port failed", zap.Error(err))
return
}
log.Info("Proxy tcp server exited")
}()
}
return nil
}
@ -345,6 +404,80 @@ func (s *Server) init() error {
return err
}
}
{
log.Info("Proxy server listen on tcp", zap.Int("port", Params.Port.GetAsInt()))
var lis net.Listener
var listenErr error
log.Info("Proxy server already listen on tcp", zap.Int("port", Params.Port.GetAsInt()))
lis, listenErr = net.Listen("tcp", ":"+strconv.Itoa(Params.Port.GetAsInt()))
if listenErr != nil {
log.Error("Proxy server(grpc/http) failed to listen on", zap.Error(err), zap.Int("port", Params.Port.GetAsInt()))
return err
}
if HTTPParams.Enabled.GetAsBool() && Params.TLSMode.GetAsInt() == 0 &&
(HTTPParams.Port.GetValue() == "" || HTTPParams.Port.GetAsInt() == Params.Port.GetAsInt()) {
s.tcpServer = cmux.New(lis)
s.grpcListener = s.tcpServer.MatchWithWriters(cmux.HTTP2MatchHeaderFieldSendSettings("content-type", "application/grpc"))
s.httpListener = s.tcpServer.Match(cmux.Any())
} else {
s.grpcListener = lis
}
if HTTPParams.Enabled.GetAsBool() && HTTPParams.Port.GetValue() != "" && HTTPParams.Port.GetAsInt() != Params.Port.GetAsInt() {
if Params.TLSMode.GetAsInt() == 0 {
s.httpListener, listenErr = net.Listen("tcp", ":"+strconv.Itoa(HTTPParams.Port.GetAsInt()))
if listenErr != nil {
log.Error("Proxy server(grpc/http) failed to listen on", zap.Error(err), zap.Int("port", Params.Port.GetAsInt()))
return err
}
} else if Params.TLSMode.GetAsInt() == 1 {
creds, err := tls.LoadX509KeyPair(Params.ServerPemPath.GetValue(), Params.ServerKeyPath.GetValue())
if err != nil {
log.Error("proxy can't create creds", zap.Error(err))
return err
}
s.httpListener, listenErr = tls.Listen("tcp", ":"+strconv.Itoa(HTTPParams.Port.GetAsInt()), &tls.Config{
Certificates: []tls.Certificate{creds},
})
if listenErr != nil {
log.Error("Proxy server(grpc/http) failed to listen on", zap.Error(err), zap.Int("port", Params.Port.GetAsInt()))
return listenErr
}
} else if Params.TLSMode.GetAsInt() == 2 {
cert, err := tls.LoadX509KeyPair(Params.ServerPemPath.GetValue(), Params.ServerKeyPath.GetValue())
if err != nil {
log.Error("proxy cant load x509 key pair", zap.Error(err))
return err
}
certPool := x509.NewCertPool()
rootBuf, err := ioutil.ReadFile(Params.CaPemPath.GetValue())
if err != nil {
log.Error("failed read ca pem", zap.Error(err))
return err
}
if !certPool.AppendCertsFromPEM(rootBuf) {
log.Warn("fail to append ca to cert")
return fmt.Errorf("fail to append ca to cert")
}
tlsConf := &tls.Config{
ClientAuth: tls.RequireAndVerifyClientCert,
Certificates: []tls.Certificate{cert},
ClientCAs: certPool,
MinVersion: tls.VersionTLS13,
}
s.httpListener, listenErr = tls.Listen("tcp", ":"+strconv.Itoa(HTTPParams.Port.GetAsInt()), tlsConf)
if listenErr != nil {
log.Error("Proxy server(grpc/http) failed to listen on", zap.Error(err), zap.Int("port", Params.Port.GetAsInt()))
return listenErr
}
}
}
}
{
s.startExternalRPCServer(Params.Port.GetAsInt(), errChan)
if err := <-errChan; err != nil {
@ -355,7 +488,7 @@ func (s *Server) init() error {
if HTTPParams.Enabled.GetAsBool() {
registerHTTPHandlerOnce.Do(func() {
log.Info("register http server of proxy")
log.Info("register Proxy http server")
s.registerHTTPServer()
})
}
@ -474,6 +607,17 @@ func (s *Server) start() error {
return err
}
if s.httpListener != nil {
log.Info("start Proxy http server")
errChan := make(chan error, 1)
s.wg.Add(1)
go s.startHTTPServer(errChan)
if err := <-errChan; err != nil {
log.Error("failed to create http rpc server", zap.Error(err))
return err
}
}
return nil
}
@ -496,9 +640,16 @@ func (s *Server) Stop() error {
log.Debug("Graceful stop grpc internal server...")
s.grpcInternalServer.GracefulStop()
}
if s.grpcExternalServer != nil {
log.Debug("Graceful stop grpc external server...")
if s.tcpServer != nil {
log.Info("Graceful stop Proxy tcp server...")
s.tcpServer.Close()
} else if s.grpcExternalServer != nil {
log.Info("Graceful stop grpc external server...")
s.grpcExternalServer.GracefulStop()
if s.httpServer != nil {
log.Info("Graceful stop grpc http server...")
s.httpServer.Close()
}
}
}()
gracefulWg.Wait()

View File

@ -966,12 +966,14 @@ func waitForGrpcReady(opt *WaitOption) {
ch <- err
return
}
_, err = grpc.Dial(address, grpc.WithBlock(), grpc.WithTransportCredentials(creds))
conn, err := grpc.Dial(address, grpc.WithBlock(), grpc.WithTransportCredentials(creds))
ch <- err
conn.Close()
return
}
if _, err := grpc.Dial(address, grpc.WithBlock(), grpc.WithInsecure()); true {
if conn, err := grpc.Dial(address, grpc.WithBlock(), grpc.WithInsecure()); true {
ch <- err
conn.Close()
}
}()
@ -1522,7 +1524,7 @@ func Test_NewServer_HTTPServer_Enabled(t *testing.T) {
t.Fatalf("test should have panicked but did not")
}
}()
// if disable workds path not registered, so it shall not panic
// if disable works path not registered, so it shall not panic
server.registerHTTPServer()
}
@ -1616,6 +1618,86 @@ func Test_NewServer_TLS_FileNotExisted(t *testing.T) {
server.Stop()
}
func Test_NewHTTPServer_TLS_TwoWay(t *testing.T) {
server := getServer(t)
Params := &paramtable.Get().ProxyGrpcServerCfg
paramtable.Get().Save(Params.TLSMode.Key, "2")
paramtable.Get().Save(Params.ServerPemPath.Key, "../../../configs/cert/server.pem")
paramtable.Get().Save(Params.ServerKeyPath.Key, "../../../configs/cert/server.key")
paramtable.Get().Save(Params.CaPemPath.Key, "../../../configs/cert/ca.pem")
paramtable.Get().Save(proxy.Params.HTTPCfg.Enabled.Key, "true")
paramtable.Get().Save(proxy.Params.HTTPCfg.Port.Key, "8080")
err := runAndWaitForServerReady(server)
assert.Nil(t, err)
assert.NotNil(t, server.grpcExternalServer)
err = server.Stop()
assert.Nil(t, err)
paramtable.Get().Save(proxy.Params.HTTPCfg.Port.Key, "19529")
err = runAndWaitForServerReady(server)
assert.NotNil(t, err)
server.Stop()
}
func Test_NewHTTPServer_TLS_OneWay(t *testing.T) {
server := getServer(t)
Params := &paramtable.Get().ProxyGrpcServerCfg
paramtable.Get().Save(Params.TLSMode.Key, "1")
paramtable.Get().Save(Params.ServerPemPath.Key, "../../../configs/cert/server.pem")
paramtable.Get().Save(Params.ServerKeyPath.Key, "../../../configs/cert/server.key")
paramtable.Get().Save(proxy.Params.HTTPCfg.Enabled.Key, "true")
paramtable.Get().Save(proxy.Params.HTTPCfg.Port.Key, "8080")
err := runAndWaitForServerReady(server)
assert.Nil(t, err)
assert.NotNil(t, server.grpcExternalServer)
err = server.Stop()
assert.Nil(t, err)
paramtable.Get().Save(proxy.Params.HTTPCfg.Port.Key, "19529")
err = runAndWaitForServerReady(server)
assert.NotNil(t, err)
server.Stop()
}
func Test_NewHTTPServer_TLS_FileNotExisted(t *testing.T) {
server := getServer(t)
Params := &paramtable.Get().ProxyGrpcServerCfg
paramtable.Get().Save(Params.TLSMode.Key, "1")
paramtable.Get().Save(Params.ServerPemPath.Key, "../not/existed/server.pem")
paramtable.Get().Save(Params.ServerKeyPath.Key, "../../../configs/cert/server.key")
paramtable.Get().Save(proxy.Params.HTTPCfg.Enabled.Key, "true")
paramtable.Get().Save(proxy.Params.HTTPCfg.Port.Key, "8080")
err := runAndWaitForServerReady(server)
assert.NotNil(t, err)
server.Stop()
paramtable.Get().Save(Params.TLSMode.Key, "2")
paramtable.Get().Save(Params.ServerPemPath.Key, "../not/existed/server.pem")
paramtable.Get().Save(Params.CaPemPath.Key, "../../../configs/cert/ca.pem")
err = runAndWaitForServerReady(server)
assert.NotNil(t, err)
server.Stop()
paramtable.Get().Save(Params.ServerPemPath.Key, "../../../configs/cert/server.pem")
paramtable.Get().Save(Params.CaPemPath.Key, "../not/existed/ca.pem")
err = runAndWaitForServerReady(server)
assert.NotNil(t, err)
server.Stop()
paramtable.Get().Save(Params.CaPemPath.Key, "service.go")
err = runAndWaitForServerReady(server)
assert.NotNil(t, err)
server.Stop()
}
func Test_NewServer_GetVersion(t *testing.T) {
req := &milvuspb.GetVersionRequest{}
t.Run("test get version failed", func(t *testing.T) {

View File

@ -77,6 +77,22 @@ func PrivilegeInterceptor(ctx context.Context, req interface{}) (context.Context
log.Warn("GetCurUserFromContext fail", zap.Error(err))
return ctx, err
}
return privilegeInterceptor(ctx, privilegeExt, username, req)
}
func PrivilegeInterceptorWithUsername(ctx context.Context, username string, req interface{}) (context.Context, error) {
if !Params.CommonCfg.AuthorizationEnabled.GetAsBool() {
return ctx, nil
}
log.Debug("PrivilegeInterceptor", zap.String("type", reflect.TypeOf(req).String()))
privilegeExt, err := funcutil.GetPrivilegeExtObj(req)
if err != nil {
log.Warn("GetPrivilegeExtObj err", zap.Error(err))
return ctx, nil
}
return privilegeInterceptor(ctx, privilegeExt, username, req)
}
func privilegeInterceptor(ctx context.Context, privilegeExt commonpb.PrivilegeExt, username string, req interface{}) (context.Context, error) {
if username == util.UserRoot {
return ctx, nil
}

View File

@ -33,6 +33,7 @@ import (
"github.com/milvus-io/milvus-proto/go-api/v2/schemapb"
"github.com/milvus-io/milvus/internal/parser/planparserv2"
"github.com/milvus-io/milvus/internal/proto/planpb"
"github.com/milvus-io/milvus/internal/proto/querypb"
"github.com/milvus-io/milvus/internal/types"
typeutil2 "github.com/milvus-io/milvus/internal/util/typeutil"
@ -820,8 +821,8 @@ func GetCurUserFromContext(ctx context.Context) (string, error) {
if !ok {
return "", fmt.Errorf("fail to get md from the context")
}
authorization := md[strings.ToLower(util.HeaderAuthorize)]
if len(authorization) < 1 {
authorization, ok := md[strings.ToLower(util.HeaderAuthorize)]
if !ok || len(authorization) < 1 {
return "", fmt.Errorf("fail to get authorization from the md, authorize:[%s]", util.HeaderAuthorize)
}
token := authorization[0]
@ -856,6 +857,10 @@ func GetRole(username string) ([]string, error) {
return globalMetaCache.GetUserRole(username), nil
}
func PasswordVerify(ctx context.Context, username, rawPwd string) bool {
return passwordVerify(ctx, username, rawPwd, globalMetaCache)
}
// PasswordVerify verify password
func passwordVerify(ctx context.Context, username, rawPwd string, globalMetaCache Cache) bool {
// it represents the cache miss if Sha256Password is empty within credInfo, which shall be updated first connection.

View File

@ -36,7 +36,7 @@ const (
embededBits
retriableFlag = 1 << 20
rootReasonCodeMask = (1 << 16) - 1
RootReasonCodeMask = (1 << 16) - 1
CanceledCode int32 = 10000
TimeoutCode int32 = 10001
@ -125,6 +125,15 @@ var (
// field related
ErrFieldNotFound = newMilvusError("field not found", 1700, false)
// high-level restful api related
ErrNeedAuthenticate = newMilvusError("user hasn't authenticate", 1800, false)
ErrIncorrectParameterFormat = newMilvusError("can only accept json format request", 1801, false)
ErrMissingRequiredParameters = newMilvusError("missing required parameters", 1802, false)
ErrMarshalCollectionSchema = newMilvusError("fail to marshal collection schema", 1803, false)
ErrInvalidInsertData = newMilvusError("fail to deal the insert data", 1804, false)
ErrInvalidSearchResult = newMilvusError("fail to parse search result", 1805, false)
ErrCheckPrimaryKey = newMilvusError("please check the primary key and its' type can only in [int, string]", 1806, false)
// Do NOT export this,
// never allow programmer using this, keep only for converting unknown error to milvusError
errUnexpected = newMilvusError("unexpected error", (1<<16)-1, false)

View File

@ -3,6 +3,7 @@ package paramtable
type httpConfig struct {
Enabled ParamItem `refreshable:"false"`
DebugMode ParamItem `refreshable:"false"`
Port ParamItem `refreshable:"false"`
}
func (p *httpConfig) init(base *BaseTable) {
@ -23,4 +24,13 @@ func (p *httpConfig) init(base *BaseTable) {
Export: true,
}
p.DebugMode.Init(base.mgr)
p.Port = ParamItem{
Key: "proxy.http.port",
Version: "2.1.0",
Doc: "high-level restful api",
PanicIfEmpty: false,
Export: true,
}
p.Port.Init(base.mgr)
}

View File

@ -12,4 +12,5 @@ func TestHTTPConfig_Init(t *testing.T) {
cf := params.HTTPCfg
assert.Equal(t, cf.Enabled.GetAsBool(), true)
assert.Equal(t, cf.DebugMode.GetAsBool(), false)
assert.Equal(t, cf.Port.GetValue(), "")
}