mirror of https://github.com/milvus-io/milvus.git
Support to modify the `context` param in the hook interceptor (#19495)
Signed-off-by: SimFG <bang.fu@zilliz.com> Signed-off-by: SimFG <bang.fu@zilliz.com>pull/19504/head
parent
377f856833
commit
9d40be7e67
|
@ -5,7 +5,7 @@ import "context"
|
|||
type Hook interface {
|
||||
Init(params map[string]string) error
|
||||
Mock(ctx context.Context, req interface{}, fullMethod string) (bool, interface{}, error)
|
||||
Before(ctx context.Context, req interface{}, fullMethod string) error
|
||||
Before(ctx context.Context, req interface{}, fullMethod string) (context.Context, error)
|
||||
After(ctx context.Context, result interface{}, err error, fullMethod string) error
|
||||
Release()
|
||||
}
|
||||
|
|
|
@ -22,8 +22,8 @@ func (d defaultHook) Mock(ctx context.Context, req interface{}, fullMethod strin
|
|||
return false, nil, nil
|
||||
}
|
||||
|
||||
func (d defaultHook) Before(ctx context.Context, req interface{}, fullMethod string) error {
|
||||
return nil
|
||||
func (d defaultHook) Before(ctx context.Context, req interface{}, fullMethod string) (context.Context, error) {
|
||||
return ctx, nil
|
||||
}
|
||||
|
||||
func (d defaultHook) After(ctx context.Context, result interface{}, err error, fullMethod string) error {
|
||||
|
@ -72,6 +72,7 @@ func UnaryServerHookInterceptor() grpc.UnaryServerInterceptor {
|
|||
return func(ctx context.Context, req interface{}, info *grpc.UnaryServerInfo, handler grpc.UnaryHandler) (interface{}, error) {
|
||||
var (
|
||||
fullMethod = info.FullMethod
|
||||
newCtx context.Context
|
||||
isMock bool
|
||||
mockResp interface{}
|
||||
realResp interface{}
|
||||
|
@ -83,11 +84,11 @@ func UnaryServerHookInterceptor() grpc.UnaryServerInterceptor {
|
|||
return mockResp, err
|
||||
}
|
||||
|
||||
if err = hoo.Before(ctx, req, fullMethod); err != nil {
|
||||
if newCtx, err = hoo.Before(ctx, req, fullMethod); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
realResp, realErr = handler(ctx, req)
|
||||
if err = hoo.After(ctx, realResp, realErr, fullMethod); err != nil {
|
||||
realResp, realErr = handler(newCtx, req)
|
||||
if err = hoo.After(newCtx, realResp, realErr, fullMethod); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return realResp, realErr
|
||||
|
|
|
@ -36,19 +36,23 @@ type req struct {
|
|||
method string
|
||||
}
|
||||
|
||||
type BeforeMockCtxKey int
|
||||
|
||||
type beforeMock struct {
|
||||
defaultHook
|
||||
method string
|
||||
err error
|
||||
method string
|
||||
ctxKey BeforeMockCtxKey
|
||||
ctxValue string
|
||||
err error
|
||||
}
|
||||
|
||||
func (b beforeMock) Before(ctx context.Context, r interface{}, fullMethod string) error {
|
||||
func (b beforeMock) Before(ctx context.Context, r interface{}, fullMethod string) (context.Context, error) {
|
||||
re, ok := r.(*req)
|
||||
if !ok {
|
||||
return errors.New("r is invalid type")
|
||||
return ctx, errors.New("r is invalid type")
|
||||
}
|
||||
re.method = b.method
|
||||
return b.err
|
||||
return context.WithValue(ctx, b.ctxKey, b.ctxValue), b.err
|
||||
}
|
||||
|
||||
type resp struct {
|
||||
|
@ -80,7 +84,7 @@ func TestHookInterceptor(t *testing.T) {
|
|||
mockHoo = mockHook{mockRes: "mock", mockErr: errors.New("mock")}
|
||||
r = &req{method: "req"}
|
||||
re = &resp{method: "resp"}
|
||||
beforeHoo = beforeMock{method: "before", err: errors.New("before")}
|
||||
beforeHoo = beforeMock{method: "before", ctxKey: 100, ctxValue: "hook", err: errors.New("before")}
|
||||
afterHoo = afterMock{method: "after", err: errors.New("after")}
|
||||
|
||||
res interface{}
|
||||
|
@ -101,6 +105,15 @@ func TestHookInterceptor(t *testing.T) {
|
|||
assert.Equal(t, r.method, beforeHoo.method)
|
||||
assert.Equal(t, err, beforeHoo.err)
|
||||
|
||||
beforeHoo.err = nil
|
||||
hoo = beforeHoo
|
||||
_, err = interceptor(ctx, r, info, func(ctx context.Context, req interface{}) (interface{}, error) {
|
||||
assert.Equal(t, beforeHoo.ctxValue, ctx.Value(beforeHoo.ctxKey))
|
||||
return nil, nil
|
||||
})
|
||||
assert.Equal(t, r.method, beforeHoo.method)
|
||||
assert.Equal(t, err, beforeHoo.err)
|
||||
|
||||
hoo = afterHoo
|
||||
_, err = interceptor(ctx, r, info, func(ctx context.Context, r interface{}) (interface{}, error) {
|
||||
return re, nil
|
||||
|
|
Loading…
Reference in New Issue