mirror of https://github.com/milvus-io/milvus.git
Support to modify the `context` param in the hook interceptor (#19495)
Signed-off-by: SimFG <bang.fu@zilliz.com> Signed-off-by: SimFG <bang.fu@zilliz.com>pull/19504/head
parent
377f856833
commit
9d40be7e67
|
@ -5,7 +5,7 @@ import "context"
|
||||||
type Hook interface {
|
type Hook interface {
|
||||||
Init(params map[string]string) error
|
Init(params map[string]string) error
|
||||||
Mock(ctx context.Context, req interface{}, fullMethod string) (bool, interface{}, error)
|
Mock(ctx context.Context, req interface{}, fullMethod string) (bool, interface{}, error)
|
||||||
Before(ctx context.Context, req interface{}, fullMethod string) error
|
Before(ctx context.Context, req interface{}, fullMethod string) (context.Context, error)
|
||||||
After(ctx context.Context, result interface{}, err error, fullMethod string) error
|
After(ctx context.Context, result interface{}, err error, fullMethod string) error
|
||||||
Release()
|
Release()
|
||||||
}
|
}
|
||||||
|
|
|
@ -22,8 +22,8 @@ func (d defaultHook) Mock(ctx context.Context, req interface{}, fullMethod strin
|
||||||
return false, nil, nil
|
return false, nil, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
func (d defaultHook) Before(ctx context.Context, req interface{}, fullMethod string) error {
|
func (d defaultHook) Before(ctx context.Context, req interface{}, fullMethod string) (context.Context, error) {
|
||||||
return nil
|
return ctx, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
func (d defaultHook) After(ctx context.Context, result interface{}, err error, fullMethod string) error {
|
func (d defaultHook) After(ctx context.Context, result interface{}, err error, fullMethod string) error {
|
||||||
|
@ -72,6 +72,7 @@ func UnaryServerHookInterceptor() grpc.UnaryServerInterceptor {
|
||||||
return func(ctx context.Context, req interface{}, info *grpc.UnaryServerInfo, handler grpc.UnaryHandler) (interface{}, error) {
|
return func(ctx context.Context, req interface{}, info *grpc.UnaryServerInfo, handler grpc.UnaryHandler) (interface{}, error) {
|
||||||
var (
|
var (
|
||||||
fullMethod = info.FullMethod
|
fullMethod = info.FullMethod
|
||||||
|
newCtx context.Context
|
||||||
isMock bool
|
isMock bool
|
||||||
mockResp interface{}
|
mockResp interface{}
|
||||||
realResp interface{}
|
realResp interface{}
|
||||||
|
@ -83,11 +84,11 @@ func UnaryServerHookInterceptor() grpc.UnaryServerInterceptor {
|
||||||
return mockResp, err
|
return mockResp, err
|
||||||
}
|
}
|
||||||
|
|
||||||
if err = hoo.Before(ctx, req, fullMethod); err != nil {
|
if newCtx, err = hoo.Before(ctx, req, fullMethod); err != nil {
|
||||||
return nil, err
|
return nil, err
|
||||||
}
|
}
|
||||||
realResp, realErr = handler(ctx, req)
|
realResp, realErr = handler(newCtx, req)
|
||||||
if err = hoo.After(ctx, realResp, realErr, fullMethod); err != nil {
|
if err = hoo.After(newCtx, realResp, realErr, fullMethod); err != nil {
|
||||||
return nil, err
|
return nil, err
|
||||||
}
|
}
|
||||||
return realResp, realErr
|
return realResp, realErr
|
||||||
|
|
|
@ -36,19 +36,23 @@ type req struct {
|
||||||
method string
|
method string
|
||||||
}
|
}
|
||||||
|
|
||||||
|
type BeforeMockCtxKey int
|
||||||
|
|
||||||
type beforeMock struct {
|
type beforeMock struct {
|
||||||
defaultHook
|
defaultHook
|
||||||
method string
|
method string
|
||||||
err error
|
ctxKey BeforeMockCtxKey
|
||||||
|
ctxValue string
|
||||||
|
err error
|
||||||
}
|
}
|
||||||
|
|
||||||
func (b beforeMock) Before(ctx context.Context, r interface{}, fullMethod string) error {
|
func (b beforeMock) Before(ctx context.Context, r interface{}, fullMethod string) (context.Context, error) {
|
||||||
re, ok := r.(*req)
|
re, ok := r.(*req)
|
||||||
if !ok {
|
if !ok {
|
||||||
return errors.New("r is invalid type")
|
return ctx, errors.New("r is invalid type")
|
||||||
}
|
}
|
||||||
re.method = b.method
|
re.method = b.method
|
||||||
return b.err
|
return context.WithValue(ctx, b.ctxKey, b.ctxValue), b.err
|
||||||
}
|
}
|
||||||
|
|
||||||
type resp struct {
|
type resp struct {
|
||||||
|
@ -80,7 +84,7 @@ func TestHookInterceptor(t *testing.T) {
|
||||||
mockHoo = mockHook{mockRes: "mock", mockErr: errors.New("mock")}
|
mockHoo = mockHook{mockRes: "mock", mockErr: errors.New("mock")}
|
||||||
r = &req{method: "req"}
|
r = &req{method: "req"}
|
||||||
re = &resp{method: "resp"}
|
re = &resp{method: "resp"}
|
||||||
beforeHoo = beforeMock{method: "before", err: errors.New("before")}
|
beforeHoo = beforeMock{method: "before", ctxKey: 100, ctxValue: "hook", err: errors.New("before")}
|
||||||
afterHoo = afterMock{method: "after", err: errors.New("after")}
|
afterHoo = afterMock{method: "after", err: errors.New("after")}
|
||||||
|
|
||||||
res interface{}
|
res interface{}
|
||||||
|
@ -101,6 +105,15 @@ func TestHookInterceptor(t *testing.T) {
|
||||||
assert.Equal(t, r.method, beforeHoo.method)
|
assert.Equal(t, r.method, beforeHoo.method)
|
||||||
assert.Equal(t, err, beforeHoo.err)
|
assert.Equal(t, err, beforeHoo.err)
|
||||||
|
|
||||||
|
beforeHoo.err = nil
|
||||||
|
hoo = beforeHoo
|
||||||
|
_, err = interceptor(ctx, r, info, func(ctx context.Context, req interface{}) (interface{}, error) {
|
||||||
|
assert.Equal(t, beforeHoo.ctxValue, ctx.Value(beforeHoo.ctxKey))
|
||||||
|
return nil, nil
|
||||||
|
})
|
||||||
|
assert.Equal(t, r.method, beforeHoo.method)
|
||||||
|
assert.Equal(t, err, beforeHoo.err)
|
||||||
|
|
||||||
hoo = afterHoo
|
hoo = afterHoo
|
||||||
_, err = interceptor(ctx, r, info, func(ctx context.Context, r interface{}) (interface{}, error) {
|
_, err = interceptor(ctx, r, info, func(ctx context.Context, r interface{}) (interface{}, error) {
|
||||||
return re, nil
|
return re, nil
|
||||||
|
|
Loading…
Reference in New Issue