mirror of https://github.com/milvus-io/milvus.git
161 lines
3.9 KiB
Go
161 lines
3.9 KiB
Go
package proxy
|
|
|
|
import (
|
|
"context"
|
|
"testing"
|
|
|
|
"github.com/cockroachdb/errors"
|
|
"github.com/stretchr/testify/assert"
|
|
"google.golang.org/grpc"
|
|
|
|
"github.com/milvus-io/milvus/pkg/util/paramtable"
|
|
)
|
|
|
|
// TestInitHook checks initHook against the configured hook plugin path
// (Params.ProxyCfg.SoPath). An empty path must succeed and leave a
// defaultHook{} in the package-level hoo. The bogus path "/a/b/hook.so"
// must make initHook return an error.
func TestInitHook(t *testing.T) {
	// Put the path back to empty once the test ends, even if it fails or
	// panics partway through, so later tests see the default configuration.
	t.Cleanup(func() {
		paramtable.Get().Save(Params.ProxyCfg.SoPath.Key, "")
	})

	paramtable.Get().Save(Params.ProxyCfg.SoPath.Key, "")
	// The IsType check below is only meaningful on the success path, so the
	// error must not be silently dropped.
	assert.NoError(t, initHook())
	assert.IsType(t, defaultHook{}, hoo)

	// Presumably nothing loadable exists at this path on the test host.
	paramtable.Get().Save(Params.ProxyCfg.SoPath.Key, "/a/b/hook.so")
	assert.Error(t, initHook())
}
|
|
|
|
type mockHook struct {
|
|
defaultHook
|
|
mockRes interface{}
|
|
mockErr error
|
|
}
|
|
|
|
func (m mockHook) Mock(ctx context.Context, req interface{}, fullMethod string) (bool, interface{}, error) {
|
|
return true, m.mockRes, m.mockErr
|
|
}
|
|
|
|
// req is a minimal stand-in for a request message. beforeMock.Before
// overwrites its method field so tests can observe that the hook ran.
type req struct{ method string }
|
|
|
|
type BeforeMockCtxKey int
|
|
|
|
type beforeMock struct {
|
|
defaultHook
|
|
method string
|
|
ctxKey BeforeMockCtxKey
|
|
ctxValue string
|
|
err error
|
|
}
|
|
|
|
func (b beforeMock) Before(ctx context.Context, r interface{}, fullMethod string) (context.Context, error) {
|
|
re, ok := r.(*req)
|
|
if !ok {
|
|
return ctx, errors.New("r is invalid type")
|
|
}
|
|
re.method = b.method
|
|
return context.WithValue(ctx, b.ctxKey, b.ctxValue), b.err
|
|
}
|
|
|
|
// resp is a minimal stand-in for a response message. afterMock.After
// overwrites its method field so tests can observe that the hook ran.
type resp struct{ method string }
|
|
|
|
type afterMock struct {
|
|
defaultHook
|
|
method string
|
|
err error
|
|
}
|
|
|
|
func (a afterMock) After(ctx context.Context, r interface{}, err error, fullMethod string) error {
|
|
re, ok := r.(*resp)
|
|
if !ok {
|
|
return errors.New("r is invalid type")
|
|
}
|
|
re.method = a.method
|
|
return a.err
|
|
}
|
|
|
|
// TestHookInterceptor drives UnaryServerHookInterceptor through each hook
// extension point by swapping the package-level hoo between cases:
//   - mockHook: the canned response/error is returned, both with a normal
//     and with an empty FullMethod;
//   - beforeMock: Before rewrites the request and its error is returned; with
//     a nil error, the value it puts on the context reaches the handler;
//   - afterMock: After rewrites the handler's response and its error is
//     returned;
//   - defaultHook: the handler's response and nil error pass through.
//
// testify's assert.Equal takes (expected, actual) in that order.
func TestHookInterceptor(t *testing.T) {
	// Leave the global hook in its default state even if a case panics.
	t.Cleanup(func() { hoo = defaultHook{} })

	var (
		ctx             = context.Background()
		info            = &grpc.UnaryServerInfo{FullMethod: "test"}
		emptyFullMethod = &grpc.UnaryServerInfo{FullMethod: ""}
		interceptor     = UnaryServerHookInterceptor()
		mockHoo         = mockHook{mockRes: "mock", mockErr: errors.New("mock")}
		r               = &req{method: "req"}
		re              = &resp{method: "resp"}
		beforeHoo       = beforeMock{method: "before", ctxKey: 100, ctxValue: "hook", err: errors.New("before")}
		afterHoo        = afterMock{method: "after", err: errors.New("after")}

		res interface{}
		err error
	)

	// noop is a handler that returns neither a response nor an error.
	noop := func(context.Context, interface{}) (interface{}, error) { return nil, nil }

	// Mock short-circuits the call, with or without a FullMethod.
	hoo = mockHoo
	res, err = interceptor(ctx, "request", info, noop)
	assert.Equal(t, mockHoo.mockRes, res)
	assert.Equal(t, mockHoo.mockErr, err)

	res, err = interceptor(ctx, "request", emptyFullMethod, noop)
	assert.Equal(t, mockHoo.mockRes, res)
	assert.Equal(t, mockHoo.mockErr, err)

	// Before rewrites the request and its error is returned to the caller.
	hoo = beforeHoo
	_, err = interceptor(ctx, r, info, noop)
	assert.Equal(t, beforeHoo.method, r.method)
	assert.Equal(t, beforeHoo.err, err)

	// Reset the request so the rewrite checked below comes from this call,
	// not from the previous case.
	r.method = "req"
	beforeHoo.err = nil
	hoo = beforeHoo
	_, err = interceptor(ctx, r, info, func(ctx context.Context, _ interface{}) (interface{}, error) {
		assert.Equal(t, beforeHoo.ctxValue, ctx.Value(beforeHoo.ctxKey))
		return nil, nil
	})
	assert.Equal(t, beforeHoo.method, r.method)
	assert.NoError(t, err)

	// After rewrites the handler's response and its error is returned.
	hoo = afterHoo
	_, err = interceptor(ctx, r, info, func(context.Context, interface{}) (interface{}, error) {
		return re, nil
	})
	assert.Equal(t, afterHoo.method, re.method)
	assert.Equal(t, afterHoo.err, err)

	// The default hook passes the handler's response and error through.
	hoo = defaultHook{}
	res, err = interceptor(ctx, r, info, func(_ context.Context, request interface{}) (interface{}, error) {
		return &resp{method: request.(*req).method}, nil
	})
	assert.NoError(t, err)
	// Guard the type assertion so a wrong result fails the test instead of
	// panicking.
	if assert.IsType(t, &resp{}, res) {
		assert.Equal(t, r.method, res.(*resp).method)
	}
}
|
|
|
|
func TestDefaultHook(t *testing.T) {
|
|
d := defaultHook{}
|
|
assert.NoError(t, d.Init(nil))
|
|
{
|
|
_, err := d.VerifyAPIKey("key")
|
|
assert.Error(t, err)
|
|
}
|
|
assert.NotPanics(t, func() {
|
|
d.Release()
|
|
})
|
|
}
|
|
|
|
// TestUpdateProxyFunctionCallMetric checks that updateProxyFunctionCallMetric
// does not panic on any of these inputs:
//   - a fully qualified gRPC method name;
//   - a bare method name;
//   - an empty string.
func TestUpdateProxyFunctionCallMetric(t *testing.T) {
	fullMethods := []string{
		"/milvus.proto.milvus.MilvusService/Flush",
		"Flush",
		"",
	}
	assert.NotPanics(t, func() {
		for _, fullMethod := range fullMethods {
			updateProxyFunctionCallMetric(fullMethod)
		}
	})
}
|