// Licensed to the LF AI & Data foundation under one
// or more contributor license agreements. See the NOTICE file
// distributed with this work for additional information
// regarding copyright ownership. The ASF licenses this file
// to you under the Apache License, Version 2.0 (the
// "License"); you may not use this file except in compliance
// with the License. You may obtain a copy of the License at
//
//     http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

package importv2

import (
	"context"
	"io"
	"strconv"
	"strings"
	"sync"
	"testing"
	"time"

	"github.com/cockroachdb/errors"
	"github.com/samber/lo"
	"github.com/stretchr/testify/mock"
	"github.com/stretchr/testify/suite"

	"github.com/milvus-io/milvus-proto/go-api/v2/commonpb"
	"github.com/milvus-io/milvus-proto/go-api/v2/schemapb"
	"github.com/milvus-io/milvus/internal/flushcommon/syncmgr"
	"github.com/milvus-io/milvus/internal/json"
	"github.com/milvus-io/milvus/internal/mocks"
	"github.com/milvus-io/milvus/internal/storage"
	"github.com/milvus-io/milvus/internal/util/function"
	"github.com/milvus-io/milvus/internal/util/importutilv2"
	"github.com/milvus-io/milvus/internal/util/testutil"
	"github.com/milvus-io/milvus/pkg/v2/common"
	"github.com/milvus-io/milvus/pkg/v2/proto/datapb"
	"github.com/milvus-io/milvus/pkg/v2/proto/internalpb"
	"github.com/milvus-io/milvus/pkg/v2/util/conc"
	"github.com/milvus-io/milvus/pkg/v2/util/paramtable"
)

type sampleRow struct {
	FieldString      string    `json:"pk,omitempty"`
	FieldInt64       int64     `json:"int64,omitempty"`
	FieldFloatVector []float32 `json:"vec,omitempty"`
}

type sampleContent struct {
	Rows []sampleRow `json:"rows,omitempty"`
}

type mockReader struct {
	io.Reader
	io.Closer
	io.ReaderAt
	io.Seeker
	size int64
}

func (mr *mockReader) Size() (int64, error) {
	return mr.size, nil
}

type SchedulerSuite struct {
	suite.Suite

	numRows   int
	schema    *schemapb.CollectionSchema
	cm        storage.ChunkManager
	reader    *importutilv2.MockReader
	syncMgr   *syncmgr.MockSyncManager
	manager   TaskManager
	scheduler *scheduler
}

func (s *SchedulerSuite) SetupSuite() {
	paramtable.Init()
}

func (s *SchedulerSuite) SetupTest() {
	s.numRows = 100
	s.schema = &schemapb.CollectionSchema{
		Fields: []*schemapb.FieldSchema{
			{
				FieldID:      100,
				Name:         "pk",
				IsPrimaryKey: true,
				DataType:     schemapb.DataType_VarChar,
				TypeParams: []*commonpb.KeyValuePair{
					{Key: common.MaxLengthKey, Value: "128"},
				},
			},
			{
				FieldID:  101,
				Name:     "vec",
				DataType: schemapb.DataType_FloatVector,
				TypeParams: []*commonpb.KeyValuePair{
					{
						Key:   common.DimKey,
						Value: "4",
					},
				},
			},
			{
				FieldID:  102,
				Name:     "int64",
				DataType: schemapb.DataType_Int64,
			},
		},
	}

	s.manager = NewTaskManager()
	s.syncMgr = syncmgr.NewMockSyncManager(s.T())
	s.scheduler = NewScheduler(s.manager).(*scheduler)
}

func (s *SchedulerSuite) TearDownTest() {
	s.scheduler.Close()
}

func (s *SchedulerSuite) TestScheduler_Slots() {
	preimportReq := &datapb.PreImportRequest{
		JobID:        1,
		TaskID:       2,
		CollectionID: 3,
		PartitionIDs: []int64{4},
		Vchannels:    []string{"ch-0"},
		Schema:       s.schema,
		ImportFiles:  []*internalpb.ImportFile{{Paths: []string{"dummy.json"}}},
		TaskSlot:     10,
	}
	preimportTask := NewPreImportTask(preimportReq, s.manager, s.cm)
	s.manager.Add(preimportTask)

	slots := s.scheduler.Slots()
	s.Equal(int64(10), slots)
}

