mirror of https://github.com/milvus-io/milvus.git
Accelerate server id interceptor validation (#26468)
Signed-off-by: bigsheeper <yihao.dai@zilliz.com>pull/26483/head
parent
3be4ac4022
commit
c3128aaef3
|
@ -32,6 +32,7 @@ import (
|
|||
grpc_middleware "github.com/grpc-ecosystem/go-grpc-middleware"
|
||||
clientv3 "go.etcd.io/etcd/client/v3"
|
||||
"go.opentelemetry.io/contrib/instrumentation/google.golang.org/grpc/otelgrpc"
|
||||
"go.uber.org/atomic"
|
||||
"go.uber.org/zap"
|
||||
"google.golang.org/grpc"
|
||||
"google.golang.org/grpc/keepalive"
|
||||
|
@ -54,6 +55,8 @@ type Server struct {
|
|||
ctx context.Context
|
||||
cancel context.CancelFunc
|
||||
|
||||
serverID atomic.Int64
|
||||
|
||||
wg sync.WaitGroup
|
||||
dataCoord types.DataCoordComponent
|
||||
|
||||
|
@ -154,13 +157,23 @@ func (s *Server) startGrpcLoop(grpcPort int) {
|
|||
otelgrpc.UnaryServerInterceptor(opts...),
|
||||
logutil.UnaryTraceLoggerInterceptor,
|
||||
interceptor.ClusterValidationUnaryServerInterceptor(),
|
||||
interceptor.ServerIDValidationUnaryServerInterceptor(),
|
||||
interceptor.ServerIDValidationUnaryServerInterceptor(func() int64 {
|
||||
if s.serverID.Load() == 0 {
|
||||
s.serverID.Store(paramtable.GetNodeID())
|
||||
}
|
||||
return s.serverID.Load()
|
||||
}),
|
||||
)),
|
||||
grpc.StreamInterceptor(grpc_middleware.ChainStreamServer(
|
||||
otelgrpc.StreamServerInterceptor(opts...),
|
||||
logutil.StreamTraceLoggerInterceptor,
|
||||
interceptor.ClusterValidationStreamServerInterceptor(),
|
||||
interceptor.ServerIDValidationStreamServerInterceptor(),
|
||||
interceptor.ServerIDValidationStreamServerInterceptor(func() int64 {
|
||||
if s.serverID.Load() == 0 {
|
||||
s.serverID.Store(paramtable.GetNodeID())
|
||||
}
|
||||
return s.serverID.Load()
|
||||
}),
|
||||
)))
|
||||
indexpb.RegisterIndexCoordServer(s.grpcServer, s)
|
||||
datapb.RegisterDataCoordServer(s.grpcServer, s)
|
||||
|
|
|
@ -33,6 +33,7 @@ import (
|
|||
grpc_middleware "github.com/grpc-ecosystem/go-grpc-middleware"
|
||||
clientv3 "go.etcd.io/etcd/client/v3"
|
||||
"go.opentelemetry.io/contrib/instrumentation/google.golang.org/grpc/otelgrpc"
|
||||
"go.uber.org/atomic"
|
||||
"go.uber.org/zap"
|
||||
"google.golang.org/grpc"
|
||||
"google.golang.org/grpc/keepalive"
|
||||
|
@ -63,6 +64,8 @@ type Server struct {
|
|||
etcdCli *clientv3.Client
|
||||
factory dependency.Factory
|
||||
|
||||
serverID atomic.Int64
|
||||
|
||||
rootCoord types.RootCoord
|
||||
dataCoord types.DataCoord
|
||||
|
||||
|
@ -138,13 +141,23 @@ func (s *Server) startGrpcLoop(grpcPort int) {
|
|||
otelgrpc.UnaryServerInterceptor(opts...),
|
||||
logutil.UnaryTraceLoggerInterceptor,
|
||||
interceptor.ClusterValidationUnaryServerInterceptor(),
|
||||
interceptor.ServerIDValidationUnaryServerInterceptor(),
|
||||
interceptor.ServerIDValidationUnaryServerInterceptor(func() int64 {
|
||||
if s.serverID.Load() == 0 {
|
||||
s.serverID.Store(paramtable.GetNodeID())
|
||||
}
|
||||
return s.serverID.Load()
|
||||
}),
|
||||
)),
|
||||
grpc.StreamInterceptor(grpc_middleware.ChainStreamServer(
|
||||
otelgrpc.StreamServerInterceptor(opts...),
|
||||
logutil.StreamTraceLoggerInterceptor,
|
||||
interceptor.ClusterValidationStreamServerInterceptor(),
|
||||
interceptor.ServerIDValidationStreamServerInterceptor(),
|
||||
interceptor.ServerIDValidationStreamServerInterceptor(func() int64 {
|
||||
if s.serverID.Load() == 0 {
|
||||
s.serverID.Store(paramtable.GetNodeID())
|
||||
}
|
||||
return s.serverID.Load()
|
||||
}),
|
||||
)))
|
||||
datapb.RegisterDataNodeServer(s.grpcServer, s)
|
||||
|
||||
|
|
|
@ -29,6 +29,7 @@ import (
|
|||
"github.com/milvus-io/milvus/pkg/tracer"
|
||||
clientv3 "go.etcd.io/etcd/client/v3"
|
||||
"go.opentelemetry.io/contrib/instrumentation/google.golang.org/grpc/otelgrpc"
|
||||
"go.uber.org/atomic"
|
||||
"go.uber.org/zap"
|
||||
"google.golang.org/grpc"
|
||||
"google.golang.org/grpc/keepalive"
|
||||
|
@ -54,6 +55,8 @@ type Server struct {
|
|||
grpcServer *grpc.Server
|
||||
grpcErrChan chan error
|
||||
|
||||
serverID atomic.Int64
|
||||
|
||||
loopCtx context.Context
|
||||
loopCancel func()
|
||||
loopWg sync.WaitGroup
|
||||
|
@ -110,13 +113,23 @@ func (s *Server) startGrpcLoop(grpcPort int) {
|
|||
otelgrpc.UnaryServerInterceptor(opts...),
|
||||
logutil.UnaryTraceLoggerInterceptor,
|
||||
interceptor.ClusterValidationUnaryServerInterceptor(),
|
||||
interceptor.ServerIDValidationUnaryServerInterceptor(),
|
||||
interceptor.ServerIDValidationUnaryServerInterceptor(func() int64 {
|
||||
if s.serverID.Load() == 0 {
|
||||
s.serverID.Store(paramtable.GetNodeID())
|
||||
}
|
||||
return s.serverID.Load()
|
||||
}),
|
||||
)),
|
||||
grpc.StreamInterceptor(grpc_middleware.ChainStreamServer(
|
||||
otelgrpc.StreamServerInterceptor(opts...),
|
||||
logutil.StreamTraceLoggerInterceptor,
|
||||
interceptor.ClusterValidationStreamServerInterceptor(),
|
||||
interceptor.ServerIDValidationStreamServerInterceptor(),
|
||||
interceptor.ServerIDValidationStreamServerInterceptor(func() int64 {
|
||||
if s.serverID.Load() == 0 {
|
||||
s.serverID.Store(paramtable.GetNodeID())
|
||||
}
|
||||
return s.serverID.Load()
|
||||
}),
|
||||
)))
|
||||
indexpb.RegisterIndexNodeServer(s.grpcServer, s)
|
||||
go funcutil.CheckGrpcReady(ctx, s.grpcErrChan)
|
||||
|
|
|
@ -64,6 +64,7 @@ import (
|
|||
"github.com/milvus-io/milvus/pkg/util/logutil"
|
||||
"github.com/milvus-io/milvus/pkg/util/paramtable"
|
||||
clientv3 "go.etcd.io/etcd/client/v3"
|
||||
"go.uber.org/atomic"
|
||||
"go.uber.org/zap"
|
||||
"google.golang.org/grpc"
|
||||
"google.golang.org/grpc/codes"
|
||||
|
@ -93,6 +94,8 @@ type Server struct {
|
|||
grpcInternalServer *grpc.Server
|
||||
grpcExternalServer *grpc.Server
|
||||
|
||||
serverID atomic.Int64
|
||||
|
||||
etcdCli *clientv3.Client
|
||||
rootCoordClient types.RootCoord
|
||||
dataCoordClient types.DataCoord
|
||||
|
@ -314,11 +317,21 @@ func (s *Server) startInternalGrpc(grpcPort int, errChan chan error) {
|
|||
otelgrpc.UnaryServerInterceptor(opts...),
|
||||
logutil.UnaryTraceLoggerInterceptor,
|
||||
interceptor.ClusterValidationUnaryServerInterceptor(),
|
||||
interceptor.ServerIDValidationUnaryServerInterceptor(),
|
||||
interceptor.ServerIDValidationUnaryServerInterceptor(func() int64 {
|
||||
if s.serverID.Load() == 0 {
|
||||
s.serverID.Store(paramtable.GetNodeID())
|
||||
}
|
||||
return s.serverID.Load()
|
||||
}),
|
||||
)),
|
||||
grpc.StreamInterceptor(grpc_middleware.ChainStreamServer(
|
||||
interceptor.ClusterValidationStreamServerInterceptor(),
|
||||
interceptor.ServerIDValidationStreamServerInterceptor(),
|
||||
interceptor.ServerIDValidationStreamServerInterceptor(func() int64 {
|
||||
if s.serverID.Load() == 0 {
|
||||
s.serverID.Store(paramtable.GetNodeID())
|
||||
}
|
||||
return s.serverID.Load()
|
||||
}),
|
||||
)))
|
||||
proxypb.RegisterProxyServer(s.grpcInternalServer, s)
|
||||
grpc_health_v1.RegisterHealthServer(s.grpcInternalServer, s)
|
||||
|
|
|
@ -30,6 +30,7 @@ import (
|
|||
"github.com/milvus-io/milvus/pkg/util/interceptor"
|
||||
clientv3 "go.etcd.io/etcd/client/v3"
|
||||
"go.opentelemetry.io/contrib/instrumentation/google.golang.org/grpc/otelgrpc"
|
||||
"go.uber.org/atomic"
|
||||
"go.uber.org/zap"
|
||||
"google.golang.org/grpc"
|
||||
"google.golang.org/grpc/keepalive"
|
||||
|
@ -56,6 +57,8 @@ type Server struct {
|
|||
loopCancel context.CancelFunc
|
||||
grpcServer *grpc.Server
|
||||
|
||||
serverID atomic.Int64
|
||||
|
||||
grpcErrChan chan error
|
||||
|
||||
queryCoord types.QueryCoordComponent
|
||||
|
@ -228,13 +231,23 @@ func (s *Server) startGrpcLoop(grpcPort int) {
|
|||
otelgrpc.UnaryServerInterceptor(opts...),
|
||||
logutil.UnaryTraceLoggerInterceptor,
|
||||
interceptor.ClusterValidationUnaryServerInterceptor(),
|
||||
interceptor.ServerIDValidationUnaryServerInterceptor(),
|
||||
interceptor.ServerIDValidationUnaryServerInterceptor(func() int64 {
|
||||
if s.serverID.Load() == 0 {
|
||||
s.serverID.Store(paramtable.GetNodeID())
|
||||
}
|
||||
return s.serverID.Load()
|
||||
}),
|
||||
)),
|
||||
grpc.StreamInterceptor(grpc_middleware.ChainStreamServer(
|
||||
otelgrpc.StreamServerInterceptor(opts...),
|
||||
logutil.StreamTraceLoggerInterceptor,
|
||||
interceptor.ClusterValidationStreamServerInterceptor(),
|
||||
interceptor.ServerIDValidationStreamServerInterceptor(),
|
||||
interceptor.ServerIDValidationStreamServerInterceptor(func() int64 {
|
||||
if s.serverID.Load() == 0 {
|
||||
s.serverID.Store(paramtable.GetNodeID())
|
||||
}
|
||||
return s.serverID.Load()
|
||||
}),
|
||||
)))
|
||||
querypb.RegisterQueryCoordServer(s.grpcServer, s)
|
||||
|
||||
|
|
|
@ -30,6 +30,7 @@ import (
|
|||
"github.com/milvus-io/milvus/pkg/util/interceptor"
|
||||
clientv3 "go.etcd.io/etcd/client/v3"
|
||||
"go.opentelemetry.io/contrib/instrumentation/google.golang.org/grpc/otelgrpc"
|
||||
"go.uber.org/atomic"
|
||||
"go.uber.org/zap"
|
||||
"google.golang.org/grpc"
|
||||
"google.golang.org/grpc/keepalive"
|
||||
|
@ -60,6 +61,8 @@ type Server struct {
|
|||
cancel context.CancelFunc
|
||||
grpcErrChan chan error
|
||||
|
||||
serverID atomic.Int64
|
||||
|
||||
grpcServer *grpc.Server
|
||||
|
||||
etcdCli *clientv3.Client
|
||||
|
@ -183,13 +186,23 @@ func (s *Server) startGrpcLoop(grpcPort int) {
|
|||
otelgrpc.UnaryServerInterceptor(opts...),
|
||||
logutil.UnaryTraceLoggerInterceptor,
|
||||
interceptor.ClusterValidationUnaryServerInterceptor(),
|
||||
interceptor.ServerIDValidationUnaryServerInterceptor(),
|
||||
interceptor.ServerIDValidationUnaryServerInterceptor(func() int64 {
|
||||
if s.serverID.Load() == 0 {
|
||||
s.serverID.Store(paramtable.GetNodeID())
|
||||
}
|
||||
return s.serverID.Load()
|
||||
}),
|
||||
)),
|
||||
grpc.StreamInterceptor(grpc_middleware.ChainStreamServer(
|
||||
otelgrpc.StreamServerInterceptor(opts...),
|
||||
logutil.StreamTraceLoggerInterceptor,
|
||||
interceptor.ClusterValidationStreamServerInterceptor(),
|
||||
interceptor.ServerIDValidationStreamServerInterceptor(),
|
||||
interceptor.ServerIDValidationStreamServerInterceptor(func() int64 {
|
||||
if s.serverID.Load() == 0 {
|
||||
s.serverID.Store(paramtable.GetNodeID())
|
||||
}
|
||||
return s.serverID.Load()
|
||||
}),
|
||||
)))
|
||||
querypb.RegisterQueryNodeServer(s.grpcServer, s)
|
||||
|
||||
|
|
|
@ -29,6 +29,7 @@ import (
|
|||
"github.com/milvus-io/milvus/pkg/util/interceptor"
|
||||
clientv3 "go.etcd.io/etcd/client/v3"
|
||||
"go.opentelemetry.io/contrib/instrumentation/google.golang.org/grpc/otelgrpc"
|
||||
"go.uber.org/atomic"
|
||||
"go.uber.org/zap"
|
||||
"google.golang.org/grpc"
|
||||
"google.golang.org/grpc/keepalive"
|
||||
|
@ -61,6 +62,8 @@ type Server struct {
|
|||
ctx context.Context
|
||||
cancel context.CancelFunc
|
||||
|
||||
serverID atomic.Int64
|
||||
|
||||
etcdCli *clientv3.Client
|
||||
dataCoord types.DataCoord
|
||||
queryCoord types.QueryCoord
|
||||
|
@ -255,13 +258,23 @@ func (s *Server) startGrpcLoop(port int) {
|
|||
otelgrpc.UnaryServerInterceptor(opts...),
|
||||
logutil.UnaryTraceLoggerInterceptor,
|
||||
interceptor.ClusterValidationUnaryServerInterceptor(),
|
||||
interceptor.ServerIDValidationUnaryServerInterceptor(),
|
||||
interceptor.ServerIDValidationUnaryServerInterceptor(func() int64 {
|
||||
if s.serverID.Load() == 0 {
|
||||
s.serverID.Store(paramtable.GetNodeID())
|
||||
}
|
||||
return s.serverID.Load()
|
||||
}),
|
||||
)),
|
||||
grpc.StreamInterceptor(grpc_middleware.ChainStreamServer(
|
||||
otelgrpc.StreamServerInterceptor(opts...),
|
||||
logutil.StreamTraceLoggerInterceptor,
|
||||
interceptor.ClusterValidationStreamServerInterceptor(),
|
||||
interceptor.ServerIDValidationStreamServerInterceptor(),
|
||||
interceptor.ServerIDValidationStreamServerInterceptor(func() int64 {
|
||||
if s.serverID.Load() == 0 {
|
||||
s.serverID.Store(paramtable.GetNodeID())
|
||||
}
|
||||
return s.serverID.Load()
|
||||
}),
|
||||
)))
|
||||
rootcoordpb.RegisterRootCoordServer(s.grpcServer, s)
|
||||
|
||||
|
|
|
@ -25,14 +25,15 @@ import (
|
|||
"google.golang.org/grpc/metadata"
|
||||
|
||||
"github.com/milvus-io/milvus/pkg/util/merr"
|
||||
"github.com/milvus-io/milvus/pkg/util/paramtable"
|
||||
)
|
||||
|
||||
const ServerIDKey = "ServerID"
|
||||
|
||||
type GetServerIDFunc func() int64
|
||||
|
||||
// ServerIDValidationUnaryServerInterceptor returns a new unary server interceptor that
|
||||
// verifies whether the target server ID of request matches with the server's ID and rejects it accordingly.
|
||||
func ServerIDValidationUnaryServerInterceptor() grpc.UnaryServerInterceptor {
|
||||
func ServerIDValidationUnaryServerInterceptor(fn GetServerIDFunc) grpc.UnaryServerInterceptor {
|
||||
return func(ctx context.Context, req interface{}, info *grpc.UnaryServerInfo, handler grpc.UnaryHandler) (interface{}, error) {
|
||||
md, ok := metadata.FromIncomingContext(ctx)
|
||||
if !ok {
|
||||
|
@ -46,8 +47,9 @@ func ServerIDValidationUnaryServerInterceptor() grpc.UnaryServerInterceptor {
|
|||
if err != nil {
|
||||
return handler(ctx, req)
|
||||
}
|
||||
if serverID != paramtable.GetNodeID() {
|
||||
return nil, merr.WrapErrServerIDMismatch(serverID, paramtable.GetNodeID())
|
||||
actualServerID := fn()
|
||||
if serverID != actualServerID {
|
||||
return nil, merr.WrapErrServerIDMismatch(serverID, actualServerID)
|
||||
}
|
||||
return handler(ctx, req)
|
||||
}
|
||||
|
@ -55,7 +57,7 @@ func ServerIDValidationUnaryServerInterceptor() grpc.UnaryServerInterceptor {
|
|||
|
||||
// ServerIDValidationStreamServerInterceptor returns a new streaming server interceptor that
|
||||
// verifies whether the target server ID of request matches with the server's ID and rejects it accordingly.
|
||||
func ServerIDValidationStreamServerInterceptor() grpc.StreamServerInterceptor {
|
||||
func ServerIDValidationStreamServerInterceptor(fn GetServerIDFunc) grpc.StreamServerInterceptor {
|
||||
return func(srv interface{}, ss grpc.ServerStream, info *grpc.StreamServerInfo, handler grpc.StreamHandler) error {
|
||||
md, ok := metadata.FromIncomingContext(ss.Context())
|
||||
if !ok {
|
||||
|
@ -69,8 +71,9 @@ func ServerIDValidationStreamServerInterceptor() grpc.StreamServerInterceptor {
|
|||
if err != nil {
|
||||
return handler(srv, ss)
|
||||
}
|
||||
if serverID != paramtable.GetNodeID() {
|
||||
return merr.WrapErrServerIDMismatch(serverID, paramtable.GetNodeID())
|
||||
actualServerID := fn()
|
||||
if serverID != actualServerID {
|
||||
return merr.WrapErrServerIDMismatch(serverID, actualServerID)
|
||||
}
|
||||
return handler(srv, ss)
|
||||
}
|
||||
|
|
|
@ -78,7 +78,9 @@ func TestServerIDInterceptor(t *testing.T) {
|
|||
return nil, nil
|
||||
}
|
||||
serverInfo := &grpc.UnaryServerInfo{FullMethod: method}
|
||||
interceptor := ServerIDValidationUnaryServerInterceptor()
|
||||
interceptor := ServerIDValidationUnaryServerInterceptor(func() int64 {
|
||||
return paramtable.GetNodeID()
|
||||
})
|
||||
|
||||
// no md in context
|
||||
_, err := interceptor(context.Background(), req, serverInfo, handler)
|
||||
|
@ -112,7 +114,9 @@ func TestServerIDInterceptor(t *testing.T) {
|
|||
handler := func(srv interface{}, stream grpc.ServerStream) error {
|
||||
return nil
|
||||
}
|
||||
interceptor := ServerIDValidationStreamServerInterceptor()
|
||||
interceptor := ServerIDValidationStreamServerInterceptor(func() int64 {
|
||||
return paramtable.GetNodeID()
|
||||
})
|
||||
|
||||
// no md in context
|
||||
err := interceptor(nil, newMockSS(context.Background()), nil, handler)
|
||||
|
|
Loading…
Reference in New Issue