diff --git a/internal/datacoord/import_checker.go b/internal/datacoord/import_checker.go
index 6b1fa54ffa..5b33593fa9 100644
--- a/internal/datacoord/import_checker.go
+++ b/internal/datacoord/import_checker.go
@@ -210,6 +210,16 @@ func (c *importChecker) checkPreImportingJob(job ImportJob) {
 		return
 	}
 
+	requestSize, err := CheckDiskQuota(job, c.meta, c.imeta)
+	if err != nil {
+		log.Warn("import failed, disk quota exceeded", zap.Int64("jobID", job.GetJobID()), zap.Error(err))
+		err = c.imeta.UpdateJob(job.GetJobID(), UpdateJobState(internalpb.ImportJobState_Failed), UpdateJobReason(err.Error()))
+		if err != nil {
+			log.Warn("failed to update job state to Failed", zap.Int64("jobID", job.GetJobID()), zap.Error(err))
+		}
+		return
+	}
+
 	groups := RegroupImportFiles(job, lacks)
 	newTasks, err := NewImportTasks(groups, job, c.sm, c.alloc)
 	if err != nil {
@@ -224,7 +234,7 @@ func (c *importChecker) checkPreImportingJob(job ImportJob) {
 		}
 		log.Info("add new import task", WrapTaskLog(t)...)
 	}
-	err = c.imeta.UpdateJob(job.GetJobID(), UpdateJobState(internalpb.ImportJobState_Importing))
+	err = c.imeta.UpdateJob(job.GetJobID(), UpdateJobState(internalpb.ImportJobState_Importing), UpdateRequestedDiskSize(requestSize))
 	if err != nil {
 		log.Warn("failed to update job state to Importing", zap.Int64("jobID", job.GetJobID()), zap.Error(err))
 	}
