feat: Added code for Internal-tls (#36865)

issue : https://github.com/milvus-io/milvus/issues/36864

I have a few questions regarding my approach. I will consolidate them
here for feedback and review. Thanks.

---------

Signed-off-by: Nischay Yadav <nischay.yadav@ibm.com>
Signed-off-by: Nischay <Nischay.Yadav@ibm.com>
pull/37817/head
nish112022 2024-11-20 03:30:32 +05:30 committed by GitHub
parent 83c902992c
commit 484c6b5c44
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
25 changed files with 643 additions and 21 deletions

View File

@ -297,6 +297,10 @@ func WriteYaml(w io.Writer) {
name: "tls",
header: "\n# Configure the proxy tls enable.",
},
{
name: "internaltls",
header: "\n# Configure the node-tls enable.",
},
{
name: "common",
},

View File

@ -784,6 +784,12 @@ tls:
serverKeyPath: configs/cert/server.key
caPemPath: configs/cert/ca.pem
# Configure the node-tls enable.
internaltls:
serverPemPath: configs/cert/server.pem
serverKeyPath: configs/cert/server.key
caPemPath: configs/cert/ca.pem
common:
defaultPartitionName: _default # Name of the default partition when a collection is created
defaultIndexName: _default_idx # Name of the index when it is created with name unspecified
@ -839,6 +845,7 @@ common:
privileges: Query,Search,IndexDetail,GetFlushState,GetLoadState,GetLoadingProgress,HasPartition,ShowPartitions,DescribeCollection,DescribeAlias,GetStatistics,ListAliases,Load,Release,Insert,Delete,Upsert,Import,Flush,Compaction,LoadBalance,RenameCollection,CreateIndex,DropIndex,CreatePartition,DropPartition # Collection level readwrite privileges
admin:
privileges: Query,Search,IndexDetail,GetFlushState,GetLoadState,GetLoadingProgress,HasPartition,ShowPartitions,DescribeCollection,DescribeAlias,GetStatistics,ListAliases,Load,Release,Insert,Delete,Upsert,Import,Flush,Compaction,LoadBalance,RenameCollection,CreateIndex,DropIndex,CreatePartition,DropPartition,CreateAlias,DropAlias # Collection level admin privileges
internaltlsEnabled: false
tlsMode: 0
session:
ttl: 30 # ttl value when session granting a lease to register service

View File

@ -26,6 +26,7 @@ import (
"github.com/milvus-io/milvus-proto/go-api/v2/commonpb"
"github.com/milvus-io/milvus-proto/go-api/v2/milvuspb"
"github.com/milvus-io/milvus/internal/distributed/utils"
"github.com/milvus-io/milvus/internal/proto/datapb"
"github.com/milvus-io/milvus/internal/proto/indexpb"
"github.com/milvus-io/milvus/internal/proto/internalpb"
@ -71,6 +72,15 @@ func NewClient(ctx context.Context) (types.DataCoordClient, error) {
client.grpcClient.SetNewGrpcClientFunc(client.newGrpcClient)
client.grpcClient.SetSession(sess)
if Params.InternalTLSCfg.InternalTLSEnabled.GetAsBool() {
client.grpcClient.EnableEncryption()
cp, err := utils.CreateCertPoolforClient(Params.InternalTLSCfg.InternalTLSCaPemPath.GetValue(), "Datacoord")
if err != nil {
log.Error("Failed to create cert pool for Datacoord client")
return nil, err
}
client.grpcClient.SetInternalTLSCertPool(cp)
}
return client, nil
}

View File

@ -174,7 +174,7 @@ func (s *Server) startGrpcLoop() {
Timeout: 10 * time.Second, // Wait 10 second for the ping ack before assuming the connection is dead
}
s.grpcServer = grpc.NewServer(
grpcOpts := []grpc.ServerOption{
grpc.KeepaliveEnforcementPolicy(kaep),
grpc.KeepaliveParams(kasp),
grpc.MaxRecvMsgSize(Params.ServerMaxRecvSize.GetAsInt()),
@ -201,7 +201,11 @@ func (s *Server) startGrpcLoop() {
}),
streamingserviceinterceptor.NewStreamingServiceStreamServerInterceptor(),
)),
grpc.StatsHandler(tracer.GetDynamicOtelGrpcServerStatsHandler()))
grpc.StatsHandler(tracer.GetDynamicOtelGrpcServerStatsHandler()),
}
grpcOpts = append(grpcOpts, utils.EnableInternalTLS("DataCoord"))
s.grpcServer = grpc.NewServer(grpcOpts...)
indexpb.RegisterIndexCoordServer(s.grpcServer, s)
datapb.RegisterDataCoordServer(s.grpcServer, s)
// register the streaming coord grpc service.

View File

@ -25,6 +25,7 @@ import (
"github.com/milvus-io/milvus-proto/go-api/v2/commonpb"
"github.com/milvus-io/milvus-proto/go-api/v2/milvuspb"
"github.com/milvus-io/milvus/internal/distributed/utils"
"github.com/milvus-io/milvus/internal/proto/datapb"
"github.com/milvus-io/milvus/internal/proto/internalpb"
"github.com/milvus-io/milvus/internal/types"
@ -72,6 +73,15 @@ func NewClient(ctx context.Context, addr string, serverID int64) (types.DataNode
client.grpcClient.SetNodeID(serverID)
client.grpcClient.SetSession(sess)
if Params.InternalTLSCfg.InternalTLSEnabled.GetAsBool() {
client.grpcClient.EnableEncryption()
cp, err := utils.CreateCertPoolforClient(Params.InternalTLSCfg.InternalTLSCaPemPath.GetValue(), "DataNode")
if err != nil {
log.Error("Failed to create cert pool for DataNode client")
return nil, err
}
client.grpcClient.SetInternalTLSCertPool(cp)
}
return client, nil
}

View File

@ -129,7 +129,7 @@ func (s *Server) startGrpcLoop() {
Timeout: 10 * time.Second, // Wait 10 second for the ping ack before assuming the connection is dead
}
s.grpcServer = grpc.NewServer(
grpcOpts := []grpc.ServerOption{
grpc.KeepaliveEnforcementPolicy(kaep),
grpc.KeepaliveParams(kasp),
grpc.MaxRecvMsgSize(Params.ServerMaxRecvSize.GetAsInt()),
@ -154,7 +154,11 @@ func (s *Server) startGrpcLoop() {
return s.serverID.Load()
}),
)),
grpc.StatsHandler(tracer.GetDynamicOtelGrpcServerStatsHandler()))
grpc.StatsHandler(tracer.GetDynamicOtelGrpcServerStatsHandler()),
}
grpcOpts = append(grpcOpts, utils.EnableInternalTLS("DataNode"))
s.grpcServer = grpc.NewServer(grpcOpts...)
datapb.RegisterDataNodeServer(s.grpcServer, s)
ctx, cancel := context.WithCancel(s.ctx)