func (s *SchedulerSuite) TestScheduler_Start_Preimport() {
	content := &sampleContent{
		Rows: make([]sampleRow, 0),
	}
	for i := 0; i < 10; i++ {
		row := sampleRow{
			FieldString:      "No." + strconv.FormatInt(int64(i), 10),
			FieldInt64:       int64(99999999999999999 + i),
			FieldFloatVector: []float32{float32(i) + 0.1, float32(i) + 0.2, float32(i) + 0.3, float32(i) + 0.4},
		}
		content.Rows = append(content.Rows, row)
	}
	bytes, err := json.Marshal(content)
	s.NoError(err)

	cm := mocks.NewChunkManager(s.T())
	ioReader := strings.NewReader(string(bytes))
	cm.EXPECT().Size(mock.Anything, mock.Anything).Return(1024, nil)
	cm.EXPECT().Reader(mock.Anything, mock.Anything).Return(&mockReader{Reader: ioReader, Closer: io.NopCloser(ioReader)}, nil)
	s.cm = cm

	preimportReq := &datapb.PreImportRequest{
		JobID:        1,
		TaskID:       2,
		CollectionID: 3,
		PartitionIDs: []int64{4},
		Vchannels:    []string{"ch-0"},
		Schema:       s.schema,
		ImportFiles:  []*internalpb.ImportFile{{Paths: []string{"dummy.json"}}},
	}
	preimportTask := NewPreImportTask(preimportReq, s.manager, s.cm)
	s.manager.Add(preimportTask)

	go s.scheduler.Start()
	defer s.scheduler.Close()
	s.Eventually(func() bool {
		return s.manager.Get(preimportTask.GetTaskID()).GetState() == datapb.ImportTaskStateV2_Completed
	}, 10*time.Second, 100*time.Millisecond)
}

func (s *SchedulerSuite) TestScheduler_Start_Preimport_Failed() {
	content := &sampleContent{
		Rows: make([]sampleRow, 0),
	}
	for i := 0; i < 10; i++ {
		var row sampleRow
		if i == 0 { // make rows not consistent
			row = sampleRow{
				FieldString:      "No." + strconv.FormatInt(int64(i), 10),
				FieldFloatVector: []float32{float32(i) + 0.1, float32(i) + 0.2, float32(i) + 0.3, float32(i) + 0.4},
			}
		} else {
			row = sampleRow{
				FieldString:      "No." + strconv.FormatInt(int64(i), 10),
				FieldInt64:       int64(99999999999999999 + i),
				FieldFloatVector: []float32{float32(i) + 0.1, float32(i) + 0.2, float32(i) + 0.3, float32(i) + 0.4},
			}
		}
		content.Rows = append(content.Rows, row)
	}
	bytes, err := json.Marshal(content)
	s.NoError(err)

	cm := mocks.NewChunkManager(s.T())
	ioReader := strings.NewReader(string(bytes))
	cm.EXPECT().Size(mock.Anything, mock.Anything).Return(1024, nil)
	cm.EXPECT().Reader(mock.Anything, mock.Anything).Return(&mockReader{Reader: ioReader, Closer: io.NopCloser(ioReader)}, nil)
	s.cm = cm

	preimportReq := &datapb.PreImportRequest{
		JobID:        1,
		TaskID:       2,
		CollectionID: 3,
		PartitionIDs: []int64{4},
		Vchannels:    []string{"ch-0"},
		Schema:       s.schema,
		ImportFiles:  []*internalpb.ImportFile{{Paths: []string{"dummy.json"}}},
	}
	preimportTask := NewPreImportTask(preimportReq, s.manager, s.cm)
	s.manager.Add(preimportTask)

	go s.scheduler.Start()
	defer s.scheduler.Close()
	s.Eventually(func() bool {
		return s.manager.Get(preimportTask.GetTaskID()).GetState() == datapb.ImportTaskStateV2_Failed
	}, 10*time.Second, 100*time.Millisecond)
}

