Prevent the client from closing grpc connection by mistake (#11918)

Signed-off-by: zhenshan.cao <zhenshan.cao@zilliz.com>
pull/12056/head
zhenshan.cao 2021-11-18 11:17:12 +08:00 committed by GitHub
parent 6d652263a2
commit 543c4891b3
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
10 changed files with 350 additions and 103 deletions

View File

@ -27,6 +27,7 @@ import (
grpc_opentracing "github.com/grpc-ecosystem/go-grpc-middleware/tracing/opentracing"
"github.com/milvus-io/milvus/internal/log"
"github.com/milvus-io/milvus/internal/proto/milvuspb"
"github.com/milvus-io/milvus/internal/util/funcutil"
"github.com/milvus-io/milvus/internal/util/retry"
"github.com/milvus-io/milvus/internal/util/sessionutil"
"github.com/milvus-io/milvus/internal/util/trace"
@ -192,14 +193,15 @@ func (c *Client) recall(caller func() (interface{}, error)) (interface{}, error)
if err == nil {
return ret, nil
}
if err == context.Canceled || err == context.DeadlineExceeded {
return nil, err
}
log.Debug("DataCoord Client grpc error", zap.Error(err))
c.resetConnection()
ret, err = caller()
if err == nil {
return ret, nil
}
return ret, err
}
@ -229,7 +231,9 @@ func (c *Client) GetComponentStates(ctx context.Context) (*internalpb.ComponentS
if err != nil {
return nil, err
}
if !funcutil.CheckCtxValid(ctx) {
return nil, ctx.Err()
}
return client.GetComponentStates(ctx, &internalpb.GetComponentStatesRequest{})
})
if err != nil || ret == nil {
@ -245,7 +249,9 @@ func (c *Client) GetTimeTickChannel(ctx context.Context) (*milvuspb.StringRespon
if err != nil {
return nil, err
}
if !funcutil.CheckCtxValid(ctx) {
return nil, ctx.Err()
}
return client.GetTimeTickChannel(ctx, &internalpb.GetTimeTickChannelRequest{})
})
if err != nil || ret == nil {
@ -261,7 +267,9 @@ func (c *Client) GetStatisticsChannel(ctx context.Context) (*milvuspb.StringResp
if err != nil {
return nil, err
}
if !funcutil.CheckCtxValid(ctx) {
return nil, ctx.Err()
}
return client.GetStatisticsChannel(ctx, &internalpb.GetStatisticsChannelRequest{})
})
if err != nil || ret == nil {
@ -276,7 +284,9 @@ func (c *Client) Flush(ctx context.Context, req *datapb.FlushRequest) (*datapb.F
if err != nil {
return nil, err
}
if !funcutil.CheckCtxValid(ctx) {
return nil, ctx.Err()
}
return client.Flush(ctx, req)
})
if err != nil || ret == nil {
@ -304,7 +314,9 @@ func (c *Client) AssignSegmentID(ctx context.Context, req *datapb.AssignSegmentI
if err != nil {
return nil, err
}
if !funcutil.CheckCtxValid(ctx) {
return nil, ctx.Err()
}
return client.AssignSegmentID(ctx, req)
})
if err != nil || ret == nil {
@ -328,7 +340,9 @@ func (c *Client) GetSegmentStates(ctx context.Context, req *datapb.GetSegmentSta
if err != nil {
return nil, err
}
if !funcutil.CheckCtxValid(ctx) {
return nil, ctx.Err()
}
return client.GetSegmentStates(ctx, req)
})
if err != nil || ret == nil {
@ -351,7 +365,9 @@ func (c *Client) GetInsertBinlogPaths(ctx context.Context, req *datapb.GetInsert
if err != nil {
return nil, err
}
if !funcutil.CheckCtxValid(ctx) {
return nil, ctx.Err()
}
return client.GetInsertBinlogPaths(ctx, req)
})
if err != nil || ret == nil {
@ -374,7 +390,9 @@ func (c *Client) GetCollectionStatistics(ctx context.Context, req *datapb.GetCol
if err != nil {
return nil, err
}
if !funcutil.CheckCtxValid(ctx) {
return nil, ctx.Err()
}
return client.GetCollectionStatistics(ctx, req)
})
if err != nil || ret == nil {
@ -397,7 +415,9 @@ func (c *Client) GetPartitionStatistics(ctx context.Context, req *datapb.GetPart
if err != nil {
return nil, err
}
if !funcutil.CheckCtxValid(ctx) {
return nil, ctx.Err()
}
return client.GetPartitionStatistics(ctx, req)
})
if err != nil || ret == nil {
@ -414,7 +434,9 @@ func (c *Client) GetSegmentInfoChannel(ctx context.Context) (*milvuspb.StringRes
if err != nil {
return nil, err
}
if !funcutil.CheckCtxValid(ctx) {
return nil, ctx.Err()
}
return client.GetSegmentInfoChannel(ctx, &datapb.GetSegmentInfoChannelRequest{})
})
if err != nil || ret == nil {
@ -436,7 +458,9 @@ func (c *Client) GetSegmentInfo(ctx context.Context, req *datapb.GetSegmentInfoR
if err != nil {
return nil, err
}
if !funcutil.CheckCtxValid(ctx) {
return nil, ctx.Err()
}
return client.GetSegmentInfo(ctx, req)
})
if err != nil || ret == nil {
@ -483,7 +507,9 @@ func (c *Client) GetRecoveryInfo(ctx context.Context, req *datapb.GetRecoveryInf
if err != nil {
return nil, err
}
if !funcutil.CheckCtxValid(ctx) {
return nil, ctx.Err()
}
return client.GetRecoveryInfo(ctx, req)
})
if err != nil || ret == nil {
@ -506,7 +532,9 @@ func (c *Client) GetFlushedSegments(ctx context.Context, req *datapb.GetFlushedS
if err != nil {
return nil, err
}
if !funcutil.CheckCtxValid(ctx) {
return nil, ctx.Err()
}
return client.GetFlushedSegments(ctx, req)
})
if err != nil || ret == nil {
@ -522,7 +550,9 @@ func (c *Client) GetMetrics(ctx context.Context, req *milvuspb.GetMetricsRequest
if err != nil {
return nil, err
}
if !funcutil.CheckCtxValid(ctx) {
return nil, ctx.Err()
}
return client.GetMetrics(ctx, req)
})
if err != nil || ret == nil {
@ -537,7 +567,9 @@ func (c *Client) CompleteCompaction(ctx context.Context, req *datapb.CompactionR
if err != nil {
return nil, err
}
if !funcutil.CheckCtxValid(ctx) {
return nil, ctx.Err()
}
return client.CompleteCompaction(ctx, req)
})
if err != nil || ret == nil {
@ -552,7 +584,9 @@ func (c *Client) ManualCompaction(ctx context.Context, req *milvuspb.ManualCompa
if err != nil {
return nil, err
}
if !funcutil.CheckCtxValid(ctx) {
return nil, ctx.Err()
}
return client.ManualCompaction(ctx, req)
})
if err != nil || ret == nil {
@ -567,7 +601,9 @@ func (c *Client) GetCompactionState(ctx context.Context, req *milvuspb.GetCompac
if err != nil {
return nil, err
}
if !funcutil.CheckCtxValid(ctx) {
return nil, ctx.Err()
}
return client.GetCompactionState(ctx, req)
})
if err != nil || ret == nil {
@ -582,7 +618,9 @@ func (c *Client) GetCompactionStateWithPlans(ctx context.Context, req *milvuspb.
if err != nil {
return nil, err
}
if !funcutil.CheckCtxValid(ctx) {
return nil, ctx.Err()
}
return client.GetCompactionStateWithPlans(ctx, req)
})
if err != nil || ret == nil {
@ -597,7 +635,9 @@ func (c *Client) WatchChannels(ctx context.Context, req *datapb.WatchChannelsReq
if err != nil {
return nil, err
}
if !funcutil.CheckCtxValid(ctx) {
return nil, ctx.Err()
}
return client.WatchChannels(ctx, req)
})
if err != nil || ret == nil {

View File

@ -22,16 +22,16 @@ import (
"sync"
"time"
"github.com/milvus-io/milvus/internal/log"
"github.com/milvus-io/milvus/internal/util/retry"
grpc_middleware "github.com/grpc-ecosystem/go-grpc-middleware"
grpc_retry "github.com/grpc-ecosystem/go-grpc-middleware/retry"
grpc_opentracing "github.com/grpc-ecosystem/go-grpc-middleware/tracing/opentracing"
"github.com/milvus-io/milvus/internal/log"
"github.com/milvus-io/milvus/internal/proto/commonpb"
"github.com/milvus-io/milvus/internal/proto/datapb"
"github.com/milvus-io/milvus/internal/proto/internalpb"
"github.com/milvus-io/milvus/internal/proto/milvuspb"
"github.com/milvus-io/milvus/internal/util/funcutil"
"github.com/milvus-io/milvus/internal/util/retry"
"github.com/milvus-io/milvus/internal/util/trace"
"google.golang.org/grpc/codes"
@ -174,6 +174,10 @@ func (c *Client) recall(caller func() (interface{}, error)) (interface{}, error)
if err == nil {
return ret, nil
}
if err == context.Canceled || err == context.DeadlineExceeded {
return nil, err
}
log.Debug("DataNode Client grpc error", zap.Error(err))
c.resetConnection()
@ -214,7 +218,9 @@ func (c *Client) GetComponentStates(ctx context.Context) (*internalpb.ComponentS
if err != nil {
return nil, err
}
if !funcutil.CheckCtxValid(ctx) {
return nil, ctx.Err()
}
return client.GetComponentStates(ctx, &internalpb.GetComponentStatesRequest{})
})
if err != nil || ret == nil {
@ -229,7 +235,9 @@ func (c *Client) GetStatisticsChannel(ctx context.Context) (*milvuspb.StringResp
if err != nil {
return nil, err
}
if !funcutil.CheckCtxValid(ctx) {
return nil, ctx.Err()
}
return client.GetStatisticsChannel(ctx, &internalpb.GetStatisticsChannelRequest{})
})
if err != nil || ret == nil {
@ -244,7 +252,9 @@ func (c *Client) WatchDmChannels(ctx context.Context, req *datapb.WatchDmChannel
if err != nil {
return nil, err
}
if !funcutil.CheckCtxValid(ctx) {
return nil, ctx.Err()
}
return client.WatchDmChannels(ctx, req)
})
if err != nil || ret == nil {
@ -267,7 +277,9 @@ func (c *Client) FlushSegments(ctx context.Context, req *datapb.FlushSegmentsReq
if err != nil {
return nil, err
}
if !funcutil.CheckCtxValid(ctx) {
return nil, ctx.Err()
}
return client.FlushSegments(ctx, req)
})
if err != nil || ret == nil {
@ -282,7 +294,9 @@ func (c *Client) GetMetrics(ctx context.Context, req *milvuspb.GetMetricsRequest
if err != nil {
return nil, err
}
if !funcutil.CheckCtxValid(ctx) {
return nil, ctx.Err()
}
return client.GetMetrics(ctx, req)
})
if err != nil || ret == nil {
@ -297,7 +311,9 @@ func (c *Client) Compaction(ctx context.Context, req *datapb.CompactionPlan) (*c
if err != nil {
return nil, err
}
if !funcutil.CheckCtxValid(ctx) {
return nil, ctx.Err()
}
return client.Compaction(ctx, req)
})
if err != nil || ret == nil {

View File

@ -22,17 +22,17 @@ import (
"sync"
"time"
"google.golang.org/grpc"
grpc_middleware "github.com/grpc-ecosystem/go-grpc-middleware"
grpc_retry "github.com/grpc-ecosystem/go-grpc-middleware/retry"
grpc_opentracing "github.com/grpc-ecosystem/go-grpc-middleware/tracing/opentracing"
"github.com/milvus-io/milvus/internal/log"
"github.com/milvus-io/milvus/internal/util/funcutil"
"github.com/milvus-io/milvus/internal/util/retry"
"github.com/milvus-io/milvus/internal/util/sessionutil"
"github.com/milvus-io/milvus/internal/util/trace"
"github.com/milvus-io/milvus/internal/util/typeutil"
"go.uber.org/zap"
"google.golang.org/grpc"
"github.com/milvus-io/milvus/internal/proto/commonpb"
"github.com/milvus-io/milvus/internal/proto/indexpb"
@ -181,6 +181,10 @@ func (c *Client) recall(caller func() (interface{}, error)) (interface{}, error)
if err == nil {
return ret, nil
}
if err == context.Canceled || err == context.DeadlineExceeded {
return nil, err
}
log.Debug("IndexCoord Client grpc error", zap.Error(err))
c.resetConnection()
@ -220,7 +224,9 @@ func (c *Client) GetComponentStates(ctx context.Context) (*internalpb.ComponentS
if err != nil {
return nil, err
}
if !funcutil.CheckCtxValid(ctx) {
return nil, ctx.Err()
}
return client.GetComponentStates(ctx, &internalpb.GetComponentStatesRequest{})
})
if err != nil || ret == nil {
@ -236,7 +242,9 @@ func (c *Client) GetTimeTickChannel(ctx context.Context) (*milvuspb.StringRespon
if err != nil {
return nil, err
}
if !funcutil.CheckCtxValid(ctx) {
return nil, ctx.Err()
}
return client.GetTimeTickChannel(ctx, &internalpb.GetTimeTickChannelRequest{})
})
if err != nil || ret == nil {
@ -252,7 +260,9 @@ func (c *Client) GetStatisticsChannel(ctx context.Context) (*milvuspb.StringResp
if err != nil {
return nil, err
}
if !funcutil.CheckCtxValid(ctx) {
return nil, ctx.Err()
}
return client.GetStatisticsChannel(ctx, &internalpb.GetStatisticsChannelRequest{})
})
if err != nil || ret == nil {
@ -268,7 +278,9 @@ func (c *Client) BuildIndex(ctx context.Context, req *indexpb.BuildIndexRequest)
if err != nil {
return nil, err
}
if !funcutil.CheckCtxValid(ctx) {
return nil, ctx.Err()
}
return client.BuildIndex(ctx, req)
})
if err != nil || ret == nil {
@ -284,7 +296,9 @@ func (c *Client) DropIndex(ctx context.Context, req *indexpb.DropIndexRequest) (
if err != nil {
return nil, err
}
if !funcutil.CheckCtxValid(ctx) {
return nil, ctx.Err()
}
return client.DropIndex(ctx, req)
})
if err != nil || ret == nil {
@ -300,7 +314,9 @@ func (c *Client) GetIndexStates(ctx context.Context, req *indexpb.GetIndexStates
if err != nil {
return nil, err
}
if !funcutil.CheckCtxValid(ctx) {
return nil, ctx.Err()
}
return client.GetIndexStates(ctx, req)
})
if err != nil || ret == nil {
@ -316,7 +332,9 @@ func (c *Client) GetIndexFilePaths(ctx context.Context, req *indexpb.GetIndexFil
if err != nil {
return nil, err
}
if !funcutil.CheckCtxValid(ctx) {
return nil, ctx.Err()
}
return client.GetIndexFilePaths(ctx, req)
})
if err != nil || ret == nil {
@ -332,7 +350,9 @@ func (c *Client) GetMetrics(ctx context.Context, req *milvuspb.GetMetricsRequest
if err != nil {
return nil, err
}
if !funcutil.CheckCtxValid(ctx) {
return nil, ctx.Err()
}
return client.GetMetrics(ctx, req)
})
if err != nil || ret == nil {

View File

@ -27,6 +27,7 @@ import (
grpc_opentracing "github.com/grpc-ecosystem/go-grpc-middleware/tracing/opentracing"
"github.com/milvus-io/milvus/internal/log"
"github.com/milvus-io/milvus/internal/proto/milvuspb"
"github.com/milvus-io/milvus/internal/util/funcutil"
"github.com/milvus-io/milvus/internal/util/retry"
"github.com/milvus-io/milvus/internal/util/trace"
"go.uber.org/zap"
@ -169,6 +170,9 @@ func (c *Client) recall(caller func() (interface{}, error)) (interface{}, error)
if err == nil {
return ret, nil
}
if err == context.Canceled || err == context.DeadlineExceeded {
return nil, err
}
log.Debug("IndexNode Client grpc error", zap.Error(err))
c.resetConnection()
@ -208,7 +212,9 @@ func (c *Client) GetComponentStates(ctx context.Context) (*internalpb.ComponentS
if err != nil {
return nil, err
}
if !funcutil.CheckCtxValid(ctx) {
return nil, ctx.Err()
}
return client.GetComponentStates(ctx, &internalpb.GetComponentStatesRequest{})
})
if err != nil || ret == nil {
@ -224,7 +230,9 @@ func (c *Client) GetTimeTickChannel(ctx context.Context) (*milvuspb.StringRespon
if err != nil {
return nil, err
}
if !funcutil.CheckCtxValid(ctx) {
return nil, ctx.Err()
}
return client.GetTimeTickChannel(ctx, &internalpb.GetTimeTickChannelRequest{})
})
if err != nil || ret == nil {
@ -240,7 +248,9 @@ func (c *Client) GetStatisticsChannel(ctx context.Context) (*milvuspb.StringResp
if err != nil {
return nil, err
}
if !funcutil.CheckCtxValid(ctx) {
return nil, ctx.Err()
}
return client.GetStatisticsChannel(ctx, &internalpb.GetStatisticsChannelRequest{})
})
if err != nil || ret == nil {
@ -256,7 +266,9 @@ func (c *Client) CreateIndex(ctx context.Context, req *indexpb.CreateIndexReques
if err != nil {
return nil, err
}
if !funcutil.CheckCtxValid(ctx) {
return nil, ctx.Err()
}
return client.CreateIndex(ctx, req)
})
if err != nil || ret == nil {
@ -272,7 +284,9 @@ func (c *Client) GetMetrics(ctx context.Context, req *milvuspb.GetMetricsRequest
if err != nil {
return nil, err
}
if !funcutil.CheckCtxValid(ctx) {
return nil, ctx.Err()
}
return client.GetMetrics(ctx, req)
})
if err != nil || ret == nil {

View File

@ -30,6 +30,7 @@ import (
"github.com/milvus-io/milvus/internal/proto/internalpb"
"github.com/milvus-io/milvus/internal/proto/milvuspb"
"github.com/milvus-io/milvus/internal/proto/proxypb"
"github.com/milvus-io/milvus/internal/util/funcutil"
"github.com/milvus-io/milvus/internal/util/retry"
"github.com/milvus-io/milvus/internal/util/trace"
"go.uber.org/zap"
@ -171,6 +172,9 @@ func (c *Client) recall(caller func() (interface{}, error)) (interface{}, error)
if err == nil {
return ret, nil
}
if err == context.Canceled || err == context.DeadlineExceeded {
return nil, err
}
log.Debug("Proxy Client grpc error", zap.Error(err))
c.resetConnection()
@ -208,7 +212,9 @@ func (c *Client) GetComponentStates(ctx context.Context) (*internalpb.ComponentS
if err != nil {
return nil, err
}
if !funcutil.CheckCtxValid(ctx) {
return nil, ctx.Err()
}
return client.GetComponentStates(ctx, &internalpb.GetComponentStatesRequest{})
})
if err != nil || ret == nil {
@ -223,7 +229,9 @@ func (c *Client) GetStatisticsChannel(ctx context.Context) (*milvuspb.StringResp
if err != nil {
return nil, err
}
if !funcutil.CheckCtxValid(ctx) {
return nil, ctx.Err()
}
return client.GetStatisticsChannel(ctx, &internalpb.GetStatisticsChannelRequest{})
})
if err != nil || ret == nil {
@ -238,7 +246,9 @@ func (c *Client) InvalidateCollectionMetaCache(ctx context.Context, req *proxypb
if err != nil {
return nil, err
}
if !funcutil.CheckCtxValid(ctx) {
return nil, ctx.Err()
}
return client.InvalidateCollectionMetaCache(ctx, req)
})
if err != nil || ret == nil {
@ -253,7 +263,9 @@ func (c *Client) ReleaseDQLMessageStream(ctx context.Context, req *proxypb.Relea
if err != nil {
return nil, err
}
if !funcutil.CheckCtxValid(ctx) {
return nil, ctx.Err()
}
return client.ReleaseDQLMessageStream(ctx, req)
})
if err != nil || ret == nil {

View File

@ -25,6 +25,7 @@ import (
grpc_middleware "github.com/grpc-ecosystem/go-grpc-middleware"
grpc_retry "github.com/grpc-ecosystem/go-grpc-middleware/retry"
grpc_opentracing "github.com/grpc-ecosystem/go-grpc-middleware/tracing/opentracing"
"github.com/milvus-io/milvus/internal/util/funcutil"
"github.com/milvus-io/milvus/internal/util/retry"
"github.com/milvus-io/milvus/internal/util/sessionutil"
"github.com/milvus-io/milvus/internal/util/trace"
@ -195,6 +196,9 @@ func (c *Client) recall(caller func() (interface{}, error)) (interface{}, error)
if err == nil {
return ret, nil
}
if err == context.Canceled || err == context.DeadlineExceeded {
return nil, err
}
log.Debug("QueryCoord Client grpc error", zap.Error(err))
c.resetConnection()
@ -234,7 +238,9 @@ func (c *Client) GetComponentStates(ctx context.Context) (*internalpb.ComponentS
if err != nil {
return nil, err
}
if !funcutil.CheckCtxValid(ctx) {
return nil, ctx.Err()
}
return client.GetComponentStates(ctx, &internalpb.GetComponentStatesRequest{})
})
if err != nil || ret == nil {
@ -250,7 +256,9 @@ func (c *Client) GetTimeTickChannel(ctx context.Context) (*milvuspb.StringRespon
if err != nil {
return nil, err
}
if !funcutil.CheckCtxValid(ctx) {
return nil, ctx.Err()
}
return client.GetTimeTickChannel(ctx, &internalpb.GetTimeTickChannelRequest{})
})
if err != nil || ret == nil {
@ -266,7 +274,9 @@ func (c *Client) GetStatisticsChannel(ctx context.Context) (*milvuspb.StringResp
if err != nil {
return nil, err
}
if !funcutil.CheckCtxValid(ctx) {
return nil, ctx.Err()
}
return client.GetStatisticsChannel(ctx, &internalpb.GetStatisticsChannelRequest{})
})
if err != nil || ret == nil {
@ -282,7 +292,9 @@ func (c *Client) ShowCollections(ctx context.Context, req *querypb.ShowCollectio
if err != nil {
return nil, err
}
if !funcutil.CheckCtxValid(ctx) {
return nil, ctx.Err()
}
return client.ShowCollections(ctx, req)
})
if err != nil || ret == nil {
@ -298,7 +310,9 @@ func (c *Client) LoadCollection(ctx context.Context, req *querypb.LoadCollection
if err != nil {
return nil, err
}
if !funcutil.CheckCtxValid(ctx) {
return nil, ctx.Err()
}
return client.LoadCollection(ctx, req)
})
if err != nil || ret == nil {
@ -314,7 +328,9 @@ func (c *Client) ReleaseCollection(ctx context.Context, req *querypb.ReleaseColl
if err != nil {
return nil, err
}
if !funcutil.CheckCtxValid(ctx) {
return nil, ctx.Err()
}
return client.ReleaseCollection(ctx, req)
})
if err != nil || ret == nil {
@ -330,7 +346,9 @@ func (c *Client) ShowPartitions(ctx context.Context, req *querypb.ShowPartitions
if err != nil {
return nil, err
}
if !funcutil.CheckCtxValid(ctx) {
return nil, ctx.Err()
}
return client.ShowPartitions(ctx, req)
})
if err != nil || ret == nil {
@ -346,7 +364,9 @@ func (c *Client) LoadPartitions(ctx context.Context, req *querypb.LoadPartitions
if err != nil {
return nil, err
}
if !funcutil.CheckCtxValid(ctx) {
return nil, ctx.Err()
}
return client.LoadPartitions(ctx, req)
})
if err != nil || ret == nil {
@ -362,7 +382,9 @@ func (c *Client) ReleasePartitions(ctx context.Context, req *querypb.ReleasePart
if err != nil {
return nil, err
}
if !funcutil.CheckCtxValid(ctx) {
return nil, ctx.Err()
}
return client.ReleasePartitions(ctx, req)
})
if err != nil || ret == nil {
@ -378,7 +400,9 @@ func (c *Client) CreateQueryChannel(ctx context.Context, req *querypb.CreateQuer
if err != nil {
return nil, err
}
if !funcutil.CheckCtxValid(ctx) {
return nil, ctx.Err()
}
return client.CreateQueryChannel(ctx, req)
})
if err != nil || ret == nil {
@ -394,7 +418,9 @@ func (c *Client) GetPartitionStates(ctx context.Context, req *querypb.GetPartiti
if err != nil {
return nil, err
}
if !funcutil.CheckCtxValid(ctx) {
return nil, ctx.Err()
}
return client.GetPartitionStates(ctx, req)
})
if err != nil || ret == nil {
@ -410,7 +436,9 @@ func (c *Client) GetSegmentInfo(ctx context.Context, req *querypb.GetSegmentInfo
if err != nil {
return nil, err
}
if !funcutil.CheckCtxValid(ctx) {
return nil, ctx.Err()
}
return client.GetSegmentInfo(ctx, req)
})
if err != nil || ret == nil {
@ -426,7 +454,9 @@ func (c *Client) LoadBalance(ctx context.Context, req *querypb.LoadBalanceReques
if err != nil {
return nil, err
}
if !funcutil.CheckCtxValid(ctx) {
return nil, ctx.Err()
}
return client.LoadBalance(ctx, req)
})
if err != nil || ret == nil {
@ -442,7 +472,9 @@ func (c *Client) GetMetrics(ctx context.Context, req *milvuspb.GetMetricsRequest
if err != nil {
return nil, err
}
if !funcutil.CheckCtxValid(ctx) {
return nil, ctx.Err()
}
return client.GetMetrics(ctx, req)
})
if err != nil || ret == nil {

View File

@ -34,6 +34,7 @@ import (
"github.com/milvus-io/milvus/internal/proto/internalpb"
"github.com/milvus-io/milvus/internal/proto/milvuspb"
"github.com/milvus-io/milvus/internal/proto/querypb"
"github.com/milvus-io/milvus/internal/util/funcutil"
"github.com/milvus-io/milvus/internal/util/retry"
"github.com/milvus-io/milvus/internal/util/trace"
)
@ -161,6 +162,9 @@ func (c *Client) recall(caller func() (interface{}, error)) (interface{}, error)
if err == nil {
return ret, nil
}
if err == context.Canceled || err == context.DeadlineExceeded {
return nil, err
}
log.Debug("QueryNode Client grpc error", zap.Error(err))
c.resetConnection()
@ -200,7 +204,9 @@ func (c *Client) GetComponentStates(ctx context.Context) (*internalpb.ComponentS
if err != nil {
return nil, err
}
if !funcutil.CheckCtxValid(ctx) {
return nil, ctx.Err()
}
return client.GetComponentStates(ctx, &internalpb.GetComponentStatesRequest{})
})
if err != nil || ret == nil {
@ -216,7 +222,9 @@ func (c *Client) GetTimeTickChannel(ctx context.Context) (*milvuspb.StringRespon
if err != nil {
return nil, err
}
if !funcutil.CheckCtxValid(ctx) {
return nil, ctx.Err()
}
return client.GetTimeTickChannel(ctx, &internalpb.GetTimeTickChannelRequest{})
})
if err != nil || ret == nil {
@ -232,7 +240,9 @@ func (c *Client) GetStatisticsChannel(ctx context.Context) (*milvuspb.StringResp
if err != nil {
return nil, err
}
if !funcutil.CheckCtxValid(ctx) {
return nil, ctx.Err()
}
return client.GetStatisticsChannel(ctx, &internalpb.GetStatisticsChannelRequest{})
})
if err != nil || ret == nil {
@ -248,7 +258,9 @@ func (c *Client) AddQueryChannel(ctx context.Context, req *querypb.AddQueryChann
if err != nil {
return nil, err
}
if !funcutil.CheckCtxValid(ctx) {
return nil, ctx.Err()
}
return client.AddQueryChannel(ctx, req)
})
if err != nil || ret == nil {
@ -264,7 +276,9 @@ func (c *Client) RemoveQueryChannel(ctx context.Context, req *querypb.RemoveQuer
if err != nil {
return nil, err
}
if !funcutil.CheckCtxValid(ctx) {
return nil, ctx.Err()
}
return client.RemoveQueryChannel(ctx, req)
})
if err != nil || ret == nil {
@ -280,7 +294,9 @@ func (c *Client) WatchDmChannels(ctx context.Context, req *querypb.WatchDmChanne
if err != nil {
return nil, err
}
if !funcutil.CheckCtxValid(ctx) {
return nil, ctx.Err()
}
return client.WatchDmChannels(ctx, req)
})
if err != nil || ret == nil {
@ -296,7 +312,9 @@ func (c *Client) WatchDeltaChannels(ctx context.Context, req *querypb.WatchDelta
if err != nil {
return nil, err
}
if !funcutil.CheckCtxValid(ctx) {
return nil, ctx.Err()
}
return client.WatchDeltaChannels(ctx, req)
})
if err != nil || ret == nil {
@ -312,7 +330,9 @@ func (c *Client) LoadSegments(ctx context.Context, req *querypb.LoadSegmentsRequ
if err != nil {
return nil, err
}
if !funcutil.CheckCtxValid(ctx) {
return nil, ctx.Err()
}
return client.LoadSegments(ctx, req)
})
if err != nil || ret == nil {
@ -328,7 +348,9 @@ func (c *Client) ReleaseCollection(ctx context.Context, req *querypb.ReleaseColl
if err != nil {
return nil, err
}
if !funcutil.CheckCtxValid(ctx) {
return nil, ctx.Err()
}
return client.ReleaseCollection(ctx, req)
})
if err != nil || ret == nil {
@ -344,7 +366,9 @@ func (c *Client) ReleasePartitions(ctx context.Context, req *querypb.ReleasePart
if err != nil {
return nil, err
}
if !funcutil.CheckCtxValid(ctx) {
return nil, ctx.Err()
}
return client.ReleasePartitions(ctx, req)
})
if err != nil || ret == nil {
@ -360,7 +384,9 @@ func (c *Client) ReleaseSegments(ctx context.Context, req *querypb.ReleaseSegmen
if err != nil {
return nil, err
}
if !funcutil.CheckCtxValid(ctx) {
return nil, ctx.Err()
}
return client.ReleaseSegments(ctx, req)
})
if err != nil || ret == nil {
@ -376,7 +402,9 @@ func (c *Client) GetSegmentInfo(ctx context.Context, req *querypb.GetSegmentInfo
if err != nil {
return nil, err
}
if !funcutil.CheckCtxValid(ctx) {
return nil, ctx.Err()
}
return client.GetSegmentInfo(ctx, req)
})
if err != nil || ret == nil {
@ -392,7 +420,9 @@ func (c *Client) GetMetrics(ctx context.Context, req *milvuspb.GetMetricsRequest
if err != nil {
return nil, err
}
if !funcutil.CheckCtxValid(ctx) {
return nil, ctx.Err()
}
return client.GetMetrics(ctx, req)
})
if err != nil || ret == nil {

View File

@ -32,6 +32,7 @@ import (
"github.com/milvus-io/milvus/internal/proto/milvuspb"
"github.com/milvus-io/milvus/internal/proto/proxypb"
"github.com/milvus-io/milvus/internal/proto/rootcoordpb"
"github.com/milvus-io/milvus/internal/util/funcutil"
"github.com/milvus-io/milvus/internal/util/retry"
"github.com/milvus-io/milvus/internal/util/sessionutil"
"github.com/milvus-io/milvus/internal/util/trace"
@ -223,6 +224,9 @@ func (c *GrpcClient) recall(caller func() (interface{}, error)) (interface{}, er
if err == nil {
return ret, nil
}
if err == context.Canceled || err == context.DeadlineExceeded {
return nil, err
}
log.Debug("RootCoord Client grpc error", zap.Error(err))
c.resetConnection()
@ -241,7 +245,9 @@ func (c *GrpcClient) GetComponentStates(ctx context.Context) (*internalpb.Compon
if err != nil {
return nil, err
}
if !funcutil.CheckCtxValid(ctx) {
return nil, ctx.Err()
}
return client.GetComponentStates(ctx, &internalpb.GetComponentStatesRequest{})
})
if err != nil || ret == nil {
@ -257,7 +263,9 @@ func (c *GrpcClient) GetTimeTickChannel(ctx context.Context) (*milvuspb.StringRe
if err != nil {
return nil, err
}
if !funcutil.CheckCtxValid(ctx) {
return nil, ctx.Err()
}
return client.GetTimeTickChannel(ctx, &internalpb.GetTimeTickChannelRequest{})
})
if err != nil || ret == nil {
@ -273,7 +281,9 @@ func (c *GrpcClient) GetStatisticsChannel(ctx context.Context) (*milvuspb.String
if err != nil {
return nil, err
}
if !funcutil.CheckCtxValid(ctx) {
return nil, ctx.Err()
}
return client.GetStatisticsChannel(ctx, &internalpb.GetStatisticsChannelRequest{})
})
if err != nil || ret == nil {
@ -289,7 +299,9 @@ func (c *GrpcClient) CreateCollection(ctx context.Context, in *milvuspb.CreateCo
if err != nil {
return nil, err
}
if !funcutil.CheckCtxValid(ctx) {
return nil, ctx.Err()
}
return client.CreateCollection(ctx, in)
})
if err != nil || ret == nil {
@ -305,7 +317,9 @@ func (c *GrpcClient) DropCollection(ctx context.Context, in *milvuspb.DropCollec
if err != nil {
return nil, err
}
if !funcutil.CheckCtxValid(ctx) {
return nil, ctx.Err()
}
return client.DropCollection(ctx, in)
})
if err != nil || ret == nil {
@ -321,7 +335,9 @@ func (c *GrpcClient) HasCollection(ctx context.Context, in *milvuspb.HasCollecti
if err != nil {
return nil, err
}
if !funcutil.CheckCtxValid(ctx) {
return nil, ctx.Err()
}
return client.HasCollection(ctx, in)
})
if err != nil || ret == nil {
@ -337,7 +353,9 @@ func (c *GrpcClient) DescribeCollection(ctx context.Context, in *milvuspb.Descri
if err != nil {
return nil, err
}
if !funcutil.CheckCtxValid(ctx) {
return nil, ctx.Err()
}
return client.DescribeCollection(ctx, in)
})
if err != nil || ret == nil {
@ -353,7 +371,9 @@ func (c *GrpcClient) ShowCollections(ctx context.Context, in *milvuspb.ShowColle
if err != nil {
return nil, err
}
if !funcutil.CheckCtxValid(ctx) {
return nil, ctx.Err()
}
return client.ShowCollections(ctx, in)
})
if err != nil || ret == nil {
@ -369,7 +389,9 @@ func (c *GrpcClient) CreatePartition(ctx context.Context, in *milvuspb.CreatePar
if err != nil {
return nil, err
}
if !funcutil.CheckCtxValid(ctx) {
return nil, ctx.Err()
}
return client.CreatePartition(ctx, in)
})
if err != nil || ret == nil {
@ -385,7 +407,9 @@ func (c *GrpcClient) DropPartition(ctx context.Context, in *milvuspb.DropPartiti
if err != nil {
return nil, err
}
if !funcutil.CheckCtxValid(ctx) {
return nil, ctx.Err()
}
return client.DropPartition(ctx, in)
})
if err != nil || ret == nil {
@ -401,7 +425,9 @@ func (c *GrpcClient) HasPartition(ctx context.Context, in *milvuspb.HasPartition
if err != nil {
return nil, err
}
if !funcutil.CheckCtxValid(ctx) {
return nil, ctx.Err()
}
return client.HasPartition(ctx, in)
})
if err != nil || ret == nil {
@ -417,7 +443,9 @@ func (c *GrpcClient) ShowPartitions(ctx context.Context, in *milvuspb.ShowPartit
if err != nil {
return nil, err
}
if !funcutil.CheckCtxValid(ctx) {
return nil, ctx.Err()
}
return client.ShowPartitions(ctx, in)
})
if err != nil || ret == nil {
@ -433,7 +461,9 @@ func (c *GrpcClient) CreateIndex(ctx context.Context, in *milvuspb.CreateIndexRe
if err != nil {
return nil, err
}
if !funcutil.CheckCtxValid(ctx) {
return nil, ctx.Err()
}
return client.CreateIndex(ctx, in)
})
if err != nil || ret == nil {
@ -449,7 +479,9 @@ func (c *GrpcClient) DropIndex(ctx context.Context, in *milvuspb.DropIndexReques
if err != nil {
return nil, err
}
if !funcutil.CheckCtxValid(ctx) {
return nil, ctx.Err()
}
return client.DropIndex(ctx, in)
})
if err != nil || ret == nil {
@ -465,7 +497,9 @@ func (c *GrpcClient) DescribeIndex(ctx context.Context, in *milvuspb.DescribeInd
if err != nil {
return nil, err
}
if !funcutil.CheckCtxValid(ctx) {
return nil, ctx.Err()
}
return client.DescribeIndex(ctx, in)
})
if err != nil || ret == nil {
@ -481,7 +515,9 @@ func (c *GrpcClient) AllocTimestamp(ctx context.Context, in *rootcoordpb.AllocTi
if err != nil {
return nil, err
}
if !funcutil.CheckCtxValid(ctx) {
return nil, ctx.Err()
}
return client.AllocTimestamp(ctx, in)
})
if err != nil || ret == nil {
@ -497,7 +533,9 @@ func (c *GrpcClient) AllocID(ctx context.Context, in *rootcoordpb.AllocIDRequest
if err != nil {
return nil, err
}
if !funcutil.CheckCtxValid(ctx) {
return nil, ctx.Err()
}
return client.AllocID(ctx, in)
})
if err != nil || ret == nil {
@ -513,7 +551,9 @@ func (c *GrpcClient) UpdateChannelTimeTick(ctx context.Context, in *internalpb.C
if err != nil {
return nil, err
}
if !funcutil.CheckCtxValid(ctx) {
return nil, ctx.Err()
}
return client.UpdateChannelTimeTick(ctx, in)
})
if err != nil || ret == nil {
@ -529,7 +569,9 @@ func (c *GrpcClient) DescribeSegment(ctx context.Context, in *milvuspb.DescribeS
if err != nil {
return nil, err
}
if !funcutil.CheckCtxValid(ctx) {
return nil, ctx.Err()
}
return client.DescribeSegment(ctx, in)
})
if err != nil || ret == nil {
@ -545,7 +587,9 @@ func (c *GrpcClient) ShowSegments(ctx context.Context, in *milvuspb.ShowSegments
if err != nil {
return nil, err
}
if !funcutil.CheckCtxValid(ctx) {
return nil, ctx.Err()
}
return client.ShowSegments(ctx, in)
})
if err != nil || ret == nil {
@ -561,7 +605,9 @@ func (c *GrpcClient) ReleaseDQLMessageStream(ctx context.Context, in *proxypb.Re
if err != nil {
return nil, err
}
if !funcutil.CheckCtxValid(ctx) {
return nil, ctx.Err()
}
return client.ReleaseDQLMessageStream(ctx, in)
})
if err != nil || ret == nil {
@ -577,7 +623,9 @@ func (c *GrpcClient) SegmentFlushCompleted(ctx context.Context, in *datapb.Segme
if err != nil {
return nil, err
}
if !funcutil.CheckCtxValid(ctx) {
return nil, ctx.Err()
}
return client.SegmentFlushCompleted(ctx, in)
})
if err != nil || ret == nil {
@ -593,7 +641,9 @@ func (c *GrpcClient) GetMetrics(ctx context.Context, in *milvuspb.GetMetricsRequ
if err != nil {
return nil, err
}
if !funcutil.CheckCtxValid(ctx) {
return nil, ctx.Err()
}
return client.GetMetrics(ctx, in)
})
if err != nil || ret == nil {
@ -609,7 +659,9 @@ func (c *GrpcClient) CreateAlias(ctx context.Context, req *milvuspb.CreateAliasR
if err != nil {
return nil, err
}
if !funcutil.CheckCtxValid(ctx) {
return nil, ctx.Err()
}
return client.CreateAlias(ctx, req)
})
if err != nil || ret == nil {
@ -625,7 +677,9 @@ func (c *GrpcClient) DropAlias(ctx context.Context, req *milvuspb.DropAliasReque
if err != nil {
return nil, err
}
if !funcutil.CheckCtxValid(ctx) {
return nil, ctx.Err()
}
return client.DropAlias(ctx, req)
})
if err != nil || ret == nil {
@ -641,7 +695,9 @@ func (c *GrpcClient) AlterAlias(ctx context.Context, req *milvuspb.AlterAliasReq
if err != nil {
return nil, err
}
if !funcutil.CheckCtxValid(ctx) {
return nil, ctx.Err()
}
return client.AlterAlias(ctx, req)
})
if err != nil || ret == nil {

View File

@ -184,3 +184,7 @@ func GetAttrByKeyFromRepeatedKV(key string, kvs []*commonpb.KeyValuePair) (strin
return "", errors.New("key " + key + " not found")
}
func CheckCtxValid(ctx context.Context) bool {
return ctx.Err() != context.DeadlineExceeded && ctx.Err() != context.Canceled
}

View File

@ -264,3 +264,26 @@ func TestGetAttrByKeyFromRepeatedKV(t *testing.T) {
assert.Equal(t, test.errIsNil, err == nil)
}
}
func TestCheckCtxValid(t *testing.T) {
bgCtx := context.Background()
timeout := 20 * time.Millisecond
deltaTime := 5 * time.Millisecond
ctx1, cancel1 := context.WithTimeout(bgCtx, timeout)
defer cancel1()
assert.True(t, CheckCtxValid(ctx1))
time.Sleep(timeout + deltaTime)
assert.False(t, CheckCtxValid(ctx1))
ctx2, cancel2 := context.WithTimeout(bgCtx, timeout)
assert.True(t, CheckCtxValid(ctx2))
cancel2()
assert.False(t, CheckCtxValid(ctx2))
futureTime := time.Now().Add(timeout)
ctx3, cancel3 := context.WithDeadline(bgCtx, futureTime)
defer cancel3()
assert.True(t, CheckCtxValid(ctx3))
time.Sleep(timeout + deltaTime)
assert.False(t, CheckCtxValid(ctx3))
}