Accelerate server id interceptor validation (#26468)

Signed-off-by: bigsheeper <yihao.dai@zilliz.com>
pull/26483/head
yihao.dai 2023-08-20 21:24:19 +08:00 committed by GitHub
parent 3be4ac4022
commit c3128aaef3
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
9 changed files with 121 additions and 23 deletions

View File

@ -32,6 +32,7 @@ import (
grpc_middleware "github.com/grpc-ecosystem/go-grpc-middleware"
clientv3 "go.etcd.io/etcd/client/v3"
"go.opentelemetry.io/contrib/instrumentation/google.golang.org/grpc/otelgrpc"
"go.uber.org/atomic"
"go.uber.org/zap"
"google.golang.org/grpc"
"google.golang.org/grpc/keepalive"
@ -54,6 +55,8 @@ type Server struct {
ctx context.Context
cancel context.CancelFunc
serverID atomic.Int64
wg sync.WaitGroup
dataCoord types.DataCoordComponent
@ -154,13 +157,23 @@ func (s *Server) startGrpcLoop(grpcPort int) {
otelgrpc.UnaryServerInterceptor(opts...),
logutil.UnaryTraceLoggerInterceptor,
interceptor.ClusterValidationUnaryServerInterceptor(),
interceptor.ServerIDValidationUnaryServerInterceptor(),
interceptor.ServerIDValidationUnaryServerInterceptor(func() int64 {
if s.serverID.Load() == 0 {
s.serverID.Store(paramtable.GetNodeID())
}
return s.serverID.Load()
}),
)),
grpc.StreamInterceptor(grpc_middleware.ChainStreamServer(
otelgrpc.StreamServerInterceptor(opts...),
logutil.StreamTraceLoggerInterceptor,
interceptor.ClusterValidationStreamServerInterceptor(),
interceptor.ServerIDValidationStreamServerInterceptor(),
interceptor.ServerIDValidationStreamServerInterceptor(func() int64 {
if s.serverID.Load() == 0 {
s.serverID.Store(paramtable.GetNodeID())
}
return s.serverID.Load()
}),
)))
indexpb.RegisterIndexCoordServer(s.grpcServer, s)
datapb.RegisterDataCoordServer(s.grpcServer, s)

View File

@ -33,6 +33,7 @@ import (
grpc_middleware "github.com/grpc-ecosystem/go-grpc-middleware"
clientv3 "go.etcd.io/etcd/client/v3"
"go.opentelemetry.io/contrib/instrumentation/google.golang.org/grpc/otelgrpc"
"go.uber.org/atomic"
"go.uber.org/zap"
"google.golang.org/grpc"
"google.golang.org/grpc/keepalive"
@ -63,6 +64,8 @@ type Server struct {
etcdCli *clientv3.Client
factory dependency.Factory
serverID atomic.Int64
rootCoord types.RootCoord
dataCoord types.DataCoord
@ -138,13 +141,23 @@ func (s *Server) startGrpcLoop(grpcPort int) {
otelgrpc.UnaryServerInterceptor(opts...),
logutil.UnaryTraceLoggerInterceptor,
interceptor.ClusterValidationUnaryServerInterceptor(),
interceptor.ServerIDValidationUnaryServerInterceptor(),
interceptor.ServerIDValidationUnaryServerInterceptor(func() int64 {
if s.serverID.Load() == 0 {
s.serverID.Store(paramtable.GetNodeID())
}
return s.serverID.Load()
}),
)),
grpc.StreamInterceptor(grpc_middleware.ChainStreamServer(
otelgrpc.StreamServerInterceptor(opts...),
logutil.StreamTraceLoggerInterceptor,
interceptor.ClusterValidationStreamServerInterceptor(),
interceptor.ServerIDValidationStreamServerInterceptor(),
interceptor.ServerIDValidationStreamServerInterceptor(func() int64 {
if s.serverID.Load() == 0 {
s.serverID.Store(paramtable.GetNodeID())
}
return s.serverID.Load()
}),
)))
datapb.RegisterDataNodeServer(s.grpcServer, s)

View File

@ -29,6 +29,7 @@ import (
"github.com/milvus-io/milvus/pkg/tracer"
clientv3 "go.etcd.io/etcd/client/v3"
"go.opentelemetry.io/contrib/instrumentation/google.golang.org/grpc/otelgrpc"
"go.uber.org/atomic"
"go.uber.org/zap"
"google.golang.org/grpc"
"google.golang.org/grpc/keepalive"
@ -54,6 +55,8 @@ type Server struct {
grpcServer *grpc.Server
grpcErrChan chan error
serverID atomic.Int64
loopCtx context.Context
loopCancel func()
loopWg sync.WaitGroup
@ -110,13 +113,23 @@ func (s *Server) startGrpcLoop(grpcPort int) {
otelgrpc.UnaryServerInterceptor(opts...),
logutil.UnaryTraceLoggerInterceptor,
interceptor.ClusterValidationUnaryServerInterceptor(),
interceptor.ServerIDValidationUnaryServerInterceptor(),
interceptor.ServerIDValidationUnaryServerInterceptor(func() int64 {
if s.serverID.Load() == 0 {
s.serverID.Store(paramtable.GetNodeID())
}
return s.serverID.Load()
}),
)),
grpc.StreamInterceptor(grpc_middleware.ChainStreamServer(
otelgrpc.StreamServerInterceptor(opts...),
logutil.StreamTraceLoggerInterceptor,
interceptor.ClusterValidationStreamServerInterceptor(),
interceptor.ServerIDValidationStreamServerInterceptor(),
interceptor.ServerIDValidationStreamServerInterceptor(func() int64 {
if s.serverID.Load() == 0 {
s.serverID.Store(paramtable.GetNodeID())
}
return s.serverID.Load()
}),
)))
indexpb.RegisterIndexNodeServer(s.grpcServer, s)
go funcutil.CheckGrpcReady(ctx, s.grpcErrChan)

View File

@ -64,6 +64,7 @@ import (
"github.com/milvus-io/milvus/pkg/util/logutil"
"github.com/milvus-io/milvus/pkg/util/paramtable"
clientv3 "go.etcd.io/etcd/client/v3"
"go.uber.org/atomic"
"go.uber.org/zap"
"google.golang.org/grpc"
"google.golang.org/grpc/codes"
@ -93,6 +94,8 @@ type Server struct {
grpcInternalServer *grpc.Server
grpcExternalServer *grpc.Server
serverID atomic.Int64
etcdCli *clientv3.Client
rootCoordClient types.RootCoord
dataCoordClient types.DataCoord
@ -314,11 +317,21 @@ func (s *Server) startInternalGrpc(grpcPort int, errChan chan error) {
otelgrpc.UnaryServerInterceptor(opts...),
logutil.UnaryTraceLoggerInterceptor,
interceptor.ClusterValidationUnaryServerInterceptor(),
interceptor.ServerIDValidationUnaryServerInterceptor(),
interceptor.ServerIDValidationUnaryServerInterceptor(func() int64 {
if s.serverID.Load() == 0 {
s.serverID.Store(paramtable.GetNodeID())
}
return s.serverID.Load()
}),
)),
grpc.StreamInterceptor(grpc_middleware.ChainStreamServer(
interceptor.ClusterValidationStreamServerInterceptor(),
interceptor.ServerIDValidationStreamServerInterceptor(),
interceptor.ServerIDValidationStreamServerInterceptor(func() int64 {
if s.serverID.Load() == 0 {
s.serverID.Store(paramtable.GetNodeID())
}
return s.serverID.Load()
}),
)))
proxypb.RegisterProxyServer(s.grpcInternalServer, s)
grpc_health_v1.RegisterHealthServer(s.grpcInternalServer, s)

View File

@ -30,6 +30,7 @@ import (
"github.com/milvus-io/milvus/pkg/util/interceptor"
clientv3 "go.etcd.io/etcd/client/v3"
"go.opentelemetry.io/contrib/instrumentation/google.golang.org/grpc/otelgrpc"
"go.uber.org/atomic"
"go.uber.org/zap"
"google.golang.org/grpc"
"google.golang.org/grpc/keepalive"
@ -56,6 +57,8 @@ type Server struct {
loopCancel context.CancelFunc
grpcServer *grpc.Server
serverID atomic.Int64
grpcErrChan chan error
queryCoord types.QueryCoordComponent
@ -228,13 +231,23 @@ func (s *Server) startGrpcLoop(grpcPort int) {
otelgrpc.UnaryServerInterceptor(opts...),
logutil.UnaryTraceLoggerInterceptor,
interceptor.ClusterValidationUnaryServerInterceptor(),
interceptor.ServerIDValidationUnaryServerInterceptor(),
interceptor.ServerIDValidationUnaryServerInterceptor(func() int64 {
if s.serverID.Load() == 0 {
s.serverID.Store(paramtable.GetNodeID())
}
return s.serverID.Load()
}),
)),
grpc.StreamInterceptor(grpc_middleware.ChainStreamServer(
otelgrpc.StreamServerInterceptor(opts...),
logutil.StreamTraceLoggerInterceptor,
interceptor.ClusterValidationStreamServerInterceptor(),
interceptor.ServerIDValidationStreamServerInterceptor(),
interceptor.ServerIDValidationStreamServerInterceptor(func() int64 {
if s.serverID.Load() == 0 {
s.serverID.Store(paramtable.GetNodeID())
}
return s.serverID.Load()
}),
)))
querypb.RegisterQueryCoordServer(s.grpcServer, s)

View File

@ -30,6 +30,7 @@ import (
"github.com/milvus-io/milvus/pkg/util/interceptor"
clientv3 "go.etcd.io/etcd/client/v3"
"go.opentelemetry.io/contrib/instrumentation/google.golang.org/grpc/otelgrpc"
"go.uber.org/atomic"
"go.uber.org/zap"
"google.golang.org/grpc"
"google.golang.org/grpc/keepalive"
@ -60,6 +61,8 @@ type Server struct {
cancel context.CancelFunc
grpcErrChan chan error
serverID atomic.Int64
grpcServer *grpc.Server
etcdCli *clientv3.Client
@ -183,13 +186,23 @@ func (s *Server) startGrpcLoop(grpcPort int) {
otelgrpc.UnaryServerInterceptor(opts...),
logutil.UnaryTraceLoggerInterceptor,
interceptor.ClusterValidationUnaryServerInterceptor(),
interceptor.ServerIDValidationUnaryServerInterceptor(),
interceptor.ServerIDValidationUnaryServerInterceptor(func() int64 {
if s.serverID.Load() == 0 {
s.serverID.Store(paramtable.GetNodeID())
}
return s.serverID.Load()
}),
)),
grpc.StreamInterceptor(grpc_middleware.ChainStreamServer(
otelgrpc.StreamServerInterceptor(opts...),
logutil.StreamTraceLoggerInterceptor,
interceptor.ClusterValidationStreamServerInterceptor(),
interceptor.ServerIDValidationStreamServerInterceptor(),
interceptor.ServerIDValidationStreamServerInterceptor(func() int64 {
if s.serverID.Load() == 0 {
s.serverID.Store(paramtable.GetNodeID())
}
return s.serverID.Load()
}),
)))
querypb.RegisterQueryNodeServer(s.grpcServer, s)

View File

@ -29,6 +29,7 @@ import (
"github.com/milvus-io/milvus/pkg/util/interceptor"
clientv3 "go.etcd.io/etcd/client/v3"
"go.opentelemetry.io/contrib/instrumentation/google.golang.org/grpc/otelgrpc"
"go.uber.org/atomic"
"go.uber.org/zap"
"google.golang.org/grpc"
"google.golang.org/grpc/keepalive"
@ -61,6 +62,8 @@ type Server struct {
ctx context.Context
cancel context.CancelFunc
serverID atomic.Int64
etcdCli *clientv3.Client
dataCoord types.DataCoord
queryCoord types.QueryCoord
@ -255,13 +258,23 @@ func (s *Server) startGrpcLoop(port int) {
otelgrpc.UnaryServerInterceptor(opts...),
logutil.UnaryTraceLoggerInterceptor,
interceptor.ClusterValidationUnaryServerInterceptor(),
interceptor.ServerIDValidationUnaryServerInterceptor(),
interceptor.ServerIDValidationUnaryServerInterceptor(func() int64 {
if s.serverID.Load() == 0 {
s.serverID.Store(paramtable.GetNodeID())
}
return s.serverID.Load()
}),
)),
grpc.StreamInterceptor(grpc_middleware.ChainStreamServer(
otelgrpc.StreamServerInterceptor(opts...),
logutil.StreamTraceLoggerInterceptor,
interceptor.ClusterValidationStreamServerInterceptor(),
interceptor.ServerIDValidationStreamServerInterceptor(),
interceptor.ServerIDValidationStreamServerInterceptor(func() int64 {
if s.serverID.Load() == 0 {
s.serverID.Store(paramtable.GetNodeID())
}
return s.serverID.Load()
}),
)))
rootcoordpb.RegisterRootCoordServer(s.grpcServer, s)

View File

@ -25,14 +25,15 @@ import (
"google.golang.org/grpc/metadata"
"github.com/milvus-io/milvus/pkg/util/merr"
"github.com/milvus-io/milvus/pkg/util/paramtable"
)
const ServerIDKey = "ServerID"
type GetServerIDFunc func() int64
// ServerIDValidationUnaryServerInterceptor returns a new unary server interceptor that
// verifies whether the target server ID of request matches with the server's ID and rejects it accordingly.
func ServerIDValidationUnaryServerInterceptor() grpc.UnaryServerInterceptor {
func ServerIDValidationUnaryServerInterceptor(fn GetServerIDFunc) grpc.UnaryServerInterceptor {
return func(ctx context.Context, req interface{}, info *grpc.UnaryServerInfo, handler grpc.UnaryHandler) (interface{}, error) {
md, ok := metadata.FromIncomingContext(ctx)
if !ok {
@ -46,8 +47,9 @@ func ServerIDValidationUnaryServerInterceptor() grpc.UnaryServerInterceptor {
if err != nil {
return handler(ctx, req)
}
if serverID != paramtable.GetNodeID() {
return nil, merr.WrapErrServerIDMismatch(serverID, paramtable.GetNodeID())
actualServerID := fn()
if serverID != actualServerID {
return nil, merr.WrapErrServerIDMismatch(serverID, actualServerID)
}
return handler(ctx, req)
}
@ -55,7 +57,7 @@ func ServerIDValidationUnaryServerInterceptor() grpc.UnaryServerInterceptor {
// ServerIDValidationStreamServerInterceptor returns a new streaming server interceptor that
// verifies whether the target server ID of request matches with the server's ID and rejects it accordingly.
func ServerIDValidationStreamServerInterceptor() grpc.StreamServerInterceptor {
func ServerIDValidationStreamServerInterceptor(fn GetServerIDFunc) grpc.StreamServerInterceptor {
return func(srv interface{}, ss grpc.ServerStream, info *grpc.StreamServerInfo, handler grpc.StreamHandler) error {
md, ok := metadata.FromIncomingContext(ss.Context())
if !ok {
@ -69,8 +71,9 @@ func ServerIDValidationStreamServerInterceptor() grpc.StreamServerInterceptor {
if err != nil {
return handler(srv, ss)
}
if serverID != paramtable.GetNodeID() {
return merr.WrapErrServerIDMismatch(serverID, paramtable.GetNodeID())
actualServerID := fn()
if serverID != actualServerID {
return merr.WrapErrServerIDMismatch(serverID, actualServerID)
}
return handler(srv, ss)
}

View File

@ -78,7 +78,9 @@ func TestServerIDInterceptor(t *testing.T) {
return nil, nil
}
serverInfo := &grpc.UnaryServerInfo{FullMethod: method}
interceptor := ServerIDValidationUnaryServerInterceptor()
interceptor := ServerIDValidationUnaryServerInterceptor(func() int64 {
return paramtable.GetNodeID()
})
// no md in context
_, err := interceptor(context.Background(), req, serverInfo, handler)
@ -112,7 +114,9 @@ func TestServerIDInterceptor(t *testing.T) {
handler := func(srv interface{}, stream grpc.ServerStream) error {
return nil
}
interceptor := ServerIDValidationStreamServerInterceptor()
interceptor := ServerIDValidationStreamServerInterceptor(func() int64 {
return paramtable.GetNodeID()
})
// no md in context
err := interceptor(nil, newMockSS(context.Background()), nil, handler)