func (s *SchedulerSuite) TestScheduler_Start_Import() {
	content := &sampleContent{
		Rows: make([]sampleRow, 0),
	}
	for i := 0; i < 10; i++ {
		row := sampleRow{
			FieldString:      "No." + strconv.FormatInt(int64(i), 10),
			FieldInt64:       int64(99999999999999999 + i),
			FieldFloatVector: []float32{float32(i) + 0.1, float32(i) + 0.2, float32(i) + 0.3, float32(i) + 0.4},
		}
		content.Rows = append(content.Rows, row)
	}
	bytes, err := json.Marshal(content)
	s.NoError(err)

	cm := mocks.NewChunkManager(s.T())
	ioReader := strings.NewReader(string(bytes))
	cm.EXPECT().Reader(mock.Anything, mock.Anything).Return(&mockReader{Reader: ioReader, Closer: io.NopCloser(ioReader)}, nil)
	s.cm = cm

	s.syncMgr.EXPECT().SyncDataWithChunkManager(mock.Anything, mock.Anything, mock.Anything).RunAndReturn(func(ctx context.Context, task syncmgr.Task, cm storage.ChunkManager, callbacks ...func(error) error) (*conc.Future[struct{}], error) {
		future := conc.Go(func() (struct{}, error) {
			return struct{}{}, nil
		})
		return future, nil
	})
	importReq := &datapb.ImportRequest{
		JobID:        10,
		TaskID:       11,
		CollectionID: 12,
		PartitionIDs: []int64{13},
		Vchannels:    []string{"v0"},
		Schema:       s.schema,
		Files: []*internalpb.ImportFile{
			{
				Paths: []string{"dummy.json"},
			},
		},
		Ts: 1000,
		IDRange: &datapb.IDRange{
			Begin: 0,
			End:   int64(s.numRows),
		},
		RequestSegments: []*datapb.ImportRequestSegment{
			{
				SegmentID:   14,
				PartitionID: 13,
				Vchannel:    "v0",
			},
		},
	}
	importTask := NewImportTask(importReq, s.manager, s.syncMgr, s.cm)
	s.manager.Add(importTask)

	go s.scheduler.Start()
	defer s.scheduler.Close()
	s.Eventually(func() bool {
		return s.manager.Get(importTask.GetTaskID()).GetState() == datapb.ImportTaskStateV2_Completed
	}, 10*time.Second, 100*time.Millisecond)
}

func (s *SchedulerSuite) TestScheduler_Start_Import_Failed() {
	content := &sampleContent{
		Rows: make([]sampleRow, 0),
	}
	for i := 0; i < 10; i++ {
		row := sampleRow{
			FieldString:      "No." + strconv.FormatInt(int64(i), 10),
			FieldInt64:       int64(99999999999999999 + i),
			FieldFloatVector: []float32{float32(i) + 0.1, float32(i) + 0.2, float32(i) + 0.3, float32(i) + 0.4},
		}
		content.Rows = append(content.Rows, row)
	}
	bytes, err := json.Marshal(content)
	s.NoError(err)

	cm := mocks.NewChunkManager(s.T())
	ioReader := strings.NewReader(string(bytes))
	cm.EXPECT().Reader(mock.Anything, mock.Anything).Return(&mockReader{Reader: ioReader, Closer: io.NopCloser(ioReader)}, nil)
	s.cm = cm

	s.syncMgr.EXPECT().SyncDataWithChunkManager(mock.Anything, mock.Anything, mock.Anything).RunAndReturn(func(ctx context.Context, task syncmgr.Task, cm storage.ChunkManager, callbacks ...func(error) error) (*conc.Future[struct{}], error) {
		future := conc.Go(func() (struct{}, error) {
			return struct{}{}, errors.New("mock err")
		})
		return future, nil
	})
	importReq := &datapb.ImportRequest{
		JobID:        10,
		TaskID:       11,
		CollectionID: 12,
		PartitionIDs: []int64{13},
		Vchannels:    []string{"v0"},
		Schema:       s.schema,
		Files: []*internalpb.ImportFile{
			{
				Paths: []string{"dummy.json"},
			},
		},
		Ts: 1000,
		IDRange: &datapb.IDRange{
			Begin: 0,
			End:   int64(s.numRows),
		},
		RequestSegments: []*datapb.ImportRequestSegment{
			{
				SegmentID:   14,
				PartitionID: 13,
				Vchannel:    "v0",
			},
		},
	}
	importTask := NewImportTask(importReq, s.manager, s.syncMgr, s.cm)
	s.manager.Add(importTask)

	go s.scheduler.Start()
	defer s.scheduler.Close()
	s.Eventually(func() bool {
		return s.manager.Get(importTask.GetTaskID()).GetState() == datapb.ImportTaskStateV2_Failed
	}, 10*time.Second, 100*time.Millisecond)
}

func (s *SchedulerSuite) TestScheduler_ReadFileStat() {
	importFile := &internalpb.ImportFile{
		Paths: []string{"dummy.json"},
	}

	var once sync.Once
	data, err := testutil.CreateInsertData(s.schema, s.numRows)
	s.NoError(err)
	s.reader = importutilv2.NewMockReader(s.T())
	s.reader.EXPECT().Size().Return(1024, nil)
	s.reader.EXPECT().Read().RunAndReturn(func() (*storage.InsertData, error) {
		var res *storage.InsertData
		once.Do(func() {
			res = data
		})
		if res != nil {
			return res, nil
		}
		return nil, io.EOF
	})

	preimportReq := &datapb.PreImportRequest{
		JobID:        1,
		TaskID:       2,
		CollectionID: 3,
		PartitionIDs: []int64{4},
		Vchannels:    []string{"ch-0"},
		Schema:       s.schema,
		ImportFiles:  []*internalpb.ImportFile{importFile},
	}
	preimportTask := NewPreImportTask(preimportReq, s.manager, s.cm)
	s.manager.Add(preimportTask)
	err = preimportTask.(*PreImportTask).readFileStat(s.reader, 0)
	s.NoError(err)
}

