mirror of https://github.com/milvus-io/milvus.git
support high-level RESTFUL API, listen on the same port as grpc. (#25108)
Signed-off-by: PowderLi <min.li@zilliz.com>pull/26133/head
parent
28cea5e763
commit
a7eecb1be0
|
@ -77,6 +77,8 @@ issues:
|
|||
- G304
|
||||
# Deferring unsafe method like *os.File Close
|
||||
- G307
|
||||
# TLS MinVersion too low
|
||||
- G402
|
||||
# Use of weak random number generator math/rand
|
||||
- G404
|
||||
# Maximum issues count per one linter. Set to 0 to disable. Default is 50.
|
||||
|
|
17
go.mod
17
go.mod
|
@ -4,7 +4,7 @@ go 1.18
|
|||
|
||||
require (
|
||||
github.com/DATA-DOG/go-sqlmock v1.5.0
|
||||
github.com/aliyun/credentials-go v1.2.6
|
||||
github.com/aliyun/credentials-go v1.2.7
|
||||
github.com/antlr/antlr4/runtime/Go/antlr v0.0.0-20210826220005-b48c857c3a0e
|
||||
github.com/antonmedv/expr v1.8.9
|
||||
github.com/apache/arrow/go/v8 v8.0.0-20220322092137-778b1772fd20
|
||||
|
@ -50,6 +50,9 @@ require (
|
|||
golang.org/x/sync v0.1.0
|
||||
golang.org/x/text v0.9.0
|
||||
google.golang.org/grpc v1.54.0
|
||||
google.golang.org/grpc/examples v0.0.0-20220617181431-3e7b97febc7f
|
||||
google.golang.org/protobuf v1.30.0
|
||||
gopkg.in/natefinch/lumberjack.v2 v2.0.0
|
||||
gorm.io/driver/mysql v1.3.5
|
||||
gorm.io/gorm v1.23.8
|
||||
stathat.com/c/consistent v1.0.0
|
||||
|
@ -169,7 +172,7 @@ require (
|
|||
github.com/rs/xid v1.2.1 // indirect
|
||||
github.com/shirou/gopsutil/v3 v3.22.9
|
||||
github.com/sirupsen/logrus v1.8.1 // indirect
|
||||
github.com/soheilhy/cmux v0.1.5 // indirect
|
||||
github.com/soheilhy/cmux v0.1.5
|
||||
github.com/spaolacci/murmur3 v1.1.0
|
||||
github.com/spf13/afero v1.6.0 // indirect
|
||||
github.com/spf13/cast v1.3.1
|
||||
|
@ -209,15 +212,19 @@ require (
|
|||
golang.org/x/xerrors v0.0.0-20220907171357-04be3eba64a2 // indirect
|
||||
gonum.org/v1/gonum v0.9.3 // indirect
|
||||
google.golang.org/appengine v1.6.7 // indirect
|
||||
google.golang.org/grpc/examples v0.0.0-20220617181431-3e7b97febc7f
|
||||
google.golang.org/protobuf v1.30.0
|
||||
gopkg.in/ini.v1 v1.62.0 // indirect
|
||||
gopkg.in/natefinch/lumberjack.v2 v2.0.0
|
||||
gopkg.in/yaml.v2 v2.4.0 // indirect
|
||||
gopkg.in/yaml.v3 v3.0.1 // indirect
|
||||
sigs.k8s.io/yaml v1.2.0 // indirect
|
||||
)
|
||||
|
||||
require github.com/tidwall/gjson v1.14.4
|
||||
|
||||
require (
|
||||
github.com/tidwall/match v1.1.1 // indirect
|
||||
github.com/tidwall/pretty v1.2.0 // indirect
|
||||
)
|
||||
|
||||
replace (
|
||||
github.com/apache/pulsar-client-go => github.com/milvus-io/pulsar-client-go v0.6.10
|
||||
github.com/bketelsen/crypt => github.com/bketelsen/crypt v0.0.4 // Fix security alert for core-os/etcd
|
||||
|
|
10
go.sum
10
go.sum
|
@ -82,8 +82,8 @@ github.com/alibabacloud-go/debug v0.0.0-20190504072949-9472017b5c68 h1:NqugFkGxx
|
|||
github.com/alibabacloud-go/debug v0.0.0-20190504072949-9472017b5c68/go.mod h1:6pb/Qy8c+lqua8cFpEy7g39NRRqOWc3rOwAy8m5Y2BY=
|
||||
github.com/alibabacloud-go/tea v1.1.8 h1:vFF0707fqjGiQTxrtMnIXRjOCvQXf49CuDVRtTopmwU=
|
||||
github.com/alibabacloud-go/tea v1.1.8/go.mod h1:/tmnEaQMyb4Ky1/5D+SE1BAsa5zj/KeGOFfwYm3N/p4=
|
||||
github.com/aliyun/credentials-go v1.2.6 h1:dSMxpj4uXZj0MYOsEyljlssHzfdHw/M84iQ5QKF0Uxg=
|
||||
github.com/aliyun/credentials-go v1.2.6/go.mod h1:/KowD1cfGSLrLsH28Jr8W+xwoId0ywIy5lNzDz6O1vw=
|
||||
github.com/aliyun/credentials-go v1.2.7 h1:gLtFylxLZ1TWi1pStIt1O6a53GFU1zkNwjtJir2B4ow=
|
||||
github.com/aliyun/credentials-go v1.2.7/go.mod h1:/KowD1cfGSLrLsH28Jr8W+xwoId0ywIy5lNzDz6O1vw=
|
||||
github.com/andybalholm/brotli v1.0.4 h1:V7DdXeJtZscaqfNuAdSRuRFzuiKlHSC/Zh3zl9qY3JY=
|
||||
github.com/andybalholm/brotli v1.0.4/go.mod h1:fO7iG3H7G2nSZ7m0zPUDn85XEX2GTukHGRSepvi9Eig=
|
||||
github.com/antihax/optional v1.0.0/go.mod h1:uupD/76wgC+ih3iEmQUL+0Ugr19nfwCT1kdvxnR2qWY=
|
||||
|
@ -812,6 +812,12 @@ github.com/stretchr/testify v1.8.3/go.mod h1:sz/lmYIOXD/1dqDmKjjqLyZ2RngseejIcXl
|
|||
github.com/subosito/gotenv v1.2.0 h1:Slr1R9HxAlEKefgq5jn9U+DnETlIUa6HfgEzj0g5d7s=
|
||||
github.com/subosito/gotenv v1.2.0/go.mod h1:N0PQaV/YGNqwC0u51sEeR/aUtSLEXKX9iv69rRypqCw=
|
||||
github.com/thoas/go-funk v0.9.1 h1:O549iLZqPpTUQ10ykd26sZhzD+rmR5pWhuElrhbC20M=
|
||||
github.com/tidwall/gjson v1.14.4 h1:uo0p8EbA09J7RQaflQ1aBRffTR7xedD2bcIVSYxLnkM=
|
||||
github.com/tidwall/gjson v1.14.4/go.mod h1:/wbyibRr2FHMks5tjHJ5F8dMZh3AcwJEMf5vlfC0lxk=
|
||||
github.com/tidwall/match v1.1.1 h1:+Ho715JplO36QYgwN9PGYNhgZvoUSc9X2c80KVTi+GA=
|
||||
github.com/tidwall/match v1.1.1/go.mod h1:eRSPERbgtNPcGhD8UCthc6PmLEQXEWd3PRB5JTxsfmM=
|
||||
github.com/tidwall/pretty v1.2.0 h1:RWIZEg2iJ8/g6fDDYzMpobmaoGh5OLl4AXtGUGPcqCs=
|
||||
github.com/tidwall/pretty v1.2.0/go.mod h1:ITEVvHYasfjBbM0u2Pg8T2nJnzm8xPwvNhhsoaGGjNU=
|
||||
github.com/tklauser/go-sysconf v0.3.10 h1:IJ1AZGZRWbY8T5Vfk04D9WOA5WSejdflXxP03OUqALw=
|
||||
github.com/tklauser/go-sysconf v0.3.10/go.mod h1:C8XykCvCb+Gn0oNCWPIlcb0RuglQTYaQ2hGm7jmxEFk=
|
||||
github.com/tklauser/numcpus v0.4.0 h1:E53Dm1HjH1/R2/aoCtXtPgzmElmn51aOkhCFSuZq//o=
|
||||
|
|
|
@ -0,0 +1,57 @@
|
|||
package httpserver
|
||||
|
||||
const (
|
||||
ContextUsername = "username"
|
||||
VectorCollectionsPath = "/vector/collections"
|
||||
VectorCollectionsCreatePath = "/vector/collections/create"
|
||||
VectorCollectionsDescribePath = "/vector/collections/describe"
|
||||
VectorCollectionsDropPath = "/vector/collections/drop"
|
||||
VectorInsertPath = "/vector/insert"
|
||||
VectorSearchPath = "/vector/search"
|
||||
VectorGetPath = "/vector/get"
|
||||
VectorQueryPath = "/vector/query"
|
||||
VectorDeletePath = "/vector/delete"
|
||||
|
||||
ShardNumDefault = 1
|
||||
|
||||
EnableDynamic = true
|
||||
EnableAutoID = true
|
||||
DisableAutoID = false
|
||||
|
||||
HTTPCollectionName = "collectionName"
|
||||
HTTPDbName = "dbName"
|
||||
DefaultDbName = "default"
|
||||
DefaultIndexName = "vector_idx"
|
||||
DefaultOutputFields = "*"
|
||||
|
||||
HTTPReturnCode = "code"
|
||||
HTTPReturnMessage = "message"
|
||||
HTTPReturnData = "data"
|
||||
|
||||
HTTPReturnFieldName = "name"
|
||||
HTTPReturnFieldType = "type"
|
||||
HTTPReturnFieldPrimaryKey = "primaryKey"
|
||||
HTTPReturnFieldAutoID = "autoId"
|
||||
HTTPReturnDescription = "description"
|
||||
|
||||
HTTPReturnIndexName = "indexName"
|
||||
HTTPReturnIndexField = "fieldName"
|
||||
HTTPReturnIndexMetricsType = "metricType"
|
||||
|
||||
HTTPReturnDistance = "distance"
|
||||
|
||||
DefaultMetricType = "L2"
|
||||
DefaultPrimaryFieldName = "id"
|
||||
DefaultVectorFieldName = "vector"
|
||||
|
||||
Dim = "dim"
|
||||
)
|
||||
|
||||
const (
|
||||
ParamAnnsField = "anns_field"
|
||||
Params = "params"
|
||||
ParamRoundDecimal = "round_decimal"
|
||||
ParamOffset = "offset"
|
||||
ParamLimit = "limit"
|
||||
BoundedTimestamp = 2
|
||||
)
|
|
@ -0,0 +1,625 @@
|
|||
package httpserver
|
||||
|
||||
import (
|
||||
"encoding/json"
|
||||
"net/http"
|
||||
"strconv"
|
||||
|
||||
"github.com/milvus-io/milvus/pkg/util/merr"
|
||||
|
||||
"github.com/cockroachdb/errors"
|
||||
|
||||
"github.com/gin-gonic/gin"
|
||||
"github.com/gin-gonic/gin/binding"
|
||||
"github.com/golang/protobuf/proto"
|
||||
"github.com/milvus-io/milvus-proto/go-api/v2/commonpb"
|
||||
"github.com/milvus-io/milvus-proto/go-api/v2/milvuspb"
|
||||
"github.com/milvus-io/milvus-proto/go-api/v2/schemapb"
|
||||
"github.com/milvus-io/milvus/internal/proxy"
|
||||
"github.com/milvus-io/milvus/pkg/common"
|
||||
"github.com/milvus-io/milvus/pkg/log"
|
||||
"github.com/tidwall/gjson"
|
||||
"go.uber.org/zap"
|
||||
)
|
||||
|
||||
func checkAuthorization(c *gin.Context, req interface{}) error {
|
||||
if proxy.Params.CommonCfg.AuthorizationEnabled.GetAsBool() {
|
||||
username, ok := c.Get(ContextUsername)
|
||||
if !ok {
|
||||
c.JSON(http.StatusUnauthorized, gin.H{HTTPReturnCode: Code(merr.ErrNeedAuthenticate), HTTPReturnMessage: merr.ErrNeedAuthenticate.Error()})
|
||||
return merr.ErrNeedAuthenticate
|
||||
}
|
||||
_, authErr := proxy.PrivilegeInterceptorWithUsername(c, username.(string), req)
|
||||
if authErr != nil {
|
||||
c.JSON(http.StatusForbidden, gin.H{HTTPReturnCode: Code(authErr), HTTPReturnMessage: authErr.Error()})
|
||||
return authErr
|
||||
}
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func (h *Handlers) checkDatabase(c *gin.Context, dbName string) bool {
|
||||
if dbName == DefaultDbName {
|
||||
return true
|
||||
}
|
||||
response, err := h.proxy.ListDatabases(c, &milvuspb.ListDatabasesRequest{})
|
||||
if err != nil {
|
||||
c.AbortWithStatusJSON(http.StatusOK, gin.H{HTTPReturnCode: Code(err), HTTPReturnMessage: err.Error()})
|
||||
return false
|
||||
} else if response.Status.ErrorCode != commonpb.ErrorCode_Success {
|
||||
c.AbortWithStatusJSON(http.StatusOK, gin.H{HTTPReturnCode: code(int32(response.Status.ErrorCode)), HTTPReturnMessage: response.Status.Reason})
|
||||
return false
|
||||
}
|
||||
for _, db := range response.DbNames {
|
||||
if db == dbName {
|
||||
return true
|
||||
}
|
||||
}
|
||||
c.AbortWithStatusJSON(http.StatusOK, gin.H{HTTPReturnCode: Code(merr.ErrDatabaseNotfound), HTTPReturnMessage: merr.ErrDatabaseNotfound.Error()})
|
||||
return false
|
||||
}
|
||||
|
||||
func (h *Handlers) describeCollection(c *gin.Context, dbName string, collectionName string, needAuth bool) (*milvuspb.DescribeCollectionResponse, error) {
|
||||
req := milvuspb.DescribeCollectionRequest{
|
||||
DbName: dbName,
|
||||
CollectionName: collectionName,
|
||||
}
|
||||
if needAuth {
|
||||
if err := checkAuthorization(c, &req); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
}
|
||||
response, err := h.proxy.DescribeCollection(c, &req)
|
||||
if err != nil {
|
||||
c.AbortWithStatusJSON(http.StatusOK, gin.H{HTTPReturnCode: Code(err), HTTPReturnMessage: err.Error()})
|
||||
return nil, err
|
||||
} else if response.Status.ErrorCode != commonpb.ErrorCode_Success {
|
||||
c.AbortWithStatusJSON(http.StatusOK, gin.H{HTTPReturnCode: code(int32(response.Status.ErrorCode)), HTTPReturnMessage: response.Status.Reason})
|
||||
return nil, errors.New(response.Status.Reason)
|
||||
}
|
||||
primaryField, ok := getPrimaryField(response.Schema)
|
||||
if ok && primaryField.AutoID && !response.Schema.AutoID {
|
||||
log.Warn("primary filed autoID VS schema autoID", zap.String("collectionName", collectionName), zap.Bool("primary Field", primaryField.AutoID), zap.Bool("schema", response.Schema.AutoID))
|
||||
response.Schema.AutoID = EnableAutoID
|
||||
}
|
||||
return response, nil
|
||||
}
|
||||
|
||||
func (h *Handlers) hasCollection(c *gin.Context, dbName string, collectionName string) (bool, error) {
|
||||
req := milvuspb.HasCollectionRequest{
|
||||
DbName: dbName,
|
||||
CollectionName: collectionName,
|
||||
}
|
||||
response, err := h.proxy.HasCollection(c, &req)
|
||||
if err != nil {
|
||||
c.AbortWithStatusJSON(http.StatusOK, gin.H{HTTPReturnCode: Code(err), HTTPReturnMessage: err.Error()})
|
||||
return false, err
|
||||
} else if response.Status.ErrorCode != commonpb.ErrorCode_Success {
|
||||
c.AbortWithStatusJSON(http.StatusOK, gin.H{HTTPReturnCode: code(int32(response.Status.ErrorCode)), HTTPReturnMessage: response.Status.Reason})
|
||||
return false, errors.New(response.Status.Reason)
|
||||
} else {
|
||||
return response.Value, nil
|
||||
}
|
||||
}
|
||||
|
||||
func (h *Handlers) RegisterRoutesToV1(router gin.IRouter) {
|
||||
router.GET(VectorCollectionsPath, h.listCollections)
|
||||
router.POST(VectorCollectionsCreatePath, h.createCollection)
|
||||
router.GET(VectorCollectionsDescribePath, h.getCollectionDetails)
|
||||
router.POST(VectorCollectionsDropPath, h.dropCollection)
|
||||
router.POST(VectorQueryPath, h.query)
|
||||
router.POST(VectorGetPath, h.get)
|
||||
router.POST(VectorDeletePath, h.delete)
|
||||
router.POST(VectorInsertPath, h.insert)
|
||||
router.POST(VectorSearchPath, h.search)
|
||||
}
|
||||
|
||||
func (h *Handlers) listCollections(c *gin.Context) {
|
||||
dbName := c.DefaultQuery(HTTPDbName, DefaultDbName)
|
||||
req := milvuspb.ShowCollectionsRequest{
|
||||
DbName: dbName,
|
||||
}
|
||||
if err := checkAuthorization(c, &req); err != nil {
|
||||
return
|
||||
}
|
||||
if !h.checkDatabase(c, dbName) {
|
||||
return
|
||||
}
|
||||
response, err := h.proxy.ShowCollections(c, &req)
|
||||
if err != nil {
|
||||
c.JSON(http.StatusOK, gin.H{HTTPReturnCode: Code(err), HTTPReturnMessage: err.Error()})
|
||||
} else if response.Status.ErrorCode != commonpb.ErrorCode_Success {
|
||||
c.JSON(http.StatusOK, gin.H{HTTPReturnCode: code(int32(response.Status.ErrorCode)), HTTPReturnMessage: response.Status.Reason})
|
||||
} else {
|
||||
var collections []string
|
||||
if response.CollectionNames != nil {
|
||||
collections = response.CollectionNames
|
||||
} else {
|
||||
collections = []string{}
|
||||
}
|
||||
c.JSON(http.StatusOK, gin.H{HTTPReturnCode: http.StatusOK, HTTPReturnData: collections})
|
||||
}
|
||||
}
|
||||
|
||||
func (h *Handlers) createCollection(c *gin.Context) {
|
||||
httpReq := CreateCollectionReq{
|
||||
DbName: DefaultDbName,
|
||||
MetricType: DefaultMetricType,
|
||||
PrimaryField: DefaultPrimaryFieldName,
|
||||
VectorField: DefaultVectorFieldName,
|
||||
}
|
||||
if err := c.ShouldBindBodyWith(&httpReq, binding.JSON); err != nil {
|
||||
log.Warn("high level restful api, the parameter of create collection is incorrect", zap.Any("request", httpReq), zap.Error(err))
|
||||
c.AbortWithStatusJSON(http.StatusOK, gin.H{HTTPReturnCode: Code(merr.ErrIncorrectParameterFormat), HTTPReturnMessage: merr.ErrIncorrectParameterFormat.Error()})
|
||||
return
|
||||
}
|
||||
if httpReq.CollectionName == "" || httpReq.Dimension == 0 {
|
||||
log.Warn("high level restful api, create collection require parameters: [collectionName, dimension], but miss", zap.Any("request", httpReq))
|
||||
c.AbortWithStatusJSON(http.StatusOK, gin.H{HTTPReturnCode: Code(merr.ErrMissingRequiredParameters), HTTPReturnMessage: merr.ErrMissingRequiredParameters.Error()})
|
||||
return
|
||||
}
|
||||
schema, err := proto.Marshal(&schemapb.CollectionSchema{
|
||||
Name: httpReq.CollectionName,
|
||||
Description: httpReq.Description,
|
||||
AutoID: EnableAutoID,
|
||||
Fields: []*schemapb.FieldSchema{
|
||||
{
|
||||
FieldID: common.StartOfUserFieldID,
|
||||
Name: httpReq.PrimaryField,
|
||||
IsPrimaryKey: true,
|
||||
DataType: schemapb.DataType_Int64,
|
||||
AutoID: EnableAutoID,
|
||||
}, {
|
||||
FieldID: common.StartOfUserFieldID + 1,
|
||||
Name: httpReq.VectorField,
|
||||
IsPrimaryKey: false,
|
||||
DataType: schemapb.DataType_FloatVector,
|
||||
TypeParams: []*commonpb.KeyValuePair{
|
||||
{
|
||||
Key: Dim,
|
||||
Value: strconv.FormatInt(int64(httpReq.Dimension), 10),
|
||||
},
|
||||
},
|
||||
AutoID: DisableAutoID,
|
||||
},
|
||||
},
|
||||
EnableDynamicField: EnableDynamic,
|
||||
})
|
||||
if err != nil {
|
||||
log.Warn("high level restful api, marshal collection schema fail", zap.Any("request", httpReq), zap.Error(err))
|
||||
c.AbortWithStatusJSON(http.StatusOK, gin.H{HTTPReturnCode: Code(merr.ErrMarshalCollectionSchema), HTTPReturnMessage: merr.ErrMarshalCollectionSchema.Error()})
|
||||
return
|
||||
}
|
||||
req := milvuspb.CreateCollectionRequest{
|
||||
DbName: httpReq.DbName,
|
||||
CollectionName: httpReq.CollectionName,
|
||||
Schema: schema,
|
||||
ShardsNum: ShardNumDefault,
|
||||
ConsistencyLevel: commonpb.ConsistencyLevel_Bounded,
|
||||
}
|
||||
if err := checkAuthorization(c, &req); err != nil {
|
||||
return
|
||||
}
|
||||
if !h.checkDatabase(c, req.DbName) {
|
||||
return
|
||||
}
|
||||
response, err := h.proxy.CreateCollection(c, &req)
|
||||
if err != nil {
|
||||
c.AbortWithStatusJSON(http.StatusOK, gin.H{HTTPReturnCode: Code(err), HTTPReturnMessage: err.Error()})
|
||||
return
|
||||
} else if response.ErrorCode != commonpb.ErrorCode_Success {
|
||||
c.AbortWithStatusJSON(http.StatusOK, gin.H{HTTPReturnCode: code(int32(response.ErrorCode)), HTTPReturnMessage: response.Reason})
|
||||
return
|
||||
}
|
||||
|
||||
response, err = h.proxy.CreateIndex(c, &milvuspb.CreateIndexRequest{
|
||||
DbName: httpReq.DbName,
|
||||
CollectionName: httpReq.CollectionName,
|
||||
FieldName: httpReq.VectorField,
|
||||
IndexName: DefaultIndexName,
|
||||
ExtraParams: []*commonpb.KeyValuePair{{Key: common.MetricTypeKey, Value: httpReq.MetricType}},
|
||||
})
|
||||
if err != nil {
|
||||
c.AbortWithStatusJSON(http.StatusOK, gin.H{HTTPReturnCode: Code(err), HTTPReturnMessage: err.Error()})
|
||||
return
|
||||
} else if response.ErrorCode != commonpb.ErrorCode_Success {
|
||||
c.AbortWithStatusJSON(http.StatusOK, gin.H{HTTPReturnCode: code(int32(response.ErrorCode)), HTTPReturnMessage: response.Reason})
|
||||
return
|
||||
}
|
||||
response, err = h.proxy.LoadCollection(c, &milvuspb.LoadCollectionRequest{
|
||||
DbName: httpReq.DbName,
|
||||
CollectionName: httpReq.CollectionName,
|
||||
})
|
||||
if err != nil {
|
||||
c.AbortWithStatusJSON(http.StatusOK, gin.H{HTTPReturnCode: Code(err), HTTPReturnMessage: err.Error()})
|
||||
return
|
||||
} else if response.ErrorCode != commonpb.ErrorCode_Success {
|
||||
c.AbortWithStatusJSON(http.StatusOK, gin.H{HTTPReturnCode: code(int32(response.ErrorCode)), HTTPReturnMessage: response.Reason})
|
||||
return
|
||||
}
|
||||
c.JSON(http.StatusOK, gin.H{HTTPReturnCode: http.StatusOK, HTTPReturnData: gin.H{}})
|
||||
}
|
||||
|
||||
func (h *Handlers) getCollectionDetails(c *gin.Context) {
|
||||
collectionName := c.Query(HTTPCollectionName)
|
||||
if collectionName == "" {
|
||||
log.Warn("high level restful api, desc collection require parameter: [collectionName], but miss")
|
||||
c.AbortWithStatusJSON(http.StatusOK, gin.H{HTTPReturnCode: Code(merr.ErrMissingRequiredParameters), HTTPReturnMessage: merr.ErrMissingRequiredParameters.Error()})
|
||||
return
|
||||
}
|
||||
dbName := c.DefaultQuery(HTTPDbName, DefaultDbName)
|
||||
if !h.checkDatabase(c, dbName) {
|
||||
return
|
||||
}
|
||||
coll, err := h.describeCollection(c, dbName, collectionName, true)
|
||||
if err != nil {
|
||||
return
|
||||
}
|
||||
stateResp, stateErr := h.proxy.GetLoadState(c, &milvuspb.GetLoadStateRequest{
|
||||
DbName: dbName,
|
||||
CollectionName: collectionName,
|
||||
})
|
||||
collLoadState := ""
|
||||
if stateErr != nil {
|
||||
log.Warn("get collection load state fail", zap.String("collection", collectionName), zap.String("err", stateErr.Error()))
|
||||
} else if stateResp.Status.ErrorCode != commonpb.ErrorCode_Success {
|
||||
log.Warn("get collection load state fail", zap.String("collection", collectionName), zap.String("err", stateResp.Status.Reason))
|
||||
} else {
|
||||
collLoadState = stateResp.State.String()
|
||||
}
|
||||
vectorField := ""
|
||||
for _, field := range coll.Schema.Fields {
|
||||
if field.DataType == schemapb.DataType_BinaryVector || field.DataType == schemapb.DataType_FloatVector {
|
||||
vectorField = field.Name
|
||||
break
|
||||
}
|
||||
}
|
||||
indexResp, indexErr := h.proxy.DescribeIndex(c, &milvuspb.DescribeIndexRequest{
|
||||
DbName: dbName,
|
||||
CollectionName: collectionName,
|
||||
FieldName: vectorField,
|
||||
})
|
||||
var indexDesc []gin.H
|
||||
if indexErr != nil {
|
||||
indexDesc = []gin.H{}
|
||||
log.Warn("get indexes description fail", zap.String("collection", collectionName), zap.String("vectorField", vectorField), zap.String("err", indexErr.Error()))
|
||||
} else if indexResp.Status.ErrorCode != commonpb.ErrorCode_Success {
|
||||
indexDesc = []gin.H{}
|
||||
log.Warn("get indexes description fail", zap.String("collection", collectionName), zap.String("vectorField", vectorField), zap.String("err", indexResp.Status.Reason))
|
||||
} else {
|
||||
indexDesc = printIndexes(indexResp.IndexDescriptions)
|
||||
}
|
||||
c.JSON(http.StatusOK, gin.H{HTTPReturnCode: http.StatusOK, HTTPReturnData: gin.H{
|
||||
HTTPCollectionName: coll.CollectionName,
|
||||
HTTPReturnDescription: coll.Schema.Description,
|
||||
"fields": printFields(coll.Schema.Fields),
|
||||
"indexes": indexDesc,
|
||||
"load": collLoadState,
|
||||
"shardsNum": coll.ShardsNum,
|
||||
"enableDynamic": coll.Schema.EnableDynamicField,
|
||||
}})
|
||||
}
|
||||
|
||||
func (h *Handlers) dropCollection(c *gin.Context) {
|
||||
httpReq := DropCollectionReq{
|
||||
DbName: DefaultDbName,
|
||||
}
|
||||
if err := c.ShouldBindBodyWith(&httpReq, binding.JSON); err != nil {
|
||||
log.Warn("high level restful api, the parameter of drop collection is incorrect", zap.Any("request", httpReq), zap.Error(err))
|
||||
c.AbortWithStatusJSON(http.StatusOK, gin.H{HTTPReturnCode: Code(merr.ErrIncorrectParameterFormat), HTTPReturnMessage: merr.ErrIncorrectParameterFormat.Error()})
|
||||
return
|
||||
}
|
||||
if httpReq.CollectionName == "" {
|
||||
log.Warn("high level restful api, drop collection require parameter: [collectionName], but miss")
|
||||
c.AbortWithStatusJSON(http.StatusOK, gin.H{HTTPReturnCode: Code(merr.ErrMissingRequiredParameters), HTTPReturnMessage: merr.ErrMissingRequiredParameters.Error()})
|
||||
return
|
||||
}
|
||||
req := milvuspb.DropCollectionRequest{
|
||||
DbName: httpReq.DbName,
|
||||
CollectionName: httpReq.CollectionName,
|
||||
}
|
||||
if err := checkAuthorization(c, &req); err != nil {
|
||||
return
|
||||
}
|
||||
if !h.checkDatabase(c, req.DbName) {
|
||||
return
|
||||
}
|
||||
has, err := h.hasCollection(c, httpReq.DbName, httpReq.CollectionName)
|
||||
if err != nil {
|
||||
return
|
||||
}
|
||||
if !has {
|
||||
c.AbortWithStatusJSON(http.StatusOK, gin.H{HTTPReturnCode: Code(merr.ErrCollectionNotFound), HTTPReturnMessage: merr.ErrCollectionNotFound.Error()})
|
||||
return
|
||||
}
|
||||
response, err := h.proxy.DropCollection(c, &req)
|
||||
if err != nil {
|
||||
c.JSON(http.StatusOK, gin.H{HTTPReturnCode: Code(err), HTTPReturnMessage: err.Error()})
|
||||
} else if response.ErrorCode != commonpb.ErrorCode_Success {
|
||||
c.JSON(http.StatusOK, gin.H{HTTPReturnCode: code(int32(response.ErrorCode)), HTTPReturnMessage: response.Reason})
|
||||
} else {
|
||||
c.JSON(http.StatusOK, gin.H{HTTPReturnCode: http.StatusOK, HTTPReturnData: gin.H{}})
|
||||
}
|
||||
}
|
||||
|
||||
func (h *Handlers) query(c *gin.Context) {
|
||||
httpReq := QueryReq{
|
||||
DbName: DefaultDbName,
|
||||
Limit: 100,
|
||||
OutputFields: []string{DefaultOutputFields},
|
||||
}
|
||||
if err := c.ShouldBindBodyWith(&httpReq, binding.JSON); err != nil {
|
||||
log.Warn("high level restful api, the parameter of query is incorrect", zap.Any("request", httpReq), zap.Error(err))
|
||||
c.AbortWithStatusJSON(http.StatusOK, gin.H{HTTPReturnCode: Code(merr.ErrIncorrectParameterFormat), HTTPReturnMessage: merr.ErrIncorrectParameterFormat.Error()})
|
||||
return
|
||||
}
|
||||
if httpReq.CollectionName == "" || httpReq.Filter == "" {
|
||||
log.Warn("high level restful api, query require parameter: [collectionName, filter], but miss")
|
||||
c.AbortWithStatusJSON(http.StatusOK, gin.H{HTTPReturnCode: Code(merr.ErrMissingRequiredParameters), HTTPReturnMessage: merr.ErrMissingRequiredParameters.Error()})
|
||||
return
|
||||
}
|
||||
req := milvuspb.QueryRequest{
|
||||
DbName: httpReq.DbName,
|
||||
CollectionName: httpReq.CollectionName,
|
||||
Expr: httpReq.Filter,
|
||||
OutputFields: httpReq.OutputFields,
|
||||
GuaranteeTimestamp: BoundedTimestamp,
|
||||
QueryParams: []*commonpb.KeyValuePair{},
|
||||
}
|
||||
if httpReq.Offset > 0 {
|
||||
req.QueryParams = append(req.QueryParams, &commonpb.KeyValuePair{Key: ParamOffset, Value: strconv.FormatInt(int64(httpReq.Offset), 10)})
|
||||
}
|
||||
if httpReq.Limit > 0 {
|
||||
req.QueryParams = append(req.QueryParams, &commonpb.KeyValuePair{Key: ParamLimit, Value: strconv.FormatInt(int64(httpReq.Limit), 10)})
|
||||
}
|
||||
if err := checkAuthorization(c, &req); err != nil {
|
||||
return
|
||||
}
|
||||
if !h.checkDatabase(c, req.DbName) {
|
||||
return
|
||||
}
|
||||
response, err := h.proxy.Query(c, &req)
|
||||
if err != nil {
|
||||
c.JSON(http.StatusOK, gin.H{HTTPReturnCode: Code(err), HTTPReturnMessage: err.Error()})
|
||||
} else if response.Status.ErrorCode != commonpb.ErrorCode_Success {
|
||||
c.JSON(http.StatusOK, gin.H{HTTPReturnCode: code(int32(response.Status.ErrorCode)), HTTPReturnMessage: response.Status.Reason})
|
||||
} else {
|
||||
outputData, err := buildQueryResp(int64(0), response.OutputFields, response.FieldsData, nil, nil)
|
||||
if err != nil {
|
||||
log.Warn("high level restful api, fail to deal with query result", zap.Any("response", response), zap.Error(err))
|
||||
c.JSON(http.StatusOK, gin.H{HTTPReturnCode: Code(merr.ErrInvalidSearchResult), HTTPReturnMessage: merr.ErrInvalidSearchResult.Error()})
|
||||
} else {
|
||||
c.JSON(http.StatusOK, gin.H{HTTPReturnCode: http.StatusOK, HTTPReturnData: outputData})
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func (h *Handlers) get(c *gin.Context) {
|
||||
httpReq := GetReq{
|
||||
DbName: DefaultDbName,
|
||||
OutputFields: []string{DefaultOutputFields},
|
||||
}
|
||||
if err := c.ShouldBindBodyWith(&httpReq, binding.JSON); err != nil {
|
||||
log.Warn("high level restful api, the parameter of get is incorrect", zap.Any("request", httpReq), zap.Error(err))
|
||||
c.AbortWithStatusJSON(http.StatusOK, gin.H{HTTPReturnCode: Code(merr.ErrIncorrectParameterFormat), HTTPReturnMessage: merr.ErrIncorrectParameterFormat.Error()})
|
||||
return
|
||||
}
|
||||
if httpReq.CollectionName == "" || httpReq.ID == nil {
|
||||
log.Warn("high level restful api, get require parameter: [collectionName, id], but miss")
|
||||
c.AbortWithStatusJSON(http.StatusOK, gin.H{HTTPReturnCode: Code(merr.ErrMissingRequiredParameters), HTTPReturnMessage: merr.ErrMissingRequiredParameters.Error()})
|
||||
return
|
||||
}
|
||||
req := milvuspb.QueryRequest{
|
||||
DbName: httpReq.DbName,
|
||||
CollectionName: httpReq.CollectionName,
|
||||
OutputFields: httpReq.OutputFields,
|
||||
GuaranteeTimestamp: BoundedTimestamp,
|
||||
}
|
||||
if err := checkAuthorization(c, &req); err != nil {
|
||||
return
|
||||
}
|
||||
if !h.checkDatabase(c, req.DbName) {
|
||||
return
|
||||
}
|
||||
coll, err := h.describeCollection(c, httpReq.DbName, httpReq.CollectionName, false)
|
||||
if err != nil || coll == nil {
|
||||
return
|
||||
}
|
||||
body, _ := c.Get(gin.BodyBytesKey)
|
||||
filter, err := checkGetPrimaryKey(coll.Schema, gjson.Get(string(body.([]byte)), DefaultPrimaryFieldName))
|
||||
if err != nil {
|
||||
c.JSON(http.StatusOK, gin.H{HTTPReturnCode: Code(merr.ErrCheckPrimaryKey), HTTPReturnMessage: merr.ErrCheckPrimaryKey.Error()})
|
||||
return
|
||||
}
|
||||
req.Expr = filter
|
||||
response, err := h.proxy.Query(c, &req)
|
||||
if err != nil {
|
||||
c.JSON(http.StatusOK, gin.H{HTTPReturnCode: Code(err), HTTPReturnMessage: err.Error()})
|
||||
} else if response.Status.ErrorCode != commonpb.ErrorCode_Success {
|
||||
c.JSON(http.StatusOK, gin.H{HTTPReturnCode: code(int32(response.Status.ErrorCode)), HTTPReturnMessage: response.Status.Reason})
|
||||
} else {
|
||||
outputData, err := buildQueryResp(int64(0), response.OutputFields, response.FieldsData, nil, nil)
|
||||
if err != nil {
|
||||
log.Warn("high level restful api, fail to deal with get result", zap.Any("response", response), zap.Error(err))
|
||||
c.JSON(http.StatusOK, gin.H{HTTPReturnCode: Code(merr.ErrInvalidSearchResult), HTTPReturnMessage: merr.ErrInvalidSearchResult.Error()})
|
||||
} else {
|
||||
c.JSON(http.StatusOK, gin.H{HTTPReturnCode: http.StatusOK, HTTPReturnData: outputData})
|
||||
log.Error("get resultIS: ", zap.Any("res", outputData))
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func (h *Handlers) delete(c *gin.Context) {
|
||||
httpReq := DeleteReq{
|
||||
DbName: DefaultDbName,
|
||||
}
|
||||
if err := c.ShouldBindBodyWith(&httpReq, binding.JSON); err != nil {
|
||||
log.Warn("high level restful api, the parameter of delete is incorrect", zap.Any("request", httpReq), zap.Error(err))
|
||||
c.AbortWithStatusJSON(http.StatusOK, gin.H{HTTPReturnCode: Code(merr.ErrIncorrectParameterFormat), HTTPReturnMessage: merr.ErrIncorrectParameterFormat.Error()})
|
||||
return
|
||||
}
|
||||
if httpReq.CollectionName == "" || httpReq.ID == nil {
|
||||
log.Warn("high level restful api, delete require parameter: [collectionName, id], but miss")
|
||||
c.AbortWithStatusJSON(http.StatusOK, gin.H{HTTPReturnCode: Code(merr.ErrMissingRequiredParameters), HTTPReturnMessage: merr.ErrMissingRequiredParameters.Error()})
|
||||
return
|
||||
}
|
||||
req := milvuspb.DeleteRequest{
|
||||
DbName: httpReq.DbName,
|
||||
CollectionName: httpReq.CollectionName,
|
||||
}
|
||||
if err := checkAuthorization(c, &req); err != nil {
|
||||
return
|
||||
}
|
||||
if !h.checkDatabase(c, req.DbName) {
|
||||
return
|
||||
}
|
||||
coll, err := h.describeCollection(c, httpReq.DbName, httpReq.CollectionName, false)
|
||||
if err != nil || coll == nil {
|
||||
return
|
||||
}
|
||||
body, _ := c.Get(gin.BodyBytesKey)
|
||||
filter, err := checkGetPrimaryKey(coll.Schema, gjson.Get(string(body.([]byte)), DefaultPrimaryFieldName))
|
||||
if err != nil {
|
||||
c.JSON(http.StatusOK, gin.H{HTTPReturnCode: Code(merr.ErrCheckPrimaryKey), HTTPReturnMessage: merr.ErrCheckPrimaryKey.Error()})
|
||||
return
|
||||
}
|
||||
req.Expr = filter
|
||||
response, err := h.proxy.Delete(c, &req)
|
||||
if err != nil {
|
||||
c.JSON(http.StatusOK, gin.H{HTTPReturnCode: Code(err), HTTPReturnMessage: err.Error()})
|
||||
} else if response.Status.ErrorCode != commonpb.ErrorCode_Success {
|
||||
c.JSON(http.StatusOK, gin.H{HTTPReturnCode: code(int32(response.Status.ErrorCode)), HTTPReturnMessage: response.Status.Reason})
|
||||
} else {
|
||||
c.JSON(http.StatusOK, gin.H{HTTPReturnCode: http.StatusOK, HTTPReturnData: gin.H{}})
|
||||
}
|
||||
}
|
||||
|
||||
func (h *Handlers) insert(c *gin.Context) {
|
||||
httpReq := InsertReq{
|
||||
DbName: DefaultDbName,
|
||||
}
|
||||
if err := c.ShouldBindBodyWith(&httpReq, binding.JSON); err != nil {
|
||||
singleInsertReq := SingleInsertReq{
|
||||
DbName: DefaultDbName,
|
||||
}
|
||||
if err = c.ShouldBindBodyWith(&singleInsertReq, binding.JSON); err != nil {
|
||||
log.Warn("high level restful api, the parameter of insert is incorrect", zap.Any("request", httpReq), zap.Error(err))
|
||||
c.AbortWithStatusJSON(http.StatusOK, gin.H{HTTPReturnCode: Code(merr.ErrIncorrectParameterFormat), HTTPReturnMessage: merr.ErrIncorrectParameterFormat.Error()})
|
||||
return
|
||||
}
|
||||
httpReq.DbName = singleInsertReq.DbName
|
||||
httpReq.CollectionName = singleInsertReq.CollectionName
|
||||
httpReq.Data = []map[string]interface{}{singleInsertReq.Data}
|
||||
}
|
||||
if httpReq.CollectionName == "" || httpReq.Data == nil {
|
||||
log.Warn("high level restful api, insert require parameter: [collectionName, data], but miss")
|
||||
c.AbortWithStatusJSON(http.StatusOK, gin.H{HTTPReturnCode: Code(merr.ErrMissingRequiredParameters), HTTPReturnMessage: merr.ErrMissingRequiredParameters.Error()})
|
||||
return
|
||||
}
|
||||
req := milvuspb.InsertRequest{
|
||||
DbName: httpReq.DbName,
|
||||
CollectionName: httpReq.CollectionName,
|
||||
PartitionName: "_default",
|
||||
NumRows: uint32(len(httpReq.Data)),
|
||||
}
|
||||
if err := checkAuthorization(c, &req); err != nil {
|
||||
return
|
||||
}
|
||||
if !h.checkDatabase(c, req.DbName) {
|
||||
return
|
||||
}
|
||||
coll, err := h.describeCollection(c, httpReq.DbName, httpReq.CollectionName, false)
|
||||
if err != nil || coll == nil {
|
||||
return
|
||||
}
|
||||
body, _ := c.Get(gin.BodyBytesKey)
|
||||
err = checkAndSetData(string(body.([]byte)), coll, &httpReq)
|
||||
if err != nil {
|
||||
log.Warn("high level restful api, fail to deal with insert data", zap.Any("body", body), zap.Error(err))
|
||||
c.AbortWithStatusJSON(http.StatusOK, gin.H{HTTPReturnCode: Code(merr.ErrInvalidInsertData), HTTPReturnMessage: merr.ErrInvalidInsertData.Error()})
|
||||
return
|
||||
}
|
||||
req.FieldsData, err = anyToColumns(httpReq.Data, coll.Schema)
|
||||
if err != nil {
|
||||
log.Warn("high level restful api, fail to deal with insert data", zap.Any("data", httpReq.Data), zap.Error(err))
|
||||
c.AbortWithStatusJSON(http.StatusOK, gin.H{HTTPReturnCode: Code(merr.ErrInvalidInsertData), HTTPReturnMessage: merr.ErrInvalidInsertData.Error()})
|
||||
return
|
||||
}
|
||||
response, err := h.proxy.Insert(c, &req)
|
||||
if err != nil {
|
||||
c.JSON(http.StatusOK, gin.H{HTTPReturnCode: Code(err), HTTPReturnMessage: err.Error()})
|
||||
} else if response.Status.ErrorCode != commonpb.ErrorCode_Success {
|
||||
c.JSON(http.StatusOK, gin.H{HTTPReturnCode: code(int32(response.Status.ErrorCode)), HTTPReturnMessage: response.Status.Reason})
|
||||
} else {
|
||||
switch response.IDs.GetIdField().(type) {
|
||||
case *schemapb.IDs_IntId:
|
||||
c.JSON(http.StatusOK, gin.H{HTTPReturnCode: http.StatusOK, HTTPReturnData: gin.H{"insertCount": response.InsertCnt, "insertIds": response.IDs.IdField.(*schemapb.IDs_IntId).IntId.Data}})
|
||||
case *schemapb.IDs_StrId:
|
||||
c.JSON(http.StatusOK, gin.H{HTTPReturnCode: http.StatusOK, HTTPReturnData: gin.H{"insertCount": response.InsertCnt, "insertIds": response.IDs.IdField.(*schemapb.IDs_StrId).StrId.Data}})
|
||||
default:
|
||||
c.JSON(http.StatusOK, gin.H{HTTPReturnCode: Code(merr.ErrCheckPrimaryKey), HTTPReturnMessage: merr.ErrCheckPrimaryKey.Error()})
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func (h *Handlers) search(c *gin.Context) {
|
||||
httpReq := SearchReq{
|
||||
DbName: DefaultDbName,
|
||||
Limit: 100,
|
||||
}
|
||||
if err := c.ShouldBindBodyWith(&httpReq, binding.JSON); err != nil {
|
||||
log.Warn("high level restful api, the parameter of search is incorrect", zap.Any("request", httpReq), zap.Error(err))
|
||||
c.AbortWithStatusJSON(http.StatusOK, gin.H{HTTPReturnCode: Code(merr.ErrIncorrectParameterFormat), HTTPReturnMessage: merr.ErrIncorrectParameterFormat.Error()})
|
||||
return
|
||||
}
|
||||
if httpReq.CollectionName == "" || httpReq.Vector == nil {
|
||||
log.Warn("high level restful api, search require parameter: [collectionName, vector], but miss")
|
||||
c.AbortWithStatusJSON(http.StatusOK, gin.H{HTTPReturnCode: Code(merr.ErrMissingRequiredParameters), HTTPReturnMessage: merr.ErrMissingRequiredParameters.Error()})
|
||||
return
|
||||
}
|
||||
params := map[string]interface{}{ //auto generated mapping
|
||||
"level": int(commonpb.ConsistencyLevel_Bounded),
|
||||
}
|
||||
bs, _ := json.Marshal(params)
|
||||
searchParams := []*commonpb.KeyValuePair{
|
||||
{Key: common.TopKKey, Value: strconv.FormatInt(int64(httpReq.Limit), 10)},
|
||||
{Key: Params, Value: string(bs)},
|
||||
{Key: ParamRoundDecimal, Value: "-1"},
|
||||
{Key: ParamOffset, Value: strconv.FormatInt(int64(httpReq.Offset), 10)},
|
||||
}
|
||||
req := milvuspb.SearchRequest{
|
||||
DbName: httpReq.DbName,
|
||||
CollectionName: httpReq.CollectionName,
|
||||
Dsl: httpReq.Filter,
|
||||
PlaceholderGroup: vector2PlaceholderGroupBytes(httpReq.Vector),
|
||||
DslType: commonpb.DslType_BoolExprV1,
|
||||
OutputFields: httpReq.OutputFields,
|
||||
SearchParams: searchParams,
|
||||
GuaranteeTimestamp: BoundedTimestamp,
|
||||
Nq: int64(1),
|
||||
}
|
||||
if err := checkAuthorization(c, &req); err != nil {
|
||||
return
|
||||
}
|
||||
if !h.checkDatabase(c, req.DbName) {
|
||||
return
|
||||
}
|
||||
response, err := h.proxy.Search(c, &req)
|
||||
if err != nil {
|
||||
c.JSON(http.StatusOK, gin.H{HTTPReturnCode: Code(err), HTTPReturnMessage: err.Error()})
|
||||
} else if response.Status.ErrorCode != commonpb.ErrorCode_Success {
|
||||
c.JSON(http.StatusOK, gin.H{HTTPReturnCode: code(int32(response.Status.ErrorCode)), HTTPReturnMessage: response.Status.Reason})
|
||||
} else {
|
||||
if response.Results.TopK == int64(0) {
|
||||
c.JSON(http.StatusOK, gin.H{HTTPReturnCode: http.StatusOK, HTTPReturnData: []interface{}{}})
|
||||
} else {
|
||||
outputData, err := buildQueryResp(response.Results.TopK, response.Results.OutputFields, response.Results.FieldsData, response.Results.Ids, response.Results.Scores)
|
||||
if err != nil {
|
||||
log.Warn("high level restful api, fail to deal with search result", zap.Any("result", response.Results), zap.Error(err))
|
||||
c.JSON(http.StatusOK, gin.H{HTTPReturnCode: Code(merr.ErrInvalidSearchResult), HTTPReturnMessage: merr.ErrInvalidSearchResult.Error()})
|
||||
} else {
|
||||
c.JSON(http.StatusOK, gin.H{HTTPReturnCode: http.StatusOK, HTTPReturnData: outputData})
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
File diff suppressed because it is too large
Load Diff
|
@ -0,0 +1,60 @@
|
|||
package httpserver
|
||||
|
||||
type CreateCollectionReq struct {
|
||||
DbName string `json:"dbName"`
|
||||
CollectionName string `json:"collectionName" validate:"required"`
|
||||
Dimension int32 `json:"dimension" validate:"required"`
|
||||
Description string `json:"description"`
|
||||
MetricType string `json:"metricType"`
|
||||
PrimaryField string `json:"primaryField"`
|
||||
VectorField string `json:"vectorField"`
|
||||
}
|
||||
|
||||
type DropCollectionReq struct {
|
||||
DbName string `json:"dbName"`
|
||||
CollectionName string `json:"collectionName" validate:"required"`
|
||||
}
|
||||
|
||||
type QueryReq struct {
|
||||
DbName string `json:"dbName"`
|
||||
CollectionName string `json:"collectionName" validate:"required"`
|
||||
OutputFields []string `json:"outputFields"`
|
||||
Filter string `json:"filter" validate:"required"`
|
||||
Limit int32 `json:"limit"`
|
||||
Offset int32 `json:"offset"`
|
||||
}
|
||||
|
||||
type GetReq struct {
|
||||
DbName string `json:"dbName"`
|
||||
CollectionName string `json:"collectionName" validate:"required"`
|
||||
OutputFields []string `json:"outputFields"`
|
||||
ID interface{} `json:"id" validate:"required"`
|
||||
}
|
||||
|
||||
type DeleteReq struct {
|
||||
DbName string `json:"dbName"`
|
||||
CollectionName string `json:"collectionName" validate:"required"`
|
||||
ID interface{} `json:"id" validate:"required"`
|
||||
}
|
||||
|
||||
type InsertReq struct {
|
||||
DbName string `json:"dbName"`
|
||||
CollectionName string `json:"collectionName" validate:"required"`
|
||||
Data []map[string]interface{} `json:"data" validate:"required"`
|
||||
}
|
||||
|
||||
type SingleInsertReq struct {
|
||||
DbName string `json:"dbName"`
|
||||
CollectionName string `json:"collectionName" validate:"required"`
|
||||
Data map[string]interface{} `json:"data" validate:"required"`
|
||||
}
|
||||
|
||||
type SearchReq struct {
|
||||
DbName string `json:"dbName"`
|
||||
CollectionName string `json:"collectionName" validate:"required"`
|
||||
Filter string `json:"filter"`
|
||||
Limit int32 `json:"limit"`
|
||||
Offset int32 `json:"offset"`
|
||||
OutputFields []string `json:"outputFields"`
|
||||
Vector []float32 `json:"vector"`
|
||||
}
|
|
@ -0,0 +1,871 @@
|
|||
package httpserver
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"encoding/binary"
|
||||
"encoding/json"
|
||||
"fmt"
|
||||
"math"
|
||||
"reflect"
|
||||
"strconv"
|
||||
"strings"
|
||||
|
||||
"github.com/milvus-io/milvus/pkg/util/parameterutil.go"
|
||||
|
||||
"github.com/milvus-io/milvus/pkg/util/merr"
|
||||
|
||||
"github.com/cockroachdb/errors"
|
||||
|
||||
"github.com/gin-gonic/gin"
|
||||
"github.com/golang/protobuf/proto"
|
||||
"github.com/spf13/cast"
|
||||
"github.com/tidwall/gjson"
|
||||
|
||||
"github.com/milvus-io/milvus-proto/go-api/v2/commonpb"
|
||||
"github.com/milvus-io/milvus-proto/go-api/v2/milvuspb"
|
||||
"github.com/milvus-io/milvus-proto/go-api/v2/schemapb"
|
||||
|
||||
"github.com/milvus-io/milvus/pkg/common"
|
||||
"github.com/milvus-io/milvus/pkg/log"
|
||||
"github.com/milvus-io/milvus/pkg/util/funcutil"
|
||||
)
|
||||
|
||||
func ParseUsernamePassword(c *gin.Context) (string, string, bool) {
|
||||
username, password, ok := c.Request.BasicAuth()
|
||||
if !ok {
|
||||
auth := c.Request.Header.Get("Authorization")
|
||||
if auth != "" {
|
||||
token := strings.TrimPrefix(auth, "Bearer ")
|
||||
if token != auth {
|
||||
i := strings.IndexAny(token, ":")
|
||||
if i != -1 {
|
||||
username = token[:i]
|
||||
password = token[i+1:]
|
||||
}
|
||||
}
|
||||
}
|
||||
} else {
|
||||
c.Header("WWW-Authenticate", `Basic realm="restricted", charset="UTF-8"`)
|
||||
}
|
||||
return username, password, username != "" && password != ""
|
||||
}
|
||||
|
||||
// find the primary field of collection
|
||||
func getPrimaryField(schema *schemapb.CollectionSchema) (*schemapb.FieldSchema, bool) {
|
||||
for _, field := range schema.Fields {
|
||||
if field.IsPrimaryKey {
|
||||
return field, true
|
||||
}
|
||||
}
|
||||
return nil, false
|
||||
}
|
||||
|
||||
func joinArray(data interface{}) string {
|
||||
var buffer bytes.Buffer
|
||||
arr := reflect.ValueOf(data)
|
||||
|
||||
for i := 0; i < arr.Len(); i++ {
|
||||
if i > 0 {
|
||||
buffer.WriteString(",")
|
||||
}
|
||||
|
||||
buffer.WriteString(fmt.Sprintf("%v", arr.Index(i)))
|
||||
}
|
||||
|
||||
return buffer.String()
|
||||
}
|
||||
|
||||
func convertRange(field *schemapb.FieldSchema, result gjson.Result) (string, error) {
|
||||
var resultStr string
|
||||
fieldType := field.DataType
|
||||
|
||||
if fieldType == schemapb.DataType_Int64 {
|
||||
var dataArray []int64
|
||||
for _, data := range result.Array() {
|
||||
if data.Type == gjson.String {
|
||||
value, err := cast.ToInt64E(data.Str)
|
||||
if err != nil {
|
||||
return "", err
|
||||
}
|
||||
dataArray = append(dataArray, value)
|
||||
} else {
|
||||
value, err := cast.ToInt64E(data.Raw)
|
||||
if err != nil {
|
||||
return "", err
|
||||
}
|
||||
dataArray = append(dataArray, value)
|
||||
}
|
||||
}
|
||||
resultStr = joinArray(dataArray)
|
||||
} else if fieldType == schemapb.DataType_VarChar {
|
||||
var dataArray []string
|
||||
for _, data := range result.Array() {
|
||||
value, err := cast.ToStringE(data.Str)
|
||||
if err != nil {
|
||||
return "", err
|
||||
}
|
||||
dataArray = append(dataArray, value)
|
||||
}
|
||||
resultStr = joinArray(dataArray)
|
||||
}
|
||||
return resultStr, nil
|
||||
}
|
||||
|
||||
// generate the expression: $primaryFieldName in [1,2,3]
|
||||
func checkGetPrimaryKey(coll *schemapb.CollectionSchema, idResult gjson.Result) (string, error) {
|
||||
primaryField, ok := getPrimaryField(coll)
|
||||
if !ok {
|
||||
return "", errors.New("fail to find primary key from collection description")
|
||||
}
|
||||
resultStr, err := convertRange(primaryField, idResult)
|
||||
if err != nil {
|
||||
return "", err
|
||||
}
|
||||
filter := primaryField.Name + " in [" + resultStr + "]"
|
||||
return filter, nil
|
||||
}
|
||||
|
||||
// --------------------- collection details --------------------- //
|
||||
func printFields(fields []*schemapb.FieldSchema) []gin.H {
|
||||
var res []gin.H
|
||||
for _, field := range fields {
|
||||
fieldDetail := gin.H{
|
||||
HTTPReturnFieldName: field.Name,
|
||||
HTTPReturnFieldPrimaryKey: field.IsPrimaryKey,
|
||||
HTTPReturnFieldAutoID: field.AutoID,
|
||||
HTTPReturnDescription: field.Description,
|
||||
}
|
||||
if field.DataType == schemapb.DataType_BinaryVector || field.DataType == schemapb.DataType_FloatVector {
|
||||
dim, _ := getDim(field)
|
||||
fieldDetail[HTTPReturnFieldType] = field.DataType.String() + "(" + strconv.FormatInt(dim, 10) + ")"
|
||||
} else if field.DataType == schemapb.DataType_VarChar {
|
||||
maxLength, _ := parameterutil.GetMaxLength(field)
|
||||
fieldDetail[HTTPReturnFieldType] = field.DataType.String() + "(" + strconv.FormatInt(maxLength, 10) + ")"
|
||||
} else {
|
||||
fieldDetail[HTTPReturnFieldType] = field.DataType.String()
|
||||
}
|
||||
res = append(res, fieldDetail)
|
||||
}
|
||||
return res
|
||||
}
|
||||
|
||||
func getMetricType(pairs []*commonpb.KeyValuePair) string {
|
||||
metricType := DefaultMetricType
|
||||
for _, pair := range pairs {
|
||||
if pair.Key == common.MetricTypeKey {
|
||||
metricType = pair.Value
|
||||
break
|
||||
}
|
||||
}
|
||||
return metricType
|
||||
}
|
||||
|
||||
func printIndexes(indexes []*milvuspb.IndexDescription) []gin.H {
|
||||
var res []gin.H
|
||||
for _, index := range indexes {
|
||||
res = append(res, gin.H{
|
||||
HTTPReturnIndexName: index.IndexName,
|
||||
HTTPReturnIndexField: index.FieldName,
|
||||
HTTPReturnIndexMetricsType: getMetricType(index.Params),
|
||||
})
|
||||
}
|
||||
return res
|
||||
}
|
||||
|
||||
// --------------------- insert param --------------------- //
|
||||
|
||||
func checkAndSetData(body string, collDescResp *milvuspb.DescribeCollectionResponse, req *InsertReq) error {
|
||||
var reallyDataArray []map[string]interface{}
|
||||
dataResult := gjson.Get(body, "data")
|
||||
dataResultArray := dataResult.Array()
|
||||
if len(dataResultArray) == 0 {
|
||||
return errors.New("data is required")
|
||||
}
|
||||
|
||||
var fieldNames []string
|
||||
for _, field := range collDescResp.Schema.Fields {
|
||||
fieldNames = append(fieldNames, field.Name)
|
||||
}
|
||||
|
||||
for _, data := range dataResultArray {
|
||||
reallyData := map[string]interface{}{}
|
||||
var vectorArray []float32
|
||||
var binaryArray []byte
|
||||
if data.Type == gjson.JSON {
|
||||
for _, field := range collDescResp.Schema.Fields {
|
||||
fieldType := field.DataType
|
||||
fieldName := field.Name
|
||||
|
||||
dataString := gjson.Get(data.Raw, fieldName).String()
|
||||
|
||||
if field.IsPrimaryKey && collDescResp.Schema.AutoID {
|
||||
if dataString != "" {
|
||||
return fmt.Errorf("fieldName %s AutoId already open, not support insert data %s", fieldName, dataString)
|
||||
}
|
||||
continue
|
||||
}
|
||||
|
||||
switch fieldType {
|
||||
case schemapb.DataType_FloatVector:
|
||||
for _, vector := range gjson.Get(data.Raw, fieldName).Array() {
|
||||
vectorArray = append(vectorArray, cast.ToFloat32(vector.Num))
|
||||
}
|
||||
reallyData[fieldName] = vectorArray
|
||||
case schemapb.DataType_BinaryVector:
|
||||
for _, vector := range gjson.Get(data.Raw, fieldName).Array() {
|
||||
binaryArray = append(binaryArray, cast.ToUint8(vector.Num))
|
||||
}
|
||||
reallyData[fieldName] = binaryArray
|
||||
case schemapb.DataType_Bool:
|
||||
result, err := cast.ToBoolE(dataString)
|
||||
if err != nil {
|
||||
return fmt.Errorf("dataString %s cast to bool error: %s", dataString, err.Error())
|
||||
}
|
||||
reallyData[fieldName] = result
|
||||
case schemapb.DataType_Int8:
|
||||
result, err := cast.ToInt8E(dataString)
|
||||
if err != nil {
|
||||
return fmt.Errorf("dataString %s cast to int8 error: %s", dataString, err.Error())
|
||||
}
|
||||
reallyData[fieldName] = result
|
||||
case schemapb.DataType_Int16:
|
||||
result, err := cast.ToInt16E(dataString)
|
||||
if err != nil {
|
||||
return fmt.Errorf("dataString %s cast to int16 error: %s", dataString, err.Error())
|
||||
}
|
||||
reallyData[fieldName] = result
|
||||
case schemapb.DataType_Int32:
|
||||
result, err := cast.ToInt32E(dataString)
|
||||
if err != nil {
|
||||
return fmt.Errorf("dataString %s cast to int32 error: %s", dataString, err.Error())
|
||||
}
|
||||
reallyData[fieldName] = result
|
||||
case schemapb.DataType_Int64:
|
||||
result, err := cast.ToInt64E(dataString)
|
||||
if err != nil {
|
||||
return fmt.Errorf("dataString %s cast to int64 error: %s", dataString, err.Error())
|
||||
}
|
||||
reallyData[fieldName] = result
|
||||
case schemapb.DataType_JSON:
|
||||
reallyData[fieldName] = []byte(dataString)
|
||||
case schemapb.DataType_Float:
|
||||
result, err := cast.ToFloat32E(dataString)
|
||||
if err != nil {
|
||||
return fmt.Errorf("dataString %s cast to float32 error: %s", dataString, err.Error())
|
||||
}
|
||||
reallyData[fieldName] = result
|
||||
case schemapb.DataType_Double:
|
||||
result, err := cast.ToFloat64E(dataString)
|
||||
if err != nil {
|
||||
return fmt.Errorf("dataString %s cast to float64 error: %s", dataString, err.Error())
|
||||
}
|
||||
reallyData[fieldName] = result
|
||||
case schemapb.DataType_VarChar:
|
||||
reallyData[fieldName] = dataString
|
||||
case schemapb.DataType_String:
|
||||
reallyData[fieldName] = dataString
|
||||
default:
|
||||
return fmt.Errorf("not support fieldName %s dataType %s", fieldName, fieldType)
|
||||
}
|
||||
}
|
||||
|
||||
// fill dynamic schema
|
||||
if collDescResp.Schema.EnableDynamicField {
|
||||
for mapKey, mapValue := range data.Map() {
|
||||
if !containsString(fieldNames, mapKey) {
|
||||
mapValueStr := mapValue.String()
|
||||
if mapValue.Type == gjson.True || mapValue.Type == gjson.False {
|
||||
reallyData[mapKey] = cast.ToBool(mapValueStr)
|
||||
} else if mapValue.Type == gjson.String {
|
||||
reallyData[mapKey] = mapValueStr
|
||||
} else if mapValue.Type == gjson.Number {
|
||||
if strings.Contains(mapValue.Raw, ".") {
|
||||
reallyData[mapKey] = cast.ToFloat64(mapValue.Raw)
|
||||
} else {
|
||||
reallyData[mapKey] = cast.ToInt64(mapValueStr)
|
||||
}
|
||||
} else if mapValue.Type == gjson.JSON {
|
||||
reallyData[mapKey] = mapValue.Value()
|
||||
} else {
|
||||
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
reallyDataArray = append(reallyDataArray, reallyData)
|
||||
} else {
|
||||
return fmt.Errorf("dataType %s not Json", data.Type)
|
||||
}
|
||||
}
|
||||
req.Data = reallyDataArray
|
||||
return nil
|
||||
}
|
||||
|
||||
func containsString(arr []string, s string) bool {
|
||||
for _, str := range arr {
|
||||
if str == s {
|
||||
return true
|
||||
}
|
||||
}
|
||||
return false
|
||||
}
|
||||
|
||||
func getDim(field *schemapb.FieldSchema) (int64, error) {
|
||||
dimensionInSchema, err := funcutil.GetAttrByKeyFromRepeatedKV(common.DimKey, field.TypeParams)
|
||||
if err != nil {
|
||||
return 0, err
|
||||
}
|
||||
dim, err := strconv.Atoi(dimensionInSchema)
|
||||
if err != nil {
|
||||
return 0, err
|
||||
}
|
||||
return int64(dim), nil
|
||||
}
|
||||
|
||||
func convertFloatVectorToArray(vector [][]float32, dim int64) ([]float32, error) {
|
||||
floatArray := make([]float32, 0)
|
||||
for _, arr := range vector {
|
||||
if int64(len(arr)) != dim {
|
||||
return nil, errors.New("vector length diff from dimension")
|
||||
}
|
||||
for i := int64(0); i < dim; i++ {
|
||||
floatArray = append(floatArray, arr[i])
|
||||
}
|
||||
}
|
||||
return floatArray, nil
|
||||
}
|
||||
|
||||
func convertBinaryVectorToArray(vector [][]byte, dim int64) ([]byte, error) {
|
||||
binaryArray := make([]byte, 0)
|
||||
bytesLen := dim / 8
|
||||
for _, arr := range vector {
|
||||
if int64(len(arr)) != bytesLen {
|
||||
return nil, errors.New("vector length diff from dimension")
|
||||
}
|
||||
for i := int64(0); i < bytesLen; i++ {
|
||||
binaryArray = append(binaryArray, arr[i])
|
||||
}
|
||||
}
|
||||
return binaryArray, nil
|
||||
}
|
||||
|
||||
type fieldCandi struct {
|
||||
name string
|
||||
v reflect.Value
|
||||
options map[string]string
|
||||
}
|
||||
|
||||
func reflectValueCandi(v reflect.Value) (map[string]fieldCandi, error) {
|
||||
if v.Kind() == reflect.Ptr {
|
||||
v = v.Elem()
|
||||
}
|
||||
|
||||
result := make(map[string]fieldCandi)
|
||||
switch v.Kind() {
|
||||
case reflect.Map: // map[string]interface{}
|
||||
iter := v.MapRange()
|
||||
for iter.Next() {
|
||||
key := iter.Key().String()
|
||||
result[key] = fieldCandi{
|
||||
name: key,
|
||||
v: iter.Value(),
|
||||
}
|
||||
}
|
||||
return result, nil
|
||||
default:
|
||||
return nil, fmt.Errorf("unsupport row type: %s", v.Kind().String())
|
||||
}
|
||||
}
|
||||
|
||||
func convertToIntArray(dataType schemapb.DataType, arr interface{}) []int32 {
|
||||
var res []int32
|
||||
switch dataType {
|
||||
case schemapb.DataType_Int8:
|
||||
for _, num := range arr.([]int8) {
|
||||
res = append(res, int32(num))
|
||||
}
|
||||
case schemapb.DataType_Int16:
|
||||
for _, num := range arr.([]int16) {
|
||||
res = append(res, int32(num))
|
||||
}
|
||||
}
|
||||
return res
|
||||
}
|
||||
|
||||
func anyToColumns(rows []map[string]interface{}, sch *schemapb.CollectionSchema) ([]*schemapb.FieldData, error) {
|
||||
rowsLen := len(rows)
|
||||
if rowsLen == 0 {
|
||||
return []*schemapb.FieldData{}, errors.New("0 length column")
|
||||
}
|
||||
|
||||
isDynamic := sch.EnableDynamicField
|
||||
var dim int64
|
||||
|
||||
nameColumns := make(map[string]interface{})
|
||||
fieldData := make(map[string]*schemapb.FieldData)
|
||||
for _, field := range sch.Fields {
|
||||
// skip auto id pk field
|
||||
if field.IsPrimaryKey && field.AutoID {
|
||||
continue
|
||||
}
|
||||
var data interface{}
|
||||
switch field.DataType {
|
||||
case schemapb.DataType_Bool:
|
||||
data = make([]bool, 0, rowsLen)
|
||||
case schemapb.DataType_Int8:
|
||||
data = make([]int8, 0, rowsLen)
|
||||
case schemapb.DataType_Int16:
|
||||
data = make([]int16, 0, rowsLen)
|
||||
case schemapb.DataType_Int32:
|
||||
data = make([]int32, 0, rowsLen)
|
||||
case schemapb.DataType_Int64:
|
||||
data = make([]int64, 0, rowsLen)
|
||||
case schemapb.DataType_Float:
|
||||
data = make([]float32, 0, rowsLen)
|
||||
case schemapb.DataType_Double:
|
||||
data = make([]float64, 0, rowsLen)
|
||||
case schemapb.DataType_String:
|
||||
data = make([]string, 0, rowsLen)
|
||||
case schemapb.DataType_VarChar:
|
||||
data = make([]string, 0, rowsLen)
|
||||
case schemapb.DataType_JSON:
|
||||
data = make([][]byte, 0, rowsLen)
|
||||
case schemapb.DataType_FloatVector:
|
||||
data = make([][]float32, 0, rowsLen)
|
||||
dim, _ = getDim(field)
|
||||
case schemapb.DataType_BinaryVector:
|
||||
data = make([][]byte, 0, rowsLen)
|
||||
dim, _ = getDim(field)
|
||||
default:
|
||||
return nil, fmt.Errorf("the type(%v) of field(%v) is not supported, use other sdk please", field.DataType, field.Name)
|
||||
}
|
||||
nameColumns[field.Name] = data
|
||||
fieldData[field.Name] = &schemapb.FieldData{
|
||||
Type: field.DataType,
|
||||
FieldName: field.Name,
|
||||
FieldId: field.FieldID,
|
||||
IsDynamic: field.IsDynamic,
|
||||
}
|
||||
}
|
||||
if dim == 0 {
|
||||
return nil, errors.New("cannot find dimension")
|
||||
}
|
||||
|
||||
dynamicCol := make([][]byte, 0, rowsLen)
|
||||
|
||||
for _, row := range rows {
|
||||
// collection schema name need not be same, since receiver could have other names
|
||||
v := reflect.ValueOf(row)
|
||||
set, err := reflectValueCandi(v)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
for idx, field := range sch.Fields {
|
||||
// skip auto id pk field
|
||||
if field.IsPrimaryKey && field.AutoID {
|
||||
// remove pk field from candidates set, avoid adding it into dynamic column
|
||||
delete(set, field.Name)
|
||||
continue
|
||||
}
|
||||
|
||||
candi, ok := set[field.Name]
|
||||
if !ok {
|
||||
return nil, fmt.Errorf("row %d does not has field %s", idx, field.Name)
|
||||
}
|
||||
switch field.DataType {
|
||||
case schemapb.DataType_Bool:
|
||||
nameColumns[field.Name] = append(nameColumns[field.Name].([]bool), candi.v.Interface().(bool))
|
||||
case schemapb.DataType_Int8:
|
||||
nameColumns[field.Name] = append(nameColumns[field.Name].([]int8), candi.v.Interface().(int8))
|
||||
case schemapb.DataType_Int16:
|
||||
nameColumns[field.Name] = append(nameColumns[field.Name].([]int16), candi.v.Interface().(int16))
|
||||
case schemapb.DataType_Int32:
|
||||
nameColumns[field.Name] = append(nameColumns[field.Name].([]int32), candi.v.Interface().(int32))
|
||||
case schemapb.DataType_Int64:
|
||||
nameColumns[field.Name] = append(nameColumns[field.Name].([]int64), candi.v.Interface().(int64))
|
||||
case schemapb.DataType_Float:
|
||||
nameColumns[field.Name] = append(nameColumns[field.Name].([]float32), candi.v.Interface().(float32))
|
||||
case schemapb.DataType_Double:
|
||||
nameColumns[field.Name] = append(nameColumns[field.Name].([]float64), candi.v.Interface().(float64))
|
||||
case schemapb.DataType_String:
|
||||
nameColumns[field.Name] = append(nameColumns[field.Name].([]string), candi.v.Interface().(string))
|
||||
case schemapb.DataType_VarChar:
|
||||
nameColumns[field.Name] = append(nameColumns[field.Name].([]string), candi.v.Interface().(string))
|
||||
case schemapb.DataType_JSON:
|
||||
nameColumns[field.Name] = append(nameColumns[field.Name].([][]byte), candi.v.Interface().([]byte))
|
||||
case schemapb.DataType_FloatVector:
|
||||
nameColumns[field.Name] = append(nameColumns[field.Name].([][]float32), candi.v.Interface().([]float32))
|
||||
case schemapb.DataType_BinaryVector:
|
||||
nameColumns[field.Name] = append(nameColumns[field.Name].([][]byte), candi.v.Interface().([]byte))
|
||||
default:
|
||||
return nil, fmt.Errorf("the type(%v) of field(%v) is not supported, use other sdk please", field.DataType, field.Name)
|
||||
}
|
||||
|
||||
delete(set, field.Name)
|
||||
}
|
||||
|
||||
if isDynamic {
|
||||
m := make(map[string]interface{})
|
||||
for name, candi := range set {
|
||||
m[name] = candi.v.Interface()
|
||||
}
|
||||
bs, err := json.Marshal(m)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to marshal dynamic field %w", err)
|
||||
}
|
||||
dynamicCol = append(dynamicCol, bs)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to append value to dynamic field %w", err)
|
||||
}
|
||||
}
|
||||
}
|
||||
columns := make([]*schemapb.FieldData, 0, len(nameColumns))
|
||||
for name, column := range nameColumns {
|
||||
colData := fieldData[name]
|
||||
switch colData.Type {
|
||||
case schemapb.DataType_Bool:
|
||||
colData.Field = &schemapb.FieldData_Scalars{
|
||||
Scalars: &schemapb.ScalarField{
|
||||
Data: &schemapb.ScalarField_BoolData{
|
||||
BoolData: &schemapb.BoolArray{
|
||||
Data: column.([]bool),
|
||||
},
|
||||
},
|
||||
},
|
||||
}
|
||||
case schemapb.DataType_Int8:
|
||||
colData.Field = &schemapb.FieldData_Scalars{
|
||||
Scalars: &schemapb.ScalarField{
|
||||
Data: &schemapb.ScalarField_IntData{
|
||||
IntData: &schemapb.IntArray{
|
||||
Data: convertToIntArray(colData.Type, column),
|
||||
},
|
||||
},
|
||||
},
|
||||
}
|
||||
case schemapb.DataType_Int16:
|
||||
colData.Field = &schemapb.FieldData_Scalars{
|
||||
Scalars: &schemapb.ScalarField{
|
||||
Data: &schemapb.ScalarField_IntData{
|
||||
IntData: &schemapb.IntArray{
|
||||
Data: convertToIntArray(colData.Type, column),
|
||||
},
|
||||
},
|
||||
},
|
||||
}
|
||||
case schemapb.DataType_Int32:
|
||||
colData.Field = &schemapb.FieldData_Scalars{
|
||||
Scalars: &schemapb.ScalarField{
|
||||
Data: &schemapb.ScalarField_IntData{
|
||||
IntData: &schemapb.IntArray{
|
||||
Data: column.([]int32),
|
||||
},
|
||||
},
|
||||
},
|
||||
}
|
||||
case schemapb.DataType_Int64:
|
||||
colData.Field = &schemapb.FieldData_Scalars{
|
||||
Scalars: &schemapb.ScalarField{
|
||||
Data: &schemapb.ScalarField_LongData{
|
||||
LongData: &schemapb.LongArray{
|
||||
Data: column.([]int64),
|
||||
},
|
||||
},
|
||||
},
|
||||
}
|
||||
case schemapb.DataType_Float:
|
||||
colData.Field = &schemapb.FieldData_Scalars{
|
||||
Scalars: &schemapb.ScalarField{
|
||||
Data: &schemapb.ScalarField_FloatData{
|
||||
FloatData: &schemapb.FloatArray{
|
||||
Data: column.([]float32),
|
||||
},
|
||||
},
|
||||
},
|
||||
}
|
||||
case schemapb.DataType_Double:
|
||||
colData.Field = &schemapb.FieldData_Scalars{
|
||||
Scalars: &schemapb.ScalarField{
|
||||
Data: &schemapb.ScalarField_DoubleData{
|
||||
DoubleData: &schemapb.DoubleArray{
|
||||
Data: column.([]float64),
|
||||
},
|
||||
},
|
||||
},
|
||||
}
|
||||
case schemapb.DataType_String:
|
||||
colData.Field = &schemapb.FieldData_Scalars{
|
||||
Scalars: &schemapb.ScalarField{
|
||||
Data: &schemapb.ScalarField_StringData{
|
||||
StringData: &schemapb.StringArray{
|
||||
Data: column.([]string),
|
||||
},
|
||||
},
|
||||
},
|
||||
}
|
||||
case schemapb.DataType_VarChar:
|
||||
colData.Field = &schemapb.FieldData_Scalars{
|
||||
Scalars: &schemapb.ScalarField{
|
||||
Data: &schemapb.ScalarField_StringData{
|
||||
StringData: &schemapb.StringArray{
|
||||
Data: column.([]string),
|
||||
},
|
||||
},
|
||||
},
|
||||
}
|
||||
case schemapb.DataType_JSON:
|
||||
colData.Field = &schemapb.FieldData_Scalars{
|
||||
Scalars: &schemapb.ScalarField{
|
||||
Data: &schemapb.ScalarField_BytesData{
|
||||
BytesData: &schemapb.BytesArray{
|
||||
Data: column.([][]byte),
|
||||
},
|
||||
},
|
||||
},
|
||||
}
|
||||
case schemapb.DataType_FloatVector:
|
||||
arr, err := convertFloatVectorToArray(column.([][]float32), dim)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
colData.Field = &schemapb.FieldData_Vectors{
|
||||
Vectors: &schemapb.VectorField{
|
||||
Dim: dim,
|
||||
Data: &schemapb.VectorField_FloatVector{
|
||||
FloatVector: &schemapb.FloatArray{
|
||||
Data: arr,
|
||||
},
|
||||
},
|
||||
},
|
||||
}
|
||||
case schemapb.DataType_BinaryVector:
|
||||
arr, err := convertBinaryVectorToArray(column.([][]byte), dim)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
colData.Field = &schemapb.FieldData_Vectors{
|
||||
Vectors: &schemapb.VectorField{
|
||||
Dim: dim,
|
||||
Data: &schemapb.VectorField_BinaryVector{
|
||||
BinaryVector: arr,
|
||||
},
|
||||
},
|
||||
}
|
||||
default:
|
||||
return nil, fmt.Errorf("the type(%v) of field(%v) is not supported, use other sdk please", colData.Type, name)
|
||||
}
|
||||
columns = append(columns, colData)
|
||||
}
|
||||
if isDynamic {
|
||||
columns = append(columns, &schemapb.FieldData{
|
||||
Type: schemapb.DataType_JSON,
|
||||
FieldName: "",
|
||||
Field: &schemapb.FieldData_Scalars{
|
||||
Scalars: &schemapb.ScalarField{
|
||||
Data: &schemapb.ScalarField_JsonData{
|
||||
JsonData: &schemapb.JSONArray{
|
||||
Data: dynamicCol,
|
||||
},
|
||||
},
|
||||
},
|
||||
},
|
||||
IsDynamic: true,
|
||||
})
|
||||
}
|
||||
return columns, nil
|
||||
}
|
||||
|
||||
// --------------------- search param --------------------- //
|
||||
func serialize(fv []float32) []byte {
|
||||
data := make([]byte, 0, 4*len(fv)) // float32 occupies 4 bytes
|
||||
buf := make([]byte, 4)
|
||||
for _, f := range fv {
|
||||
binary.LittleEndian.PutUint32(buf, math.Float32bits(f))
|
||||
data = append(data, buf...)
|
||||
}
|
||||
return data
|
||||
}
|
||||
|
||||
func vector2PlaceholderGroupBytes(vectors []float32) []byte {
|
||||
var placeHolderType commonpb.PlaceholderType
|
||||
ph := &commonpb.PlaceholderValue{
|
||||
Tag: "$0",
|
||||
Values: make([][]byte, 0, len(vectors)),
|
||||
}
|
||||
if len(vectors) != 0 {
|
||||
placeHolderType = commonpb.PlaceholderType_FloatVector
|
||||
|
||||
ph.Type = placeHolderType
|
||||
ph.Values = append(ph.Values, serialize(vectors))
|
||||
}
|
||||
phg := &commonpb.PlaceholderGroup{
|
||||
Placeholders: []*commonpb.PlaceholderValue{
|
||||
ph,
|
||||
},
|
||||
}
|
||||
|
||||
bs, _ := proto.Marshal(phg)
|
||||
return bs
|
||||
}
|
||||
|
||||
// --------------------- get/query/search response --------------------- //
|
||||
func genDynamicFields(fields []string, list []*schemapb.FieldData) []string {
|
||||
nonDynamicFieldNames := make(map[string]struct{})
|
||||
for _, field := range list {
|
||||
if !field.IsDynamic {
|
||||
nonDynamicFieldNames[field.FieldName] = struct{}{}
|
||||
}
|
||||
}
|
||||
dynamicFields := []string{}
|
||||
for _, fieldName := range fields {
|
||||
if _, exist := nonDynamicFieldNames[fieldName]; !exist {
|
||||
dynamicFields = append(dynamicFields, fieldName)
|
||||
}
|
||||
}
|
||||
return dynamicFields
|
||||
}
|
||||
|
||||
func buildQueryResp(rowsNum int64, needFields []string, fieldDataList []*schemapb.FieldData, ids *schemapb.IDs, scores []float32) ([]map[string]interface{}, error) {
|
||||
var queryResp []map[string]interface{}
|
||||
|
||||
columnNum := len(fieldDataList)
|
||||
if rowsNum == int64(0) {
|
||||
if columnNum > 0 {
|
||||
switch fieldDataList[0].Type {
|
||||
case schemapb.DataType_Bool:
|
||||
rowsNum = int64(len(fieldDataList[0].GetScalars().GetBoolData().Data))
|
||||
case schemapb.DataType_Int8:
|
||||
rowsNum = int64(len(fieldDataList[0].GetScalars().GetIntData().Data))
|
||||
case schemapb.DataType_Int16:
|
||||
rowsNum = int64(len(fieldDataList[0].GetScalars().GetIntData().Data))
|
||||
case schemapb.DataType_Int32:
|
||||
rowsNum = int64(len(fieldDataList[0].GetScalars().GetIntData().Data))
|
||||
case schemapb.DataType_Int64:
|
||||
rowsNum = int64(len(fieldDataList[0].GetScalars().GetLongData().Data))
|
||||
case schemapb.DataType_Float:
|
||||
rowsNum = int64(len(fieldDataList[0].GetScalars().GetFloatData().Data))
|
||||
case schemapb.DataType_Double:
|
||||
rowsNum = int64(len(fieldDataList[0].GetScalars().GetDoubleData().Data))
|
||||
case schemapb.DataType_String:
|
||||
rowsNum = int64(len(fieldDataList[0].GetScalars().GetStringData().Data))
|
||||
case schemapb.DataType_VarChar:
|
||||
rowsNum = int64(len(fieldDataList[0].GetScalars().GetStringData().Data))
|
||||
case schemapb.DataType_JSON:
|
||||
rowsNum = int64(len(fieldDataList[0].GetScalars().GetJsonData().Data))
|
||||
case schemapb.DataType_Array:
|
||||
rowsNum = int64(len(fieldDataList[0].GetScalars().GetArrayData().Data))
|
||||
case schemapb.DataType_BinaryVector:
|
||||
rowsNum = int64(len(fieldDataList[0].GetVectors().GetBinaryVector())*8) / fieldDataList[0].GetVectors().GetDim()
|
||||
case schemapb.DataType_FloatVector:
|
||||
rowsNum = int64(len(fieldDataList[0].GetVectors().GetFloatVector().Data)) / fieldDataList[0].GetVectors().GetDim()
|
||||
default:
|
||||
return nil, fmt.Errorf("the type(%v) of field(%v) is not supported, use other sdk please", fieldDataList[0].Type, fieldDataList[0].FieldName)
|
||||
}
|
||||
} else if ids != nil {
|
||||
switch ids.IdField.(type) {
|
||||
case *schemapb.IDs_IntId:
|
||||
int64Pks := ids.GetIntId().GetData()
|
||||
rowsNum = int64(len(int64Pks))
|
||||
case *schemapb.IDs_StrId:
|
||||
stringPks := ids.GetStrId().GetData()
|
||||
rowsNum = int64(len(stringPks))
|
||||
default:
|
||||
return nil, fmt.Errorf("the type of primary key(id) is not supported, use other sdk please")
|
||||
}
|
||||
}
|
||||
}
|
||||
if rowsNum == int64(0) {
|
||||
return []map[string]interface{}{}, nil
|
||||
}
|
||||
dynamicOutputFields := genDynamicFields(needFields, fieldDataList)
|
||||
for i := int64(0); i < rowsNum; i++ {
|
||||
row := map[string]interface{}{}
|
||||
if columnNum > 0 {
|
||||
for j := 0; j < columnNum; j++ {
|
||||
switch fieldDataList[j].Type {
|
||||
case schemapb.DataType_Bool:
|
||||
row[fieldDataList[j].FieldName] = fieldDataList[j].GetScalars().GetBoolData().Data[i]
|
||||
case schemapb.DataType_Int8:
|
||||
row[fieldDataList[j].FieldName] = int8(fieldDataList[j].GetScalars().GetIntData().Data[i])
|
||||
case schemapb.DataType_Int16:
|
||||
row[fieldDataList[j].FieldName] = int16(fieldDataList[j].GetScalars().GetIntData().Data[i])
|
||||
case schemapb.DataType_Int32:
|
||||
row[fieldDataList[j].FieldName] = fieldDataList[j].GetScalars().GetIntData().Data[i]
|
||||
case schemapb.DataType_Int64:
|
||||
row[fieldDataList[j].FieldName] = fieldDataList[j].GetScalars().GetLongData().Data[i]
|
||||
case schemapb.DataType_Float:
|
||||
row[fieldDataList[j].FieldName] = fieldDataList[j].GetScalars().GetFloatData().Data[i]
|
||||
case schemapb.DataType_Double:
|
||||
row[fieldDataList[j].FieldName] = fieldDataList[j].GetScalars().GetDoubleData().Data[i]
|
||||
case schemapb.DataType_String:
|
||||
row[fieldDataList[j].FieldName] = fieldDataList[j].GetScalars().GetStringData().Data[i]
|
||||
case schemapb.DataType_VarChar:
|
||||
row[fieldDataList[j].FieldName] = fieldDataList[j].GetScalars().GetStringData().Data[i]
|
||||
case schemapb.DataType_BinaryVector:
|
||||
row[fieldDataList[j].FieldName] = fieldDataList[j].GetVectors().GetBinaryVector()[i*(fieldDataList[j].GetVectors().GetDim()/8) : (i+1)*(fieldDataList[j].GetVectors().GetDim()/8)]
|
||||
case schemapb.DataType_FloatVector:
|
||||
row[fieldDataList[j].FieldName] = fieldDataList[j].GetVectors().GetFloatVector().Data[i*fieldDataList[j].GetVectors().GetDim() : (i+1)*fieldDataList[j].GetVectors().GetDim()]
|
||||
case schemapb.DataType_Array:
|
||||
row[fieldDataList[j].FieldName] = fieldDataList[j].GetScalars().GetArrayData().Data[i]
|
||||
case schemapb.DataType_JSON:
|
||||
data, ok := fieldDataList[j].GetScalars().Data.(*schemapb.ScalarField_JsonData)
|
||||
if ok && !fieldDataList[j].IsDynamic {
|
||||
row[fieldDataList[j].FieldName] = data.JsonData.Data[i]
|
||||
} else {
|
||||
var dataMap map[string]interface{}
|
||||
|
||||
err := json.Unmarshal(fieldDataList[j].GetScalars().GetJsonData().Data[i], &dataMap)
|
||||
if err != nil {
|
||||
log.Error(fmt.Sprintf("[BuildQueryResp] Unmarshal error %s", err.Error()))
|
||||
return nil, err
|
||||
}
|
||||
|
||||
if containsString(dynamicOutputFields, fieldDataList[j].FieldName) {
|
||||
for key, value := range dataMap {
|
||||
row[key] = value
|
||||
}
|
||||
} else {
|
||||
for _, dynamicField := range dynamicOutputFields {
|
||||
if _, ok := dataMap[dynamicField]; ok {
|
||||
row[dynamicField] = dataMap[dynamicField]
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
}
|
||||
default:
|
||||
row[fieldDataList[j].FieldName] = ""
|
||||
}
|
||||
}
|
||||
}
|
||||
if ids != nil {
|
||||
switch ids.IdField.(type) {
|
||||
case *schemapb.IDs_IntId:
|
||||
int64Pks := ids.GetIntId().GetData()
|
||||
row[DefaultPrimaryFieldName] = int64Pks[i]
|
||||
case *schemapb.IDs_StrId:
|
||||
stringPks := ids.GetStrId().GetData()
|
||||
row[DefaultPrimaryFieldName] = stringPks[i]
|
||||
default:
|
||||
return nil, fmt.Errorf("the type of primary key(id) is not supported, use other sdk please")
|
||||
}
|
||||
}
|
||||
if scores != nil && int64(len(scores)) > i {
|
||||
row[HTTPReturnDistance] = scores[i]
|
||||
}
|
||||
queryResp = append(queryResp, row)
|
||||
}
|
||||
|
||||
return queryResp, nil
|
||||
}
|
||||
|
||||
// --------------------- error code --------------------- //
|
||||
func code(code int32) int32 {
|
||||
return code & merr.RootReasonCodeMask
|
||||
}
|
||||
|
||||
func Code(err error) int32 {
|
||||
return code(merr.Code(err))
|
||||
}
|
|
@ -0,0 +1,808 @@
|
|||
package httpserver
|
||||
|
||||
import (
|
||||
"strconv"
|
||||
"testing"
|
||||
|
||||
"github.com/gin-gonic/gin"
|
||||
"github.com/milvus-io/milvus-proto/go-api/v2/commonpb"
|
||||
"github.com/milvus-io/milvus-proto/go-api/v2/milvuspb"
|
||||
"github.com/milvus-io/milvus-proto/go-api/v2/schemapb"
|
||||
"github.com/milvus-io/milvus/pkg/common"
|
||||
"github.com/stretchr/testify/assert"
|
||||
"github.com/tidwall/gjson"
|
||||
)
|
||||
|
||||
const (
|
||||
FieldWordCount = "word_count"
|
||||
FieldBookID = "book_id"
|
||||
FieldBookIntro = "book_intro"
|
||||
)
|
||||
|
||||
func generatePrimaryField(datatype schemapb.DataType) schemapb.FieldSchema {
|
||||
return schemapb.FieldSchema{
|
||||
FieldID: common.StartOfUserFieldID,
|
||||
Name: FieldBookID,
|
||||
IsPrimaryKey: true,
|
||||
Description: "",
|
||||
DataType: datatype,
|
||||
AutoID: false,
|
||||
}
|
||||
}
|
||||
|
||||
func generateIds(num int) *schemapb.IDs {
|
||||
var intArray []int64
|
||||
if num == 0 {
|
||||
intArray = []int64{}
|
||||
} else {
|
||||
for i := int64(1); i < int64(num+1); i++ {
|
||||
intArray = append(intArray, i)
|
||||
}
|
||||
}
|
||||
return &schemapb.IDs{
|
||||
IdField: &schemapb.IDs_IntId{
|
||||
IntId: &schemapb.LongArray{
|
||||
Data: intArray,
|
||||
},
|
||||
},
|
||||
}
|
||||
}
|
||||
|
||||
func generateVectorFieldSchema(useBinary bool) schemapb.FieldSchema {
|
||||
if useBinary {
|
||||
return schemapb.FieldSchema{
|
||||
FieldID: common.StartOfUserFieldID + 2,
|
||||
Name: "field-binary",
|
||||
IsPrimaryKey: false,
|
||||
Description: "",
|
||||
DataType: 100,
|
||||
AutoID: false,
|
||||
TypeParams: []*commonpb.KeyValuePair{
|
||||
{
|
||||
Key: common.DimKey,
|
||||
Value: "8",
|
||||
},
|
||||
},
|
||||
}
|
||||
}
|
||||
return schemapb.FieldSchema{
|
||||
FieldID: common.StartOfUserFieldID + 2,
|
||||
Name: FieldBookIntro,
|
||||
IsPrimaryKey: false,
|
||||
Description: "",
|
||||
DataType: 101,
|
||||
AutoID: false,
|
||||
TypeParams: []*commonpb.KeyValuePair{
|
||||
{
|
||||
Key: common.DimKey,
|
||||
Value: "2",
|
||||
},
|
||||
},
|
||||
}
|
||||
}
|
||||
|
||||
func generateCollectionSchema(useBinary bool) *schemapb.CollectionSchema {
|
||||
primaryField := generatePrimaryField(schemapb.DataType_Int64)
|
||||
vectorField := generateVectorFieldSchema(useBinary)
|
||||
return &schemapb.CollectionSchema{
|
||||
Name: DefaultCollectionName,
|
||||
Description: "",
|
||||
AutoID: false,
|
||||
Fields: []*schemapb.FieldSchema{
|
||||
&primaryField, {
|
||||
FieldID: common.StartOfUserFieldID + 1,
|
||||
Name: FieldWordCount,
|
||||
IsPrimaryKey: false,
|
||||
Description: "",
|
||||
DataType: 5,
|
||||
AutoID: false,
|
||||
}, &vectorField,
|
||||
},
|
||||
EnableDynamicField: true,
|
||||
}
|
||||
}
|
||||
|
||||
func generateIndexes() []*milvuspb.IndexDescription {
|
||||
return []*milvuspb.IndexDescription{
|
||||
{
|
||||
IndexName: DefaultIndexName,
|
||||
IndexID: 442051985533243300,
|
||||
Params: []*commonpb.KeyValuePair{
|
||||
{
|
||||
Key: common.MetricTypeKey,
|
||||
Value: DefaultMetricType,
|
||||
},
|
||||
{
|
||||
Key: "index_type",
|
||||
Value: "IVF_FLAT",
|
||||
}, {
|
||||
Key: Params,
|
||||
Value: "{\"nlist\":1024}",
|
||||
},
|
||||
},
|
||||
State: 3,
|
||||
FieldName: FieldBookIntro,
|
||||
},
|
||||
}
|
||||
}
|
||||
|
||||
func generateVectorFieldData(useBinary bool) schemapb.FieldData {
|
||||
if useBinary {
|
||||
return schemapb.FieldData{
|
||||
Type: schemapb.DataType_BinaryVector,
|
||||
FieldName: "field-binary",
|
||||
Field: &schemapb.FieldData_Vectors{
|
||||
Vectors: &schemapb.VectorField{
|
||||
Dim: 8,
|
||||
Data: &schemapb.VectorField_BinaryVector{
|
||||
BinaryVector: []byte{byte(0), byte(1), byte(2)},
|
||||
},
|
||||
},
|
||||
},
|
||||
IsDynamic: false,
|
||||
}
|
||||
}
|
||||
return schemapb.FieldData{
|
||||
Type: schemapb.DataType_FloatVector,
|
||||
FieldName: FieldBookIntro,
|
||||
Field: &schemapb.FieldData_Vectors{
|
||||
Vectors: &schemapb.VectorField{
|
||||
Dim: 2,
|
||||
Data: &schemapb.VectorField_FloatVector{
|
||||
FloatVector: &schemapb.FloatArray{
|
||||
Data: []float32{0.1, 0.11, 0.2, 0.22, 0.3, 0.33},
|
||||
},
|
||||
},
|
||||
},
|
||||
},
|
||||
IsDynamic: false,
|
||||
}
|
||||
}
|
||||
|
||||
func generateFieldData() []*schemapb.FieldData {
|
||||
fieldData1 := schemapb.FieldData{
|
||||
Type: schemapb.DataType_Int64,
|
||||
FieldName: FieldBookID,
|
||||
Field: &schemapb.FieldData_Scalars{
|
||||
Scalars: &schemapb.ScalarField{
|
||||
Data: &schemapb.ScalarField_LongData{
|
||||
LongData: &schemapb.LongArray{
|
||||
Data: []int64{1, 2, 3},
|
||||
},
|
||||
},
|
||||
},
|
||||
},
|
||||
IsDynamic: false,
|
||||
}
|
||||
|
||||
fieldData2 := schemapb.FieldData{
|
||||
Type: schemapb.DataType_Int64,
|
||||
FieldName: FieldWordCount,
|
||||
Field: &schemapb.FieldData_Scalars{
|
||||
Scalars: &schemapb.ScalarField{
|
||||
Data: &schemapb.ScalarField_LongData{
|
||||
LongData: &schemapb.LongArray{
|
||||
Data: []int64{1000, 2000, 3000},
|
||||
},
|
||||
},
|
||||
},
|
||||
},
|
||||
IsDynamic: false,
|
||||
}
|
||||
|
||||
fieldData3 := generateVectorFieldData(false)
|
||||
return []*schemapb.FieldData{&fieldData1, &fieldData2, &fieldData3}
|
||||
}
|
||||
|
||||
func generateSearchResult() []map[string]interface{} {
|
||||
row1 := map[string]interface{}{
|
||||
DefaultPrimaryFieldName: int64(1),
|
||||
FieldBookID: int64(1),
|
||||
FieldWordCount: int64(1000),
|
||||
FieldBookIntro: []float32{0.1, 0.11},
|
||||
HTTPReturnDistance: float32(0.01),
|
||||
}
|
||||
row2 := map[string]interface{}{
|
||||
DefaultPrimaryFieldName: int64(2),
|
||||
FieldBookID: int64(2),
|
||||
FieldWordCount: int64(2000),
|
||||
FieldBookIntro: []float32{0.2, 0.22},
|
||||
HTTPReturnDistance: float32(0.04),
|
||||
}
|
||||
row3 := map[string]interface{}{
|
||||
DefaultPrimaryFieldName: int64(3),
|
||||
FieldBookID: int64(3),
|
||||
FieldWordCount: int64(3000),
|
||||
FieldBookIntro: []float32{0.3, 0.33},
|
||||
HTTPReturnDistance: float32(0.09),
|
||||
}
|
||||
return []map[string]interface{}{row1, row2, row3}
|
||||
}
|
||||
|
||||
func generateQueryResult64(withDistance bool) []map[string]interface{} {
|
||||
row1 := map[string]interface{}{
|
||||
FieldBookID: float64(1),
|
||||
FieldWordCount: float64(1000),
|
||||
FieldBookIntro: []float64{0.1, 0.11},
|
||||
}
|
||||
row2 := map[string]interface{}{
|
||||
FieldBookID: float64(2),
|
||||
FieldWordCount: float64(2000),
|
||||
FieldBookIntro: []float64{0.2, 0.22},
|
||||
}
|
||||
row3 := map[string]interface{}{
|
||||
FieldBookID: float64(3),
|
||||
FieldWordCount: float64(3000),
|
||||
FieldBookIntro: []float64{0.3, 0.33},
|
||||
}
|
||||
if withDistance {
|
||||
row1[HTTPReturnDistance] = float64(0.01)
|
||||
row2[HTTPReturnDistance] = float64(0.04)
|
||||
row3[HTTPReturnDistance] = float64(0.09)
|
||||
}
|
||||
return []map[string]interface{}{row1, row2, row3}
|
||||
}
|
||||
|
||||
func TestPrintCollectionDetails(t *testing.T) {
|
||||
coll := generateCollectionSchema(false)
|
||||
indexes := generateIndexes()
|
||||
assert.Equal(t, printFields(coll.Fields), []gin.H{
|
||||
{
|
||||
HTTPReturnFieldName: FieldBookID,
|
||||
HTTPReturnFieldType: "Int64",
|
||||
HTTPReturnFieldPrimaryKey: true,
|
||||
HTTPReturnFieldAutoID: false,
|
||||
HTTPReturnDescription: ""},
|
||||
{
|
||||
HTTPReturnFieldName: FieldWordCount,
|
||||
HTTPReturnFieldType: "Int64",
|
||||
HTTPReturnFieldPrimaryKey: false,
|
||||
HTTPReturnFieldAutoID: false,
|
||||
HTTPReturnDescription: ""},
|
||||
{
|
||||
HTTPReturnFieldName: FieldBookIntro,
|
||||
HTTPReturnFieldType: "FloatVector(2)",
|
||||
HTTPReturnFieldPrimaryKey: false,
|
||||
HTTPReturnFieldAutoID: false,
|
||||
HTTPReturnDescription: ""},
|
||||
})
|
||||
assert.Equal(t, printIndexes(indexes), []gin.H{
|
||||
{
|
||||
HTTPReturnIndexName: DefaultIndexName,
|
||||
HTTPReturnIndexField: FieldBookIntro,
|
||||
HTTPReturnIndexMetricsType: DefaultMetricType},
|
||||
})
|
||||
assert.Equal(t, getMetricType(indexes[0].Params), DefaultMetricType)
|
||||
assert.Equal(t, getMetricType(nil), DefaultMetricType)
|
||||
fields := []*schemapb.FieldSchema{}
|
||||
for _, field := range newCollectionSchema(coll).Fields {
|
||||
if field.DataType == schemapb.DataType_VarChar {
|
||||
fields = append(fields, field)
|
||||
}
|
||||
}
|
||||
assert.Equal(t, printFields(fields), []gin.H{
|
||||
{
|
||||
HTTPReturnFieldName: "field-varchar",
|
||||
HTTPReturnFieldType: "VarChar(10)",
|
||||
HTTPReturnFieldPrimaryKey: false,
|
||||
HTTPReturnFieldAutoID: false,
|
||||
HTTPReturnDescription: ""},
|
||||
})
|
||||
}
|
||||
|
||||
func TestPrimaryField(t *testing.T) {
|
||||
coll := generateCollectionSchema(false)
|
||||
primaryField := generatePrimaryField(schemapb.DataType_Int64)
|
||||
field, ok := getPrimaryField(coll)
|
||||
assert.Equal(t, ok, true)
|
||||
assert.Equal(t, *field, primaryField)
|
||||
|
||||
assert.Equal(t, joinArray([]int64{1, 2, 3}), "1,2,3")
|
||||
assert.Equal(t, joinArray([]string{"1", "2", "3"}), "1,2,3")
|
||||
|
||||
jsonStr := "{\"id\": [1, 2, 3]}"
|
||||
idStr := gjson.Get(jsonStr, "id")
|
||||
rangeStr, err := convertRange(&primaryField, idStr)
|
||||
assert.Equal(t, err, nil)
|
||||
assert.Equal(t, rangeStr, "1,2,3")
|
||||
filter, err := checkGetPrimaryKey(coll, idStr)
|
||||
assert.Equal(t, err, nil)
|
||||
assert.Equal(t, filter, "book_id in [1,2,3]")
|
||||
|
||||
primaryField = generatePrimaryField(schemapb.DataType_VarChar)
|
||||
jsonStr = "{\"id\": [\"1\", \"2\", \"3\"]}"
|
||||
idStr = gjson.Get(jsonStr, "id")
|
||||
rangeStr, err = convertRange(&primaryField, idStr)
|
||||
assert.Equal(t, err, nil)
|
||||
assert.Equal(t, rangeStr, "1,2,3")
|
||||
filter, err = checkGetPrimaryKey(coll, idStr)
|
||||
assert.Equal(t, err, nil)
|
||||
assert.Equal(t, filter, "book_id in [1,2,3]")
|
||||
}
|
||||
|
||||
func TestInsertWithDynamicFields(t *testing.T) {
|
||||
body := "{\"data\": {\"id\": 0, \"book_id\": 1, \"book_intro\": [0.1, 0.2], \"word_count\": 2}}"
|
||||
req := InsertReq{}
|
||||
coll := generateCollectionSchema(false)
|
||||
err := checkAndSetData(body, &milvuspb.DescribeCollectionResponse{
|
||||
Status: &commonpb.Status{ErrorCode: commonpb.ErrorCode_Success},
|
||||
Schema: coll,
|
||||
}, &req)
|
||||
assert.Equal(t, err, nil)
|
||||
assert.Equal(t, req.Data[0]["id"], int64(0))
|
||||
assert.Equal(t, req.Data[0]["book_id"], int64(1))
|
||||
assert.Equal(t, req.Data[0]["word_count"], int64(2))
|
||||
fieldsData, err := anyToColumns(req.Data, coll)
|
||||
assert.Equal(t, err, nil)
|
||||
assert.Equal(t, fieldsData[len(fieldsData)-1].IsDynamic, true)
|
||||
assert.Equal(t, fieldsData[len(fieldsData)-1].Type, schemapb.DataType_JSON)
|
||||
assert.Equal(t, string(fieldsData[len(fieldsData)-1].GetScalars().GetJsonData().GetData()[0]), "{\"id\":0}")
|
||||
}
|
||||
|
||||
func TestSerialize(t *testing.T) {
|
||||
parameters := []float32{0.11111, 0.22222}
|
||||
//assert.Equal(t, string(serialize(parameters)), "\ufffd\ufffd\ufffd=\ufffd\ufffdc\u003e")
|
||||
//assert.Equal(t, string(vector2PlaceholderGroupBytes(parameters)), "vector2PlaceholderGroupBytes") // todo
|
||||
assert.Equal(t, string(serialize(parameters)), "\xa4\x8d\xe3=\xa4\x8dc>")
|
||||
assert.Equal(t, string(vector2PlaceholderGroupBytes(parameters)), "\n\x10\n\x02$0\x10e\x1a\b\xa4\x8d\xe3=\xa4\x8dc>") // todo
|
||||
}
|
||||
|
||||
func compareRow64(m1 map[string]interface{}, m2 map[string]interface{}) bool {
|
||||
for key, value := range m1 {
|
||||
if key == FieldBookIntro {
|
||||
arr1 := value.([]interface{})
|
||||
arr2 := m2[key].([]float64)
|
||||
if len(arr1) != len(arr2) {
|
||||
return false
|
||||
}
|
||||
for j, element := range arr1 {
|
||||
if element != arr2[j] {
|
||||
return false
|
||||
}
|
||||
}
|
||||
} else if value != m2[key] {
|
||||
return false
|
||||
}
|
||||
}
|
||||
|
||||
for key, value := range m2 {
|
||||
if key == FieldBookIntro {
|
||||
continue
|
||||
} else if value != m1[key] {
|
||||
return false
|
||||
}
|
||||
}
|
||||
return true
|
||||
|
||||
}
|
||||
|
||||
func compareRow(m1 map[string]interface{}, m2 map[string]interface{}) bool {
|
||||
for key, value := range m1 {
|
||||
if key == FieldBookIntro {
|
||||
arr1 := value.([]float32)
|
||||
arr2 := m2[key].([]float32)
|
||||
if len(arr1) != len(arr2) {
|
||||
return false
|
||||
}
|
||||
for j, element := range arr1 {
|
||||
if element != arr2[j] {
|
||||
return false
|
||||
}
|
||||
}
|
||||
} else if (key == "field-binary") || (key == "field-json") {
|
||||
arr1 := value.([]byte)
|
||||
arr2 := m2[key].([]byte)
|
||||
if len(arr1) != len(arr2) {
|
||||
return false
|
||||
}
|
||||
for j, element := range arr1 {
|
||||
if element != arr2[j] {
|
||||
return false
|
||||
}
|
||||
}
|
||||
} else if value != m2[key] {
|
||||
return false
|
||||
}
|
||||
}
|
||||
|
||||
for key, value := range m2 {
|
||||
if (key == FieldBookIntro) || (key == "field-binary") || (key == "field-json") {
|
||||
continue
|
||||
} else if value != m1[key] {
|
||||
return false
|
||||
}
|
||||
}
|
||||
return true
|
||||
|
||||
}
|
||||
|
||||
type CompareFunc func(map[string]interface{}, map[string]interface{}) bool
|
||||
|
||||
func compareRows(row1 []map[string]interface{}, row2 []map[string]interface{}, compareFunc CompareFunc) bool {
|
||||
if len(row1) != len(row2) {
|
||||
return false
|
||||
}
|
||||
for i, row := range row1 {
|
||||
if !compareFunc(row, row2[i]) {
|
||||
return false
|
||||
}
|
||||
}
|
||||
return true
|
||||
}
|
||||
|
||||
func TestBuildQueryResp(t *testing.T) {
|
||||
outputFields := []string{FieldBookID, FieldWordCount, "author", "date"}
|
||||
rows, err := buildQueryResp(int64(0), outputFields, generateFieldData(), generateIds(3), []float32{0.01, 0.04, 0.09}) // []*schemapb.FieldData{&fieldData1, &fieldData2, &fieldData3}
|
||||
assert.Equal(t, err, nil)
|
||||
exceptRows := generateSearchResult()
|
||||
assert.Equal(t, compareRows(rows, exceptRows, compareRow), true)
|
||||
}
|
||||
|
||||
func newCollectionSchema(coll *schemapb.CollectionSchema) *schemapb.CollectionSchema {
|
||||
fieldSchema1 := schemapb.FieldSchema{
|
||||
Name: "field-bool",
|
||||
DataType: schemapb.DataType_Bool,
|
||||
}
|
||||
coll.Fields = append(coll.Fields, &fieldSchema1)
|
||||
|
||||
fieldSchema2 := schemapb.FieldSchema{
|
||||
Name: "field-int8",
|
||||
DataType: schemapb.DataType_Int8,
|
||||
}
|
||||
coll.Fields = append(coll.Fields, &fieldSchema2)
|
||||
|
||||
fieldSchema3 := schemapb.FieldSchema{
|
||||
Name: "field-int16",
|
||||
DataType: schemapb.DataType_Int16,
|
||||
}
|
||||
coll.Fields = append(coll.Fields, &fieldSchema3)
|
||||
|
||||
fieldSchema4 := schemapb.FieldSchema{
|
||||
Name: "field-int32",
|
||||
DataType: schemapb.DataType_Int32,
|
||||
}
|
||||
coll.Fields = append(coll.Fields, &fieldSchema4)
|
||||
|
||||
fieldSchema5 := schemapb.FieldSchema{
|
||||
Name: "field-float",
|
||||
DataType: schemapb.DataType_Float,
|
||||
}
|
||||
coll.Fields = append(coll.Fields, &fieldSchema5)
|
||||
|
||||
fieldSchema6 := schemapb.FieldSchema{
|
||||
Name: "field-double",
|
||||
DataType: schemapb.DataType_Double,
|
||||
}
|
||||
coll.Fields = append(coll.Fields, &fieldSchema6)
|
||||
|
||||
fieldSchema7 := schemapb.FieldSchema{
|
||||
Name: "field-string",
|
||||
DataType: schemapb.DataType_String,
|
||||
}
|
||||
coll.Fields = append(coll.Fields, &fieldSchema7)
|
||||
|
||||
fieldSchema8 := schemapb.FieldSchema{
|
||||
Name: "field-varchar",
|
||||
DataType: schemapb.DataType_VarChar,
|
||||
TypeParams: []*commonpb.KeyValuePair{
|
||||
{Key: "max_length", Value: "10"},
|
||||
},
|
||||
}
|
||||
coll.Fields = append(coll.Fields, &fieldSchema8)
|
||||
|
||||
fieldSchema9 := schemapb.FieldSchema{
|
||||
Name: "field-json",
|
||||
DataType: schemapb.DataType_JSON,
|
||||
IsDynamic: false,
|
||||
}
|
||||
coll.Fields = append(coll.Fields, &fieldSchema9)
|
||||
|
||||
//fieldSchema10 := schemapb.FieldSchema{
|
||||
// Name: "$meta",
|
||||
// DataType: schemapb.DataType_JSON,
|
||||
// IsDynamic: true,
|
||||
//}
|
||||
//coll.Fields = append(coll.Fields, &fieldSchema10)
|
||||
|
||||
return coll
|
||||
}
|
||||
|
||||
func withUnsupportField(coll *schemapb.CollectionSchema) *schemapb.CollectionSchema {
|
||||
fieldSchema10 := schemapb.FieldSchema{
|
||||
Name: "field-array",
|
||||
DataType: schemapb.DataType_Array,
|
||||
IsDynamic: false,
|
||||
}
|
||||
coll.Fields = append(coll.Fields, &fieldSchema10)
|
||||
|
||||
return coll
|
||||
}
|
||||
|
||||
func withDynamicField(coll *schemapb.CollectionSchema) *schemapb.CollectionSchema {
|
||||
fieldSchema11 := schemapb.FieldSchema{
|
||||
Name: "$meta",
|
||||
DataType: schemapb.DataType_JSON,
|
||||
IsDynamic: true,
|
||||
}
|
||||
coll.Fields = append(coll.Fields, &fieldSchema11)
|
||||
|
||||
return coll
|
||||
}
|
||||
|
||||
func newFieldData(fieldDatas []*schemapb.FieldData, firstFieldType schemapb.DataType) []*schemapb.FieldData {
|
||||
fieldData1 := schemapb.FieldData{
|
||||
Type: schemapb.DataType_Bool,
|
||||
FieldName: "field-bool",
|
||||
Field: &schemapb.FieldData_Scalars{
|
||||
Scalars: &schemapb.ScalarField{
|
||||
Data: &schemapb.ScalarField_BoolData{
|
||||
BoolData: &schemapb.BoolArray{
|
||||
Data: []bool{true, true, true},
|
||||
},
|
||||
},
|
||||
},
|
||||
},
|
||||
IsDynamic: false,
|
||||
}
|
||||
fieldDatas = append(fieldDatas, &fieldData1)
|
||||
|
||||
fieldData2 := schemapb.FieldData{
|
||||
Type: schemapb.DataType_Int8,
|
||||
FieldName: "field-int8",
|
||||
Field: &schemapb.FieldData_Scalars{
|
||||
Scalars: &schemapb.ScalarField{
|
||||
Data: &schemapb.ScalarField_IntData{
|
||||
IntData: &schemapb.IntArray{
|
||||
Data: []int32{0, 1, 2},
|
||||
},
|
||||
},
|
||||
},
|
||||
},
|
||||
IsDynamic: false,
|
||||
}
|
||||
fieldDatas = append(fieldDatas, &fieldData2)
|
||||
|
||||
fieldData3 := schemapb.FieldData{
|
||||
Type: schemapb.DataType_Int16,
|
||||
FieldName: "field-int16",
|
||||
Field: &schemapb.FieldData_Scalars{
|
||||
Scalars: &schemapb.ScalarField{
|
||||
Data: &schemapb.ScalarField_IntData{
|
||||
IntData: &schemapb.IntArray{
|
||||
Data: []int32{0, 1, 2},
|
||||
},
|
||||
},
|
||||
},
|
||||
},
|
||||
IsDynamic: false,
|
||||
}
|
||||
fieldDatas = append(fieldDatas, &fieldData3)
|
||||
|
||||
fieldData4 := schemapb.FieldData{
|
||||
Type: schemapb.DataType_Int32,
|
||||
FieldName: "field-int32",
|
||||
Field: &schemapb.FieldData_Scalars{
|
||||
Scalars: &schemapb.ScalarField{
|
||||
Data: &schemapb.ScalarField_IntData{
|
||||
IntData: &schemapb.IntArray{
|
||||
Data: []int32{0, 1, 2},
|
||||
},
|
||||
},
|
||||
},
|
||||
},
|
||||
IsDynamic: false,
|
||||
}
|
||||
fieldDatas = append(fieldDatas, &fieldData4)
|
||||
|
||||
fieldData5 := schemapb.FieldData{
|
||||
Type: schemapb.DataType_Float,
|
||||
FieldName: "field-float",
|
||||
Field: &schemapb.FieldData_Scalars{
|
||||
Scalars: &schemapb.ScalarField{
|
||||
Data: &schemapb.ScalarField_FloatData{
|
||||
FloatData: &schemapb.FloatArray{
|
||||
Data: []float32{0, 1, 2},
|
||||
},
|
||||
},
|
||||
},
|
||||
},
|
||||
IsDynamic: false,
|
||||
}
|
||||
fieldDatas = append(fieldDatas, &fieldData5)
|
||||
|
||||
fieldData6 := schemapb.FieldData{
|
||||
Type: schemapb.DataType_Double,
|
||||
FieldName: "field-double",
|
||||
Field: &schemapb.FieldData_Scalars{
|
||||
Scalars: &schemapb.ScalarField{
|
||||
Data: &schemapb.ScalarField_DoubleData{
|
||||
DoubleData: &schemapb.DoubleArray{
|
||||
Data: []float64{0, 1, 2},
|
||||
},
|
||||
},
|
||||
},
|
||||
},
|
||||
IsDynamic: false,
|
||||
}
|
||||
fieldDatas = append(fieldDatas, &fieldData6)
|
||||
|
||||
fieldData7 := schemapb.FieldData{
|
||||
Type: schemapb.DataType_String,
|
||||
FieldName: "field-string",
|
||||
Field: &schemapb.FieldData_Scalars{
|
||||
Scalars: &schemapb.ScalarField{
|
||||
Data: &schemapb.ScalarField_StringData{
|
||||
StringData: &schemapb.StringArray{
|
||||
Data: []string{"0", "1", "2"},
|
||||
},
|
||||
},
|
||||
},
|
||||
},
|
||||
IsDynamic: false,
|
||||
}
|
||||
fieldDatas = append(fieldDatas, &fieldData7)
|
||||
|
||||
fieldData8 := schemapb.FieldData{
|
||||
Type: schemapb.DataType_VarChar,
|
||||
FieldName: "field-varchar",
|
||||
Field: &schemapb.FieldData_Scalars{
|
||||
Scalars: &schemapb.ScalarField{
|
||||
Data: &schemapb.ScalarField_StringData{
|
||||
StringData: &schemapb.StringArray{
|
||||
Data: []string{"0", "1", "2"},
|
||||
},
|
||||
},
|
||||
},
|
||||
},
|
||||
IsDynamic: false,
|
||||
}
|
||||
fieldDatas = append(fieldDatas, &fieldData8)
|
||||
|
||||
fieldData9 := schemapb.FieldData{
|
||||
Type: schemapb.DataType_JSON,
|
||||
FieldName: "field-json",
|
||||
Field: &schemapb.FieldData_Scalars{
|
||||
Scalars: &schemapb.ScalarField{
|
||||
Data: &schemapb.ScalarField_JsonData{
|
||||
JsonData: &schemapb.JSONArray{
|
||||
Data: [][]byte{[]byte(`{"XXX": 0}`), []byte(`{"XXX": 0}`), []byte(`{"XXX": 0}`)},
|
||||
},
|
||||
},
|
||||
},
|
||||
},
|
||||
IsDynamic: false,
|
||||
}
|
||||
fieldDatas = append(fieldDatas, &fieldData9)
|
||||
|
||||
fieldData10 := schemapb.FieldData{
|
||||
Type: schemapb.DataType_Array,
|
||||
FieldName: "field-array",
|
||||
Field: &schemapb.FieldData_Scalars{
|
||||
Scalars: &schemapb.ScalarField{
|
||||
Data: &schemapb.ScalarField_ArrayData{
|
||||
ArrayData: &schemapb.ArrayArray{
|
||||
Data: []*schemapb.ScalarField{
|
||||
{Data: &schemapb.ScalarField_BoolData{BoolData: &schemapb.BoolArray{Data: []bool{true}}}},
|
||||
{Data: &schemapb.ScalarField_BoolData{BoolData: &schemapb.BoolArray{Data: []bool{true}}}},
|
||||
{Data: &schemapb.ScalarField_BoolData{BoolData: &schemapb.BoolArray{Data: []bool{true}}}},
|
||||
},
|
||||
},
|
||||
},
|
||||
},
|
||||
},
|
||||
IsDynamic: false,
|
||||
}
|
||||
|
||||
fieldData11 := schemapb.FieldData{
|
||||
Type: schemapb.DataType_JSON,
|
||||
FieldName: "$meta",
|
||||
Field: &schemapb.FieldData_Scalars{
|
||||
Scalars: &schemapb.ScalarField{
|
||||
Data: &schemapb.ScalarField_JsonData{
|
||||
JsonData: &schemapb.JSONArray{
|
||||
Data: [][]byte{[]byte(`{"XXX": 0, "YYY": "0"}`), []byte(`{"XXX": 1, "YYY": "1"}`), []byte(`{"XXX": 2, "YYY": "2"}`)},
|
||||
},
|
||||
},
|
||||
},
|
||||
},
|
||||
IsDynamic: true,
|
||||
}
|
||||
fieldDatas = append(fieldDatas, &fieldData11)
|
||||
|
||||
switch firstFieldType {
|
||||
case schemapb.DataType_None:
|
||||
break
|
||||
case schemapb.DataType_Bool:
|
||||
return []*schemapb.FieldData{&fieldData1}
|
||||
case schemapb.DataType_Int8:
|
||||
return []*schemapb.FieldData{&fieldData2}
|
||||
case schemapb.DataType_Int16:
|
||||
return []*schemapb.FieldData{&fieldData3}
|
||||
case schemapb.DataType_Int32:
|
||||
return []*schemapb.FieldData{&fieldData4}
|
||||
case schemapb.DataType_Float:
|
||||
return []*schemapb.FieldData{&fieldData5}
|
||||
case schemapb.DataType_Double:
|
||||
return []*schemapb.FieldData{&fieldData6}
|
||||
case schemapb.DataType_String:
|
||||
return []*schemapb.FieldData{&fieldData7}
|
||||
case schemapb.DataType_VarChar:
|
||||
return []*schemapb.FieldData{&fieldData8}
|
||||
case schemapb.DataType_BinaryVector:
|
||||
vectorField := generateVectorFieldData(true)
|
||||
return []*schemapb.FieldData{&vectorField}
|
||||
case schemapb.DataType_FloatVector:
|
||||
vectorField := generateVectorFieldData(false)
|
||||
return []*schemapb.FieldData{&vectorField}
|
||||
case schemapb.DataType_Array:
|
||||
return []*schemapb.FieldData{&fieldData10}
|
||||
case schemapb.DataType_JSON:
|
||||
return []*schemapb.FieldData{&fieldData9}
|
||||
default:
|
||||
return []*schemapb.FieldData{
|
||||
{
|
||||
FieldName: "wrong-field-type",
|
||||
Type: firstFieldType,
|
||||
},
|
||||
}
|
||||
}
|
||||
|
||||
return fieldDatas
|
||||
}
|
||||
|
||||
func newSearchResult(results []map[string]interface{}) []map[string]interface{} {
|
||||
for i, result := range results {
|
||||
result["field-bool"] = true
|
||||
result["field-int8"] = int8(i)
|
||||
result["field-int16"] = int16(i)
|
||||
result["field-int32"] = int32(i)
|
||||
result["field-float"] = float32(i)
|
||||
result["field-double"] = float64(i)
|
||||
result["field-varchar"] = strconv.Itoa(i)
|
||||
result["field-string"] = strconv.Itoa(i)
|
||||
result["field-binary"] = []byte{byte(i)}
|
||||
result["field-json"] = []byte(`{"XXX": 0}`)
|
||||
result["XXX"] = float64(i)
|
||||
result["YYY"] = strconv.Itoa(i)
|
||||
results[i] = result
|
||||
}
|
||||
return results
|
||||
}
|
||||
|
||||
func TestAnyToColumn(t *testing.T) {
|
||||
data, err := anyToColumns(newSearchResult(generateSearchResult()), newCollectionSchema(generateCollectionSchema(false)))
|
||||
assert.Equal(t, err, nil)
|
||||
assert.Equal(t, len(data), 13)
|
||||
}
|
||||
|
||||
func TestBuildQueryResps(t *testing.T) {
|
||||
outputFields := []string{"XXX", "YYY"}
|
||||
outputFieldsList := [][]string{outputFields, {"$meta"}, {"$meta", FieldBookID, FieldBookIntro, "YYY"}}
|
||||
for _, theOutputFields := range outputFieldsList {
|
||||
rows, err := buildQueryResp(int64(0), theOutputFields, newFieldData(generateFieldData(), schemapb.DataType_None), generateIds(3), []float32{0.01, 0.04, 0.09})
|
||||
assert.Equal(t, err, nil)
|
||||
exceptRows := newSearchResult(generateSearchResult())
|
||||
assert.Equal(t, compareRows(rows, exceptRows, compareRow), true)
|
||||
}
|
||||
|
||||
dataTypes := []schemapb.DataType{schemapb.DataType_FloatVector, schemapb.DataType_BinaryVector,
|
||||
schemapb.DataType_Bool, schemapb.DataType_Int8, schemapb.DataType_Int16, schemapb.DataType_Int32,
|
||||
schemapb.DataType_Float, schemapb.DataType_Double,
|
||||
schemapb.DataType_String, schemapb.DataType_VarChar,
|
||||
schemapb.DataType_JSON, schemapb.DataType_Array}
|
||||
for _, dateType := range dataTypes {
|
||||
_, err := buildQueryResp(int64(0), outputFields, newFieldData([]*schemapb.FieldData{}, dateType), generateIds(3), []float32{0.01, 0.04, 0.09})
|
||||
assert.Equal(t, err, nil)
|
||||
}
|
||||
|
||||
_, err := buildQueryResp(int64(0), outputFields, newFieldData([]*schemapb.FieldData{}, 1000), generateIds(3), []float32{0.01, 0.04, 0.09})
|
||||
assert.Equal(t, err.Error(), "the type(1000) of field(wrong-field-type) is not supported, use other sdk please")
|
||||
|
||||
res, err := buildQueryResp(int64(0), outputFields, []*schemapb.FieldData{}, generateIds(3), []float32{0.01, 0.04, 0.09})
|
||||
assert.Equal(t, len(res), 3)
|
||||
assert.Equal(t, err, nil)
|
||||
|
||||
// len(rows) != len(scores), didn't show distance
|
||||
_, err = buildQueryResp(int64(0), outputFields, newFieldData(generateFieldData(), schemapb.DataType_None), generateIds(3), []float32{0.01, 0.04})
|
||||
assert.Equal(t, err, nil)
|
||||
}
|
|
@ -30,12 +30,18 @@ import (
|
|||
"sync"
|
||||
"time"
|
||||
|
||||
"github.com/milvus-io/milvus/pkg/util/merr"
|
||||
|
||||
"google.golang.org/grpc/credentials"
|
||||
|
||||
management "github.com/milvus-io/milvus/internal/http"
|
||||
"github.com/milvus-io/milvus/internal/proxy/accesslog"
|
||||
"github.com/milvus-io/milvus/internal/util/componentutil"
|
||||
"github.com/milvus-io/milvus/internal/util/dependency"
|
||||
"github.com/milvus-io/milvus/pkg/tracer"
|
||||
"github.com/milvus-io/milvus/pkg/util/interceptor"
|
||||
"github.com/milvus-io/milvus/pkg/util/metricsinfo"
|
||||
"github.com/soheilhy/cmux"
|
||||
"go.opentelemetry.io/contrib/instrumentation/google.golang.org/grpc/otelgrpc"
|
||||
|
||||
"github.com/gin-gonic/gin"
|
||||
|
@ -61,7 +67,6 @@ import (
|
|||
"go.uber.org/zap"
|
||||
"google.golang.org/grpc"
|
||||
"google.golang.org/grpc/codes"
|
||||
"google.golang.org/grpc/credentials"
|
||||
"google.golang.org/grpc/health/grpc_health_v1"
|
||||
"google.golang.org/grpc/keepalive"
|
||||
"google.golang.org/grpc/status"
|
||||
|
@ -81,6 +86,10 @@ type Server struct {
|
|||
ctx context.Context
|
||||
wg sync.WaitGroup
|
||||
proxy types.ProxyComponent
|
||||
httpListener net.Listener
|
||||
grpcListener net.Listener
|
||||
tcpServer cmux.CMux
|
||||
httpServer *http.Server
|
||||
grpcInternalServer *grpc.Server
|
||||
grpcExternalServer *grpc.Server
|
||||
|
||||
|
@ -105,6 +114,21 @@ func NewServer(ctx context.Context, factory dependency.Factory) (*Server, error)
|
|||
return server, err
|
||||
}
|
||||
|
||||
func authenticate(c *gin.Context) {
|
||||
if !proxy.Params.CommonCfg.AuthorizationEnabled.GetAsBool() {
|
||||
return
|
||||
}
|
||||
username, password, ok := httpserver.ParseUsernamePassword(c)
|
||||
if ok {
|
||||
if proxy.PasswordVerify(c, username, password) {
|
||||
log.Debug("auth successful", zap.String("username", username))
|
||||
c.Set(httpserver.ContextUsername, username)
|
||||
return
|
||||
}
|
||||
}
|
||||
c.AbortWithStatusJSON(http.StatusUnauthorized, gin.H{httpserver.HTTPReturnCode: httpserver.Code(merr.ErrNeedAuthenticate), httpserver.HTTPReturnMessage: merr.ErrNeedAuthenticate.Error()})
|
||||
}
|
||||
|
||||
// registerHTTPServer register the http server, panic when failed
|
||||
func (s *Server) registerHTTPServer() {
|
||||
// (Embedded Milvus Only) Discard gin logs if logging is disabled.
|
||||
|
@ -117,10 +141,40 @@ func (s *Server) registerHTTPServer() {
|
|||
if !proxy.Params.HTTPCfg.DebugMode.GetAsBool() {
|
||||
gin.SetMode(gin.ReleaseMode)
|
||||
}
|
||||
ginHandler := gin.Default()
|
||||
apiv1 := ginHandler.Group(apiPathPrefix)
|
||||
metricsGinHandler := gin.Default()
|
||||
apiv1 := metricsGinHandler.Group(apiPathPrefix)
|
||||
httpserver.NewHandlers(s.proxy).RegisterRoutesTo(apiv1)
|
||||
http.Handle("/", ginHandler)
|
||||
management.Register(&management.Handler{
|
||||
Path: "/",
|
||||
HandlerFunc: nil,
|
||||
Handler: metricsGinHandler.Handler(),
|
||||
})
|
||||
}
|
||||
|
||||
func (s *Server) startHTTPServer(errChan chan error) {
|
||||
defer s.wg.Done()
|
||||
ginHandler := gin.Default()
|
||||
ginHandler.Use(func(c *gin.Context) {
|
||||
c.Writer.Header().Set("Access-Control-Allow-Origin", "*")
|
||||
c.Writer.Header().Set("Access-Control-Allow-Credentials", "true")
|
||||
c.Writer.Header().Set("Access-Control-Allow-Headers", "Content-Type, Content-Length, Accept-Encoding, X-CSRF-Token, Authorization, accept, origin, Cache-Control, X-Requested-With")
|
||||
c.Writer.Header().Set("Access-Control-Allow-Methods", "GET, HEAD, POST, PUT, DELETE, OPTIONS, PATCH")
|
||||
if c.Request.Method == "OPTIONS" {
|
||||
c.AbortWithStatus(204)
|
||||
return
|
||||
}
|
||||
c.Next()
|
||||
})
|
||||
app := ginHandler.Group("/v1", authenticate)
|
||||
httpserver.NewHandlers(s.proxy).RegisterRoutesToV1(app)
|
||||
s.httpServer = &http.Server{Handler: ginHandler, ReadHeaderTimeout: time.Second}
|
||||
errChan <- nil
|
||||
if err := s.httpServer.Serve(s.httpListener); err != nil && err != cmux.ErrServerClosed {
|
||||
log.Error("start Proxy http server to listen failed", zap.Error(err))
|
||||
errChan <- err
|
||||
return
|
||||
}
|
||||
log.Info("Proxy http server exited")
|
||||
}
|
||||
|
||||
func (s *Server) startInternalRPCServer(grpcInternalPort int, errChan chan error) {
|
||||
|
@ -146,15 +200,6 @@ func (s *Server) startExternalGrpc(grpcPort int, errChan chan error) {
|
|||
Timeout: 10 * time.Second, // Wait 10 second for the ping ack before assuming the connection is dead
|
||||
}
|
||||
|
||||
log.Debug("Proxy server listen on tcp", zap.Int("port", grpcPort))
|
||||
lis, err := net.Listen("tcp", ":"+strconv.Itoa(grpcPort))
|
||||
if err != nil {
|
||||
log.Warn("Proxy server failed to listen on", zap.Error(err), zap.Int("port", grpcPort))
|
||||
errChan <- err
|
||||
return
|
||||
}
|
||||
log.Debug("Proxy server already listen on tcp", zap.Int("port", grpcPort))
|
||||
|
||||
limiter, err := s.proxy.GetRateLimiter()
|
||||
if err != nil {
|
||||
log.Error("Get proxy rate limiter failed", zap.Int("port", grpcPort), zap.Error(err))
|
||||
|
@ -185,6 +230,7 @@ func (s *Server) startExternalGrpc(grpcPort int, errChan chan error) {
|
|||
if Params.TLSMode.GetAsInt() == 1 {
|
||||
creds, err := credentials.NewServerTLSFromFile(Params.ServerPemPath.GetValue(), Params.ServerKeyPath.GetValue())
|
||||
if err != nil {
|
||||
log.Warn("proxy can't create creds", zap.Error(err))
|
||||
log.Warn("proxy can't create creds", zap.Error(err))
|
||||
errChan <- err
|
||||
return
|
||||
|
@ -228,11 +274,12 @@ func (s *Server) startExternalGrpc(grpcPort int, errChan chan error) {
|
|||
zap.Any("enforcement policy", kaep),
|
||||
zap.Any("server parameters", kasp))
|
||||
|
||||
if err := s.grpcExternalServer.Serve(lis); err != nil {
|
||||
if err := s.grpcExternalServer.Serve(s.grpcListener); err != nil && err != cmux.ErrServerClosed {
|
||||
log.Error("failed to serve on Proxy's listener", zap.Error(err))
|
||||
errChan <- err
|
||||
return
|
||||
}
|
||||
log.Info("Proxy external grpc server exited")
|
||||
}
|
||||
|
||||
func (s *Server) startInternalGrpc(grpcPort int, errChan chan error) {
|
||||
|
@ -283,6 +330,7 @@ func (s *Server) startInternalGrpc(grpcPort int, errChan chan error) {
|
|||
errChan <- err
|
||||
return
|
||||
}
|
||||
log.Info("Proxy internal grpc server exited")
|
||||
}
|
||||
|
||||
// Start start the Proxy Server
|
||||
|
@ -301,6 +349,17 @@ func (s *Server) Run() error {
|
|||
}
|
||||
log.Debug("start Proxy server done")
|
||||
|
||||
if s.tcpServer != nil {
|
||||
s.wg.Add(1)
|
||||
go func() {
|
||||
defer s.wg.Done()
|
||||
if err := s.tcpServer.Serve(); err != nil && err != cmux.ErrServerClosed {
|
||||
log.Warn("Proxy server for tcp port failed", zap.Error(err))
|
||||
return
|
||||
}
|
||||
log.Info("Proxy tcp server exited")
|
||||
}()
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
|
@ -345,6 +404,80 @@ func (s *Server) init() error {
|
|||
return err
|
||||
}
|
||||
}
|
||||
{
|
||||
log.Info("Proxy server listen on tcp", zap.Int("port", Params.Port.GetAsInt()))
|
||||
var lis net.Listener
|
||||
var listenErr error
|
||||
|
||||
log.Info("Proxy server already listen on tcp", zap.Int("port", Params.Port.GetAsInt()))
|
||||
lis, listenErr = net.Listen("tcp", ":"+strconv.Itoa(Params.Port.GetAsInt()))
|
||||
if listenErr != nil {
|
||||
log.Error("Proxy server(grpc/http) failed to listen on", zap.Error(err), zap.Int("port", Params.Port.GetAsInt()))
|
||||
return err
|
||||
}
|
||||
|
||||
if HTTPParams.Enabled.GetAsBool() && Params.TLSMode.GetAsInt() == 0 &&
|
||||
(HTTPParams.Port.GetValue() == "" || HTTPParams.Port.GetAsInt() == Params.Port.GetAsInt()) {
|
||||
s.tcpServer = cmux.New(lis)
|
||||
s.grpcListener = s.tcpServer.MatchWithWriters(cmux.HTTP2MatchHeaderFieldSendSettings("content-type", "application/grpc"))
|
||||
s.httpListener = s.tcpServer.Match(cmux.Any())
|
||||
} else {
|
||||
s.grpcListener = lis
|
||||
}
|
||||
|
||||
if HTTPParams.Enabled.GetAsBool() && HTTPParams.Port.GetValue() != "" && HTTPParams.Port.GetAsInt() != Params.Port.GetAsInt() {
|
||||
if Params.TLSMode.GetAsInt() == 0 {
|
||||
s.httpListener, listenErr = net.Listen("tcp", ":"+strconv.Itoa(HTTPParams.Port.GetAsInt()))
|
||||
if listenErr != nil {
|
||||
log.Error("Proxy server(grpc/http) failed to listen on", zap.Error(err), zap.Int("port", Params.Port.GetAsInt()))
|
||||
return err
|
||||
}
|
||||
} else if Params.TLSMode.GetAsInt() == 1 {
|
||||
creds, err := tls.LoadX509KeyPair(Params.ServerPemPath.GetValue(), Params.ServerKeyPath.GetValue())
|
||||
if err != nil {
|
||||
log.Error("proxy can't create creds", zap.Error(err))
|
||||
return err
|
||||
}
|
||||
s.httpListener, listenErr = tls.Listen("tcp", ":"+strconv.Itoa(HTTPParams.Port.GetAsInt()), &tls.Config{
|
||||
Certificates: []tls.Certificate{creds},
|
||||
})
|
||||
if listenErr != nil {
|
||||
log.Error("Proxy server(grpc/http) failed to listen on", zap.Error(err), zap.Int("port", Params.Port.GetAsInt()))
|
||||
return listenErr
|
||||
}
|
||||
} else if Params.TLSMode.GetAsInt() == 2 {
|
||||
cert, err := tls.LoadX509KeyPair(Params.ServerPemPath.GetValue(), Params.ServerKeyPath.GetValue())
|
||||
if err != nil {
|
||||
log.Error("proxy cant load x509 key pair", zap.Error(err))
|
||||
return err
|
||||
}
|
||||
|
||||
certPool := x509.NewCertPool()
|
||||
rootBuf, err := ioutil.ReadFile(Params.CaPemPath.GetValue())
|
||||
if err != nil {
|
||||
log.Error("failed read ca pem", zap.Error(err))
|
||||
return err
|
||||
}
|
||||
if !certPool.AppendCertsFromPEM(rootBuf) {
|
||||
log.Warn("fail to append ca to cert")
|
||||
return fmt.Errorf("fail to append ca to cert")
|
||||
}
|
||||
|
||||
tlsConf := &tls.Config{
|
||||
ClientAuth: tls.RequireAndVerifyClientCert,
|
||||
Certificates: []tls.Certificate{cert},
|
||||
ClientCAs: certPool,
|
||||
MinVersion: tls.VersionTLS13,
|
||||
}
|
||||
s.httpListener, listenErr = tls.Listen("tcp", ":"+strconv.Itoa(HTTPParams.Port.GetAsInt()), tlsConf)
|
||||
if listenErr != nil {
|
||||
log.Error("Proxy server(grpc/http) failed to listen on", zap.Error(err), zap.Int("port", Params.Port.GetAsInt()))
|
||||
return listenErr
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
}
|
||||
{
|
||||
s.startExternalRPCServer(Params.Port.GetAsInt(), errChan)
|
||||
if err := <-errChan; err != nil {
|
||||
|
@ -355,7 +488,7 @@ func (s *Server) init() error {
|
|||
|
||||
if HTTPParams.Enabled.GetAsBool() {
|
||||
registerHTTPHandlerOnce.Do(func() {
|
||||
log.Info("register http server of proxy")
|
||||
log.Info("register Proxy http server")
|
||||
s.registerHTTPServer()
|
||||
})
|
||||
}
|
||||
|
@ -474,6 +607,17 @@ func (s *Server) start() error {
|
|||
return err
|
||||
}
|
||||
|
||||
if s.httpListener != nil {
|
||||
log.Info("start Proxy http server")
|
||||
errChan := make(chan error, 1)
|
||||
s.wg.Add(1)
|
||||
go s.startHTTPServer(errChan)
|
||||
if err := <-errChan; err != nil {
|
||||
log.Error("failed to create http rpc server", zap.Error(err))
|
||||
return err
|
||||
}
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
|
@ -496,9 +640,16 @@ func (s *Server) Stop() error {
|
|||
log.Debug("Graceful stop grpc internal server...")
|
||||
s.grpcInternalServer.GracefulStop()
|
||||
}
|
||||
if s.grpcExternalServer != nil {
|
||||
log.Debug("Graceful stop grpc external server...")
|
||||
if s.tcpServer != nil {
|
||||
log.Info("Graceful stop Proxy tcp server...")
|
||||
s.tcpServer.Close()
|
||||
} else if s.grpcExternalServer != nil {
|
||||
log.Info("Graceful stop grpc external server...")
|
||||
s.grpcExternalServer.GracefulStop()
|
||||
if s.httpServer != nil {
|
||||
log.Info("Graceful stop grpc http server...")
|
||||
s.httpServer.Close()
|
||||
}
|
||||
}
|
||||
}()
|
||||
gracefulWg.Wait()
|
||||
|
|
|
@ -966,12 +966,14 @@ func waitForGrpcReady(opt *WaitOption) {
|
|||
ch <- err
|
||||
return
|
||||
}
|
||||
_, err = grpc.Dial(address, grpc.WithBlock(), grpc.WithTransportCredentials(creds))
|
||||
conn, err := grpc.Dial(address, grpc.WithBlock(), grpc.WithTransportCredentials(creds))
|
||||
ch <- err
|
||||
conn.Close()
|
||||
return
|
||||
}
|
||||
if _, err := grpc.Dial(address, grpc.WithBlock(), grpc.WithInsecure()); true {
|
||||
if conn, err := grpc.Dial(address, grpc.WithBlock(), grpc.WithInsecure()); true {
|
||||
ch <- err
|
||||
conn.Close()
|
||||
}
|
||||
}()
|
||||
|
||||
|
@ -1522,7 +1524,7 @@ func Test_NewServer_HTTPServer_Enabled(t *testing.T) {
|
|||
t.Fatalf("test should have panicked but did not")
|
||||
}
|
||||
}()
|
||||
// if disable workds path not registered, so it shall not panic
|
||||
// if disable works path not registered, so it shall not panic
|
||||
server.registerHTTPServer()
|
||||
}
|
||||
|
||||
|
@ -1616,6 +1618,86 @@ func Test_NewServer_TLS_FileNotExisted(t *testing.T) {
|
|||
server.Stop()
|
||||
}
|
||||
|
||||
func Test_NewHTTPServer_TLS_TwoWay(t *testing.T) {
|
||||
server := getServer(t)
|
||||
|
||||
Params := ¶mtable.Get().ProxyGrpcServerCfg
|
||||
|
||||
paramtable.Get().Save(Params.TLSMode.Key, "2")
|
||||
paramtable.Get().Save(Params.ServerPemPath.Key, "../../../configs/cert/server.pem")
|
||||
paramtable.Get().Save(Params.ServerKeyPath.Key, "../../../configs/cert/server.key")
|
||||
paramtable.Get().Save(Params.CaPemPath.Key, "../../../configs/cert/ca.pem")
|
||||
paramtable.Get().Save(proxy.Params.HTTPCfg.Enabled.Key, "true")
|
||||
paramtable.Get().Save(proxy.Params.HTTPCfg.Port.Key, "8080")
|
||||
|
||||
err := runAndWaitForServerReady(server)
|
||||
assert.Nil(t, err)
|
||||
assert.NotNil(t, server.grpcExternalServer)
|
||||
err = server.Stop()
|
||||
assert.Nil(t, err)
|
||||
|
||||
paramtable.Get().Save(proxy.Params.HTTPCfg.Port.Key, "19529")
|
||||
err = runAndWaitForServerReady(server)
|
||||
assert.NotNil(t, err)
|
||||
server.Stop()
|
||||
}
|
||||
|
||||
func Test_NewHTTPServer_TLS_OneWay(t *testing.T) {
|
||||
server := getServer(t)
|
||||
|
||||
Params := ¶mtable.Get().ProxyGrpcServerCfg
|
||||
|
||||
paramtable.Get().Save(Params.TLSMode.Key, "1")
|
||||
paramtable.Get().Save(Params.ServerPemPath.Key, "../../../configs/cert/server.pem")
|
||||
paramtable.Get().Save(Params.ServerKeyPath.Key, "../../../configs/cert/server.key")
|
||||
paramtable.Get().Save(proxy.Params.HTTPCfg.Enabled.Key, "true")
|
||||
paramtable.Get().Save(proxy.Params.HTTPCfg.Port.Key, "8080")
|
||||
|
||||
err := runAndWaitForServerReady(server)
|
||||
assert.Nil(t, err)
|
||||
assert.NotNil(t, server.grpcExternalServer)
|
||||
err = server.Stop()
|
||||
assert.Nil(t, err)
|
||||
|
||||
paramtable.Get().Save(proxy.Params.HTTPCfg.Port.Key, "19529")
|
||||
err = runAndWaitForServerReady(server)
|
||||
assert.NotNil(t, err)
|
||||
server.Stop()
|
||||
}
|
||||
|
||||
func Test_NewHTTPServer_TLS_FileNotExisted(t *testing.T) {
|
||||
server := getServer(t)
|
||||
|
||||
Params := ¶mtable.Get().ProxyGrpcServerCfg
|
||||
|
||||
paramtable.Get().Save(Params.TLSMode.Key, "1")
|
||||
paramtable.Get().Save(Params.ServerPemPath.Key, "../not/existed/server.pem")
|
||||
paramtable.Get().Save(Params.ServerKeyPath.Key, "../../../configs/cert/server.key")
|
||||
paramtable.Get().Save(proxy.Params.HTTPCfg.Enabled.Key, "true")
|
||||
paramtable.Get().Save(proxy.Params.HTTPCfg.Port.Key, "8080")
|
||||
err := runAndWaitForServerReady(server)
|
||||
assert.NotNil(t, err)
|
||||
server.Stop()
|
||||
|
||||
paramtable.Get().Save(Params.TLSMode.Key, "2")
|
||||
paramtable.Get().Save(Params.ServerPemPath.Key, "../not/existed/server.pem")
|
||||
paramtable.Get().Save(Params.CaPemPath.Key, "../../../configs/cert/ca.pem")
|
||||
err = runAndWaitForServerReady(server)
|
||||
assert.NotNil(t, err)
|
||||
server.Stop()
|
||||
|
||||
paramtable.Get().Save(Params.ServerPemPath.Key, "../../../configs/cert/server.pem")
|
||||
paramtable.Get().Save(Params.CaPemPath.Key, "../not/existed/ca.pem")
|
||||
err = runAndWaitForServerReady(server)
|
||||
assert.NotNil(t, err)
|
||||
server.Stop()
|
||||
|
||||
paramtable.Get().Save(Params.CaPemPath.Key, "service.go")
|
||||
err = runAndWaitForServerReady(server)
|
||||
assert.NotNil(t, err)
|
||||
server.Stop()
|
||||
}
|
||||
|
||||
func Test_NewServer_GetVersion(t *testing.T) {
|
||||
req := &milvuspb.GetVersionRequest{}
|
||||
t.Run("test get version failed", func(t *testing.T) {
|
||||
|
|
|
@ -77,6 +77,22 @@ func PrivilegeInterceptor(ctx context.Context, req interface{}) (context.Context
|
|||
log.Warn("GetCurUserFromContext fail", zap.Error(err))
|
||||
return ctx, err
|
||||
}
|
||||
return privilegeInterceptor(ctx, privilegeExt, username, req)
|
||||
}
|
||||
|
||||
func PrivilegeInterceptorWithUsername(ctx context.Context, username string, req interface{}) (context.Context, error) {
|
||||
if !Params.CommonCfg.AuthorizationEnabled.GetAsBool() {
|
||||
return ctx, nil
|
||||
}
|
||||
log.Debug("PrivilegeInterceptor", zap.String("type", reflect.TypeOf(req).String()))
|
||||
privilegeExt, err := funcutil.GetPrivilegeExtObj(req)
|
||||
if err != nil {
|
||||
log.Warn("GetPrivilegeExtObj err", zap.Error(err))
|
||||
return ctx, nil
|
||||
}
|
||||
return privilegeInterceptor(ctx, privilegeExt, username, req)
|
||||
}
|
||||
func privilegeInterceptor(ctx context.Context, privilegeExt commonpb.PrivilegeExt, username string, req interface{}) (context.Context, error) {
|
||||
if username == util.UserRoot {
|
||||
return ctx, nil
|
||||
}
|
||||
|
|
|
@ -33,6 +33,7 @@ import (
|
|||
"github.com/milvus-io/milvus-proto/go-api/v2/schemapb"
|
||||
"github.com/milvus-io/milvus/internal/parser/planparserv2"
|
||||
"github.com/milvus-io/milvus/internal/proto/planpb"
|
||||
|
||||
"github.com/milvus-io/milvus/internal/proto/querypb"
|
||||
"github.com/milvus-io/milvus/internal/types"
|
||||
typeutil2 "github.com/milvus-io/milvus/internal/util/typeutil"
|
||||
|
@ -820,8 +821,8 @@ func GetCurUserFromContext(ctx context.Context) (string, error) {
|
|||
if !ok {
|
||||
return "", fmt.Errorf("fail to get md from the context")
|
||||
}
|
||||
authorization := md[strings.ToLower(util.HeaderAuthorize)]
|
||||
if len(authorization) < 1 {
|
||||
authorization, ok := md[strings.ToLower(util.HeaderAuthorize)]
|
||||
if !ok || len(authorization) < 1 {
|
||||
return "", fmt.Errorf("fail to get authorization from the md, authorize:[%s]", util.HeaderAuthorize)
|
||||
}
|
||||
token := authorization[0]
|
||||
|
@ -856,6 +857,10 @@ func GetRole(username string) ([]string, error) {
|
|||
return globalMetaCache.GetUserRole(username), nil
|
||||
}
|
||||
|
||||
func PasswordVerify(ctx context.Context, username, rawPwd string) bool {
|
||||
return passwordVerify(ctx, username, rawPwd, globalMetaCache)
|
||||
}
|
||||
|
||||
// PasswordVerify verify password
|
||||
func passwordVerify(ctx context.Context, username, rawPwd string, globalMetaCache Cache) bool {
|
||||
// it represents the cache miss if Sha256Password is empty within credInfo, which shall be updated first connection.
|
||||
|
|
|
@ -36,7 +36,7 @@ const (
|
|||
embededBits
|
||||
|
||||
retriableFlag = 1 << 20
|
||||
rootReasonCodeMask = (1 << 16) - 1
|
||||
RootReasonCodeMask = (1 << 16) - 1
|
||||
|
||||
CanceledCode int32 = 10000
|
||||
TimeoutCode int32 = 10001
|
||||
|
@ -125,6 +125,15 @@ var (
|
|||
// field related
|
||||
ErrFieldNotFound = newMilvusError("field not found", 1700, false)
|
||||
|
||||
// high-level restful api related
|
||||
ErrNeedAuthenticate = newMilvusError("user hasn't authenticate", 1800, false)
|
||||
ErrIncorrectParameterFormat = newMilvusError("can only accept json format request", 1801, false)
|
||||
ErrMissingRequiredParameters = newMilvusError("missing required parameters", 1802, false)
|
||||
ErrMarshalCollectionSchema = newMilvusError("fail to marshal collection schema", 1803, false)
|
||||
ErrInvalidInsertData = newMilvusError("fail to deal the insert data", 1804, false)
|
||||
ErrInvalidSearchResult = newMilvusError("fail to parse search result", 1805, false)
|
||||
ErrCheckPrimaryKey = newMilvusError("please check the primary key and its' type can only in [int, string]", 1806, false)
|
||||
|
||||
// Do NOT export this,
|
||||
// never allow programmer using this, keep only for converting unknown error to milvusError
|
||||
errUnexpected = newMilvusError("unexpected error", (1<<16)-1, false)
|
||||
|
|
|
@ -3,6 +3,7 @@ package paramtable
|
|||
type httpConfig struct {
|
||||
Enabled ParamItem `refreshable:"false"`
|
||||
DebugMode ParamItem `refreshable:"false"`
|
||||
Port ParamItem `refreshable:"false"`
|
||||
}
|
||||
|
||||
func (p *httpConfig) init(base *BaseTable) {
|
||||
|
@ -23,4 +24,13 @@ func (p *httpConfig) init(base *BaseTable) {
|
|||
Export: true,
|
||||
}
|
||||
p.DebugMode.Init(base.mgr)
|
||||
|
||||
p.Port = ParamItem{
|
||||
Key: "proxy.http.port",
|
||||
Version: "2.1.0",
|
||||
Doc: "high-level restful api",
|
||||
PanicIfEmpty: false,
|
||||
Export: true,
|
||||
}
|
||||
p.Port.Init(base.mgr)
|
||||
}
|
||||
|
|
|
@ -12,4 +12,5 @@ func TestHTTPConfig_Init(t *testing.T) {
|
|||
cf := params.HTTPCfg
|
||||
assert.Equal(t, cf.Enabled.GetAsBool(), true)
|
||||
assert.Equal(t, cf.DebugMode.GetAsBool(), false)
|
||||
assert.Equal(t, cf.Port.GetValue(), "")
|
||||
}
|
||||
|
|
Loading…
Reference in New Issue