// Source: milvus/internal/datacoord/import_task_preimport_test.go (355 lines, 12 KiB, Go)

// Licensed to the LF AI & Data foundation under one
// or more contributor license agreements. See the NOTICE file
// distributed with this work for additional information
// regarding copyright ownership. The ASF licenses this file
// to you under the Apache License, Version 2.0 (the
// "License"); you may not use this file except in compliance
// with the License. You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
package datacoord
import (
"context"
"testing"
"time"
"github.com/cockroachdb/errors"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/mock"
"github.com/milvus-io/milvus/internal/datacoord/session"
"github.com/milvus-io/milvus/internal/metastore/mocks"
"github.com/milvus-io/milvus/pkg/v2/proto/datapb"
"github.com/milvus-io/milvus/pkg/v2/proto/internalpb"
taskcommon "github.com/milvus-io/milvus/pkg/v2/taskcommon"
"github.com/milvus-io/milvus/pkg/v2/util/timerecord"
)
// TestPreImportTask_GetTaskType verifies that a pre-import task always
// reports the PreImport task type, even on a zero-value task.
func TestPreImportTask_GetTaskType(t *testing.T) {
	pt := &preImportTask{}
	assert.Equal(t, taskcommon.PreImport, pt.GetTaskType())
}
// TestPreImportTask_TaskTime verifies that SetTaskTime/GetTaskTime round-trip
// the queue, start, and end timestamps independently of one another.
func TestPreImportTask_TaskTime(t *testing.T) {
	task := &preImportTask{
		times: taskcommon.NewTimes(),
	}
	startTime := time.Now()
	endTime := time.Now()
	queueTime := time.Now()
	task.SetTaskTime(taskcommon.TimeStart, startTime)
	task.SetTaskTime(taskcommon.TimeEnd, endTime)
	task.SetTaskTime(taskcommon.TimeQueue, queueTime)
	// assert.Equal takes (expected, actual); keep the expected timestamp
	// first so failure messages read correctly (matches the argument order
	// used by the other tests in this file).
	assert.Equal(t, startTime, task.GetTaskTime(taskcommon.TimeStart))
	assert.Equal(t, endTime, task.GetTaskTime(taskcommon.TimeEnd))
	assert.Equal(t, queueTime, task.GetTaskTime(taskcommon.TimeQueue))
}
// TestPreImportTask_CreateTaskOnWorker covers dispatching a pre-import task
// to a worker node: the CreatePreImport RPC failing, the subsequent meta
// update failing, and the happy path.
func TestPreImportTask_CreateTaskOnWorker(t *testing.T) {
	t.Run("CreatePreImportTask rpc failed", func(t *testing.T) {
		mockCatalog := mocks.NewDataCoordCatalog(t)
		mockCatalog.EXPECT().ListImportJobs(mock.Anything).Return(nil, nil)
		mockCatalog.EXPECT().ListPreImportTasks(mock.Anything).Return(nil, nil)
		mockCatalog.EXPECT().ListImportTasks(mock.Anything).Return(nil, nil)
		mockCatalog.EXPECT().SavePreImportTask(mock.Anything, mock.Anything).Return(nil)
		mockCatalog.EXPECT().SaveImportJob(mock.Anything, mock.Anything).Return(nil)

		meta, err := NewImportMeta(context.TODO(), mockCatalog, nil, nil)
		assert.NoError(t, err)

		var j ImportJob = &importJob{
			ImportJob: &datapb.ImportJob{
				JobID: 1,
			},
		}
		assert.NoError(t, meta.AddJob(context.TODO(), j))

		pt := &preImportTask{
			importMeta: meta,
			tr:         timerecord.NewTimeRecorder(""),
		}
		pt.task.Store(&datapb.PreImportTask{
			JobID:        1,
			TaskID:       2,
			CollectionID: 3,
			State:        datapb.ImportTaskStateV2_Pending,
		})
		assert.NoError(t, meta.AddTask(context.TODO(), pt))

		mockCluster := session.NewMockCluster(t)
		mockCluster.EXPECT().CreatePreImport(mock.Anything, mock.Anything, mock.Anything).Return(errors.New("test"))

		// The RPC fails, so the task must remain pending for a retry.
		pt.CreateTaskOnWorker(1, mockCluster)
		assert.Equal(t, datapb.ImportTaskStateV2_Pending, pt.GetState())
	})

	t.Run("UpdateTask failed", func(t *testing.T) {
		mockCatalog := mocks.NewDataCoordCatalog(t)
		mockCatalog.EXPECT().ListImportJobs(mock.Anything).Return(nil, nil)
		mockCatalog.EXPECT().ListPreImportTasks(mock.Anything).Return(nil, nil)
		mockCatalog.EXPECT().ListImportTasks(mock.Anything).Return(nil, nil)
		mockCatalog.EXPECT().SavePreImportTask(mock.Anything, mock.Anything).Return(nil)
		mockCatalog.EXPECT().SaveImportJob(mock.Anything, mock.Anything).Return(nil)

		meta, err := NewImportMeta(context.TODO(), mockCatalog, nil, nil)
		assert.NoError(t, err)

		var j ImportJob = &importJob{
			ImportJob: &datapb.ImportJob{
				JobID: 1,
			},
		}
		assert.NoError(t, meta.AddJob(context.TODO(), j))

		pt := &preImportTask{
			importMeta: meta,
			tr:         timerecord.NewTimeRecorder(""),
		}
		pt.task.Store(&datapb.PreImportTask{
			JobID:        1,
			TaskID:       2,
			CollectionID: 3,
			State:        datapb.ImportTaskStateV2_Pending,
		})
		assert.NoError(t, meta.AddTask(context.TODO(), pt))

		mockCluster := session.NewMockCluster(t)
		mockCluster.EXPECT().CreatePreImport(mock.Anything, mock.Anything, mock.Anything).Return(nil)

		// Swap in a catalog whose persistence fails, so that the state
		// transition after the successful RPC cannot be saved.
		failingCatalog := mocks.NewDataCoordCatalog(t)
		failingCatalog.EXPECT().SavePreImportTask(mock.Anything, mock.Anything).Return(errors.New("mock err"))
		pt.importMeta.(*importMeta).catalog = failingCatalog

		pt.CreateTaskOnWorker(1, mockCluster)
		assert.Equal(t, datapb.ImportTaskStateV2_Pending, pt.GetState())
	})

	t.Run("normal", func(t *testing.T) {
		mockCatalog := mocks.NewDataCoordCatalog(t)
		mockCatalog.EXPECT().ListImportJobs(mock.Anything).Return(nil, nil)
		mockCatalog.EXPECT().ListPreImportTasks(mock.Anything).Return(nil, nil)
		mockCatalog.EXPECT().ListImportTasks(mock.Anything).Return(nil, nil)
		mockCatalog.EXPECT().SavePreImportTask(mock.Anything, mock.Anything).Return(nil)
		mockCatalog.EXPECT().SaveImportJob(mock.Anything, mock.Anything).Return(nil)

		meta, err := NewImportMeta(context.TODO(), mockCatalog, nil, nil)
		assert.NoError(t, err)

		var j ImportJob = &importJob{
			ImportJob: &datapb.ImportJob{
				JobID: 1,
			},
		}
		assert.NoError(t, meta.AddJob(context.TODO(), j))

		pt := &preImportTask{
			importMeta: meta,
			tr:         timerecord.NewTimeRecorder(""),
		}
		pt.task.Store(&datapb.PreImportTask{
			JobID:        1,
			TaskID:       2,
			CollectionID: 3,
			State:        datapb.ImportTaskStateV2_Pending,
		})
		assert.NoError(t, meta.AddTask(context.TODO(), pt))

		mockCluster := session.NewMockCluster(t)
		mockCluster.EXPECT().CreatePreImport(mock.Anything, mock.Anything, mock.Anything).Return(nil)

		// Successful dispatch moves the task to InProgress.
		pt.CreateTaskOnWorker(1, mockCluster)
		assert.Equal(t, datapb.ImportTaskStateV2_InProgress, pt.GetState())
	})
}
// TestPreImportTask_QueryTaskOnWorker covers polling an in-progress
// pre-import task on its worker: the query RPC failing, the worker reporting
// a failed pre-import, and the worker reporting completion.
func TestPreImportTask_QueryTaskOnWorker(t *testing.T) {
	t.Run("QueryPreImport rpc failed", func(t *testing.T) {
		mockCatalog := mocks.NewDataCoordCatalog(t)
		mockCatalog.EXPECT().ListImportJobs(mock.Anything).Return(nil, nil)
		mockCatalog.EXPECT().ListPreImportTasks(mock.Anything).Return(nil, nil)
		mockCatalog.EXPECT().ListImportTasks(mock.Anything).Return(nil, nil)
		mockCatalog.EXPECT().SavePreImportTask(mock.Anything, mock.Anything).Return(nil)

		meta, err := NewImportMeta(context.TODO(), mockCatalog, nil, nil)
		assert.NoError(t, err)

		pt := &preImportTask{
			importMeta: meta,
			tr:         timerecord.NewTimeRecorder(""),
		}
		pt.task.Store(&datapb.PreImportTask{
			JobID:        1,
			TaskID:       2,
			CollectionID: 3,
			NodeID:       7,
			State:        datapb.ImportTaskStateV2_InProgress,
		})
		assert.NoError(t, meta.AddTask(context.TODO(), pt))

		mockCluster := session.NewMockCluster(t)
		mockCluster.EXPECT().QueryPreImport(mock.Anything, mock.Anything).Return(nil, errors.New("mock err"))

		// An unreachable worker resets the task to Pending so it can be
		// rescheduled elsewhere.
		pt.QueryTaskOnWorker(mockCluster)
		assert.Equal(t, datapb.ImportTaskStateV2_Pending, pt.GetState())
	})

	t.Run("preimport failed", func(t *testing.T) {
		mockCatalog := mocks.NewDataCoordCatalog(t)
		mockCatalog.EXPECT().ListImportJobs(mock.Anything).Return(nil, nil)
		mockCatalog.EXPECT().ListPreImportTasks(mock.Anything).Return(nil, nil)
		mockCatalog.EXPECT().ListImportTasks(mock.Anything).Return(nil, nil)
		mockCatalog.EXPECT().SavePreImportTask(mock.Anything, mock.Anything).Return(nil)
		mockCatalog.EXPECT().SaveImportJob(mock.Anything, mock.Anything).Return(nil)

		meta, err := NewImportMeta(context.TODO(), mockCatalog, nil, nil)
		assert.NoError(t, err)

		var j ImportJob = &importJob{
			ImportJob: &datapb.ImportJob{
				JobID: 1,
				State: internalpb.ImportJobState_PreImporting,
			},
		}
		assert.NoError(t, meta.AddJob(context.TODO(), j))

		pt := &preImportTask{
			importMeta: meta,
			tr:         timerecord.NewTimeRecorder(""),
		}
		pt.task.Store(&datapb.PreImportTask{
			JobID:        1,
			TaskID:       2,
			CollectionID: 3,
			NodeID:       7,
			State:        datapb.ImportTaskStateV2_InProgress,
		})
		assert.NoError(t, meta.AddTask(context.TODO(), pt))

		mockCluster := session.NewMockCluster(t)
		mockCluster.EXPECT().QueryPreImport(mock.Anything, mock.Anything).Return(&datapb.QueryPreImportResponse{
			State: datapb.ImportTaskStateV2_Failed,
		}, nil)

		// A worker-side failure marks the owning job as Failed while the
		// task record itself keeps its InProgress state.
		pt.QueryTaskOnWorker(mockCluster)
		assert.Equal(t, datapb.ImportTaskStateV2_InProgress, pt.GetState())
		j = meta.GetJob(context.TODO(), 1)
		assert.Equal(t, internalpb.ImportJobState_Failed, j.GetState())
	})

	t.Run("normal", func(t *testing.T) {
		mockCatalog := mocks.NewDataCoordCatalog(t)
		mockCatalog.EXPECT().ListImportJobs(mock.Anything).Return(nil, nil)
		mockCatalog.EXPECT().ListPreImportTasks(mock.Anything).Return(nil, nil)
		mockCatalog.EXPECT().ListImportTasks(mock.Anything).Return(nil, nil)
		mockCatalog.EXPECT().SavePreImportTask(mock.Anything, mock.Anything).Return(nil)

		meta, err := NewImportMeta(context.TODO(), mockCatalog, nil, nil)
		assert.NoError(t, err)

		pt := &preImportTask{
			importMeta: meta,
			tr:         timerecord.NewTimeRecorder(""),
		}
		pt.task.Store(&datapb.PreImportTask{
			JobID:        1,
			TaskID:       2,
			CollectionID: 3,
			NodeID:       7,
			State:        datapb.ImportTaskStateV2_InProgress,
		})
		assert.NoError(t, meta.AddTask(context.TODO(), pt))

		mockCluster := session.NewMockCluster(t)
		mockCluster.EXPECT().QueryPreImport(mock.Anything, mock.Anything).Return(&datapb.QueryPreImportResponse{
			State: datapb.ImportTaskStateV2_Completed,
		}, nil)

		// The worker reports completion, which is persisted on the task.
		pt.QueryTaskOnWorker(mockCluster)
		assert.Equal(t, datapb.ImportTaskStateV2_Completed, pt.GetState())
	})
}
// TestPreImportTask_DropTaskOnWorker covers releasing a completed pre-import
// task from its worker: the DropImport RPC failing (node assignment kept)
// and the happy path (node assignment cleared).
func TestPreImportTask_DropTaskOnWorker(t *testing.T) {
	t.Run("DropImport rpc failed", func(t *testing.T) {
		mockCatalog := mocks.NewDataCoordCatalog(t)
		mockCatalog.EXPECT().ListImportJobs(mock.Anything).Return(nil, nil)
		mockCatalog.EXPECT().ListPreImportTasks(mock.Anything).Return(nil, nil)
		mockCatalog.EXPECT().ListImportTasks(mock.Anything).Return(nil, nil)
		mockCatalog.EXPECT().SavePreImportTask(mock.Anything, mock.Anything).Return(nil)

		meta, err := NewImportMeta(context.TODO(), mockCatalog, nil, nil)
		assert.NoError(t, err)

		pt := &preImportTask{
			importMeta: meta,
			tr:         timerecord.NewTimeRecorder(""),
		}
		pt.task.Store(&datapb.PreImportTask{
			JobID:        1,
			TaskID:       2,
			CollectionID: 3,
			NodeID:       7,
			State:        datapb.ImportTaskStateV2_Completed,
		})
		assert.NoError(t, meta.AddTask(context.TODO(), pt))

		mockCluster := session.NewMockCluster(t)
		mockCluster.EXPECT().DropImport(mock.Anything, mock.Anything).Return(errors.New("mock err"))

		// When the drop RPC fails, the task keeps both its state and its
		// node assignment so the drop can be retried.
		pt.DropTaskOnWorker(mockCluster)
		assert.Equal(t, datapb.ImportTaskStateV2_Completed, pt.GetState())
		assert.Equal(t, int64(7), pt.GetNodeID())
	})

	t.Run("normal", func(t *testing.T) {
		mockCatalog := mocks.NewDataCoordCatalog(t)
		mockCatalog.EXPECT().ListImportJobs(mock.Anything).Return(nil, nil)
		mockCatalog.EXPECT().ListPreImportTasks(mock.Anything).Return(nil, nil)
		mockCatalog.EXPECT().ListImportTasks(mock.Anything).Return(nil, nil)
		mockCatalog.EXPECT().SavePreImportTask(mock.Anything, mock.Anything).Return(nil)

		meta, err := NewImportMeta(context.TODO(), mockCatalog, nil, nil)
		assert.NoError(t, err)

		pt := &preImportTask{
			importMeta: meta,
			tr:         timerecord.NewTimeRecorder(""),
		}
		pt.task.Store(&datapb.PreImportTask{
			JobID:        1,
			TaskID:       2,
			CollectionID: 3,
			NodeID:       7,
			State:        datapb.ImportTaskStateV2_Completed,
		})
		assert.NoError(t, meta.AddTask(context.TODO(), pt))

		mockCluster := session.NewMockCluster(t)
		mockCluster.EXPECT().DropImport(mock.Anything, mock.Anything).Return(nil)

		// A successful drop detaches the task from its worker node.
		pt.DropTaskOnWorker(mockCluster)
		assert.Equal(t, datapb.ImportTaskStateV2_Completed, pt.GetState())
		assert.Equal(t, int64(NullNodeID), pt.GetNodeID())
	})
}