diff --git a/internal/datacoord/import_job.go b/internal/datacoord/import_job.go
index 2d82d27763..e7109f8c0d 100644
--- a/internal/datacoord/import_job.go
+++ b/internal/datacoord/import_job.go
@@ -42,6 +42,9 @@ func UpdateJobState(state internalpb.ImportJobState) UpdateJobAction {
 	return func(job ImportJob) {
 		job.(*importJob).ImportJob.State = state
 		if state == internalpb.ImportJobState_Completed || state == internalpb.ImportJobState_Failed {
+			// releases requested disk resource
+			job.(*importJob).ImportJob.RequestedDiskSize = 0
+			// set cleanup ts
 			dur := Params.DataCoordCfg.ImportTaskRetention.GetAsDuration(time.Second)
 			cleanupTs := tsoutil.ComposeTSByTime(time.Now().Add(dur), 0)
 			job.(*importJob).ImportJob.CleanupTs = cleanupTs
@@ -55,6 +58,12 @@ func UpdateJobReason(reason string) UpdateJobAction {
 	}
 }
 
+func UpdateRequestedDiskSize(requestSize int64) UpdateJobAction {
+	return func(job ImportJob) {
+		job.(*importJob).ImportJob.RequestedDiskSize = requestSize
+	}
+}
+
 func UpdateJobCompleteTime(completeTime string) UpdateJobAction {
 	return func(job ImportJob) {
 		job.(*importJob).ImportJob.CompleteTime = completeTime
@@ -72,6 +81,7 @@ type ImportJob interface {
 	GetCleanupTs() uint64
 	GetState() internalpb.ImportJobState
 	GetReason() string
+	GetRequestedDiskSize() int64
 	GetCompleteTime() string
 	GetFiles() []*internalpb.ImportFile
 	GetOptions() []*commonpb.KeyValuePair
diff --git a/internal/datacoord/import_util.go b/internal/datacoord/import_util.go
index b3b06d3e4a..7cb6eb95b6 100644
--- a/internal/datacoord/import_util.go
+++ b/internal/datacoord/import_util.go
@@ -247,6 +247,45 @@ func RegroupImportFiles(job ImportJob, files []*datapb.ImportFileStats) [][]*dat
 	return fileGroups
 }
 
+func CheckDiskQuota(job ImportJob, meta *meta, imeta ImportMeta) (int64, error) {
+	if !Params.QuotaConfig.DiskProtectionEnabled.GetAsBool() {
+		return 0, nil
+	}
+
+	var (
+		requestedTotal       int64
+		requestedCollections = make(map[int64]int64)
+	)
+	for _, j := range imeta.GetJobBy() {
+		requested := j.GetRequestedDiskSize()
+		requestedTotal += requested
+		requestedCollections[j.GetCollectionID()] += requested
+	}
+
+	err := merr.WrapErrServiceQuotaExceeded("disk quota exceeded, please allocate more resources")
+	totalUsage, collectionsUsage := meta.GetCollectionBinlogSize()
+
+	tasks := imeta.GetTaskBy(WithJob(job.GetJobID()), WithType(PreImportTaskType))
+	files := make([]*datapb.ImportFileStats, 0)
+	for _, task := range tasks {
+		files = append(files, task.GetFileStats()...)
+	}
+	requestSize := lo.SumBy(files, func(file *datapb.ImportFileStats) int64 {
+		return file.GetTotalMemorySize()
+	})
+
+	totalDiskQuota := Params.QuotaConfig.DiskQuota.GetAsFloat()
+	if float64(totalUsage+requestedTotal+requestSize) > totalDiskQuota {
+		return 0, err
+	}
+	collectionDiskQuota := Params.QuotaConfig.DiskQuotaPerCollection.GetAsFloat()
+	colID := job.GetCollectionID()
+	if float64(collectionsUsage[colID]+requestedCollections[colID]+requestSize) > collectionDiskQuota {
+		return 0, err
+	}
+	return requestSize, nil
+}
+
 func getPendingProgress(jobID int64, imeta ImportMeta) float32 {
 	tasks := imeta.GetTaskBy(WithJob(jobID), WithType(PreImportTaskType))
 	preImportingFiles := lo.SumBy(tasks, func(task ImportTask) int {
diff --git a/internal/datacoord/import_util_test.go b/internal/datacoord/import_util_test.go
index ecf058ac98..329bb79e11 100644
--- a/internal/datacoord/import_util_test.go
+++ b/internal/datacoord/import_util_test.go
@@ -27,12 +27,14 @@ import (
 	"github.com/samber/lo"
 	"github.com/stretchr/testify/assert"
 	"github.com/stretchr/testify/mock"
+	"go.uber.org/atomic"
 
 	"github.com/milvus-io/milvus-proto/go-api/v2/commonpb"
 	"github.com/milvus-io/milvus/internal/metastore/mocks"
 	mocks2 "github.com/milvus-io/milvus/internal/mocks"
 	"github.com/milvus-io/milvus/internal/proto/datapb"
 	"github.com/milvus-io/milvus/internal/proto/internalpb"
+	"github.com/milvus-io/milvus/pkg/util/merr"
 	"github.com/milvus-io/milvus/pkg/util/paramtable"
 )
 
@@ -216,6 +218,78 @@ func TestImportUtil_RegroupImportFiles(t *testing.T) {
 	assert.Equal(t, fileNum, total)
 }
 
+func TestImportUtil_CheckDiskQuota(t *testing.T) {
+	catalog := mocks.NewDataCoordCatalog(t)
+	catalog.EXPECT().ListImportJobs().Return(nil, nil)
+	catalog.EXPECT().ListImportTasks().Return(nil, nil)
+	catalog.EXPECT().ListPreImportTasks().Return(nil, nil)
+	catalog.EXPECT().SaveImportJob(mock.Anything).Return(nil)
+	catalog.EXPECT().SavePreImportTask(mock.Anything).Return(nil)
+	catalog.EXPECT().ListIndexes(mock.Anything).Return(nil, nil)
+	catalog.EXPECT().ListSegmentIndexes(mock.Anything).Return(nil, nil)
+	catalog.EXPECT().ListSegments(mock.Anything).Return(nil, nil)
+	catalog.EXPECT().ListChannelCheckpoint(mock.Anything).Return(nil, nil)
+	catalog.EXPECT().AddSegment(mock.Anything, mock.Anything).Return(nil)
+
+	imeta, err := NewImportMeta(catalog)
+	assert.NoError(t, err)
+
+	meta, err := newMeta(context.TODO(), catalog, nil)
+	assert.NoError(t, err)
+
+	job := &importJob{
+		ImportJob: &datapb.ImportJob{
+			JobID:        0,
+			CollectionID: 100,
+		},
+	}
+	err = imeta.AddJob(job)
+	assert.NoError(t, err)
+
+	pit := &preImportTask{
+		PreImportTask: &datapb.PreImportTask{
+			JobID:  job.GetJobID(),
+			TaskID: 1,
+			FileStats: []*datapb.ImportFileStats{
+				{TotalMemorySize: 1000 * 1024 * 1024},
+				{TotalMemorySize: 2000 * 1024 * 1024},
+			},
+		},
+	}
+	err = imeta.AddTask(pit)
+	assert.NoError(t, err)
+
+	Params.Save(Params.QuotaConfig.DiskProtectionEnabled.Key, "false")
+	defer Params.Reset(Params.QuotaConfig.DiskProtectionEnabled.Key)
+	_, err = CheckDiskQuota(job, meta, imeta)
+	assert.NoError(t, err)
+
+	segment := &SegmentInfo{
+		SegmentInfo: &datapb.SegmentInfo{ID: 5, CollectionID: 100, State: commonpb.SegmentState_Flushed},
+		size:        *atomic.NewInt64(3000 * 1024 * 1024),
+	}
+	err = meta.AddSegment(context.Background(), segment)
+	assert.NoError(t, err)
+
+	Params.Save(Params.QuotaConfig.DiskProtectionEnabled.Key, "true")
+	Params.Save(Params.QuotaConfig.DiskQuota.Key, "10000")
+	Params.Save(Params.QuotaConfig.DiskQuotaPerCollection.Key, "10000")
+	defer Params.Reset(Params.QuotaConfig.DiskQuota.Key)
+	defer Params.Reset(Params.QuotaConfig.DiskQuotaPerCollection.Key)
+	requestSize, err := CheckDiskQuota(job, meta, imeta)
+	assert.NoError(t, err)
+	assert.Equal(t, int64(3000*1024*1024), requestSize)
+
+	Params.Save(Params.QuotaConfig.DiskQuota.Key, "5000")
+	_, err = CheckDiskQuota(job, meta, imeta)
+	assert.True(t, errors.Is(err, merr.ErrServiceQuotaExceeded))
+
+	Params.Save(Params.QuotaConfig.DiskQuota.Key, "10000")
+	Params.Save(Params.QuotaConfig.DiskQuotaPerCollection.Key, "5000")
+	_, err = CheckDiskQuota(job, meta, imeta)
+	assert.True(t, errors.Is(err, merr.ErrServiceQuotaExceeded))
+}
+
 func TestImportUtil_DropImportTask(t *testing.T) {
 	cluster := NewMockCluster(t)
 	cluster.EXPECT().DropImport(mock.Anything, mock.Anything).Return(nil)
diff --git a/internal/datacoord/meta.go b/internal/datacoord/meta.go
index 2a71e53ede..03ec4f13f2 100644
--- a/internal/datacoord/meta.go
+++ b/internal/datacoord/meta.go
@@ -261,7 +261,7 @@ func (m *meta) GetCollectionBinlogSize() (int64, map[UniqueID]int64) {
 	var total int64
 	for _, segment := range segments {
 		segmentSize := segment.getSegmentSize()
-		if isSegmentHealthy(segment) {
+		if isSegmentHealthy(segment) && !segment.GetIsImporting() {
 			total += segmentSize
 			collectionBinlogSize[segment.GetCollectionID()] += segmentSize
 			metrics.DataCoordStoredBinlogSize.WithLabelValues(
diff --git a/internal/datacoord/segment_manager.go b/internal/datacoord/segment_manager.go
index 2bded38eca..a0b3e82b23 100644
--- a/internal/datacoord/segment_manager.go
+++ b/internal/datacoord/segment_manager.go
@@ -426,6 +426,8 @@ func (s *SegmentManager) AllocImportSegment(ctx context.Context, taskID int64, c
 		log.Error("failed to add import segment", zap.Error(err))
 		return nil, err
 	}
+	s.mu.Lock()
+	defer s.mu.Unlock()
 	s.segments = append(s.segments, id)
 	log.Info("add import segment done",
 		zap.Int64("taskID", taskID),
diff --git a/internal/proto/data_coord.proto b/internal/proto/data_coord.proto
index ca4299d586..8f331a1bbe 100644
--- a/internal/proto/data_coord.proto
+++ b/internal/proto/data_coord.proto
@@ -865,11 +865,12 @@ message ImportJob {
     schema.CollectionSchema schema = 7;
     uint64 timeout_ts = 8;
     uint64 cleanup_ts = 9;
-    internal.ImportJobState state = 10;
-    string reason = 11;
-    string complete_time = 12;
-    repeated internal.ImportFile files = 13;
-    repeated common.KeyValuePair options = 14;
+    int64 requestedDiskSize = 10;
+    internal.ImportJobState state = 11;
+    string reason = 12;
+    string complete_time = 13;
+    repeated internal.ImportFile files = 14;
+    repeated common.KeyValuePair options = 15;
 }
 
 enum ImportTaskStateV2 {
diff --git a/tests/integration/import/import_test.go b/tests/integration/import/import_test.go
index c821dbf92d..e32c54aea1 100644
--- a/tests/integration/import/import_test.go
+++ b/tests/integration/import/import_test.go
@@ -37,19 +37,25 @@ import (
 	"github.com/milvus-io/milvus/pkg/log"
 	"github.com/milvus-io/milvus/pkg/util/funcutil"
 	"github.com/milvus-io/milvus/pkg/util/metric"
+	"github.com/milvus-io/milvus/pkg/util/paramtable"
 	"github.com/milvus-io/milvus/tests/integration"
 )
 
 type BulkInsertSuite struct {
 	integration.MiniClusterSuite
 
+	failed       bool
+	failedReason string
+
 	pkType   schemapb.DataType
 	autoID   bool
 	fileType importutilv2.FileType
 }
 
 func (s *BulkInsertSuite) SetupTest() {
+	paramtable.Init()
 	s.MiniClusterSuite.SetupTest()
+	s.failed = false
 	s.fileType = importutilv2.Parquet
 	s.pkType = schemapb.DataType_Int64
 	s.autoID = false
@@ -124,6 +130,12 @@ func (s *BulkInsertSuite) run() {
 	jobID := importResp.GetJobID()
 
 	err = WaitForImportDone(ctx, c, jobID)
+	if s.failed {
+		s.T().Logf("expect failed import, err=%s", err)
+		s.Error(err)
+		s.Contains(err.Error(), s.failedReason)
+		return
+	}
 	s.NoError(err)
 
 	segments, err := c.MetaWatcher.ShowSegments()
@@ -254,6 +266,20 @@ func (s *BulkInsertSuite) TestZeroRowCount() {
 	s.Empty(segments)
 }
 
+func (s *BulkInsertSuite) TestDiskQuotaExceeded() {
+	paramtable.Get().Save(paramtable.Get().QuotaConfig.DiskProtectionEnabled.Key, "true")
+	paramtable.Get().Save(paramtable.Get().QuotaConfig.DiskQuota.Key, "100")
+	defer paramtable.Get().Reset(paramtable.Get().QuotaConfig.DiskProtectionEnabled.Key)
+	defer paramtable.Get().Reset(paramtable.Get().QuotaConfig.DiskQuota.Key)
+	s.failed = false
+	s.run()
+
+	paramtable.Get().Save(paramtable.Get().QuotaConfig.DiskQuota.Key, "0.01")
+	s.failed = true
+	s.failedReason = "disk quota exceeded"
+	s.run()
+}
+
 func TestBulkInsert(t *testing.T) {
 	suite.Run(t, new(BulkInsertSuite))
 }