View File

@ -25,6 +25,7 @@ import (
"github.com/milvus-io/milvus-proto/go-api/v2/commonpb"
"github.com/milvus-io/milvus-proto/go-api/v2/milvuspb"
"github.com/milvus-io/milvus/internal/distributed/utils"
"github.com/milvus-io/milvus/internal/proto/internalpb"
"github.com/milvus-io/milvus/internal/proto/workerpb"
"github.com/milvus-io/milvus/internal/types"
@ -72,6 +73,15 @@ func NewClient(ctx context.Context, addr string, nodeID int64, encryption bool)
if encryption {
client.grpcClient.EnableEncryption()
}
if Params.InternalTLSCfg.InternalTLSEnabled.GetAsBool() {
client.grpcClient.EnableEncryption()
cp, err := utils.CreateCertPoolforClient(Params.InternalTLSCfg.InternalTLSCaPemPath.GetValue(), "IndexNode")
if err != nil {
log.Error("Failed to create cert pool for IndexNode client")
return nil, err
}
client.grpcClient.SetInternalTLSCertPool(cp)
}
return client, nil
}

View File

@ -114,7 +114,7 @@ func (s *Server) startGrpcLoop() {
Timeout: 10 * time.Second, // Wait 10 second for the ping ack before assuming the connection is dead
}
s.grpcServer = grpc.NewServer(
grpcOpts := []grpc.ServerOption{
grpc.KeepaliveEnforcementPolicy(kaep),
grpc.KeepaliveParams(kasp),
grpc.MaxRecvMsgSize(Params.ServerMaxRecvSize.GetAsInt()),
@ -139,7 +139,11 @@ func (s *Server) startGrpcLoop() {
return s.serverID.Load()
}),
)),
grpc.StatsHandler(tracer.GetDynamicOtelGrpcServerStatsHandler()))
grpc.StatsHandler(tracer.GetDynamicOtelGrpcServerStatsHandler()),
}
grpcOpts = append(grpcOpts, utils.EnableInternalTLS("IndexNode"))
s.grpcServer = grpc.NewServer(grpcOpts...)
workerpb.RegisterIndexNodeServer(s.grpcServer, s)
go funcutil.CheckGrpcReady(ctx, s.grpcErrChan)
if err := s.grpcServer.Serve(s.listener); err != nil {

View File

@ -25,6 +25,7 @@ import (
"github.com/milvus-io/milvus-proto/go-api/v2/commonpb"
"github.com/milvus-io/milvus-proto/go-api/v2/milvuspb"
"github.com/milvus-io/milvus/internal/distributed/utils"
"github.com/milvus-io/milvus/internal/proto/internalpb"
"github.com/milvus-io/milvus/internal/proto/proxypb"
"github.com/milvus-io/milvus/internal/types"
@ -69,6 +70,15 @@ func NewClient(ctx context.Context, addr string, nodeID int64) (types.ProxyClien
client.grpcClient.SetNewGrpcClientFunc(client.newGrpcClient)
client.grpcClient.SetNodeID(nodeID)
client.grpcClient.SetSession(sess)
if Params.InternalTLSCfg.InternalTLSEnabled.GetAsBool() {
client.grpcClient.EnableEncryption()
cp, err := utils.CreateCertPoolforClient(Params.InternalTLSCfg.InternalTLSCaPemPath.GetValue(), "Proxy")
if err != nil {
log.Error("Failed to create cert pool for Proxy client")
return nil, err
}
client.grpcClient.SetInternalTLSCertPool(cp)
}
return client, nil
}

View File

@ -342,7 +342,7 @@ func (s *Server) startInternalGrpc(errChan chan error) {
}
opts := tracer.GetInterceptorOpts()
s.grpcInternalServer = grpc.NewServer(
grpcOpts := []grpc.ServerOption{
grpc.KeepaliveEnforcementPolicy(kaep),
grpc.KeepaliveParams(kasp),
grpc.MaxRecvMsgSize(Params.ServerMaxRecvSize.GetAsInt()),
@ -366,7 +366,12 @@ func (s *Server) startInternalGrpc(errChan chan error) {
}
return s.serverID.Load()
}),
)))
)),
grpc.StatsHandler(tracer.GetDynamicOtelGrpcServerStatsHandler()),
}
grpcOpts = append(grpcOpts, utils.EnableInternalTLS("Proxy"))
s.grpcInternalServer = grpc.NewServer(grpcOpts...)
proxypb.RegisterProxyServer(s.grpcInternalServer, s)
grpc_health_v1.RegisterHealthServer(s.grpcInternalServer, s)
errChan <- nil

View File

