// Source: milvus/internal/datacoord/task_analyze_test.go
// (source-viewer metadata: 218 lines, 6.3 KiB, Go)

// Licensed to the LF AI & Data foundation under one
// or more contributor license agreements. See the NOTICE file
// distributed with this work for additional information
// regarding copyright ownership. The ASF licenses this file
// to you under the Apache License, Version 2.0 (the
// "License"); you may not use this file except in compliance
// with the License. You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
package datacoord
import (
"context"
"fmt"
"testing"
"time"
"github.com/stretchr/testify/mock"
"github.com/stretchr/testify/suite"
"github.com/milvus-io/milvus-proto/go-api/v2/schemapb"
"github.com/milvus-io/milvus/internal/datacoord/session"
catalogmocks "github.com/milvus-io/milvus/internal/metastore/mocks"
"github.com/milvus-io/milvus/pkg/v2/proto/indexpb"
"github.com/milvus-io/milvus/pkg/v2/proto/workerpb"
"github.com/milvus-io/milvus/pkg/v2/taskcommon"
"github.com/milvus-io/milvus/pkg/v2/util/merr"
)
// analyzeTaskSuite exercises analyzeTask against an in-memory meta whose
// analyzeMeta is backed by a mocked DataCoordCatalog.
//
// Every field is populated once in SetupSuite and shared by all test methods,
// so changes one test makes to mt are visible to the tests that run after it.
type analyzeTaskSuite struct {
	suite.Suite

	// mt is the meta passed to every analyzeTask the tests construct.
	mt *meta

	// Identifiers of the single analyze task seeded into mt.
	collID  int64
	partID  int64
	fieldID int64
	taskID  int64
}
func Test_analyzeTaskSuite(t *testing.T) {
suite.Run(t, new(analyzeTaskSuite))
}
// SetupSuite runs once before any test method. It fixes the collection,
// partition, field and task identifiers and builds s.mt: a meta whose
// analyzeMeta uses a mocked catalog and is pre-seeded with one analyze task
// (state JobStateInit, SegmentIDs 101 and 102) keyed by s.taskID.
func (s *analyzeTaskSuite) SetupSuite() {
	s.collID, s.partID, s.fieldID, s.taskID = 1, 2, 3, 1000

	// The task every test method refers to through s.taskID.
	seeded := &indexpb.AnalyzeTask{
		CollectionID: s.collID,
		PartitionID:  s.partID,
		FieldID:      s.fieldID,
		FieldName:    "vector_field",
		FieldType:    schemapb.DataType_FloatVector,
		TaskID:       s.taskID,
		Version:      1,
		SegmentIDs:   []int64{101, 102},
		NodeID:       0,
		State:        indexpb.JobState_JobStateInit,
		FailReason:   "",
		Dim:          128,
	}

	// No expectations are registered on this catalog; test methods that need
	// catalog writes install their own mock on s.mt.analyzeMeta.catalog.
	s.mt = &meta{
		analyzeMeta: &analyzeMeta{
			ctx:     context.Background(),
			catalog: catalogmocks.NewDataCoordCatalog(s.T()),
			tasks:   map[int64]*indexpb.AnalyzeTask{s.taskID: seeded},
		},
	}
}
// TestBasicTaskOperations checks analyzeTask's simple accessors: task type,
// state mirroring, slot usage, per-phase timestamps and SetState. All
// subtests share one task instance and run in declaration order.
func (s *analyzeTaskSuite) TestBasicTaskOperations() {
	task := newAnalyzeTask(&indexpb.AnalyzeTask{
		TaskID: s.taskID,
		State:  indexpb.JobState_JobStateInit,
	}, s.mt)

	s.Run("task type and state", func() {
		s.Equal(taskcommon.Analyze, task.GetTaskType())
		// GetTaskState must report the same state as GetState.
		s.Equal(task.GetState(), task.GetTaskState())
		// The slot cost comes straight from the DataCoord config.
		s.Equal(Params.DataCoordCfg.AnalyzeTaskSlotUsage.GetAsInt64(), task.GetTaskSlot())
	})

	s.Run("time management", func() {
		stamp := time.Now()

		// Each phase timestamp must round-trip through Set/GetTaskTime.
		task.SetTaskTime(taskcommon.TimeQueue, stamp)
		s.Equal(stamp, task.GetTaskTime(taskcommon.TimeQueue))

		task.SetTaskTime(taskcommon.TimeStart, stamp)
		s.Equal(stamp, task.GetTaskTime(taskcommon.TimeStart))

		task.SetTaskTime(taskcommon.TimeEnd, stamp)
		s.Equal(stamp, task.GetTaskTime(taskcommon.TimeEnd))
	})

	s.Run("state management", func() {
		const reason = "test reason"
		task.SetState(indexpb.JobState_JobStateInProgress, reason)
		s.Equal(indexpb.JobState_JobStateInProgress, task.GetState())
		s.Equal(reason, task.GetFailReason())
	})
}
// TestCreateTaskOnWorker covers two outcomes of CreateTaskOnWorker: the task
// having disappeared from meta (state becomes JobStateNone), and a successful
// dispatch to the worker (state becomes JobStateInProgress). Both subtests
// share one task instance and the suite-wide s.mt.
func (s *analyzeTaskSuite) TestCreateTaskOnWorker() {
	at := newAnalyzeTask(&indexpb.AnalyzeTask{
		TaskID: s.taskID,
		State:  indexpb.JobState_JobStateInit,
	}, s.mt)

	s.Run("task not exist in meta", func() {
		// Temporarily remove the seeded task from the shared meta. The restore
		// is deferred so it still runs if the call below ends the subtest via
		// FailNow (e.g. an unexpected call on the mock cluster) or a panic;
		// otherwise the next subtest and later test methods would also find
		// the task missing.
		originalTask := s.mt.analyzeMeta.tasks[s.taskID]
		delete(s.mt.analyzeMeta.tasks, s.taskID)
		defer func() { s.mt.analyzeMeta.tasks[s.taskID] = originalTask }()

		// The cluster mock has no expectations, so any worker call fails the
		// subtest: a task missing from meta must not be sent to a worker.
		at.CreateTaskOnWorker(1, session.NewMockCluster(s.T()))
		s.Equal(indexpb.JobState_JobStateNone, at.GetState())
	})

	s.Run("successful creation", func() {
		cluster := session.NewMockCluster(s.T())
		cluster.EXPECT().CreateAnalyze(mock.Anything, mock.Anything).Return(nil)

		// CreateTaskOnWorker is expected to persist the task through the
		// catalog (presumably from analyzeMeta.UpdateVersion — confirm), so
		// install a catalog mock that accepts SaveAnalyzeTask. Note that this
		// mock stays installed on the shared s.mt after the subtest.
		catalog := catalogmocks.NewDataCoordCatalog(s.T())
		catalog.EXPECT().SaveAnalyzeTask(mock.Anything, mock.Anything).Return(nil)
		s.mt.analyzeMeta.catalog = catalog

		at.CreateTaskOnWorker(1, cluster)
		s.Equal(indexpb.JobState_JobStateInProgress, at.GetState())
	})
}
// TestQueryTaskOnWorker drives QueryTaskOnWorker through two query failures
// and then a finished result. The subtests run in order on the same task
// instance, so each starts from the state the previous one left behind.
func (s *analyzeTaskSuite) TestQueryTaskOnWorker() {
	task := newAnalyzeTask(&indexpb.AnalyzeTask{
		TaskID: s.taskID,
		NodeID: 1,
		State:  indexpb.JobState_JobStateInProgress,
	}, s.mt)

	// For both query errors the mock cluster expects a DropAnalyze call, and
	// the task must end up back in JobStateInit.
	failures := []struct {
		name string
		err  error
	}{
		{name: "query failed", err: fmt.Errorf("mock error")},
		{name: "node not found", err: merr.ErrNodeNotFound},
	}
	for _, tc := range failures {
		s.Run(tc.name, func() {
			cluster := session.NewMockCluster(s.T())
			cluster.EXPECT().QueryAnalyze(mock.Anything, mock.Anything).Return(nil, tc.err)
			cluster.EXPECT().DropAnalyze(mock.Anything, mock.Anything).Return(nil)

			task.QueryTaskOnWorker(cluster)
			s.Equal(indexpb.JobState_JobStateInit, task.GetState())
		})
	}

	s.Run("task finished", func() {
		finished := &workerpb.AnalyzeResults{
			Results: []*workerpb.AnalyzeResult{{
				TaskID:     s.taskID,
				State:      indexpb.JobState_JobStateFinished,
				FailReason: "",
			}},
		}
		cluster := session.NewMockCluster(s.T())
		cluster.EXPECT().QueryAnalyze(mock.Anything, mock.Anything).Return(finished, nil)

		// Recording the finished result is expected to write the task through
		// the catalog (presumably via analyzeMeta.FinishTask — confirm).
		catalog := catalogmocks.NewDataCoordCatalog(s.T())
		catalog.EXPECT().SaveAnalyzeTask(mock.Anything, mock.Anything).Return(nil)
		s.mt.analyzeMeta.catalog = catalog

		task.QueryTaskOnWorker(cluster)
		s.Equal(indexpb.JobState_JobStateFinished, task.GetState())
	})
}
// TestDropTaskOnWorker calls DropTaskOnWorker once with a worker that rejects
// the drop and once with one that accepts it. No task state is asserted; each
// subtest only checks, through the mock cluster, that DropAnalyze is issued.
// A failed drop is meant to be logged rather than surfaced, per the original
// intent of this test.
func (s *analyzeTaskSuite) TestDropTaskOnWorker() {
	task := newAnalyzeTask(&indexpb.AnalyzeTask{
		TaskID: s.taskID,
		NodeID: 1,
		State:  indexpb.JobState_JobStateInProgress,
	}, s.mt)

	cases := []struct {
		name    string
		dropErr error
	}{
		{name: "drop failed", dropErr: fmt.Errorf("mock error")},
		{name: "drop success", dropErr: nil},
	}
	for _, tc := range cases {
		s.Run(tc.name, func() {
			cluster := session.NewMockCluster(s.T())
			cluster.EXPECT().DropAnalyze(mock.Anything, mock.Anything).Return(tc.dropErr)
			task.DropTaskOnWorker(cluster)
		})
	}
}