mirror of https://github.com/milvus-io/milvus.git
Support to trace the grpc request (#28349)
Signed-off-by: SimFG <bang.fu@zilliz.com>pull/28402/head
parent
87465f07a7
commit
cfb6edea61
|
@ -540,6 +540,7 @@ common:
|
|||
info: 500 # minimum milliseconds for printing durations in info level
|
||||
warn: 1000 # minimum milliseconds for printing durations in warn level
|
||||
ttMsgEnabled: true # Whether the instance disable sending ts messages
|
||||
traceLogMode: 0 # trace request info, 0: none, 1: simple request info, like collection/partition/database name, 2: request detail
|
||||
|
||||
# QuotaConfig, configurations of Milvus quota and limits.
|
||||
# By default, we enable:
|
||||
|
|
|
@ -123,7 +123,6 @@ func authenticate(c *gin.Context) {
|
|||
if !proxy.Params.CommonCfg.AuthorizationEnabled.GetAsBool() {
|
||||
return
|
||||
}
|
||||
// TODO fubang
|
||||
username, password, ok := httpserver.ParseUsernamePassword(c)
|
||||
if ok {
|
||||
if proxy.PasswordVerify(c, username, password) {
|
||||
|
@ -188,8 +187,8 @@ func (s *Server) startHTTPServer(errChan chan error) {
|
|||
return
|
||||
}
|
||||
c.Next()
|
||||
})
|
||||
app := ginHandler.Group("/v1", authenticate)
|
||||
}, authenticate, proxy.HTTPTraceLog)
|
||||
app := ginHandler.Group("/v1")
|
||||
httpserver.NewHandlers(s.proxy).RegisterRoutesToV1(app)
|
||||
s.httpServer = &http.Server{Handler: ginHandler, ReadHeaderTimeout: time.Second}
|
||||
errChan <- nil
|
||||
|
@ -247,6 +246,7 @@ func (s *Server) startExternalGrpc(grpcPort int, errChan chan error) {
|
|||
logutil.UnaryTraceLoggerInterceptor,
|
||||
proxy.RateLimitInterceptor(limiter),
|
||||
accesslog.UnaryAccessLoggerInterceptor,
|
||||
proxy.TraceLogInterceptor,
|
||||
proxy.KeepActiveInterceptor,
|
||||
)),
|
||||
}
|
||||
|
|
|
@ -0,0 +1,101 @@
|
|||
/*
|
||||
* Licensed to the LF AI & Data foundation under one
|
||||
* or more contributor license agreements. See the NOTICE file
|
||||
* distributed with this work for additional information
|
||||
* regarding copyright ownership. The ASF licenses this file
|
||||
* to you under the Apache License, Version 2.0 (the
|
||||
* "License"); you may not use this file except in compliance
|
||||
* with the License. You may obtain a copy of the License at
|
||||
*
|
||||
* http://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License.
|
||||
*/
|
||||
|
||||
package proxy
|
||||
|
||||
import (
|
||||
"context"
|
||||
"path"
|
||||
|
||||
"github.com/gin-gonic/gin"
|
||||
"go.uber.org/zap"
|
||||
"google.golang.org/grpc"
|
||||
|
||||
"github.com/milvus-io/milvus-proto/go-api/v2/milvuspb"
|
||||
"github.com/milvus-io/milvus/pkg/log"
|
||||
"github.com/milvus-io/milvus/pkg/util/requestutil"
|
||||
)
|
||||
|
||||
func TraceLogInterceptor(ctx context.Context, req interface{}, info *grpc.UnaryServerInfo, handler grpc.UnaryHandler) (interface{}, error) {
|
||||
switch Params.CommonCfg.TraceLogMode.GetAsInt() {
|
||||
case 0: // none
|
||||
return handler(ctx, req)
|
||||
case 1: // simple info
|
||||
fields := GetRequestBaseInfo(ctx, req, info, false)
|
||||
log.Ctx(ctx).Info("trace info: simple", fields...)
|
||||
return handler(ctx, req)
|
||||
case 2: // detail info
|
||||
fields := GetRequestBaseInfo(ctx, req, info, true)
|
||||
fields = append(fields, GetRequestFieldWithoutSensitiveInfo(req))
|
||||
log.Ctx(ctx).Info("trace info: detail", fields...)
|
||||
return handler(ctx, req)
|
||||
default:
|
||||
return handler(ctx, req)
|
||||
}
|
||||
}
|
||||
|
||||
func GetRequestBaseInfo(ctx context.Context, req interface{}, info *grpc.UnaryServerInfo, skipBaseRequestInfo bool) []zap.Field {
|
||||
var fields []zap.Field
|
||||
|
||||
_, requestName := path.Split(info.FullMethod)
|
||||
fields = append(fields, zap.String("request_name", requestName))
|
||||
|
||||
username, err := GetCurUserFromContext(ctx)
|
||||
if err == nil && username != "" {
|
||||
fields = append(fields, zap.String("username", username))
|
||||
}
|
||||
|
||||
if !skipBaseRequestInfo {
|
||||
for baseInfoName, f := range requestutil.TraceLogBaseInfoFuncMap {
|
||||
baseInfo, ok := f(req)
|
||||
if !ok {
|
||||
continue
|
||||
}
|
||||
fields = append(fields, zap.Any(baseInfoName, baseInfo))
|
||||
}
|
||||
}
|
||||
|
||||
return fields
|
||||
}
|
||||
|
||||
func GetRequestFieldWithoutSensitiveInfo(req interface{}) zap.Field {
|
||||
createCredentialReq, ok := req.(*milvuspb.CreateCredentialRequest)
|
||||
if ok {
|
||||
return zap.Any("request", &milvuspb.CreateCredentialRequest{
|
||||
Base: createCredentialReq.Base,
|
||||
Username: createCredentialReq.Username,
|
||||
CreatedUtcTimestamps: createCredentialReq.CreatedUtcTimestamps,
|
||||
ModifiedUtcTimestamps: createCredentialReq.ModifiedUtcTimestamps,
|
||||
})
|
||||
}
|
||||
updateCredentialReq, ok := req.(*milvuspb.UpdateCredentialRequest)
|
||||
if ok {
|
||||
return zap.Any("request", &milvuspb.UpdateCredentialRequest{
|
||||
Base: updateCredentialReq.Base,
|
||||
Username: updateCredentialReq.Username,
|
||||
CreatedUtcTimestamps: updateCredentialReq.CreatedUtcTimestamps,
|
||||
ModifiedUtcTimestamps: updateCredentialReq.ModifiedUtcTimestamps,
|
||||
})
|
||||
}
|
||||
return zap.Any("request", req)
|
||||
}
|
||||
|
||||
func HTTPTraceLog(ctx *gin.Context) {
|
||||
// TODO trace http request info
|
||||
ctx.Next()
|
||||
}
|
|
@ -0,0 +1,86 @@
|
|||
/*
|
||||
* Licensed to the LF AI & Data foundation under one
|
||||
* or more contributor license agreements. See the NOTICE file
|
||||
* distributed with this work for additional information
|
||||
* regarding copyright ownership. The ASF licenses this file
|
||||
* to you under the Apache License, Version 2.0 (the
|
||||
* "License"); you may not use this file except in compliance
|
||||
* with the License. You may obtain a copy of the License at
|
||||
*
|
||||
* http://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License.
|
||||
*/
|
||||
|
||||
package proxy
|
||||
|
||||
import (
|
||||
"context"
|
||||
"fmt"
|
||||
"strings"
|
||||
"testing"
|
||||
|
||||
"github.com/stretchr/testify/assert"
|
||||
"google.golang.org/grpc"
|
||||
|
||||
"github.com/milvus-io/milvus-proto/go-api/v2/milvuspb"
|
||||
"github.com/milvus-io/milvus/pkg/util"
|
||||
"github.com/milvus-io/milvus/pkg/util/paramtable"
|
||||
)
|
||||
|
||||
func TestTraceLogInterceptor(t *testing.T) {
|
||||
handler := func(ctx context.Context, req interface{}) (interface{}, error) {
|
||||
return nil, nil
|
||||
}
|
||||
|
||||
// none
|
||||
_ = paramtable.Get().Save(paramtable.Get().CommonCfg.TraceLogMode.Key, "0")
|
||||
_, _ = TraceLogInterceptor(context.Background(), &milvuspb.ShowCollectionsRequest{}, &grpc.UnaryServerInfo{}, handler)
|
||||
|
||||
// invalid mode
|
||||
_ = paramtable.Get().Save(paramtable.Get().CommonCfg.TraceLogMode.Key, "10")
|
||||
_, _ = TraceLogInterceptor(context.Background(), &milvuspb.ShowCollectionsRequest{}, &grpc.UnaryServerInfo{}, handler)
|
||||
|
||||
// simple mode
|
||||
ctx := GetContext(context.Background(), fmt.Sprintf("%s%s%s", "foo", util.CredentialSeperator, "FOO123456"))
|
||||
_ = paramtable.Get().Save(paramtable.Get().CommonCfg.TraceLogMode.Key, "1")
|
||||
{
|
||||
_, _ = TraceLogInterceptor(ctx, &milvuspb.CreateCollectionRequest{
|
||||
DbName: "db",
|
||||
CollectionName: "col1",
|
||||
}, &grpc.UnaryServerInfo{
|
||||
FullMethod: "/milvus.proto.milvus.MilvusService/ShowCollections",
|
||||
}, handler)
|
||||
}
|
||||
|
||||
// detail mode
|
||||
_ = paramtable.Get().Save(paramtable.Get().CommonCfg.TraceLogMode.Key, "2")
|
||||
{
|
||||
_, _ = TraceLogInterceptor(ctx, &milvuspb.CreateCollectionRequest{
|
||||
DbName: "db",
|
||||
CollectionName: "col1",
|
||||
}, &grpc.UnaryServerInfo{
|
||||
FullMethod: "/milvus.proto.milvus.MilvusService/ShowCollections",
|
||||
}, handler)
|
||||
}
|
||||
|
||||
{
|
||||
f1 := GetRequestFieldWithoutSensitiveInfo(&milvuspb.CreateCredentialRequest{
|
||||
Username: "foo",
|
||||
Password: "123456",
|
||||
})
|
||||
assert.NotContains(t, strings.ToLower(fmt.Sprint(f1.Interface)), "password")
|
||||
|
||||
f2 := GetRequestFieldWithoutSensitiveInfo(&milvuspb.UpdateCredentialRequest{
|
||||
Username: "foo",
|
||||
OldPassword: "123456",
|
||||
NewPassword: "FOO123456",
|
||||
})
|
||||
assert.NotContains(t, strings.ToLower(fmt.Sprint(f2.Interface)), "password")
|
||||
}
|
||||
_ = paramtable.Get().Save(paramtable.Get().CommonCfg.TraceLogMode.Key, "0")
|
||||
}
|
|
@ -219,6 +219,7 @@ type commonConfig struct {
|
|||
LockSlowLogWarnThreshold ParamItem `refreshable:"true"`
|
||||
|
||||
TTMsgEnabled ParamItem `refreshable:"true"`
|
||||
TraceLogMode ParamItem `refreshable:"true"`
|
||||
}
|
||||
|
||||
func (p *commonConfig) init(base *BaseTable) {
|
||||
|
@ -633,6 +634,14 @@ like the old password verification when updating the credential`,
|
|||
Doc: "Whether the instance disable sending ts messages",
|
||||
}
|
||||
p.TTMsgEnabled.Init(base.mgr)
|
||||
|
||||
p.TraceLogMode = ParamItem{
|
||||
Key: "common.traceLogMode",
|
||||
Version: "2.3.4",
|
||||
DefaultValue: "0",
|
||||
Doc: "trace request info",
|
||||
}
|
||||
p.TraceLogMode.Init(base.mgr)
|
||||
}
|
||||
|
||||
type traceConfig struct {
|
||||
|
|
|
@ -0,0 +1,154 @@
|
|||
/*
|
||||
* Licensed to the LF AI & Data foundation under one
|
||||
* or more contributor license agreements. See the NOTICE file
|
||||
* distributed with this work for additional information
|
||||
* regarding copyright ownership. The ASF licenses this file
|
||||
* to you under the Apache License, Version 2.0 (the
|
||||
* "License"); you may not use this file except in compliance
|
||||
* with the License. You may obtain a copy of the License at
|
||||
*
|
||||
* http://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License.
|
||||
*/
|
||||
|
||||
package requestutil
|
||||
|
||||
import "github.com/milvus-io/milvus-proto/go-api/v2/commonpb"
|
||||
|
||||
type CollectionNameGetter interface {
|
||||
GetCollectionName() string
|
||||
}
|
||||
|
||||
func GetCollectionNameFromRequest(req any) (any, bool) {
|
||||
getter, ok := req.(CollectionNameGetter)
|
||||
if !ok {
|
||||
return "", false
|
||||
}
|
||||
return getter.GetCollectionName(), true
|
||||
}
|
||||
|
||||
type DBNameGetter interface {
|
||||
GetDbName() string
|
||||
}
|
||||
|
||||
func GetDbNameFromRequest(req interface{}) (any, bool) {
|
||||
getter, ok := req.(DBNameGetter)
|
||||
if !ok {
|
||||
return "", false
|
||||
}
|
||||
return getter.GetDbName(), true
|
||||
}
|
||||
|
||||
type PartitionNameGetter interface {
|
||||
GetPartitionName() string
|
||||
}
|
||||
|
||||
func GetPartitionNameFromRequest(req interface{}) (any, bool) {
|
||||
getter, ok := req.(PartitionNameGetter)
|
||||
if !ok {
|
||||
return "", false
|
||||
}
|
||||
return getter.GetPartitionName(), true
|
||||
}
|
||||
|
||||
type PartitionNamesGetter interface {
|
||||
GetPartitionNames() []string
|
||||
}
|
||||
|
||||
func GetPartitionNamesFromRequest(req interface{}) (any, bool) {
|
||||
getter, ok := req.(PartitionNamesGetter)
|
||||
if !ok {
|
||||
return nil, false
|
||||
}
|
||||
return getter.GetPartitionNames(), true
|
||||
}
|
||||
|
||||
type FieldNameGetter interface {
|
||||
GetFieldName() string
|
||||
}
|
||||
|
||||
func GetFieldNameFromRequest(req interface{}) (any, bool) {
|
||||
getter, ok := req.(FieldNameGetter)
|
||||
if !ok {
|
||||
return "", false
|
||||
}
|
||||
return getter.GetFieldName(), true
|
||||
}
|
||||
|
||||
type OutputFieldsGetter interface {
|
||||
GetOutputFields() []string
|
||||
}
|
||||
|
||||
func GetOutputFieldsFromRequest(req interface{}) (any, bool) {
|
||||
getter, ok := req.(OutputFieldsGetter)
|
||||
if !ok {
|
||||
return nil, false
|
||||
}
|
||||
return getter.GetOutputFields(), true
|
||||
}
|
||||
|
||||
type QueryParamsGetter interface {
|
||||
GetQueryParams() []*commonpb.KeyValuePair
|
||||
}
|
||||
|
||||
func GetQueryParamsFromRequest(req interface{}) (any, bool) {
|
||||
getter, ok := req.(QueryParamsGetter)
|
||||
if !ok {
|
||||
return nil, false
|
||||
}
|
||||
return getter.GetQueryParams(), true
|
||||
}
|
||||
|
||||
type ExprGetter interface {
|
||||
GetExpr() string
|
||||
}
|
||||
|
||||
func GetExprFromRequest(req interface{}) (any, bool) {
|
||||
getter, ok := req.(ExprGetter)
|
||||
if !ok {
|
||||
return "", false
|
||||
}
|
||||
return getter.GetExpr(), true
|
||||
}
|
||||
|
||||
type SearchParamsGetter interface {
|
||||
GetSearchParams() []*commonpb.KeyValuePair
|
||||
}
|
||||
|
||||
func GetSearchParamsFromRequest(req interface{}) (any, bool) {
|
||||
getter, ok := req.(SearchParamsGetter)
|
||||
if !ok {
|
||||
return nil, false
|
||||
}
|
||||
return getter.GetSearchParams(), true
|
||||
}
|
||||
|
||||
type DSLGetter interface {
|
||||
GetDsl() string
|
||||
}
|
||||
|
||||
func GetDSLFromRequest(req interface{}) (any, bool) {
|
||||
getter, ok := req.(DSLGetter)
|
||||
if !ok {
|
||||
return "", false
|
||||
}
|
||||
return getter.GetDsl(), true
|
||||
}
|
||||
|
||||
var TraceLogBaseInfoFuncMap = map[string]func(interface{}) (any, bool){
|
||||
"collection_name": GetCollectionNameFromRequest,
|
||||
"db_name": GetDbNameFromRequest,
|
||||
"partition_name": GetPartitionNameFromRequest,
|
||||
"partition_names": GetPartitionNamesFromRequest,
|
||||
"field_name": GetFieldNameFromRequest,
|
||||
"output_fields": GetOutputFieldsFromRequest,
|
||||
"query_params": GetQueryParamsFromRequest,
|
||||
"expr": GetExprFromRequest,
|
||||
"search_params": GetSearchParamsFromRequest,
|
||||
"dsl": GetDSLFromRequest,
|
||||
}
|
|
@ -0,0 +1,457 @@
|
|||
/*
|
||||
* Licensed to the LF AI & Data foundation under one
|
||||
* or more contributor license agreements. See the NOTICE file
|
||||
* distributed with this work for additional information
|
||||
* regarding copyright ownership. The ASF licenses this file
|
||||
* to you under the Apache License, Version 2.0 (the
|
||||
* "License"); you may not use this file except in compliance
|
||||
* with the License. You may obtain a copy of the License at
|
||||
*
|
||||
* http://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License.
|
||||
*/
|
||||
|
||||
package requestutil
|
||||
|
||||
import (
|
||||
"reflect"
|
||||
"testing"
|
||||
|
||||
"github.com/milvus-io/milvus-proto/go-api/v2/commonpb"
|
||||
"github.com/milvus-io/milvus-proto/go-api/v2/milvuspb"
|
||||
)
|
||||
|
||||
func TestGetCollectionNameFromRequest(t *testing.T) {
|
||||
type args struct {
|
||||
req any
|
||||
}
|
||||
tests := []struct {
|
||||
name string
|
||||
args args
|
||||
want any
|
||||
want1 bool
|
||||
}{
|
||||
{
|
||||
name: "true",
|
||||
args: args{
|
||||
req: &milvuspb.CreateCollectionRequest{
|
||||
CollectionName: "foo",
|
||||
},
|
||||
},
|
||||
want: "foo",
|
||||
want1: true,
|
||||
},
|
||||
{
|
||||
name: "fail",
|
||||
args: args{
|
||||
req: &commonpb.Status{},
|
||||
},
|
||||
want1: false,
|
||||
},
|
||||
}
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
got, got1 := GetCollectionNameFromRequest(tt.args.req)
|
||||
if got1 && !reflect.DeepEqual(got, tt.want) {
|
||||
t.Errorf("GetCollectionNameFromRequest() got = %v, want %v", got, tt.want)
|
||||
}
|
||||
if got1 != tt.want1 {
|
||||
t.Errorf("GetCollectionNameFromRequest() got1 = %v, want %v", got1, tt.want1)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestGetDbNameFromRequest(t *testing.T) {
|
||||
type args struct {
|
||||
req interface{}
|
||||
}
|
||||
tests := []struct {
|
||||
name string
|
||||
args args
|
||||
want any
|
||||
want1 bool
|
||||
}{
|
||||
{
|
||||
name: "true",
|
||||
args: args{
|
||||
req: &milvuspb.CreateDatabaseRequest{
|
||||
DbName: "foo",
|
||||
},
|
||||
},
|
||||
want: "foo",
|
||||
want1: true,
|
||||
},
|
||||
{
|
||||
name: "fail",
|
||||
args: args{
|
||||
req: &commonpb.Status{},
|
||||
},
|
||||
want1: false,
|
||||
},
|
||||
}
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
got, got1 := GetDbNameFromRequest(tt.args.req)
|
||||
if got1 && !reflect.DeepEqual(got, tt.want) {
|
||||
t.Errorf("GetDbNameFromRequest() got = %v, want %v", got, tt.want)
|
||||
}
|
||||
if got1 != tt.want1 {
|
||||
t.Errorf("GetDbNameFromRequest() got1 = %v, want %v", got1, tt.want1)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestGetPartitionNameFromRequest(t *testing.T) {
|
||||
type args struct {
|
||||
req interface{}
|
||||
}
|
||||
tests := []struct {
|
||||
name string
|
||||
args args
|
||||
want any
|
||||
want1 bool
|
||||
}{
|
||||
{
|
||||
name: "true",
|
||||
args: args{
|
||||
req: &milvuspb.CreatePartitionRequest{
|
||||
PartitionName: "baz",
|
||||
},
|
||||
},
|
||||
want: "baz",
|
||||
want1: true,
|
||||
},
|
||||
{
|
||||
name: "fail",
|
||||
args: args{
|
||||
req: &commonpb.Status{},
|
||||
},
|
||||
want1: false,
|
||||
},
|
||||
}
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
got, got1 := GetPartitionNameFromRequest(tt.args.req)
|
||||
if got1 && !reflect.DeepEqual(got, tt.want) {
|
||||
t.Errorf("GetPartitionNameFromRequest() got = %v, want %v", got, tt.want)
|
||||
}
|
||||
if got1 != tt.want1 {
|
||||
t.Errorf("GetPartitionNameFromRequest() got1 = %v, want %v", got1, tt.want1)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestGetPartitionNamesFromRequest(t *testing.T) {
|
||||
type args struct {
|
||||
req interface{}
|
||||
}
|
||||
tests := []struct {
|
||||
name string
|
||||
args args
|
||||
want any
|
||||
want1 bool
|
||||
}{
|
||||
{
|
||||
name: "true",
|
||||
args: args{
|
||||
req: &milvuspb.SearchRequest{
|
||||
PartitionNames: []string{"baz", "faz"},
|
||||
},
|
||||
},
|
||||
want: []string{"baz", "faz"},
|
||||
want1: true,
|
||||
},
|
||||
{
|
||||
name: "fail",
|
||||
args: args{
|
||||
req: &commonpb.Status{},
|
||||
},
|
||||
want1: false,
|
||||
},
|
||||
}
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
got, got1 := GetPartitionNamesFromRequest(tt.args.req)
|
||||
if got1 && !reflect.DeepEqual(got, tt.want) {
|
||||
t.Errorf("GetPartitionNamesFromRequest() got = %v, want %v", got, tt.want)
|
||||
}
|
||||
if got1 != tt.want1 {
|
||||
t.Errorf("GetPartitionNamesFromRequest() got1 = %v, want %v", got1, tt.want1)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestGetFieldNameFromRequest(t *testing.T) {
|
||||
type args struct {
|
||||
req interface{}
|
||||
}
|
||||
tests := []struct {
|
||||
name string
|
||||
args args
|
||||
want any
|
||||
want1 bool
|
||||
}{
|
||||
{
|
||||
name: "ok",
|
||||
args: args{
|
||||
req: &milvuspb.CreateIndexRequest{
|
||||
FieldName: "foo",
|
||||
},
|
||||
},
|
||||
want: "foo",
|
||||
want1: true,
|
||||
},
|
||||
{
|
||||
name: "fail",
|
||||
args: args{
|
||||
req: &commonpb.Status{},
|
||||
},
|
||||
want1: false,
|
||||
},
|
||||
}
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
got, got1 := GetFieldNameFromRequest(tt.args.req)
|
||||
if got1 && !reflect.DeepEqual(got, tt.want) {
|
||||
t.Errorf("GetFieldNameFromRequest() got = %v, want %v", got, tt.want)
|
||||
}
|
||||
if got1 != tt.want1 {
|
||||
t.Errorf("GetFieldNameFromRequest() got1 = %v, want %v", got1, tt.want1)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestGetOutputFieldsFromRequest(t *testing.T) {
|
||||
type args struct {
|
||||
req interface{}
|
||||
}
|
||||
tests := []struct {
|
||||
name string
|
||||
args args
|
||||
want any
|
||||
want1 bool
|
||||
}{
|
||||
{
|
||||
name: "ok",
|
||||
args: args{
|
||||
req: &milvuspb.SearchRequest{
|
||||
OutputFields: []string{"foo", "bar"},
|
||||
},
|
||||
},
|
||||
want: []string{"foo", "bar"},
|
||||
want1: true,
|
||||
},
|
||||
{
|
||||
name: "fail",
|
||||
args: args{
|
||||
req: &commonpb.Status{},
|
||||
},
|
||||
want1: false,
|
||||
},
|
||||
}
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
got, got1 := GetOutputFieldsFromRequest(tt.args.req)
|
||||
if got1 && !reflect.DeepEqual(got, tt.want) {
|
||||
t.Errorf("GetOutputFieldsFromRequest() got = %v, want %v", got, tt.want)
|
||||
}
|
||||
if got1 != tt.want1 {
|
||||
t.Errorf("GetOutputFieldsFromRequest() got1 = %v, want %v", got1, tt.want1)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestGetQueryParamsFromRequest(t *testing.T) {
|
||||
type args struct {
|
||||
req interface{}
|
||||
}
|
||||
tests := []struct {
|
||||
name string
|
||||
args args
|
||||
want any
|
||||
want1 bool
|
||||
}{
|
||||
{
|
||||
name: "ok",
|
||||
args: args{
|
||||
req: &milvuspb.QueryRequest{
|
||||
QueryParams: []*commonpb.KeyValuePair{
|
||||
{
|
||||
Key: "foo",
|
||||
Value: "bar",
|
||||
},
|
||||
},
|
||||
},
|
||||
},
|
||||
want: []*commonpb.KeyValuePair{
|
||||
{
|
||||
Key: "foo",
|
||||
Value: "bar",
|
||||
},
|
||||
},
|
||||
want1: true,
|
||||
},
|
||||
{
|
||||
name: "fail",
|
||||
args: args{
|
||||
req: &commonpb.Status{},
|
||||
},
|
||||
want1: false,
|
||||
},
|
||||
}
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
got, got1 := GetQueryParamsFromRequest(tt.args.req)
|
||||
if got1 && !reflect.DeepEqual(got, tt.want) {
|
||||
t.Errorf("GetQueryParamsFromRequest() got = %v, want %v", got, tt.want)
|
||||
}
|
||||
if got1 != tt.want1 {
|
||||
t.Errorf("GetQueryParamsFromRequest() got1 = %v, want %v", got1, tt.want1)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestGetExprFromRequest(t *testing.T) {
|
||||
type args struct {
|
||||
req interface{}
|
||||
}
|
||||
tests := []struct {
|
||||
name string
|
||||
args args
|
||||
want any
|
||||
want1 bool
|
||||
}{
|
||||
{
|
||||
name: "ok",
|
||||
args: args{
|
||||
req: &milvuspb.QueryRequest{
|
||||
Expr: "foo",
|
||||
},
|
||||
},
|
||||
want: "foo",
|
||||
want1: true,
|
||||
},
|
||||
{
|
||||
name: "fail",
|
||||
args: args{
|
||||
req: &commonpb.Status{},
|
||||
},
|
||||
want1: false,
|
||||
},
|
||||
}
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
got, got1 := GetExprFromRequest(tt.args.req)
|
||||
if got1 && !reflect.DeepEqual(got, tt.want) {
|
||||
t.Errorf("GetExprFromRequest() got = %v, want %v", got, tt.want)
|
||||
}
|
||||
if got1 != tt.want1 {
|
||||
t.Errorf("GetExprFromRequest() got1 = %v, want %v", got1, tt.want1)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestGetSearchParamsFromRequest(t *testing.T) {
|
||||
type args struct {
|
||||
req interface{}
|
||||
}
|
||||
tests := []struct {
|
||||
name string
|
||||
args args
|
||||
want any
|
||||
want1 bool
|
||||
}{
|
||||
{
|
||||
name: "ok",
|
||||
args: args{
|
||||
req: &milvuspb.SearchRequest{
|
||||
SearchParams: []*commonpb.KeyValuePair{
|
||||
{
|
||||
Key: "foo",
|
||||
Value: "bar",
|
||||
},
|
||||
},
|
||||
},
|
||||
},
|
||||
want: []*commonpb.KeyValuePair{
|
||||
{
|
||||
Key: "foo",
|
||||
Value: "bar",
|
||||
},
|
||||
},
|
||||
want1: true,
|
||||
},
|
||||
{
|
||||
name: "fail",
|
||||
args: args{
|
||||
req: &commonpb.Status{},
|
||||
},
|
||||
want1: false,
|
||||
},
|
||||
}
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
got, got1 := GetSearchParamsFromRequest(tt.args.req)
|
||||
if got1 && !reflect.DeepEqual(got, tt.want) {
|
||||
t.Errorf("GetSearchParamsFromRequest() got = %v, want %v", got, tt.want)
|
||||
}
|
||||
if got1 != tt.want1 {
|
||||
t.Errorf("GetSearchParamsFromRequest() got1 = %v, want %v", got1, tt.want1)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestGetDSLFromRequest(t *testing.T) {
|
||||
type args struct {
|
||||
req interface{}
|
||||
}
|
||||
tests := []struct {
|
||||
name string
|
||||
args args
|
||||
want any
|
||||
want1 bool
|
||||
}{
|
||||
{
|
||||
name: "ok",
|
||||
args: args{
|
||||
req: &milvuspb.SearchRequest{
|
||||
Dsl: "foo",
|
||||
},
|
||||
},
|
||||
want: "foo",
|
||||
want1: true,
|
||||
},
|
||||
{
|
||||
name: "fail",
|
||||
args: args{
|
||||
req: &commonpb.Status{},
|
||||
},
|
||||
want1: false,
|
||||
},
|
||||
}
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
got, got1 := GetDSLFromRequest(tt.args.req)
|
||||
if got1 && !reflect.DeepEqual(got, tt.want) {
|
||||
t.Errorf("GetDSLFromRequest() got = %v, want %v", got, tt.want)
|
||||
}
|
||||
if got1 != tt.want1 {
|
||||
t.Errorf("GetDSLFromRequest() got1 = %v, want %v", got1, tt.want1)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
Loading…
Reference in New Issue