@ -25,6 +25,7 @@ import (
"github.com/milvus-io/milvus-proto/go-api/v2/commonpb"
"github.com/milvus-io/milvus-proto/go-api/v2/milvuspb"
"github.com/milvus-io/milvus/internal/distributed/utils"
"github.com/milvus-io/milvus/internal/proto/internalpb"
"github.com/milvus-io/milvus/internal/proto/querypb"
"github.com/milvus-io/milvus/internal/types"
@ -63,6 +64,15 @@ func NewClient(ctx context.Context) (types.QueryCoordClient, error) {
client.grpcClient.SetNewGrpcClientFunc(client.newGrpcClient)
client.grpcClient.SetSession(sess)
if Params.InternalTLSCfg.InternalTLSEnabled.GetAsBool() {
client.grpcClient.EnableEncryption()
cp, err := utils.CreateCertPoolforClient(Params.InternalTLSCfg.InternalTLSCaPemPath.GetValue(), "QueryCoord")
if err != nil {
log.Error("Failed to create cert pool for QueryCoord client")
return nil, err
}
client.grpcClient.SetInternalTLSCertPool(cp)
}
return client, nil
}

View File

@ -230,7 +230,7 @@ func (s *Server) startGrpcLoop() {
ctx, cancel := context.WithCancel(s.loopCtx)
defer cancel()
s.grpcServer = grpc.NewServer(
grpcOpts := []grpc.ServerOption{
grpc.KeepaliveEnforcementPolicy(kaep),
grpc.KeepaliveParams(kasp),
grpc.MaxRecvMsgSize(Params.ServerMaxRecvSize.GetAsInt()),
@ -256,7 +256,10 @@ func (s *Server) startGrpcLoop() {
}),
)),
grpc.StatsHandler(tracer.GetDynamicOtelGrpcServerStatsHandler()),
)
}
grpcOpts = append(grpcOpts, utils.EnableInternalTLS("QueryCoord"))
s.grpcServer = grpc.NewServer(grpcOpts...)
querypb.RegisterQueryCoordServer(s.grpcServer, s)
go funcutil.CheckGrpcReady(ctx, s.grpcErrChan)

View File

@ -25,6 +25,7 @@ import (
"github.com/milvus-io/milvus-proto/go-api/v2/commonpb"
"github.com/milvus-io/milvus-proto/go-api/v2/milvuspb"
"github.com/milvus-io/milvus/internal/distributed/utils"
"github.com/milvus-io/milvus/internal/proto/internalpb"
"github.com/milvus-io/milvus/internal/proto/querypb"
"github.com/milvus-io/milvus/internal/types"
@ -37,6 +38,8 @@ import (
"github.com/milvus-io/milvus/pkg/util/typeutil"
)
var Params *paramtable.ComponentParam = paramtable.Get()
// Client is the grpc client of QueryNode.
type Client struct {
grpcClient grpcclient.GrpcClient[querypb.QueryNodeClient]
@ -70,6 +73,15 @@ func NewClient(ctx context.Context, addr string, nodeID int64) (types.QueryNodeC
client.grpcClient.SetNodeID(nodeID)
client.grpcClient.SetSession(sess)
if Params.InternalTLSCfg.InternalTLSEnabled.GetAsBool() {
client.grpcClient.EnableEncryption()
cp, err := utils.CreateCertPoolforClient(Params.InternalTLSCfg.InternalTLSCaPemPath.GetValue(), "QueryNode")
if err != nil {
log.Error("Failed to create cert pool for QueryNode client")
return nil, err
}
client.grpcClient.SetInternalTLSCertPool(cp)
}
return client, nil
}

View File

@ -176,7 +176,7 @@ func (s *Server) startGrpcLoop() {
Timeout: 10 * time.Second, // Wait 10 second for the ping ack before assuming the connection is dead
}
s.grpcServer = grpc.NewServer(
grpcOpts := []grpc.ServerOption{
grpc.KeepaliveEnforcementPolicy(kaep),
grpc.KeepaliveParams(kasp),
grpc.MaxRecvMsgSize(Params.ServerMaxRecvSize.GetAsInt()),
@ -204,7 +204,10 @@ func (s *Server) startGrpcLoop() {
}),
)),
grpc.StatsHandler(tracer.GetDynamicOtelGrpcServerStatsHandler()),
)
}
grpcOpts = append(grpcOpts, utils.EnableInternalTLS("QueryNode"))
s.grpcServer = grpc.NewServer(grpcOpts...)
querypb.RegisterQueryNodeServer(s.grpcServer, s)
ctx, cancel := context.WithCancel(s.ctx)

View File

@ -27,6 +27,7 @@ import (
"github.com/milvus-io/milvus-proto/go-api/v2/commonpb"
"github.com/milvus-io/milvus-proto/go-api/v2/milvuspb"
"github.com/milvus-io/milvus/internal/distributed/utils"
"github.com/milvus-io/milvus/internal/proto/internalpb"
"github.com/milvus-io/milvus/internal/proto/proxypb"
"github.com/milvus-io/milvus/internal/proto/rootcoordpb"
@ -70,6 +71,15 @@ func NewClient(ctx context.Context) (types.RootCoordClient, error) {
client.grpcClient.SetNewGrpcClientFunc(client.newGrpcClient)
client.grpcClient.SetSession(sess)
if Params.InternalTLSCfg.InternalTLSEnabled.GetAsBool() {
client.grpcClient.EnableEncryption()
cp, err := utils.CreateCertPoolforClient(Params.InternalTLSCfg.InternalTLSCaPemPath.GetValue(), "RootCoord")
if err != nil {
log.Error("Failed to create cert pool for RootCoord client")
return nil, err
}
client.grpcClient.SetInternalTLSCertPool(cp)
}
return client, nil
}

View File

@ -278,7 +278,7 @@ func (s *Server) startGrpcLoop() {
ctx, cancel := context.WithCancel(s.ctx)
defer cancel()
s.grpcServer = grpc.NewServer(
grpcOpts := []grpc.ServerOption{
grpc.KeepaliveEnforcementPolicy(kaep),
grpc.KeepaliveParams(kasp),
grpc.MaxRecvMsgSize(Params.ServerMaxRecvSize.GetAsInt()),
@ -303,7 +303,11 @@ func (s *Server) startGrpcLoop() {
return s.serverID.Load()
}),
)),
grpc.StatsHandler(tracer.GetDynamicOtelGrpcServerStatsHandler()))
grpc.StatsHandler(tracer.GetDynamicOtelGrpcServerStatsHandler()),
}
grpcOpts = append(grpcOpts, utils.EnableInternalTLS("RootCoord"))
s.grpcServer = grpc.NewServer(grpcOpts...)
rootcoordpb.RegisterRootCoordServer(s.grpcServer, s)
go funcutil.CheckGrpcReady(ctx, s.grpcErrChan)

