Support to modify the `context` param in the hook interceptor (#19495)

Signed-off-by: SimFG <bang.fu@zilliz.com>

Signed-off-by: SimFG <bang.fu@zilliz.com>
pull/19504/head
SimFG 2022-09-28 13:26:54 +08:00 committed by GitHub
parent 377f856833
commit 9d40be7e67
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
3 changed files with 26 additions and 12 deletions

View File

@ -5,7 +5,7 @@ import "context"
type Hook interface {
Init(params map[string]string) error
Mock(ctx context.Context, req interface{}, fullMethod string) (bool, interface{}, error)
Before(ctx context.Context, req interface{}, fullMethod string) error
Before(ctx context.Context, req interface{}, fullMethod string) (context.Context, error)
After(ctx context.Context, result interface{}, err error, fullMethod string) error
Release()
}

View File

@ -22,8 +22,8 @@ func (d defaultHook) Mock(ctx context.Context, req interface{}, fullMethod strin
return false, nil, nil
}
func (d defaultHook) Before(ctx context.Context, req interface{}, fullMethod string) error {
return nil
func (d defaultHook) Before(ctx context.Context, req interface{}, fullMethod string) (context.Context, error) {
return ctx, nil
}
func (d defaultHook) After(ctx context.Context, result interface{}, err error, fullMethod string) error {
@ -72,6 +72,7 @@ func UnaryServerHookInterceptor() grpc.UnaryServerInterceptor {
return func(ctx context.Context, req interface{}, info *grpc.UnaryServerInfo, handler grpc.UnaryHandler) (interface{}, error) {
var (
fullMethod = info.FullMethod
newCtx context.Context
isMock bool
mockResp interface{}
realResp interface{}
@ -83,11 +84,11 @@ func UnaryServerHookInterceptor() grpc.UnaryServerInterceptor {
return mockResp, err
}
if err = hoo.Before(ctx, req, fullMethod); err != nil {
if newCtx, err = hoo.Before(ctx, req, fullMethod); err != nil {
return nil, err
}
realResp, realErr = handler(ctx, req)
if err = hoo.After(ctx, realResp, realErr, fullMethod); err != nil {
realResp, realErr = handler(newCtx, req)
if err = hoo.After(newCtx, realResp, realErr, fullMethod); err != nil {
return nil, err
}
return realResp, realErr

View File

@ -36,19 +36,23 @@ type req struct {
method string
}
type BeforeMockCtxKey int
type beforeMock struct {
defaultHook
method string
ctxKey BeforeMockCtxKey
ctxValue string
err error
}
func (b beforeMock) Before(ctx context.Context, r interface{}, fullMethod string) error {
func (b beforeMock) Before(ctx context.Context, r interface{}, fullMethod string) (context.Context, error) {
re, ok := r.(*req)
if !ok {
return errors.New("r is invalid type")
return ctx, errors.New("r is invalid type")
}
re.method = b.method
return b.err
return context.WithValue(ctx, b.ctxKey, b.ctxValue), b.err
}
type resp struct {
@ -80,7 +84,7 @@ func TestHookInterceptor(t *testing.T) {
mockHoo = mockHook{mockRes: "mock", mockErr: errors.New("mock")}
r = &req{method: "req"}
re = &resp{method: "resp"}
beforeHoo = beforeMock{method: "before", err: errors.New("before")}
beforeHoo = beforeMock{method: "before", ctxKey: 100, ctxValue: "hook", err: errors.New("before")}
afterHoo = afterMock{method: "after", err: errors.New("after")}
res interface{}
@ -101,6 +105,15 @@ func TestHookInterceptor(t *testing.T) {
assert.Equal(t, r.method, beforeHoo.method)
assert.Equal(t, err, beforeHoo.err)
beforeHoo.err = nil
hoo = beforeHoo
_, err = interceptor(ctx, r, info, func(ctx context.Context, req interface{}) (interface{}, error) {
assert.Equal(t, beforeHoo.ctxValue, ctx.Value(beforeHoo.ctxKey))
return nil, nil
})
assert.Equal(t, r.method, beforeHoo.method)
assert.Equal(t, err, beforeHoo.err)
hoo = afterHoo
_, err = interceptor(ctx, r, info, func(ctx context.Context, r interface{}) (interface{}, error) {
return re, nil