func (s *SchedulerSuite) TestScheduler_ImportFile() {
	s.syncMgr.EXPECT().SyncDataWithChunkManager(mock.Anything, mock.Anything, mock.Anything).RunAndReturn(func(ctx context.Context, task syncmgr.Task, cm storage.ChunkManager, callbacks ...func(error) error) (*conc.Future[struct{}], error) {
		future := conc.Go(func() (struct{}, error) {
			return struct{}{}, nil
		})
		return future, nil
	})
	var once sync.Once
	data, err := testutil.CreateInsertData(s.schema, s.numRows)
	s.NoError(err)
	s.reader = importutilv2.NewMockReader(s.T())
	s.reader.EXPECT().Read().RunAndReturn(func() (*storage.InsertData, error) {
		var res *storage.InsertData
		once.Do(func() {
			res = data
		})
		if res != nil {
			return res, nil
		}
		return nil, io.EOF
	})
	importReq := &datapb.ImportRequest{
		JobID:        10,
		TaskID:       11,
		CollectionID: 12,
		PartitionIDs: []int64{13},
		Vchannels:    []string{"v0"},
		Schema:       s.schema,
		Files: []*internalpb.ImportFile{
			{
				Paths: []string{"dummy.json"},
			},
		},
		Ts: 1000,
		IDRange: &datapb.IDRange{
			Begin: 0,
			End:   int64(s.numRows),
		},
		RequestSegments: []*datapb.ImportRequestSegment{
			{
				SegmentID:   14,
				PartitionID: 13,
				Vchannel:    "v0",
			},
		},
	}
	importTask := NewImportTask(importReq, s.manager, s.syncMgr, s.cm)
	s.manager.Add(importTask)
	err = importTask.(*ImportTask).importFile(s.reader)
	s.NoError(err)
}

func (s *SchedulerSuite) TestScheduler_ImportFileWithFunction() {
	paramtable.Init()
	paramtable.Get().CredentialCfg.Credential.GetFunc = func() map[string]string {
		return map[string]string{
			"mock.apikey": "mock",
		}
	}
	s.syncMgr.EXPECT().SyncDataWithChunkManager(mock.Anything, mock.Anything, mock.Anything).RunAndReturn(func(ctx context.Context, task syncmgr.Task, cm storage.ChunkManager, callbacks ...func(error) error) (*conc.Future[struct{}], error) {
		future := conc.Go(func() (struct{}, error) {
			return struct{}{}, nil
		})
		return future, nil
	})
	ts := function.CreateOpenAIEmbeddingServer()
	defer ts.Close()
	paramtable.Get().FunctionCfg.TextEmbeddingProviders.GetFunc = func() map[string]string {
		return map[string]string{
			"openai.url": ts.URL,
		}
	}
	schema := &schemapb.CollectionSchema{
		Fields: []*schemapb.FieldSchema{
			{
				FieldID:      100,
				Name:         "pk",
				IsPrimaryKey: true,
				DataType:     schemapb.DataType_VarChar,
				TypeParams: []*commonpb.KeyValuePair{
					{Key: common.MaxLengthKey, Value: "128"},
				},
			},
			{
				FieldID:  101,
				Name:     "vec",
				DataType: schemapb.DataType_FloatVector,
				TypeParams: []*commonpb.KeyValuePair{
					{
						Key:   common.DimKey,
						Value: "4",
					},
				},
			},
			{
				FieldID:  102,
				Name:     "int64",
				DataType: schemapb.DataType_Int64,
			},
		},
		Functions: []*schemapb.FunctionSchema{
			{
				Name:             "test",
				Type:             schemapb.FunctionType_TextEmbedding,
				InputFieldIds:    []int64{100},
				InputFieldNames:  []string{"text"},
				OutputFieldIds:   []int64{101},
				OutputFieldNames: []string{"vec"},
				Params: []*commonpb.KeyValuePair{
					{Key: "provider", Value: "openai"},
					{Key: "model_name", Value: "text-embedding-ada-002"},
					{Key: "credential", Value: "mock"},
					{Key: "dim", Value: "4"},
				},
			},
		},
	}
	var once sync.Once
	data, err := testutil.CreateInsertData(schema, s.numRows)
	s.NoError(err)
	s.reader = importutilv2.NewMockReader(s.T())
	s.reader.EXPECT().Read().RunAndReturn(func() (*storage.InsertData, error) {
		var res *storage.InsertData
		once.Do(func() {
			res = data
		})
		if res != nil {
			return res, nil
		}
		return nil, io.EOF
	})
	importReq := &datapb.ImportRequest{
		JobID:        10,
		TaskID:       11,
		CollectionID: 12,
		PartitionIDs: []int64{13},
		Vchannels:    []string{"v0"},
		Schema:       schema,
		Files: []*internalpb.ImportFile{
			{
				Paths: []string{"dummy.json"},
			},
		},
		Ts: 1000,
		IDRange: &datapb.IDRange{
			Begin: 0,
			End:   int64(s.numRows),
		},
		RequestSegments: []*datapb.ImportRequestSegment{
			{
				SegmentID:   14,
				PartitionID: 13,
				Vchannel:    "v0",
			},
		},
	}
	importTask := NewImportTask(importReq, s.manager, s.syncMgr, s.cm)
	s.manager.Add(importTask)
	err = importTask.(*ImportTask).importFile(s.reader)
	s.NoError(err)
}