View File

@ -1,9 +1,14 @@
package utils
import (
"crypto/x509"
"os"
"time"
"github.com/cockroachdb/errors"
"go.uber.org/zap"
"google.golang.org/grpc"
"google.golang.org/grpc/credentials"
"github.com/milvus-io/milvus/pkg/log"
"github.com/milvus-io/milvus/pkg/util/paramtable"
@ -30,3 +35,47 @@ func GracefulStopGRPCServer(s *grpc.Server) {
<-ch
}
}
// getTLSCreds builds server-side TLS transport credentials from the given
// certificate/key pair. nodeType is used only for log context.
//
// NOTE(review): on failure this only logs a warning and still returns the
// (invalid) credentials; consider propagating the error so the caller can
// fail fast instead of starting a broken TLS listener.
func getTLSCreds(certFile string, keyFile string, nodeType string) credentials.TransportCredentials {
	log.Info("TLS Server PEM Path", zap.String("path", certFile))
	log.Info("TLS Server Key Path", zap.String("path", keyFile))
	creds, err := credentials.NewServerTLSFromFile(certFile, keyFile)
	if err != nil {
		// Fixed: the warning was previously logged twice on the same error.
		log.Warn(nodeType+" can't create creds", zap.Error(err))
	}
	return creds
}
// EnableInternalTLS returns the gRPC server option enabling internal
// (node-to-node) TLS when common.security.internaltlsEnabled is set.
// nodeType is used only for log context.
//
// When internal TLS is disabled it returns grpc.Creds(nil), which leaves
// the server on its default plain-text transport.
func EnableInternalTLS(nodeType string) grpc.ServerOption {
	// Idiomatic Go: locals use lowerCamelCase (was `Params`, `NodeType`).
	params := paramtable.Get()
	certFile := params.InternalTLSCfg.InternalTLSServerPemPath.GetValue()
	keyFile := params.InternalTLSCfg.InternalTLSServerKeyPath.GetValue()
	internalTLSEnabled := params.InternalTLSCfg.InternalTLSEnabled.GetAsBool()
	log.Info("Internal TLS Enabled", zap.Bool("value", internalTLSEnabled))
	if internalTLSEnabled {
		creds := getTLSCreds(certFile, keyFile, nodeType)
		return grpc.Creds(creds)
	}
	return grpc.Creds(nil)
}
// CreateCertPoolforClient loads the CA PEM file at caFile into a fresh
// x509.CertPool, for use as the client-side RootCAs when dialing peers with
// internal TLS. nodeType is used only for log context.
//
// Returns an error when the file cannot be read or contains no parseable
// certificate.
func CreateCertPoolforClient(caFile string, nodeType string) (*x509.CertPool, error) {
	log.Info("Creating cert pool for " + nodeType)
	log.Info("Cert file path:", zap.String("caFile", caFile))
	certPool := x509.NewCertPool()
	b, err := os.ReadFile(caFile)
	if err != nil {
		log.Error("Error reading cert file in client", zap.Error(err))
		return nil, err
	}
	if !certPool.AppendCertsFromPEM(b) {
		log.Error("credentials: failed to append certificates")
		return nil, errors.New("failed to append certificates")
	}
	// Fixed: previously returned the stale `err` variable; on this path it is
	// always nil, so return nil explicitly for clarity.
	return certPool, nil
}

View File

