Fix data race in gRPC client (#26574)

Signed-off-by: yah01 <yah2er0ne@outlook.com>
pull/26644/head
yah01 2023-08-28 18:26:28 +08:00 committed by GitHub
parent bc6b376c13
commit dd4bc5b6a0
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
8 changed files with 43 additions and 116 deletions

View File

@ -19,7 +19,6 @@ package grpcdatacoordclient
import (
"context"
"fmt"
"time"
"github.com/milvus-io/milvus/internal/util/grpcclient"
"github.com/milvus-io/milvus/internal/util/sessionutil"
@ -60,22 +59,10 @@ func NewClient(ctx context.Context, metaRoot string, etcdCli *clientv3.Client) (
return nil, err
}
clientParams := &Params.DataCoordGrpcClientCfg
config := &Params.DataCoordGrpcClientCfg
client := &Client{
grpcClient: &grpcclient.ClientBase[datapb.DataCoordClient]{
ClientMaxRecvSize: clientParams.ClientMaxRecvSize.GetAsInt(),
ClientMaxSendSize: clientParams.ClientMaxSendSize.GetAsInt(),
DialTimeout: clientParams.DialTimeout.GetAsDuration(time.Millisecond),
KeepAliveTime: clientParams.KeepAliveTime.GetAsDuration(time.Millisecond),
KeepAliveTimeout: clientParams.KeepAliveTimeout.GetAsDuration(time.Millisecond),
RetryServiceNameConfig: "milvus.proto.data.DataCoord",
MaxAttempts: clientParams.MaxAttempts.GetAsInt(),
InitialBackoff: float32(clientParams.InitialBackoff.GetAsFloat()),
MaxBackoff: float32(clientParams.MaxBackoff.GetAsFloat()),
BackoffMultiplier: float32(clientParams.BackoffMultiplier.GetAsFloat()),
CompressionEnabled: clientParams.CompressionEnabled.GetAsBool(),
},
sess: sess,
grpcClient: grpcclient.NewClientBase[datapb.DataCoordClient](config, "milvus.proto.data.DataCoord"),
sess: sess,
}
client.grpcClient.SetRole(typeutil.DataCoordRole)
client.grpcClient.SetGetAddrFunc(client.getDataCoordAddr)

View File

@ -19,7 +19,6 @@ package grpcdatanodeclient
import (
"context"
"fmt"
"time"
"github.com/milvus-io/milvus-proto/go-api/v2/commonpb"
"github.com/milvus-io/milvus-proto/go-api/v2/milvuspb"
@ -46,22 +45,10 @@ func NewClient(ctx context.Context, addr string, nodeID int64) (*Client, error)
if addr == "" {
return nil, fmt.Errorf("address is empty")
}
clientParams := &Params.DataNodeGrpcClientCfg
config := &Params.DataNodeGrpcClientCfg
client := &Client{
addr: addr,
grpcClient: &grpcclient.ClientBase[datapb.DataNodeClient]{
ClientMaxRecvSize: clientParams.ClientMaxRecvSize.GetAsInt(),
ClientMaxSendSize: clientParams.ClientMaxSendSize.GetAsInt(),
DialTimeout: clientParams.DialTimeout.GetAsDuration(time.Millisecond),
KeepAliveTime: clientParams.KeepAliveTime.GetAsDuration(time.Millisecond),
KeepAliveTimeout: clientParams.KeepAliveTimeout.GetAsDuration(time.Millisecond),
RetryServiceNameConfig: "milvus.proto.data.DataNode",
MaxAttempts: clientParams.MaxAttempts.GetAsInt(),
InitialBackoff: float32(clientParams.InitialBackoff.GetAsFloat()),
MaxBackoff: float32(clientParams.MaxBackoff.GetAsFloat()),
BackoffMultiplier: float32(clientParams.BackoffMultiplier.GetAsFloat()),
CompressionEnabled: clientParams.CompressionEnabled.GetAsBool(),
},
addr: addr,
grpcClient: grpcclient.NewClientBase[datapb.DataNodeClient](config, "milvus.proto.data.DataNode"),
}
client.grpcClient.SetRole(typeutil.DataNodeRole)
client.grpcClient.SetGetAddrFunc(client.getAddr)

View File

@ -19,7 +19,6 @@ package grpcindexnodeclient
import (
"context"
"fmt"
"time"
"github.com/milvus-io/milvus/internal/util/grpcclient"
"google.golang.org/grpc"
@ -47,22 +46,10 @@ func NewClient(ctx context.Context, addr string, nodeID int64, encryption bool)
if addr == "" {
return nil, fmt.Errorf("address is empty")
}
clientParams := &Params.IndexNodeGrpcClientCfg
config := &Params.IndexNodeGrpcClientCfg
client := &Client{
addr: addr,
grpcClient: &grpcclient.ClientBase[indexpb.IndexNodeClient]{
ClientMaxRecvSize: clientParams.ClientMaxRecvSize.GetAsInt(),
ClientMaxSendSize: clientParams.ClientMaxSendSize.GetAsInt(),
DialTimeout: clientParams.DialTimeout.GetAsDuration(time.Millisecond),
KeepAliveTime: clientParams.KeepAliveTime.GetAsDuration(time.Millisecond),
KeepAliveTimeout: clientParams.KeepAliveTimeout.GetAsDuration(time.Millisecond),
RetryServiceNameConfig: "milvus.proto.index.IndexNode",
MaxAttempts: clientParams.MaxAttempts.GetAsInt(),
InitialBackoff: float32(clientParams.InitialBackoff.GetAsFloat()),
MaxBackoff: float32(clientParams.MaxBackoff.GetAsFloat()),
BackoffMultiplier: float32(clientParams.BackoffMultiplier.GetAsFloat()),
CompressionEnabled: clientParams.CompressionEnabled.GetAsBool(),
},
addr: addr,
grpcClient: grpcclient.NewClientBase[indexpb.IndexNodeClient](config, "milvus.proto.index.IndexNode"),
}
client.grpcClient.SetRole(typeutil.IndexNodeRole)
client.grpcClient.SetGetAddrFunc(client.getAddr)

View File

@ -19,7 +19,6 @@ package grpcproxyclient
import (
"context"
"fmt"
"time"
"github.com/milvus-io/milvus-proto/go-api/v2/commonpb"
"github.com/milvus-io/milvus-proto/go-api/v2/milvuspb"
@ -46,22 +45,10 @@ func NewClient(ctx context.Context, addr string, nodeID int64) (*Client, error)
if addr == "" {
return nil, fmt.Errorf("address is empty")
}
clientParams := &Params.ProxyGrpcClientCfg
config := &Params.ProxyGrpcClientCfg
client := &Client{
addr: addr,
grpcClient: &grpcclient.ClientBase[proxypb.ProxyClient]{
ClientMaxRecvSize: clientParams.ClientMaxRecvSize.GetAsInt(),
ClientMaxSendSize: clientParams.ClientMaxSendSize.GetAsInt(),
DialTimeout: clientParams.DialTimeout.GetAsDuration(time.Millisecond),
KeepAliveTime: clientParams.KeepAliveTime.GetAsDuration(time.Millisecond),
KeepAliveTimeout: clientParams.KeepAliveTimeout.GetAsDuration(time.Millisecond),
RetryServiceNameConfig: "milvus.proto.proxy.Proxy",
MaxAttempts: clientParams.MaxAttempts.GetAsInt(),
InitialBackoff: float32(clientParams.InitialBackoff.GetAsFloat()),
MaxBackoff: float32(clientParams.MaxBackoff.GetAsFloat()),
BackoffMultiplier: float32(clientParams.BackoffMultiplier.GetAsFloat()),
CompressionEnabled: clientParams.CompressionEnabled.GetAsBool(),
},
addr: addr,
grpcClient: grpcclient.NewClientBase[proxypb.ProxyClient](config, "milvus.proto.proxy.Proxy"),
}
client.grpcClient.SetRole(typeutil.ProxyRole)
client.grpcClient.SetGetAddrFunc(client.getAddr)

View File

@ -19,7 +19,6 @@ package grpcquerycoordclient
import (
"context"
"fmt"
"time"
"github.com/milvus-io/milvus-proto/go-api/v2/commonpb"
"github.com/milvus-io/milvus-proto/go-api/v2/milvuspb"
@ -53,22 +52,10 @@ func NewClient(ctx context.Context, metaRoot string, etcdCli *clientv3.Client) (
log.Debug("QueryCoordClient NewClient failed", zap.Error(err))
return nil, err
}
clientParams := &Params.QueryCoordGrpcClientCfg
config := &Params.QueryCoordGrpcClientCfg
client := &Client{
grpcClient: &grpcclient.ClientBase[querypb.QueryCoordClient]{
ClientMaxRecvSize: clientParams.ClientMaxRecvSize.GetAsInt(),
ClientMaxSendSize: clientParams.ClientMaxSendSize.GetAsInt(),
DialTimeout: clientParams.DialTimeout.GetAsDuration(time.Millisecond),
KeepAliveTime: clientParams.KeepAliveTime.GetAsDuration(time.Millisecond),
KeepAliveTimeout: clientParams.KeepAliveTimeout.GetAsDuration(time.Millisecond),
RetryServiceNameConfig: "milvus.proto.query.QueryCoord",
MaxAttempts: clientParams.MaxAttempts.GetAsInt(),
InitialBackoff: float32(clientParams.InitialBackoff.GetAsFloat()),
MaxBackoff: float32(clientParams.MaxBackoff.GetAsFloat()),
BackoffMultiplier: float32(clientParams.BackoffMultiplier.GetAsFloat()),
CompressionEnabled: clientParams.CompressionEnabled.GetAsBool(),
},
sess: sess,
grpcClient: grpcclient.NewClientBase[querypb.QueryCoordClient](config, "milvus.proto.query.QueryCoord"),
sess: sess,
}
client.grpcClient.SetRole(typeutil.QueryCoordRole)
client.grpcClient.SetGetAddrFunc(client.getQueryCoordAddr)

View File

@ -19,7 +19,6 @@ package grpcquerynodeclient
import (
"context"
"fmt"
"time"
"github.com/milvus-io/milvus/internal/util/grpcclient"
"google.golang.org/grpc"
@ -45,22 +44,10 @@ func NewClient(ctx context.Context, addr string, nodeID int64) (*Client, error)
if addr == "" {
return nil, fmt.Errorf("addr is empty")
}
clientParams := &paramtable.Get().QueryNodeGrpcClientCfg
config := &paramtable.Get().QueryNodeGrpcClientCfg
client := &Client{
addr: addr,
grpcClient: &grpcclient.ClientBase[querypb.QueryNodeClient]{
ClientMaxRecvSize: clientParams.ClientMaxRecvSize.GetAsInt(),
ClientMaxSendSize: clientParams.ClientMaxSendSize.GetAsInt(),
DialTimeout: clientParams.DialTimeout.GetAsDuration(time.Millisecond),
KeepAliveTime: clientParams.KeepAliveTime.GetAsDuration(time.Millisecond),
KeepAliveTimeout: clientParams.KeepAliveTimeout.GetAsDuration(time.Millisecond),
RetryServiceNameConfig: "milvus.proto.query.QueryNode",
MaxAttempts: clientParams.MaxAttempts.GetAsInt(),
InitialBackoff: float32(clientParams.InitialBackoff.GetAsFloat()),
MaxBackoff: float32(clientParams.MaxBackoff.GetAsFloat()),
BackoffMultiplier: float32(clientParams.BackoffMultiplier.GetAsFloat()),
CompressionEnabled: clientParams.CompressionEnabled.GetAsBool(),
},
addr: addr,
grpcClient: grpcclient.NewClientBase[querypb.QueryNodeClient](config, "milvus.proto.query.QueryNode"),
}
client.grpcClient.SetRole(typeutil.QueryNodeRole)
client.grpcClient.SetGetAddrFunc(client.getAddr)

View File

@ -19,7 +19,6 @@ package grpcrootcoordclient
import (
"context"
"fmt"
"time"
"github.com/milvus-io/milvus-proto/go-api/v2/commonpb"
"github.com/milvus-io/milvus-proto/go-api/v2/milvuspb"
@ -60,22 +59,10 @@ func NewClient(ctx context.Context, metaRoot string, etcdCli *clientv3.Client) (
log.Debug("QueryCoordClient NewClient failed", zap.Error(err))
return nil, err
}
clientParams := &Params.RootCoordGrpcClientCfg
config := &Params.RootCoordGrpcClientCfg
client := &Client{
grpcClient: &grpcclient.ClientBase[rootcoordpb.RootCoordClient]{
ClientMaxRecvSize: clientParams.ClientMaxRecvSize.GetAsInt(),
ClientMaxSendSize: clientParams.ClientMaxSendSize.GetAsInt(),
DialTimeout: clientParams.DialTimeout.GetAsDuration(time.Millisecond),
KeepAliveTime: clientParams.KeepAliveTime.GetAsDuration(time.Millisecond),
KeepAliveTimeout: clientParams.KeepAliveTimeout.GetAsDuration(time.Millisecond),
RetryServiceNameConfig: "milvus.proto.rootcoord.RootCoord",
MaxAttempts: clientParams.MaxAttempts.GetAsInt(),
InitialBackoff: float32(clientParams.InitialBackoff.GetAsFloat()),
MaxBackoff: float32(clientParams.MaxBackoff.GetAsFloat()),
BackoffMultiplier: float32(clientParams.BackoffMultiplier.GetAsFloat()),
CompressionEnabled: clientParams.CompressionEnabled.GetAsBool(),
},
sess: sess,
grpcClient: grpcclient.NewClientBase[rootcoordpb.RootCoordClient](config, "milvus.proto.rootcoord.RootCoord"),
sess: sess,
}
client.grpcClient.SetRole(typeutil.RootCoordRole)
client.grpcClient.SetGetAddrFunc(client.getRootCoordAddr)

View File

@ -75,7 +75,7 @@ type ClientBase[T interface {
grpcClient T
encryption bool
addr string
addr atomic.String
conn *grpc.ClientConn
grpcClientMtx sync.RWMutex
role string
@ -98,6 +98,24 @@ type ClientBase[T interface {
sf singleflight.Group
}
func NewClientBase[T interface {
GetComponentStates(ctx context.Context, in *milvuspb.GetComponentStatesRequest, opts ...grpc.CallOption) (*milvuspb.ComponentStates, error)
}](config *paramtable.GrpcClientConfig, serviceName string) *ClientBase[T] {
return &ClientBase[T]{
ClientMaxRecvSize: config.ClientMaxRecvSize.GetAsInt(),
ClientMaxSendSize: config.ClientMaxSendSize.GetAsInt(),
DialTimeout: config.DialTimeout.GetAsDuration(time.Millisecond),
KeepAliveTime: config.KeepAliveTime.GetAsDuration(time.Millisecond),
KeepAliveTimeout: config.KeepAliveTimeout.GetAsDuration(time.Millisecond),
RetryServiceNameConfig: serviceName,
MaxAttempts: config.MaxAttempts.GetAsInt(),
InitialBackoff: float32(config.InitialBackoff.GetAsFloat()),
MaxBackoff: float32(config.MaxBackoff.GetAsFloat()),
BackoffMultiplier: float32(config.BackoffMultiplier.GetAsFloat()),
CompressionEnabled: config.CompressionEnabled.GetAsBool(),
}
}
// SetRole sets role of client
func (c *ClientBase[T]) SetRole(role string) {
c.role = role
@ -110,7 +128,7 @@ func (c *ClientBase[T]) GetRole() string {
// GetAddr returns address of client
func (c *ClientBase[T]) GetAddr() string {
return c.addr
return c.addr.Load()
}
// SetGetAddrFunc sets getAddrFunc of client
@ -165,7 +183,7 @@ func (c *ClientBase[T]) resetConnection(client T) {
_ = c.conn.Close()
}
c.conn = nil
c.addr = ""
c.addr.Store("")
c.grpcClient = generic.Zero[T]()
}
@ -288,7 +306,7 @@ func (c *ClientBase[T]) connect(ctx context.Context) error {
}
c.conn = conn
c.addr = addr
c.addr.Store(addr)
c.grpcClient = c.newGrpcClient(c.conn)
return nil
}