// TestScheduler_ScheduleTasks tests the scheduleTasks method with various scenarios
func (s *SchedulerSuite) TestScheduler_ScheduleTasks() {
	// Memory limit exceeded - some tasks should be skipped
	s.Run("MemoryLimitExceeded", func() {
		manager := NewMockTaskManager(s.T())
		s.scheduler.manager = manager

		// Add tasks that exceed memory limit
		tasks := make(map[int64]Task, 0)
		for i := 0; i < 5; i++ {
			t := NewMockTask(s.T())
			t.EXPECT().GetTaskID().Return(int64(i))
			t.EXPECT().GetBufferSize().Return(int64(300))
			if i < 3 {
				// Only first 3 tasks should be allocated (900 total)
				t.EXPECT().Execute().Return([]*conc.Future[any]{})
			}
			tasks[t.GetTaskID()] = t
		}

		manager.EXPECT().GetBy(mock.Anything).Return(lo.Values(tasks))
		manager.EXPECT().Update(mock.Anything, mock.Anything).Return()

		memAllocator := NewMemoryAllocator(1000 / 0.2)
		s.scheduler.memoryAllocator = memAllocator

		s.scheduler.scheduleTasks()
		s.Equal(int64(0), memAllocator.(*memoryAllocator).usedMemory)
	})

	// Task execution failure - memory should be released
	s.Run("TaskExecutionFailure", func() {
		manager := NewMockTaskManager(s.T())
		s.scheduler.manager = manager

		tasks := make(map[int64]Task, 0)

		// Create a task that will fail execution
		failedTask := NewMockTask(s.T())
		failedTask.EXPECT().GetTaskID().Return(int64(1))
		failedTask.EXPECT().GetBufferSize().Return(int64(256))

		// Create a future that will fail
		failedFuture := conc.Go(func() (any, error) {
			return nil, errors.New("mock execution error")
		})
		failedTask.EXPECT().Execute().Return([]*conc.Future[any]{failedFuture})
		tasks[failedTask.GetTaskID()] = failedTask

		// Create a successful task
		successTask := NewMockTask(s.T())
		successTask.EXPECT().GetTaskID().Return(int64(2))
		successTask.EXPECT().GetBufferSize().Return(int64(128))
		successTask.EXPECT().Execute().Return([]*conc.Future[any]{})
		tasks[successTask.GetTaskID()] = successTask

		manager.EXPECT().GetBy(mock.Anything).Return(lo.Values(tasks))
		manager.EXPECT().Update(mock.Anything, mock.Anything).Return()

		memAllocator := NewMemoryAllocator(512 * 5)
		s.scheduler.memoryAllocator = memAllocator

		s.scheduler.scheduleTasks()
		s.Equal(int64(0), memAllocator.(*memoryAllocator).usedMemory)
	})

	// Empty task list
	s.Run("EmptyTaskList", func() {
		manager := NewMockTaskManager(s.T())
		s.scheduler.manager = manager

		memAllocator := NewMemoryAllocator(1024)
		s.scheduler.memoryAllocator = memAllocator

		manager.EXPECT().GetBy(mock.Anything).Return(nil)

		// Should not panic or error
		s.NotPanics(func() {
			s.scheduler.scheduleTasks()
		})
		s.Equal(int64(0), memAllocator.(*memoryAllocator).usedMemory)
	})
}

func TestScheduler(t *testing.T) {
	suite.Run(t, new(SchedulerSuite))
}