@ -12,6 +12,8 @@ import (
mock "github.com/stretchr/testify/mock"
sessionutil "github.com/milvus-io/milvus/internal/util/sessionutil"
x509 "crypto/x509"
)
// MockGrpcClient is an autogenerated mock type for the GrpcClient type
@ -325,6 +327,39 @@ func (_c *MockGrpcClient_SetGetAddrFunc_Call[T]) RunAndReturn(run func(func() (s
return _c
}
// SetInternalTLSCertPool provides a mock function with given fields: cp
// (autogenerated by mockery; it only records the call, the pool is unused).
func (_m *MockGrpcClient[T]) SetInternalTLSCertPool(cp *x509.CertPool) {
	_m.Called(cp)
}
// MockGrpcClient_SetInternalTLSCertPool_Call is a *mock.Call that shadows Run/Return methods with type explicit version for method 'SetInternalTLSCertPool'
type MockGrpcClient_SetInternalTLSCertPool_Call[T grpcclient.GrpcComponent] struct {
	*mock.Call
}
// SetInternalTLSCertPool is a helper method to define mock.On call
//   - cp *x509.CertPool
func (_e *MockGrpcClient_Expecter[T]) SetInternalTLSCertPool(cp interface{}) *MockGrpcClient_SetInternalTLSCertPool_Call[T] {
	return &MockGrpcClient_SetInternalTLSCertPool_Call[T]{Call: _e.mock.On("SetInternalTLSCertPool", cp)}
}
// Run registers a callback invoked with the recorded *x509.CertPool argument.
func (_c *MockGrpcClient_SetInternalTLSCertPool_Call[T]) Run(run func(cp *x509.CertPool)) *MockGrpcClient_SetInternalTLSCertPool_Call[T] {
	_c.Call.Run(func(args mock.Arguments) {
		run(args[0].(*x509.CertPool))
	})
	return _c
}
// Return completes the expectation; the mocked method returns nothing.
func (_c *MockGrpcClient_SetInternalTLSCertPool_Call[T]) Return() *MockGrpcClient_SetInternalTLSCertPool_Call[T] {
	_c.Call.Return()
	return _c
}
// RunAndReturn registers the given function as the mock implementation.
func (_c *MockGrpcClient_SetInternalTLSCertPool_Call[T]) RunAndReturn(run func(*x509.CertPool)) *MockGrpcClient_SetInternalTLSCertPool_Call[T] {
	_c.Call.Return(run)
	return _c
}
// SetNewGrpcClientFunc provides a mock function with given fields: _a0
func (_m *MockGrpcClient[T]) SetNewGrpcClientFunc(_a0 func(*grpc.ClientConn) T) {
_m.Called(_a0)

View File

@ -19,6 +19,7 @@ package grpcclient
import (
"context"
"crypto/tls"
"crypto/x509"
"strings"
"sync"
"time"
@ -84,6 +85,7 @@ type GrpcClient[T GrpcComponent] interface {
GetRole() string
SetGetAddrFunc(func() (string, error))
EnableEncryption()
SetInternalTLSCertPool(cp *x509.CertPool)
SetNewGrpcClientFunc(func(cc *grpc.ClientConn) T)
ReCall(ctx context.Context, caller func(client T) (any, error)) (any, error)
Call(ctx context.Context, caller func(client T) (any, error)) (any, error)
@ -101,9 +103,10 @@ type ClientBase[T interface {
newGrpcClient func(cc *grpc.ClientConn) T
// grpcClient T
grpcClient *clientConnWrapper[T]
encryption bool
addr atomic.String
grpcClient *clientConnWrapper[T]
encryption bool
cpInternalTLS *x509.CertPool
addr atomic.String
// conn *grpc.ClientConn
grpcClientMtx sync.RWMutex
role string
@ -187,6 +190,10 @@ func (c *ClientBase[T]) EnableEncryption() {
c.encryption = true
}
// SetInternalTLSCertPool stores the CA cert pool used as RootCAs when the
// client dials with encryption enabled (consumed by connect()).
func (c *ClientBase[T]) SetInternalTLSCertPool(cp *x509.CertPool) {
	c.cpInternalTLS = cp
}
// SetNewGrpcClientFunc sets newGrpcClient of client
func (c *ClientBase[T]) SetNewGrpcClientFunc(f func(cc *grpc.ClientConn) T) {
c.newGrpcClient = f
@ -257,11 +264,12 @@ func (c *ClientBase[T]) connect(ctx context.Context) error {
compress = Zstd
}
if c.encryption {
log.Debug("Running in internalTLS mode with encryption enabled")
conn, err = grpc.DialContext(
dialContext,
addr,
// #nosec G402
grpc.WithTransportCredentials(credentials.NewTLS(&tls.Config{})),
grpc.WithTransportCredentials(credentials.NewTLS(&tls.Config{RootCAs: c.cpInternalTLS})),
grpc.WithBlock(),
grpc.WithDefaultCallOptions(
grpc.MaxCallRecvMsgSize(c.ClientMaxRecvSize),

View File

@ -18,6 +18,7 @@ package mock
import (
"context"
"crypto/x509"
"fmt"
"sync"
@ -37,6 +38,7 @@ type GRPCClientBase[T any] struct {
newGrpcClient func(cc *grpc.ClientConn) T
grpcClient T
cpInternalTLS *x509.CertPool
conn *grpc.ClientConn
grpcClientMtx sync.RWMutex
GetGrpcClientErr error
@ -60,6 +62,10 @@ func (c *GRPCClientBase[T]) SetRole(role string) {
// EnableEncryption is a no-op in this mock client.
func (c *GRPCClientBase[T]) EnableEncryption() {
}
// SetInternalTLSCertPool records the given CA pool on the mock; it is not
// used for any real TLS handshake here.
func (c *GRPCClientBase[T]) SetInternalTLSCertPool(cp *x509.CertPool) {
	c.cpInternalTLS = cp
}
func (c *GRPCClientBase[T]) SetNewGrpcClientFunc(f func(cc *grpc.ClientConn) T) {
c.newGrpcClient = f
}

View File

@ -64,9 +64,10 @@ func globalConfigPrefixs() []string {
return []string{"metastore", "localStorage", "etcd", "tikv", "minio", "pulsar", "kafka", "rocksmq", "log", "grpc", "common", "quotaAndLimits", "trace"}
}
// support read "milvus.yaml", "default.yaml", "user.yaml" as this order.
// order: milvus.yaml < default.yaml < user.yaml, do not change the order below
var defaultYaml = []string{"milvus.yaml", "default.yaml", "user.yaml"}
// support read "milvus.yaml", "_test.yaml", "default.yaml", "user.yaml" as this order.
// order: milvus.yaml < _test.yaml < default.yaml < user.yaml, do not change the order below.
// Use _test.yaml only for test related purpose.
var defaultYaml = []string{"milvus.yaml", "_test.yaml", "default.yaml", "user.yaml"}
// BaseTable the basics of paramtable
type BaseTable struct {

View File

@ -83,6 +83,8 @@ type ComponentParam struct {
RbacConfig rbacConfig
StreamingCfg streamingConfig
InternalTLSCfg InternalTLSConfig
RootCoordGrpcServerCfg GrpcServerConfig
ProxyGrpcServerCfg GrpcServerConfig
QueryCoordGrpcServerCfg GrpcServerConfig
@ -139,6 +141,8 @@ func (p *ComponentParam) init(bt *BaseTable) {
p.GpuConfig.init(bt)
p.KnowhereConfig.init(bt)
p.InternalTLSCfg.Init(bt)
p.RootCoordGrpcServerCfg.Init("rootCoord", bt)
p.ProxyGrpcServerCfg.Init("proxy", bt)
p.ProxyGrpcServerCfg.InternalPort.Export = true

View File

@ -535,3 +535,41 @@ func (p *GrpcClientConfig) GetDefaultRetryPolicy() map[string]interface{} {
"backoffMultiplier": p.BackoffMultiplier.GetAsFloat(),
}
}
// InternalTLSConfig holds the settings for TLS between Milvus internal
// components (node-to-node gRPC). None of the items are hot-refreshable.
type InternalTLSConfig struct {
	InternalTLSEnabled       ParamItem `refreshable:"false"` // common.security.internaltlsEnabled
	InternalTLSServerPemPath ParamItem `refreshable:"false"` // internaltls.serverPemPath
	InternalTLSServerKeyPath ParamItem `refreshable:"false"` // internaltls.serverKeyPath
	InternalTLSCaPemPath     ParamItem `refreshable:"false"` // internaltls.caPemPath
}
// Init declares the internal-TLS ParamItems and registers each of them with
// the config manager of the given BaseTable.
func (p *InternalTLSConfig) Init(base *BaseTable) {
	p.InternalTLSEnabled = ParamItem{
		Key:          "common.security.internaltlsEnabled",
		Version:      "2.0.0",
		DefaultValue: "false",
		Export:       true,
	}
	p.InternalTLSServerPemPath = ParamItem{
		Key:     "internaltls.serverPemPath",
		Version: "2.0.0",
		Export:  true,
	}
	p.InternalTLSServerKeyPath = ParamItem{
		Key:     "internaltls.serverKeyPath",
		Version: "2.0.0",
		Export:  true,
	}
	p.InternalTLSCaPemPath = ParamItem{
		Key:     "internaltls.caPemPath",
		Version: "2.0.0",
		Export:  true,
	}
	// Register every item in declaration order.
	for _, item := range []*ParamItem{
		&p.InternalTLSEnabled,
		&p.InternalTLSServerPemPath,
		&p.InternalTLSServerKeyPath,
		&p.InternalTLSCaPemPath,
	} {
		item.Init(base.mgr)
	}
}

View File

@ -178,3 +178,19 @@ func TestGrpcClientParams(t *testing.T) {
assert.Equal(t, clientConfig.ServerKeyPath.GetValue(), "/key")
assert.Equal(t, clientConfig.CaPemPath.GetValue(), "/ca")
}
// TestInternalTLSParams verifies that the internal-TLS ParamItems pick up
// values saved under their declared config keys.
func TestInternalTLSParams(t *testing.T) {
	base := ComponentParam{}
	base.Init(NewBaseTable(SkipRemote(true)))

	var internalTLSCfg InternalTLSConfig
	internalTLSCfg.Init(base.baseTable)

	// Fixed: use the exact key declared in InternalTLSConfig.Init
	// ("internaltlsEnabled", not "internalTlsEnabled") so the test does not
	// rely on key-case normalization.
	base.Save("common.security.internaltlsEnabled", "true")
	base.Save("internaltls.serverPemPath", "/pem")
	base.Save("internaltls.serverKeyPath", "/key")
	base.Save("internaltls.caPemPath", "/ca")

	// testify convention: expected value first, actual second.
	assert.Equal(t, true, internalTLSCfg.InternalTLSEnabled.GetAsBool())
	assert.Equal(t, "/pem", internalTLSCfg.InternalTLSServerPemPath.GetValue())
	assert.Equal(t, "/key", internalTLSCfg.InternalTLSServerKeyPath.GetValue())
	assert.Equal(t, "/ca", internalTLSCfg.InternalTLSCaPemPath.GetValue())
}

View File

@ -0,0 +1,355 @@
// Licensed to the LF AI & Data foundation under one
// or more contributor license agreements. See the NOTICE file
// distributed with this work for additional information
// regarding copyright ownership. The ASF licenses this file
// to you under the Apache License, Version 2.0 (the
// "License"); you may not use this file except in compliance
// with the License. You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
package internaltls
import (
"context"
"fmt"
"os"
"testing"
"time"
"github.com/stretchr/testify/suite"
"go.uber.org/zap"
"google.golang.org/protobuf/proto"
"github.com/milvus-io/milvus-proto/go-api/v2/commonpb"
"github.com/milvus-io/milvus-proto/go-api/v2/milvuspb"
"github.com/milvus-io/milvus-proto/go-api/v2/schemapb"
"github.com/milvus-io/milvus/internal/util/hookutil"
"github.com/milvus-io/milvus/pkg/common"
"github.com/milvus-io/milvus/pkg/log"
"github.com/milvus-io/milvus/pkg/util/funcutil"
"github.com/milvus-io/milvus/pkg/util/merr"
"github.com/milvus-io/milvus/pkg/util/metric"
"github.com/milvus-io/milvus/pkg/util/paramtable"
"github.com/milvus-io/milvus/tests/integration"
)
// InternaltlsTestSuit runs the basic hello-milvus flow against a mini
// cluster configured (via configs/_test.yaml) with internal TLS enabled,
// so every inter-node gRPC call exercises the TLS transport.
type InternaltlsTestSuit struct {
	integration.MiniClusterSuite

	indexType  string            // index type under test
	metricType string            // metric type under test
	vecType    schemapb.DataType // vector field data type
}
// Define the content for the configuration YAML file.
// It pins every component to localhost and enables internal TLS with the
// repo-provided test certificates.
// NOTE(review): the YAML nesting below was reconstructed with 2-space
// indentation; confirm against the committed file.
var configContent = `
rootCoord:
  ip: localhost

proxy:
  ip: localhost

queryCoord:
  ip: localhost

queryNode:
  ip: localhost

indexNode:
  ip: localhost

dataCoord:
  ip: localhost

dataNode:
  ip: localhost

common:
  security:
    internaltlsEnabled : true

internaltls:
  serverPemPath: ../../../configs/cert/server.pem
  serverKeyPath: ../../../configs/cert/server.key
  caPemPath: ../../../configs/cert/ca.pem
`

// configFilePath is read automatically because "_test.yaml" is part of the
// paramtable default config file list.
const configFilePath = "../../../configs/_test.yaml"
// CreateConfigFile creates the YAML configuration file (_test.yaml) that
// enables internal TLS for the test cluster; paramtable.Init() picks it up.
func CreateConfigFile() {
	// Write config content to _test.yaml file
	err := os.WriteFile(configFilePath, []byte(configContent), 0o600)
	if err != nil {
		log.Error("Failed to create config file", zap.Error(err))
		// Fixed: do not log success when the write failed.
		return
	}
	log.Info(fmt.Sprintf("Config file created: %s", configFilePath))
}
// SetupSuite writes the internal-TLS test config, initializes paramtable
// (which reads _test.yaml), and starts the embedded etcd for the cluster.
func (s *InternaltlsTestSuit) SetupSuite() {
	log.Info("Initializing paramtable...")
	CreateConfigFile()
	paramtable.Init()
	log.Info("Setting up EmbedEtcd...")
	s.Require().NoError(s.SetupEmbedEtcd())
}
// run executes the end-to-end hello-milvus flow — create collection, insert,
// flush, build index, load, search, query, delete, release, drop — over the
// mini cluster. With internal TLS enabled in SetupSuite, every inter-node
// gRPC call in this flow goes over the TLS transport.
func (s *InternaltlsTestSuit) run() {
	ctx, cancel := context.WithCancel(context.Background())
	defer cancel()
	c := s.Cluster

	const (
		dim    = 128
		dbName = ""
		rowNum = 3000
	)

	collectionName := "TestHelloMilvus" + funcutil.GenRandomStr()

	schema := integration.ConstructSchemaOfVecDataType(collectionName, dim, true, s.vecType)
	marshaledSchema, err := proto.Marshal(schema)
	s.NoError(err)

	// create collection
	createCollectionStatus, err := c.Proxy.CreateCollection(ctx, &milvuspb.CreateCollectionRequest{
		DbName:         dbName,
		CollectionName: collectionName,
		Schema:         marshaledSchema,
		ShardsNum:      common.DefaultShardsNum,
	})
	s.NoError(err)
	if createCollectionStatus.GetErrorCode() != commonpb.ErrorCode_Success {
		log.Warn("createCollectionStatus fail reason", zap.String("reason", createCollectionStatus.GetReason()))
	}
	s.Equal(createCollectionStatus.GetErrorCode(), commonpb.ErrorCode_Success)

	log.Info("CreateCollection result", zap.Any("createCollectionStatus", createCollectionStatus))
	showCollectionsResp, err := c.Proxy.ShowCollections(ctx, &milvuspb.ShowCollectionsRequest{})
	s.NoError(err)
	s.Equal(showCollectionsResp.GetStatus().GetErrorCode(), commonpb.ErrorCode_Success)
	log.Info("ShowCollections result", zap.Any("showCollectionsResp", showCollectionsResp))

	// build the vector column matching the suite's vector type
	var fVecColumn *schemapb.FieldData
	if s.vecType == schemapb.DataType_SparseFloatVector {
		fVecColumn = integration.NewSparseFloatVectorFieldData(integration.SparseFloatVecField, rowNum)
	} else {
		fVecColumn = integration.NewFloatVectorFieldData(integration.FloatVecField, rowNum, dim)
	}
	hashKeys := integration.GenerateHashKeys(rowNum)

	// watch the hook report channel to confirm the insert was observed
	insertCheckReport := func() {
		timeoutCtx, cancelFunc := context.WithTimeout(ctx, 5*time.Second)
		defer cancelFunc()

		for {
			select {
			case <-timeoutCtx.Done():
				s.Fail("insert check timeout")
			case report := <-c.Extension.GetReportChan():
				reportInfo := report.(map[string]any)
				log.Info("insert report info", zap.Any("reportInfo", reportInfo))
				s.Equal(hookutil.OpTypeInsert, reportInfo[hookutil.OpTypeKey])
				s.NotEqualValues(0, reportInfo[hookutil.RequestDataSizeKey])
				return
			}
		}
	}
	go insertCheckReport()
	insertResult, err := c.Proxy.Insert(ctx, &milvuspb.InsertRequest{
		DbName:         dbName,
		CollectionName: collectionName,
		FieldsData:     []*schemapb.FieldData{fVecColumn},
		HashKeys:       hashKeys,
		NumRows:        uint32(rowNum),
	})
	s.NoError(err)
	s.Equal(insertResult.GetStatus().GetErrorCode(), commonpb.ErrorCode_Success)

	// flush
	flushResp, err := c.Proxy.Flush(ctx, &milvuspb.FlushRequest{
		DbName:          dbName,
		CollectionNames: []string{collectionName},
	})
	s.NoError(err)
	segmentIDs, has := flushResp.GetCollSegIDs()[collectionName]
	ids := segmentIDs.GetData()
	s.Require().NotEmpty(segmentIDs)
	s.Require().True(has)
	flushTs, has := flushResp.GetCollFlushTs()[collectionName]
	s.True(has)

	segments, err := c.MetaWatcher.ShowSegments()
	s.NoError(err)
	s.NotEmpty(segments)
	for _, segment := range segments {
		log.Info("ShowSegments result", zap.String("segment", segment.String()))
	}
	s.WaitForFlush(ctx, ids, flushTs, dbName, collectionName)

	// create index
	createIndexStatus, err := c.Proxy.CreateIndex(ctx, &milvuspb.CreateIndexRequest{
		CollectionName: collectionName,
		FieldName:      fVecColumn.FieldName,
		IndexName:      "_default",
		ExtraParams:    integration.ConstructIndexParam(dim, s.indexType, s.metricType),
	})
	if createIndexStatus.GetErrorCode() != commonpb.ErrorCode_Success {
		log.Warn("createIndexStatus fail reason", zap.String("reason", createIndexStatus.GetReason()))
	}
	s.NoError(err)
	s.Equal(commonpb.ErrorCode_Success, createIndexStatus.GetErrorCode())

	s.WaitForIndexBuilt(ctx, collectionName, fVecColumn.FieldName)

	// load
	loadStatus, err := c.Proxy.LoadCollection(ctx, &milvuspb.LoadCollectionRequest{
		DbName:         dbName,
		CollectionName: collectionName,
	})
	s.NoError(err)
	if loadStatus.GetErrorCode() != commonpb.ErrorCode_Success {
		log.Warn("loadStatus fail reason", zap.String("reason", loadStatus.GetReason()))
	}
	s.Equal(commonpb.ErrorCode_Success, loadStatus.GetErrorCode())
	s.WaitForLoad(ctx, collectionName)

	// search
	expr := fmt.Sprintf("%s > 0", integration.Int64Field)
	nq := 10
	topk := 10
	roundDecimal := -1

	params := integration.GetSearchParams(s.indexType, s.metricType)
	searchReq := integration.ConstructSearchRequest("", collectionName, expr,
		fVecColumn.FieldName, s.vecType, nil, s.metricType, params, nq, dim, topk, roundDecimal)

	// confirm the search was observed through the hook report channel
	searchCheckReport := func() {
		timeoutCtx, cancelFunc := context.WithTimeout(ctx, 5*time.Second)
		defer cancelFunc()

		for {
			select {
			case <-timeoutCtx.Done():
				s.Fail("search check timeout")
			case report := <-c.Extension.GetReportChan():
				reportInfo := report.(map[string]any)
				log.Info("search report info", zap.Any("reportInfo", reportInfo))
				s.Equal(hookutil.OpTypeSearch, reportInfo[hookutil.OpTypeKey])
				s.NotEqualValues(0, reportInfo[hookutil.ResultDataSizeKey])
				s.NotEqualValues(0, reportInfo[hookutil.RelatedDataSizeKey])
				s.EqualValues(rowNum, reportInfo[hookutil.RelatedCntKey])
				return
			}
		}
	}
	go searchCheckReport()
	searchResult, err := c.Proxy.Search(ctx, searchReq)

	err = merr.CheckRPCCall(searchResult, err)
	s.NoError(err)

	// confirm the count(*) query was observed through the hook report channel
	queryCheckReport := func() {
		timeoutCtx, cancelFunc := context.WithTimeout(ctx, 5*time.Second)
		defer cancelFunc()

		for {
			select {
			case <-timeoutCtx.Done():
				s.Fail("query check timeout")
			case report := <-c.Extension.GetReportChan():
				reportInfo := report.(map[string]any)
				log.Info("query report info", zap.Any("reportInfo", reportInfo))
				s.Equal(hookutil.OpTypeQuery, reportInfo[hookutil.OpTypeKey])
				s.NotEqualValues(0, reportInfo[hookutil.ResultDataSizeKey])
				s.NotEqualValues(0, reportInfo[hookutil.RelatedDataSizeKey])
				s.EqualValues(rowNum, reportInfo[hookutil.RelatedCntKey])
				return
			}
		}
	}
	go queryCheckReport()
	queryResult, err := c.Proxy.Query(ctx, &milvuspb.QueryRequest{
		DbName:         dbName,
		CollectionName: collectionName,
		Expr:           "",
		OutputFields:   []string{"count(*)"},
	})
	if queryResult.GetStatus().GetErrorCode() != commonpb.ErrorCode_Success {
		log.Warn("searchResult fail reason", zap.String("reason", queryResult.GetStatus().GetReason()))
	}
	s.NoError(err)
	s.Equal(commonpb.ErrorCode_Success, queryResult.GetStatus().GetErrorCode())

	// confirm the delete was observed through the hook report channel
	deleteCheckReport := func() {
		timeoutCtx, cancelFunc := context.WithTimeout(ctx, 5*time.Second)
		defer cancelFunc()

		for {
			select {
			case <-timeoutCtx.Done():
				s.Fail("delete check timeout")
			case report := <-c.Extension.GetReportChan():
				reportInfo := report.(map[string]any)
				log.Info("delete report info", zap.Any("reportInfo", reportInfo))
				s.Equal(hookutil.OpTypeDelete, reportInfo[hookutil.OpTypeKey])
				s.EqualValues(2, reportInfo[hookutil.SuccessCntKey])
				s.EqualValues(0, reportInfo[hookutil.RelatedCntKey])
				return
			}
		}
	}
	go deleteCheckReport()
	deleteResult, err := c.Proxy.Delete(ctx, &milvuspb.DeleteRequest{
		DbName:         dbName,
		CollectionName: collectionName,
		Expr:           integration.Int64Field + " in [1, 2]",
	})
	if deleteResult.GetStatus().GetErrorCode() != commonpb.ErrorCode_Success {
		log.Warn("deleteResult fail reason", zap.String("reason", deleteResult.GetStatus().GetReason()))
	}
	s.NoError(err)
	s.Equal(commonpb.ErrorCode_Success, deleteResult.GetStatus().GetErrorCode())

	// release and drop the collection to leave the cluster clean
	status, err := c.Proxy.ReleaseCollection(ctx, &milvuspb.ReleaseCollectionRequest{
		CollectionName: collectionName,
	})
	err = merr.CheckRPCCall(status, err)
	s.NoError(err)

	status, err = c.Proxy.DropCollection(ctx, &milvuspb.DropCollectionRequest{
		CollectionName: collectionName,
	})
	err = merr.CheckRPCCall(status, err)
	s.NoError(err)

	log.Info("TestHelloMilvus succeed")
}
// TestHelloMilvus_basic runs the hello-milvus flow with an IVF_FLAT index,
// L2 metric, and dense float vectors, over the TLS-enabled cluster.
func (s *InternaltlsTestSuit) TestHelloMilvus_basic() {
	log.Info("Under test Internal TLS hellomilvus...")
	s.indexType = integration.IndexFaissIvfFlat
	s.metricType = metric.L2
	s.vecType = schemapb.DataType_FloatVector
	s.run()
}
// TearDownSuite removes the generated _test.yaml (after the base suite
// teardown runs) so later test runs do not inherit the TLS config.
func (s *InternaltlsTestSuit) TearDownSuite() {
	defer func() {
		err := os.Remove(configFilePath)
		if err != nil {
			log.Error("Failed to delete config file:", zap.Error(err))
			return
		}
		log.Info(fmt.Sprintf("Config file deleted: %s", configFilePath))
	}()
	s.MiniClusterSuite.TearDownSuite()
}
// TestInternalTLS is the go-test entry point for the internal-TLS suite.
func TestInternalTLS(t *testing.T) {
	suite.Run(t, new(InternaltlsTestSuit))
}