Add task unittest for query node (#7599)

Signed-off-by: bigsheeper <yihao.dai@zilliz.com>
pull/7616/head
bigsheeper 2021-09-09 10:10:00 +08:00 committed by GitHub
parent da405486ed
commit 5906551f41
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
2 changed files with 310 additions and 18 deletions

View File

@ -518,6 +518,19 @@ func genSimpleRowIDField() []IntPrimaryKey {
return ids
}
func genMsgStreamBaseMsg() msgstream.BaseMsg {
return msgstream.BaseMsg{
HashValues: []uint32{0},
}
}
func genCommonMsgBase(msgType commonpb.MsgType) *commonpb.MsgBase {
return &commonpb.MsgBase{
MsgType: msgType,
MsgID: rand.Int63(),
}
}
func genSimpleInsertMsg() (*msgstream.InsertMsg, error) {
rowData, err := genSimpleCommonBlob()
if err != nil {
@ -525,14 +538,9 @@ func genSimpleInsertMsg() (*msgstream.InsertMsg, error) {
}
return &msgstream.InsertMsg{
BaseMsg: msgstream.BaseMsg{
HashValues: []uint32{0},
},
BaseMsg: genMsgStreamBaseMsg(),
InsertRequest: internalpb.InsertRequest{
Base: &commonpb.MsgBase{
MsgType: commonpb.MsgType_Retrieve,
MsgID: rand.Int63(),
},
Base: genCommonMsgBase(commonpb.MsgType_Retrieve),
CollectionName: defaultCollectionName,
PartitionName: defaultPartitionName,
CollectionID: defaultCollectionID,
@ -861,10 +869,7 @@ func genSimpleSearchRequest() (*internalpb.SearchRequest, error) {
return nil, err
}
return &internalpb.SearchRequest{
Base: &commonpb.MsgBase{
MsgType: commonpb.MsgType_Search,
MsgID: rand.Int63(), // TODO: random msgID?
},
Base: genCommonMsgBase(commonpb.MsgType_Search),
CollectionID: defaultCollectionID,
PartitionIDs: []UniqueID{defaultPartitionID},
Dsl: simpleDSL,
@ -898,9 +903,7 @@ func genSimpleSearchMsg() (*msgstream.SearchMsg, error) {
return nil, err
}
return &msgstream.SearchMsg{
BaseMsg: msgstream.BaseMsg{
HashValues: []uint32{0},
},
BaseMsg: genMsgStreamBaseMsg(),
SearchRequest: *req,
}, nil
}
@ -911,9 +914,7 @@ func genSimpleRetrieveMsg() (*msgstream.RetrieveMsg, error) {
return nil, err
}
return &msgstream.RetrieveMsg{
BaseMsg: msgstream.BaseMsg{
HashValues: []uint32{0},
},
BaseMsg: genMsgStreamBaseMsg(),
RetrieveRequest: *req,
}, nil
}

View File

@ -12,10 +12,301 @@
package querynode
import (
"context"
"testing"
"github.com/stretchr/testify/assert"
"github.com/milvus-io/milvus/internal/msgstream"
"github.com/milvus-io/milvus/internal/proto/commonpb"
"github.com/milvus-io/milvus/internal/proto/datapb"
"github.com/milvus-io/milvus/internal/proto/querypb"
)
// TODO: add task ut
func TestTask_watchDmChannelsTask(t *testing.T) {
ctx, cancel := context.WithCancel(context.Background())
defer cancel()
genWatchDMChannelsRequest := func() *querypb.WatchDmChannelsRequest {
schema, _ := genSimpleSchema()
req := &querypb.WatchDmChannelsRequest{
Base: genCommonMsgBase(commonpb.MsgType_WatchDmChannels),
CollectionID: defaultCollectionID,
PartitionID: defaultPartitionID,
Schema: schema,
}
return req
}
t.Run("test timestamp", func(t *testing.T) {
task := watchDmChannelsTask{
req: genWatchDMChannelsRequest(),
}
timestamp := Timestamp(1000)
task.req.Base.Timestamp = timestamp
resT := task.Timestamp()
assert.Equal(t, timestamp, resT)
task.req.Base = nil
resT = task.Timestamp()
assert.Equal(t, Timestamp(0), resT)
})
t.Run("test OnEnqueue", func(t *testing.T) {
task := watchDmChannelsTask{
req: genWatchDMChannelsRequest(),
}
err := task.OnEnqueue()
assert.NoError(t, err)
task.req.Base = nil
err = task.OnEnqueue()
assert.NoError(t, err)
})
t.Run("test execute loadCollection", func(t *testing.T) {
node, err := genSimpleQueryNode(ctx)
assert.NoError(t, err)
task := watchDmChannelsTask{
req: genWatchDMChannelsRequest(),
node: node,
}
task.req.Infos = []*datapb.VchannelInfo{
{
CollectionID: defaultCollectionID,
ChannelName: defaultVChannel,
},
}
task.req.PartitionID = 0
err = task.Execute(ctx)
assert.NoError(t, err)
})
t.Run("test execute loadPartition", func(t *testing.T) {
node, err := genSimpleQueryNode(ctx)
assert.NoError(t, err)
task := watchDmChannelsTask{
req: genWatchDMChannelsRequest(),
node: node,
}
task.req.Infos = []*datapb.VchannelInfo{
{
CollectionID: defaultCollectionID,
ChannelName: defaultVChannel,
},
}
err = task.Execute(ctx)
assert.NoError(t, err)
})
// TODO: time consuming, reduce seek error time
t.Run("test execute seek error", func(t *testing.T) {
node, err := genSimpleQueryNode(ctx)
assert.NoError(t, err)
task := watchDmChannelsTask{
req: genWatchDMChannelsRequest(),
node: node,
}
task.req.Infos = []*datapb.VchannelInfo{
{
CollectionID: defaultCollectionID,
ChannelName: defaultVChannel,
SeekPosition: &msgstream.MsgPosition{
ChannelName: defaultVChannel,
MsgID: []byte{1, 2, 3},
MsgGroup: defaultSubName,
Timestamp: 0,
},
},
}
err = task.Execute(ctx)
assert.Error(t, err)
})
}
func TestTask_loadSegmentsTask(t *testing.T) {
ctx, cancel := context.WithCancel(context.Background())
defer cancel()
genLoadSegmentsRequest := func() *querypb.LoadSegmentsRequest {
_, schema := genSimpleSchema()
req := &querypb.LoadSegmentsRequest{
Base: genCommonMsgBase(commonpb.MsgType_LoadSegments),
Schema: schema,
LoadCondition: querypb.TriggerCondition_grpcRequest,
}
return req
}
t.Run("test timestamp", func(t *testing.T) {
task := loadSegmentsTask{
req: genLoadSegmentsRequest(),
}
timestamp := Timestamp(1000)
task.req.Base.Timestamp = timestamp
resT := task.Timestamp()
assert.Equal(t, timestamp, resT)
task.req.Base = nil
resT = task.Timestamp()
assert.Equal(t, Timestamp(0), resT)
})
t.Run("test OnEnqueue", func(t *testing.T) {
task := loadSegmentsTask{
req: genLoadSegmentsRequest(),
}
err := task.OnEnqueue()
assert.NoError(t, err)
task.req.Base = nil
err = task.OnEnqueue()
assert.NoError(t, err)
})
t.Run("test execute grpc", func(t *testing.T) {
node, err := genSimpleQueryNode(ctx)
assert.NoError(t, err)
task := loadSegmentsTask{
req: genLoadSegmentsRequest(),
node: node,
}
task.req.Infos = []*querypb.SegmentLoadInfo{
{
SegmentID: defaultSegmentID + 1,
PartitionID: defaultPartitionID + 1,
CollectionID: defaultCollectionID + 1,
},
}
err = task.Execute(ctx)
assert.Error(t, err)
})
t.Run("test execute node down", func(t *testing.T) {
node, err := genSimpleQueryNode(ctx)
assert.NoError(t, err)
task := loadSegmentsTask{
req: genLoadSegmentsRequest(),
node: node,
}
task.req.Infos = []*querypb.SegmentLoadInfo{
{
SegmentID: defaultSegmentID + 1,
PartitionID: defaultPartitionID + 1,
CollectionID: defaultCollectionID + 1,
},
}
task.req.LoadCondition = querypb.TriggerCondition_nodeDown
err = task.Execute(ctx)
assert.Error(t, err)
})
t.Run("test execute load balance", func(t *testing.T) {
node, err := genSimpleQueryNode(ctx)
assert.NoError(t, err)
task := loadSegmentsTask{
req: genLoadSegmentsRequest(),
node: node,
}
task.req.Infos = []*querypb.SegmentLoadInfo{
{
SegmentID: defaultSegmentID + 1,
PartitionID: defaultPartitionID + 1,
CollectionID: defaultCollectionID + 1,
},
}
task.req.LoadCondition = querypb.TriggerCondition_loadBalance
err = task.Execute(ctx)
assert.Error(t, err)
})
}
func TestTask_releaseCollectionTask(t *testing.T) {
genReleaseCollectionRequest := func() *querypb.ReleaseCollectionRequest {
req := &querypb.ReleaseCollectionRequest{
Base: genCommonMsgBase(commonpb.MsgType_LoadSegments),
CollectionID: defaultCollectionID,
}
return req
}
t.Run("test timestamp", func(t *testing.T) {
task := releaseCollectionTask{
req: genReleaseCollectionRequest(),
}
timestamp := Timestamp(1000)
task.req.Base.Timestamp = timestamp
resT := task.Timestamp()
assert.Equal(t, timestamp, resT)
task.req.Base = nil
resT = task.Timestamp()
assert.Equal(t, Timestamp(0), resT)
})
t.Run("test OnEnqueue", func(t *testing.T) {
task := releaseCollectionTask{
req: genReleaseCollectionRequest(),
}
err := task.OnEnqueue()
assert.NoError(t, err)
task.req.Base = nil
err = task.OnEnqueue()
assert.NoError(t, err)
})
}
func TestTask_releasePartitionTask(t *testing.T) {
ctx, cancel := context.WithCancel(context.Background())
defer cancel()
genReleasePartitionsRequest := func() *querypb.ReleasePartitionsRequest {
req := &querypb.ReleasePartitionsRequest{
Base: genCommonMsgBase(commonpb.MsgType_LoadSegments),
CollectionID: defaultCollectionID,
PartitionIDs: []UniqueID{defaultPartitionID},
}
return req
}
t.Run("test timestamp", func(t *testing.T) {
task := releasePartitionsTask{
req: genReleasePartitionsRequest(),
}
timestamp := Timestamp(1000)
task.req.Base.Timestamp = timestamp
resT := task.Timestamp()
assert.Equal(t, timestamp, resT)
task.req.Base = nil
resT = task.Timestamp()
assert.Equal(t, Timestamp(0), resT)
})
t.Run("test OnEnqueue", func(t *testing.T) {
task := releasePartitionsTask{
req: genReleasePartitionsRequest(),
}
err := task.OnEnqueue()
assert.NoError(t, err)
task.req.Base = nil
err = task.OnEnqueue()
assert.NoError(t, err)
})
t.Run("test execute", func(t *testing.T) {
node, err := genSimpleQueryNode(ctx)
assert.NoError(t, err)
task := releasePartitionsTask{
req: genReleasePartitionsRequest(),
node: node,
}
err = task.node.streaming.dataSyncService.addPartitionFlowGraph(defaultCollectionID,
defaultPartitionID,
[]Channel{defaultVChannel})
assert.NoError(t, err)
err = task.Execute(ctx)
assert.NoError(t, err)
})
}