From 2146af1fb22e18542fb2dfcb26ba660e4f3b0716 Mon Sep 17 00:00:00 2001 From: bigsheeper Date: Tue, 10 Jan 2023 20:35:39 +0800 Subject: [PATCH] Return insufficient memory error when load failed (#21574) Signed-off-by: bigsheeper --- internal/proxy/error.go | 12 ++ internal/proxy/error_test.go | 12 ++ internal/proxy/impl.go | 16 +++ internal/proxy/proxy_test.go | 27 ++++ internal/proxy/util.go | 10 ++ internal/querycoordv2/job/job.go | 4 + internal/querycoordv2/job/job_test.go | 1 + .../querycoordv2/meta/failed_load_cache.go | 115 ++++++++++++++++++ .../meta/failed_load_cache_test.go | 55 +++++++++ internal/querycoordv2/server.go | 3 + internal/querycoordv2/services.go | 20 +++ internal/querycoordv2/services_test.go | 40 ++++++ internal/querycoordv2/task/executor.go | 11 ++ internal/querycoordv2/task/scheduler.go | 19 ++- internal/querycoordv2/task/task_test.go | 1 + internal/querynode/errors.go | 2 + internal/querynode/impl.go | 3 + internal/querynode/impl_utils.go | 9 ++ internal/querynode/impl_utils_test.go | 33 +++++ internal/querynode/segment_loader.go | 3 +- internal/querynode/shard_cluster.go | 4 + 21 files changed, 397 insertions(+), 3 deletions(-) create mode 100644 internal/querycoordv2/meta/failed_load_cache.go create mode 100644 internal/querycoordv2/meta/failed_load_cache_test.go diff --git a/internal/proxy/error.go b/internal/proxy/error.go index 1f4af9cded..10120540c8 100644 --- a/internal/proxy/error.go +++ b/internal/proxy/error.go @@ -23,11 +23,23 @@ import ( "google.golang.org/grpc/codes" "google.golang.org/grpc/status" + "github.com/milvus-io/milvus-proto/go-api/commonpb" "github.com/milvus-io/milvus-proto/go-api/schemapb" ) // TODO(dragondriver): add more common error type +// ErrInsufficientMemory returns insufficient memory error. +var ErrInsufficientMemory = errors.New("InsufficientMemoryToLoad") + +// InSufficientMemoryStatus returns insufficient memory status. 
+func InSufficientMemoryStatus(collectionName string) *commonpb.Status {
+	return &commonpb.Status{
+		ErrorCode: commonpb.ErrorCode_InsufficientMemoryToLoad,
+		Reason:    fmt.Sprintf("deny to load, insufficient memory, please allocate more resources, collectionName: %s", collectionName),
+	}
+}
+
 func errInvalidNumRows(numRows uint32) error {
 	return fmt.Errorf("invalid num_rows: %d", numRows)
 }
diff --git a/internal/proxy/error_test.go b/internal/proxy/error_test.go
index 7cf7621a7d..35ecdd4e20 100644
--- a/internal/proxy/error_test.go
+++ b/internal/proxy/error_test.go
@@ -17,10 +17,14 @@ package proxy
 
 import (
+	"errors"
+	"fmt"
 	"testing"
 
+	"github.com/stretchr/testify/assert"
 	"go.uber.org/zap"
 
+	"github.com/milvus-io/milvus-proto/go-api/commonpb"
 	"github.com/milvus-io/milvus-proto/go-api/schemapb"
 	"github.com/milvus-io/milvus/internal/log"
 )
@@ -150,3 +154,11 @@ func Test_errProxyIsUnhealthy(t *testing.T) {
 			zap.Error(errProxyIsUnhealthy(id)))
 	}
 }
+
+func Test_ErrInsufficientMemory(t *testing.T) {
+	err := fmt.Errorf("%w, mock insufficient memory error", ErrInsufficientMemory)
+	assert.True(t, errors.Is(err, ErrInsufficientMemory))
+
+	status := InSufficientMemoryStatus("collection1")
+	assert.Equal(t, commonpb.ErrorCode_InsufficientMemoryToLoad, status.GetErrorCode())
+}
diff --git a/internal/proxy/impl.go b/internal/proxy/impl.go
index 446f6a68ca..c218358bcd 100644
--- a/internal/proxy/impl.go
+++ b/internal/proxy/impl.go
@@ -18,6 +18,7 @@ package proxy
 
 import (
 	"context"
+	"errors"
 	"fmt"
 	"os"
 	"strconv"
@@ -1445,6 +1446,11 @@ func (node *Proxy) GetLoadingProgress(ctx context.Context, request *milvuspb.Get
 			zap.Strings("partition_name", request.PartitionNames),
 			zap.Error(err))
 		metrics.ProxyFunctionCall.WithLabelValues(strconv.FormatInt(paramtable.GetNodeID(), 10), method, metrics.FailLabel).Inc()
+		if errors.Is(err, ErrInsufficientMemory) {
+			return &milvuspb.GetLoadingProgressResponse{
+				Status: InSufficientMemoryStatus(request.GetCollectionName()),
+			}, nil
+		}
 		return &milvuspb.GetLoadingProgressResponse{
 			Status: &commonpb.Status{
 				ErrorCode: commonpb.ErrorCode_UnexpectedError,
@@ -1574,12 +1580,22 @@ func (node *Proxy) GetLoadState(ctx context.Context, request *milvuspb.GetLoadSt
 	var progress int64
 	if len(request.GetPartitionNames()) == 0 {
 		if progress, err = getCollectionProgress(ctx, node.queryCoord, request.GetBase(), collectionID); err != nil {
+			if errors.Is(err, ErrInsufficientMemory) {
+				return &milvuspb.GetLoadStateResponse{
+					Status: InSufficientMemoryStatus(request.GetCollectionName()),
+				}, nil
+			}
 			successResponse.State = commonpb.LoadState_LoadStateNotLoad
 			return successResponse, nil
 		}
 	} else {
 		if progress, err = getPartitionProgress(ctx, node.queryCoord, request.GetBase(),
 			request.GetPartitionNames(), request.GetCollectionName(), collectionID); err != nil {
+			if errors.Is(err, ErrInsufficientMemory) {
+				return &milvuspb.GetLoadStateResponse{
+					Status: InSufficientMemoryStatus(request.GetCollectionName()),
+				}, nil
+			}
 			successResponse.State = commonpb.LoadState_LoadStateNotLoad
 			return successResponse, nil
 		}
diff --git a/internal/proxy/proxy_test.go b/internal/proxy/proxy_test.go
index d6e5184b11..31860770a8 100644
--- a/internal/proxy/proxy_test.go
+++ b/internal/proxy/proxy_test.go
@@ -4299,4 +4299,31 @@ func TestProxy_GetLoadState(t *testing.T) {
 		assert.Equal(t, commonpb.ErrorCode_Success, progressResp.Status.ErrorCode)
 		assert.Equal(t, int64(50), progressResp.Progress)
 	}
+
+	t.Run("test insufficient memory", func(t *testing.T) {
+		q := 
NewQueryCoordMock(SetQueryCoordShowCollectionsFunc(func(ctx context.Context, request *querypb.ShowCollectionsRequest) (*querypb.ShowCollectionsResponse, error) { + return &querypb.ShowCollectionsResponse{ + Status: &commonpb.Status{ErrorCode: commonpb.ErrorCode_InsufficientMemoryToLoad}, + }, nil + }), SetQueryCoordShowPartitionsFunc(func(ctx context.Context, request *querypb.ShowPartitionsRequest) (*querypb.ShowPartitionsResponse, error) { + return &querypb.ShowPartitionsResponse{ + Status: &commonpb.Status{ErrorCode: commonpb.ErrorCode_InsufficientMemoryToLoad}, + }, nil + })) + q.state.Store(commonpb.StateCode_Healthy) + proxy := &Proxy{queryCoord: q} + proxy.stateCode.Store(commonpb.StateCode_Healthy) + + stateResp, err := proxy.GetLoadState(context.Background(), &milvuspb.GetLoadStateRequest{CollectionName: "foo"}) + assert.NoError(t, err) + assert.Equal(t, commonpb.ErrorCode_InsufficientMemoryToLoad, stateResp.Status.ErrorCode) + + progressResp, err := proxy.GetLoadingProgress(context.Background(), &milvuspb.GetLoadingProgressRequest{CollectionName: "foo"}) + assert.NoError(t, err) + assert.Equal(t, commonpb.ErrorCode_InsufficientMemoryToLoad, progressResp.Status.ErrorCode) + + progressResp, err = proxy.GetLoadingProgress(context.Background(), &milvuspb.GetLoadingProgressRequest{CollectionName: "foo", PartitionNames: []string{"p1"}}) + assert.NoError(t, err) + assert.Equal(t, commonpb.ErrorCode_InsufficientMemoryToLoad, progressResp.Status.ErrorCode) + }) } diff --git a/internal/proxy/util.go b/internal/proxy/util.go index 8849e9d52d..04496ce663 100644 --- a/internal/proxy/util.go +++ b/internal/proxy/util.go @@ -1002,6 +1002,11 @@ func getCollectionProgress(ctx context.Context, queryCoord types.QueryCoord, return 0, err } + if resp.Status.ErrorCode == commonpb.ErrorCode_InsufficientMemoryToLoad { + log.Warn("detected insufficientMemoryError when getCollectionProgress", zap.Int64("collection_id", collectionID), zap.String("reason", resp.GetStatus().GetReason())) + return 0, ErrInsufficientMemory + } + if resp.Status.ErrorCode != commonpb.ErrorCode_Success { log.Warn("fail to show collections", zap.Int64("collection_id", collectionID), zap.String("reason", resp.Status.Reason)) @@ -1043,6 +1048,11 @@ func getPartitionProgress(ctx context.Context, queryCoord types.QueryCoord, zap.Error(err)) return 0, err } + if resp.GetStatus().GetErrorCode() == commonpb.ErrorCode_InsufficientMemoryToLoad { + log.Warn("detected insufficientMemoryError when getPartitionProgress", zap.Int64("collection_id", collectionID), + zap.String("collection_name", collectionName), zap.Strings("partition_names", partitionNames), zap.String("reason", resp.GetStatus().GetReason())) + return 0, ErrInsufficientMemory + } if len(resp.InMemoryPercentages) != len(partitionIDs) { errMsg := "fail to show partitions from the querycoord, invalid data num" log.Warn(errMsg, zap.Int64("collection_id", collectionID), diff --git a/internal/querycoordv2/job/job.go b/internal/querycoordv2/job/job.go index 8cb24dd101..e4c9e22e7f 100644 --- a/internal/querycoordv2/job/job.go +++ b/internal/querycoordv2/job/job.go @@ -186,6 +186,8 @@ func (job *LoadCollectionJob) Execute() error { zap.Int64("collectionID", req.GetCollectionID()), ) + meta.GlobalFailedLoadCache.Remove(req.GetCollectionID()) + // Clear stale replicas err := job.meta.ReplicaManager.RemoveCollection(req.GetCollectionID()) if err != nil { @@ -393,6 +395,8 @@ func (job *LoadPartitionJob) Execute() error { zap.Int64s("partitionIDs", req.GetPartitionIDs()), ) + 
meta.GlobalFailedLoadCache.Remove(req.GetCollectionID()) + // Clear stale replicas err := job.meta.ReplicaManager.RemoveCollection(req.GetCollectionID()) if err != nil { diff --git a/internal/querycoordv2/job/job_test.go b/internal/querycoordv2/job/job_test.go index e3a4d02846..58462ebd5a 100644 --- a/internal/querycoordv2/job/job_test.go +++ b/internal/querycoordv2/job/job_test.go @@ -136,6 +136,7 @@ func (suite *JobSuite) SetupTest() { suite.scheduler = NewScheduler() suite.scheduler.Start(context.Background()) + meta.GlobalFailedLoadCache = meta.NewFailedLoadCache() } func (suite *JobSuite) TearDownTest() { diff --git a/internal/querycoordv2/meta/failed_load_cache.go b/internal/querycoordv2/meta/failed_load_cache.go new file mode 100644 index 0000000000..1954e8a565 --- /dev/null +++ b/internal/querycoordv2/meta/failed_load_cache.go @@ -0,0 +1,115 @@ +// Licensed to the LF AI & Data foundation under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package meta + +import ( + "sync" + "time" + + "go.uber.org/zap" + + "github.com/milvus-io/milvus-proto/go-api/commonpb" + "github.com/milvus-io/milvus/internal/log" + . 
"github.com/milvus-io/milvus/internal/util/typeutil" +) + +const expireTime = 24 * time.Hour + +var GlobalFailedLoadCache *FailedLoadCache + +type failInfo struct { + count int + err error + lastTime time.Time +} + +type FailedLoadCache struct { + mu sync.RWMutex + records map[UniqueID]map[commonpb.ErrorCode]*failInfo +} + +func NewFailedLoadCache() *FailedLoadCache { + return &FailedLoadCache{ + records: make(map[UniqueID]map[commonpb.ErrorCode]*failInfo), + } +} + +func (l *FailedLoadCache) Get(collectionID UniqueID) *commonpb.Status { + l.mu.RLock() + defer l.mu.RUnlock() + status := &commonpb.Status{ErrorCode: commonpb.ErrorCode_Success} + if _, ok := l.records[collectionID]; !ok { + return status + } + if len(l.records[collectionID]) == 0 { + return status + } + var max = 0 + for code, info := range l.records[collectionID] { + if info.count > max { + max = info.count + status.ErrorCode = code + status.Reason = info.err.Error() + } + } + log.Warn("FailedLoadCache hits failed record", zap.Int64("collectionID", collectionID), + zap.String("errCode", status.GetErrorCode().String()), zap.String("reason", status.GetReason())) + return status +} + +func (l *FailedLoadCache) Put(collectionID UniqueID, errCode commonpb.ErrorCode, err error) { + if errCode == commonpb.ErrorCode_Success { + return + } + + l.mu.Lock() + defer l.mu.Unlock() + if _, ok := l.records[collectionID]; !ok { + l.records[collectionID] = make(map[commonpb.ErrorCode]*failInfo) + } + if _, ok := l.records[collectionID][errCode]; !ok { + l.records[collectionID][errCode] = &failInfo{} + } + l.records[collectionID][errCode].count++ + l.records[collectionID][errCode].err = err + l.records[collectionID][errCode].lastTime = time.Now() + log.Warn("FailedLoadCache put failed record", zap.Int64("collectionID", collectionID), + zap.String("errCode", errCode.String()), zap.Error(err)) +} + +func (l *FailedLoadCache) Remove(collectionID UniqueID) { + l.mu.Lock() + defer l.mu.Unlock() + delete(l.records, collectionID) + log.Info("FailedLoadCache removes cache", zap.Int64("collectionID", collectionID)) +} + +func (l *FailedLoadCache) TryExpire() { + l.mu.Lock() + defer l.mu.Unlock() + for col, infos := range l.records { + for code, info := range infos { + if time.Since(info.lastTime) > expireTime { + delete(l.records[col], code) + } + } + if len(l.records[col]) == 0 { + delete(l.records, col) + log.Info("FailedLoadCache expires cache", zap.Int64("collectionID", col)) + } + } +} diff --git a/internal/querycoordv2/meta/failed_load_cache_test.go b/internal/querycoordv2/meta/failed_load_cache_test.go new file mode 100644 index 0000000000..af6a7819da --- /dev/null +++ b/internal/querycoordv2/meta/failed_load_cache_test.go @@ -0,0 +1,55 @@ +// Licensed to the LF AI & Data foundation under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. 
+ +package meta + +import ( + "fmt" + "testing" + "time" + + "github.com/stretchr/testify/assert" + + "github.com/milvus-io/milvus-proto/go-api/commonpb" +) + +func TestFailedLoadCache(t *testing.T) { + GlobalFailedLoadCache = NewFailedLoadCache() + + colID := int64(0) + errCode := commonpb.ErrorCode_InsufficientMemoryToLoad + mockErr := fmt.Errorf("mock insufficient memory reason") + + GlobalFailedLoadCache.Put(colID, commonpb.ErrorCode_Success, nil) + res := GlobalFailedLoadCache.Get(colID) + assert.Equal(t, commonpb.ErrorCode_Success, res.GetErrorCode()) + + GlobalFailedLoadCache.Put(colID, errCode, mockErr) + res = GlobalFailedLoadCache.Get(colID) + assert.Equal(t, errCode, res.GetErrorCode()) + + GlobalFailedLoadCache.Remove(colID) + res = GlobalFailedLoadCache.Get(colID) + assert.Equal(t, commonpb.ErrorCode_Success, res.GetErrorCode()) + + GlobalFailedLoadCache.Put(colID, errCode, mockErr) + GlobalFailedLoadCache.mu.Lock() + GlobalFailedLoadCache.records[colID][errCode].lastTime = time.Now().Add(-expireTime * 2) + GlobalFailedLoadCache.mu.Unlock() + GlobalFailedLoadCache.TryExpire() + res = GlobalFailedLoadCache.Get(colID) + assert.Equal(t, commonpb.ErrorCode_Success, res.GetErrorCode()) +} diff --git a/internal/querycoordv2/server.go b/internal/querycoordv2/server.go index 3efe17035b..85ee60de82 100644 --- a/internal/querycoordv2/server.go +++ b/internal/querycoordv2/server.go @@ -234,6 +234,9 @@ func (s *Server) Init() error { // Init observers s.initObserver() + // Init load status cache + meta.GlobalFailedLoadCache = meta.NewFailedLoadCache() + log.Info("QueryCoord init success") return err } diff --git a/internal/querycoordv2/services.go b/internal/querycoordv2/services.go index 968d15b7b4..999aec8b4e 100644 --- a/internal/querycoordv2/services.go +++ b/internal/querycoordv2/services.go @@ -56,6 +56,7 @@ func (s *Server) ShowCollections(ctx context.Context, req *querypb.ShowCollectio Status: utils.WrapStatus(commonpb.ErrorCode_UnexpectedError, msg, ErrNotHealthy), }, nil } + defer meta.GlobalFailedLoadCache.TryExpire() isGetAll := false collectionSet := typeutil.NewUniqueSet(req.GetCollectionIDs()...) 
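The FailedLoadCache introduced above is the glue between QueryCoord's load path and the user-facing Show* APIs: a failed load task Puts the error keyed by collection, ShowCollections/ShowPartitions Get the cached status back for collections that are no longer tracked, a fresh load job or a release Removes the entry, and TryExpire drops records older than 24 hours. Below is a minimal sketch of that call pattern, not part of the patch; the collection ID and error text are arbitrary, and the snippet would have to live inside the milvus module to import the internal meta package.

package main

import (
	"errors"
	"fmt"

	"github.com/milvus-io/milvus-proto/go-api/commonpb"
	"github.com/milvus-io/milvus/internal/querycoordv2/meta"
)

func main() {
	// QueryCoord initializes the global cache once at startup (see server.go above).
	meta.GlobalFailedLoadCache = meta.NewFailedLoadCache()

	// A load task failed for collection 100 (arbitrary ID): record the failure.
	loadErr := errors.New("mock: not enough memory on the query node")
	meta.GlobalFailedLoadCache.Put(100, commonpb.ErrorCode_InsufficientMemoryToLoad, loadErr)

	// ShowCollections/ShowPartitions consult the cache for collections that are not
	// loaded and return the recorded status instead of a generic "not loaded" error.
	status := meta.GlobalFailedLoadCache.Get(100)
	fmt.Println(status.GetErrorCode(), status.GetReason())

	// A new load job or a release clears the record; stale entries are dropped lazily
	// whenever TryExpire runs (deferred in ShowCollections/ShowPartitions above).
	meta.GlobalFailedLoadCache.Remove(100)
	meta.GlobalFailedLoadCache.TryExpire()
}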
@@ -86,6 +87,13 @@ func (s *Server) ShowCollections(ctx context.Context, req *querypb.ShowCollectio // ignore it continue } + status := meta.GlobalFailedLoadCache.Get(collectionID) + if status.ErrorCode != commonpb.ErrorCode_Success { + log.Warn("show collection failed", zap.String("errCode", status.GetErrorCode().String()), zap.String("reason", status.GetReason())) + return &querypb.ShowCollectionsResponse{ + Status: status, + }, nil + } err := fmt.Errorf("collection %d has not been loaded to memory or load failed", collectionID) log.Warn("show collection failed", zap.Error(err)) return &querypb.ShowCollectionsResponse{ @@ -114,6 +122,7 @@ func (s *Server) ShowPartitions(ctx context.Context, req *querypb.ShowPartitions Status: utils.WrapStatus(commonpb.ErrorCode_UnexpectedError, msg, ErrNotHealthy), }, nil } + defer meta.GlobalFailedLoadCache.TryExpire() // TODO(yah01): now, for load collection, the percentage of partition is equal to the percentage of collection, // we can calculates the real percentage of partitions @@ -163,6 +172,13 @@ func (s *Server) ShowPartitions(ctx context.Context, req *querypb.ShowPartitions } if isReleased { + status := meta.GlobalFailedLoadCache.Get(req.GetCollectionID()) + if status.ErrorCode != commonpb.ErrorCode_Success { + log.Warn("show collection failed", zap.String("errCode", status.GetErrorCode().String()), zap.String("reason", status.GetReason())) + return &querypb.ShowPartitionsResponse{ + Status: status, + }, nil + } msg := fmt.Sprintf("collection %v has not been loaded into QueryNode", req.GetCollectionID()) log.Warn(msg) return &querypb.ShowPartitionsResponse{ @@ -251,6 +267,8 @@ func (s *Server) ReleaseCollection(ctx context.Context, req *querypb.ReleaseColl log.Info("collection released") metrics.QueryCoordReleaseCount.WithLabelValues(metrics.SuccessLabel).Inc() metrics.QueryCoordReleaseLatency.WithLabelValues().Observe(float64(tr.ElapseSpan().Milliseconds())) + meta.GlobalFailedLoadCache.Remove(req.GetCollectionID()) + return successStatus, nil } @@ -333,6 +351,8 @@ func (s *Server) ReleasePartitions(ctx context.Context, req *querypb.ReleasePart metrics.QueryCoordReleaseCount.WithLabelValues(metrics.SuccessLabel).Inc() metrics.QueryCoordReleaseLatency.WithLabelValues().Observe(float64(tr.ElapseSpan().Milliseconds())) + + meta.GlobalFailedLoadCache.Remove(req.GetCollectionID()) return successStatus, nil } diff --git a/internal/querycoordv2/services_test.go b/internal/querycoordv2/services_test.go index f465f9da31..5f1c2837b7 100644 --- a/internal/querycoordv2/services_test.go +++ b/internal/querycoordv2/services_test.go @@ -19,6 +19,7 @@ package querycoordv2 import ( "context" "encoding/json" + "fmt" "testing" "time" @@ -141,6 +142,7 @@ func (suite *ServiceSuite) SetupTest() { suite.meta, suite.targetMgr, ) + meta.GlobalFailedLoadCache = meta.NewFailedLoadCache() suite.server = &Server{ kv: suite.kv, @@ -185,6 +187,18 @@ func (suite *ServiceSuite) TestShowCollections() { suite.Len(resp.CollectionIDs, 1) suite.Equal(collection, resp.CollectionIDs[0]) + // Test insufficient memory + colBak := suite.meta.CollectionManager.GetCollection(collection) + err = suite.meta.CollectionManager.RemoveCollection(collection) + suite.NoError(err) + meta.GlobalFailedLoadCache.Put(collection, commonpb.ErrorCode_InsufficientMemoryToLoad, fmt.Errorf("mock insufficient memory reason")) + resp, err = server.ShowCollections(ctx, req) + suite.NoError(err) + suite.Equal(commonpb.ErrorCode_InsufficientMemoryToLoad, resp.GetStatus().GetErrorCode()) + 
meta.GlobalFailedLoadCache.Remove(collection) + err = suite.meta.CollectionManager.PutCollection(colBak) + suite.NoError(err) + // Test when server is not healthy server.UpdateStateCode(commonpb.StateCode_Initializing) resp, err = server.ShowCollections(ctx, req) @@ -225,6 +239,32 @@ func (suite *ServiceSuite) TestShowPartitions() { for _, partition := range partitions[0:1] { suite.Contains(resp.PartitionIDs, partition) } + + // Test insufficient memory + if suite.loadTypes[collection] == querypb.LoadType_LoadCollection { + colBak := suite.meta.CollectionManager.GetCollection(collection) + err = suite.meta.CollectionManager.RemoveCollection(collection) + suite.NoError(err) + meta.GlobalFailedLoadCache.Put(collection, commonpb.ErrorCode_InsufficientMemoryToLoad, fmt.Errorf("mock insufficient memory reason")) + resp, err = server.ShowPartitions(ctx, req) + suite.NoError(err) + suite.Equal(commonpb.ErrorCode_InsufficientMemoryToLoad, resp.GetStatus().GetErrorCode()) + meta.GlobalFailedLoadCache.Remove(collection) + err = suite.meta.CollectionManager.PutCollection(colBak) + suite.NoError(err) + } else { + partitionID := partitions[0] + parBak := suite.meta.CollectionManager.GetPartition(partitionID) + err = suite.meta.CollectionManager.RemovePartition(partitionID) + suite.NoError(err) + meta.GlobalFailedLoadCache.Put(collection, commonpb.ErrorCode_InsufficientMemoryToLoad, fmt.Errorf("mock insufficient memory reason")) + resp, err = server.ShowPartitions(ctx, req) + suite.NoError(err) + suite.Equal(commonpb.ErrorCode_InsufficientMemoryToLoad, resp.GetStatus().GetErrorCode()) + meta.GlobalFailedLoadCache.Remove(collection) + err = suite.meta.CollectionManager.PutPartition(parBak) + suite.NoError(err) + } } // Test when server is not healthy diff --git a/internal/querycoordv2/task/executor.go b/internal/querycoordv2/task/executor.go index a855003997..068ae6ac32 100644 --- a/internal/querycoordv2/task/executor.go +++ b/internal/querycoordv2/task/executor.go @@ -19,6 +19,7 @@ package task import ( "context" "errors" + "fmt" "sync" "time" @@ -146,8 +147,12 @@ func (ex *Executor) processMergeTask(mergeTask *LoadSegmentsTask) { action := task.Actions()[mergeTask.steps[0]] defer func() { + canceled := task.canceled.Load() for i := range mergeTask.tasks { mergeTask.tasks[i].SetErr(task.Err()) + if canceled { + mergeTask.tasks[i].Cancel() + } ex.removeTask(mergeTask.tasks[i], mergeTask.steps[i]) } }() @@ -184,6 +189,12 @@ func (ex *Executor) processMergeTask(mergeTask *LoadSegmentsTask) { log.Warn("failed to load segment, it may be a false failure", zap.Error(err)) return } + if status.ErrorCode == commonpb.ErrorCode_InsufficientMemoryToLoad { + log.Warn("insufficient memory to load segment", zap.String("err", status.GetReason())) + task.SetErr(fmt.Errorf("%w, err:%s", ErrInsufficientMemory, status.GetReason())) + task.Cancel() + return + } if status.ErrorCode != commonpb.ErrorCode_Success { log.Warn("failed to load segment", zap.String("reason", status.GetReason())) return diff --git a/internal/querycoordv2/task/scheduler.go b/internal/querycoordv2/task/scheduler.go index 193dba617d..0d1db9d558 100644 --- a/internal/querycoordv2/task/scheduler.go +++ b/internal/querycoordv2/task/scheduler.go @@ -23,6 +23,7 @@ import ( "runtime" "sync" + "github.com/milvus-io/milvus-proto/go-api/commonpb" "github.com/milvus-io/milvus/internal/log" "github.com/milvus-io/milvus/internal/metrics" "github.com/milvus-io/milvus/internal/proto/datapb" @@ -53,8 +54,8 @@ var ( // or the target channel is not in TargetManager 
ErrTaskStale = errors.New("TaskStale") - // No enough memory to load segment - ErrResourceNotEnough = errors.New("ResourceNotEnough") + // ErrInsufficientMemory returns insufficient memory error. + ErrInsufficientMemory = errors.New("InsufficientMemoryToLoad") ErrFailedResponse = errors.New("RpcFailed") ErrTaskAlreadyDone = errors.New("TaskAlreadyDone") @@ -658,6 +659,16 @@ func (scheduler *taskScheduler) RemoveByNode(node int64) { } } +func (scheduler *taskScheduler) recordSegmentTaskError(task *SegmentTask) { + var errCode commonpb.ErrorCode + if errors.Is(task.Err(), ErrInsufficientMemory) { + errCode = commonpb.ErrorCode_InsufficientMemoryToLoad + } else { + errCode = commonpb.ErrorCode_UnexpectedError + } + meta.GlobalFailedLoadCache.Put(task.collectionID, errCode, task.Err()) +} + func (scheduler *taskScheduler) remove(task Task) { log := log.With( zap.Int64("taskID", task.ID()), @@ -675,6 +686,10 @@ func (scheduler *taskScheduler) remove(task Task) { index := NewReplicaSegmentIndex(task) delete(scheduler.segmentTasks, index) log = log.With(zap.Int64("segmentID", task.SegmentID())) + if task.Err() != nil { + log.Warn("task scheduler recordSegmentTaskError", zap.Error(task.err)) + scheduler.recordSegmentTaskError(task) + } case *ChannelTask: index := replicaChannelIndex{task.ReplicaID(), task.Channel()} diff --git a/internal/querycoordv2/task/task_test.go b/internal/querycoordv2/task/task_test.go index 17c24dc150..66383edb53 100644 --- a/internal/querycoordv2/task/task_test.go +++ b/internal/querycoordv2/task/task_test.go @@ -142,6 +142,7 @@ func (suite *TaskSuite) SetupTest() { suite.scheduler.AddExecutor(1) suite.scheduler.AddExecutor(2) suite.scheduler.AddExecutor(3) + meta.GlobalFailedLoadCache = meta.NewFailedLoadCache() } func (suite *TaskSuite) BeforeTest(suiteName, testName string) { diff --git a/internal/querynode/errors.go b/internal/querynode/errors.go index 210d90d152..1f044f790f 100644 --- a/internal/querynode/errors.go +++ b/internal/querynode/errors.go @@ -27,6 +27,8 @@ var ( ErrShardNotAvailable = errors.New("ShardNotAvailable") // ErrTsLagTooLarge serviceable and guarantee lag too large. ErrTsLagTooLarge = errors.New("Timestamp lag too large") + // ErrInsufficientMemory returns insufficient memory error. + ErrInsufficientMemory = errors.New("InsufficientMemoryToLoad") ) // WrapErrShardNotAvailable wraps ErrShardNotAvailable with replica id and channel name. 
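On the query node side the feature rests on standard Go error wrapping: checkSegmentSize wraps ErrInsufficientMemory with %w (see segment_loader.go below), and callers such as ShardCluster.LoadSegments and QueryNode.LoadSegments detect it with errors.Is before translating it into commonpb.ErrorCode_InsufficientMemoryToLoad. A self-contained sketch of that propagation follows; checkSegmentSize here is an illustrative stand-in, not the function from the patch.

package main

import (
	"errors"
	"fmt"
)

// Sentinel error, mirroring querynode's ErrInsufficientMemory.
var errInsufficientMemory = errors.New("InsufficientMemoryToLoad")

// checkSegmentSize stands in for segmentLoader.checkSegmentSize: it wraps the
// sentinel with %w so the detail is kept but errors.Is still matches.
func checkSegmentSize(usedMB, totalMB uint64) error {
	if usedMB > totalMB {
		return fmt.Errorf("%w, load segment failed, OOM if load, usedMemAfterLoad = %d MB, totalMem = %d MB",
			errInsufficientMemory, usedMB, totalMB)
	}
	return nil
}

func main() {
	err := checkSegmentSize(9000, 8192)
	// Further up the call chain the wrapped error is still recognizable,
	// so it can be mapped to ErrorCode_InsufficientMemoryToLoad.
	fmt.Println(errors.Is(err, errInsufficientMemory)) // true
	fmt.Println(err)
}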
diff --git a/internal/querynode/impl.go b/internal/querynode/impl.go index 6f657a2fbf..d4bb0540dc 100644 --- a/internal/querynode/impl.go +++ b/internal/querynode/impl.go @@ -520,6 +520,9 @@ func (node *QueryNode) LoadSegments(ctx context.Context, in *querypb.LoadSegment ErrorCode: commonpb.ErrorCode_UnexpectedError, Reason: err.Error(), } + if errors.Is(err, ErrInsufficientMemory) { + status.ErrorCode = commonpb.ErrorCode_InsufficientMemoryToLoad + } log.Warn(err.Error()) return status, nil } diff --git a/internal/querynode/impl_utils.go b/internal/querynode/impl_utils.go index ee962d7159..337eb5f1f2 100644 --- a/internal/querynode/impl_utils.go +++ b/internal/querynode/impl_utils.go @@ -2,6 +2,8 @@ package querynode import ( "context" + "errors" + "fmt" "github.com/milvus-io/milvus-proto/go-api/commonpb" "github.com/milvus-io/milvus/internal/log" @@ -38,6 +40,13 @@ func (node *QueryNode) TransferLoad(ctx context.Context, req *querypb.LoadSegmen req.NeedTransfer = false err := shardCluster.LoadSegments(ctx, req) if err != nil { + if errors.Is(err, ErrInsufficientMemory) { + log.Warn("insufficient memory when shard cluster load segments", zap.Error(err)) + return &commonpb.Status{ + ErrorCode: commonpb.ErrorCode_InsufficientMemoryToLoad, + Reason: fmt.Sprintf("insufficient memory when shard cluster load segments, err:%s", err.Error()), + }, nil + } log.Warn("shard cluster failed to load segments", zap.Error(err)) return &commonpb.Status{ ErrorCode: commonpb.ErrorCode_UnexpectedError, diff --git a/internal/querynode/impl_utils_test.go b/internal/querynode/impl_utils_test.go index 2700bd8ab9..57c0593dc5 100644 --- a/internal/querynode/impl_utils_test.go +++ b/internal/querynode/impl_utils_test.go @@ -157,6 +157,39 @@ func (s *ImplUtilsSuite) TestTransferLoad() { s.NoError(err) s.Equal(commonpb.ErrorCode_UnexpectedError, status.GetErrorCode()) }) + + s.Run("insufficient memory", func() { + cs, ok := s.querynode.ShardClusterService.getShardCluster(defaultChannelName) + s.Require().True(ok) + cs.nodes[100] = &shardNode{ + nodeID: 100, + nodeAddr: "test", + client: &mockShardQueryNode{ + loadSegmentsResults: &commonpb.Status{ + ErrorCode: commonpb.ErrorCode_InsufficientMemoryToLoad, + Reason: "mock InsufficientMemoryToLoad", + }, + }, + } + + status, err := s.querynode.TransferLoad(ctx, &querypb.LoadSegmentsRequest{ + Base: &commonpb.MsgBase{ + TargetID: s.querynode.session.ServerID, + }, + DstNodeID: 100, + Infos: []*querypb.SegmentLoadInfo{ + { + SegmentID: defaultSegmentID, + InsertChannel: defaultChannelName, + CollectionID: defaultCollectionID, + PartitionID: defaultPartitionID, + }, + }, + }) + + s.NoError(err) + s.Equal(commonpb.ErrorCode_InsufficientMemoryToLoad, status.GetErrorCode()) + }) } func (s *ImplUtilsSuite) TestTransferRelease() { diff --git a/internal/querynode/segment_loader.go b/internal/querynode/segment_loader.go index 1cc7b19df2..536bfce745 100644 --- a/internal/querynode/segment_loader.go +++ b/internal/querynode/segment_loader.go @@ -952,7 +952,8 @@ func (loader *segmentLoader) checkSegmentSize(collectionID UniqueID, segmentLoad zap.Uint64("diskUsageAfterLoad", toMB(usedLocalSizeAfterLoad))) if memLoadingUsage > uint64(float64(totalMem)*Params.QueryNodeCfg.OverloadedMemoryThresholdPercentage.GetAsFloat()) { - return fmt.Errorf("load segment failed, OOM if load, collectionID = %d, maxSegmentSize = %v MB, concurrency = %d, usedMemAfterLoad = %v MB, totalMem = %v MB, thresholdFactor = %f", + return fmt.Errorf("%w, load segment failed, OOM if load, collectionID = %d, 
maxSegmentSize = %v MB, concurrency = %d, usedMemAfterLoad = %v MB, totalMem = %v MB, thresholdFactor = %f", + ErrInsufficientMemory, collectionID, toMB(maxSegmentSize), concurrency, diff --git a/internal/querynode/shard_cluster.go b/internal/querynode/shard_cluster.go index ffeadda3f0..44171828b4 100644 --- a/internal/querynode/shard_cluster.go +++ b/internal/querynode/shard_cluster.go @@ -643,6 +643,10 @@ func (sc *ShardCluster) LoadSegments(ctx context.Context, req *querypb.LoadSegme log.Warn("failed to dispatch load segment request", zap.Error(err)) return err } + if resp.GetErrorCode() == commonpb.ErrorCode_InsufficientMemoryToLoad { + log.Warn("insufficient memory when follower load segment", zap.String("reason", resp.GetReason())) + return fmt.Errorf("%w, reason:%s", ErrInsufficientMemory, resp.GetReason()) + } if resp.GetErrorCode() != commonpb.ErrorCode_Success { log.Warn("follower load segment failed", zap.String("reason", resp.GetReason())) return fmt.Errorf("follower %d failed to load segment, reason %s", req.DstNodeID, resp.GetReason())
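Because Go errors do not cross the gRPC boundary, the patch re-encodes the condition at every hop: the sender maps errors.Is(err, ErrInsufficientMemory) to ErrorCode_InsufficientMemoryToLoad in the returned status, and the receiver maps that code back into a wrapped ErrInsufficientMemory, as ShardCluster.LoadSegments does above. A rough sketch of that round trip; encodeStatus and decodeStatus are illustrative helpers, not functions from the patch.

package main

import (
	"errors"
	"fmt"

	"github.com/milvus-io/milvus-proto/go-api/commonpb"
)

var errInsufficientMemory = errors.New("InsufficientMemoryToLoad")

// encodeStatus mimics the sending side (e.g. QueryNode.LoadSegments): a Go error
// becomes a commonpb.Status because errors do not survive the RPC boundary.
func encodeStatus(err error) *commonpb.Status {
	if err == nil {
		return &commonpb.Status{ErrorCode: commonpb.ErrorCode_Success}
	}
	code := commonpb.ErrorCode_UnexpectedError
	if errors.Is(err, errInsufficientMemory) {
		code = commonpb.ErrorCode_InsufficientMemoryToLoad
	}
	return &commonpb.Status{ErrorCode: code, Reason: err.Error()}
}

// decodeStatus mimics the receiving side (e.g. ShardCluster.LoadSegments or the
// proxy helpers): the status code is turned back into a wrapped sentinel error.
func decodeStatus(status *commonpb.Status) error {
	switch status.GetErrorCode() {
	case commonpb.ErrorCode_Success:
		return nil
	case commonpb.ErrorCode_InsufficientMemoryToLoad:
		return fmt.Errorf("%w, reason:%s", errInsufficientMemory, status.GetReason())
	default:
		return errors.New(status.GetReason())
	}
}

func main() {
	loadErr := fmt.Errorf("%w, OOM if load", errInsufficientMemory)
	status := encodeStatus(loadErr)
	roundTripped := decodeStatus(status)
	fmt.Println(errors.Is(roundTripped, errInsufficientMemory)) // true
}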