mirror of https://github.com/milvus-io/milvus.git
Add task unittest for query node (#7599)
Signed-off-by: bigsheeper <yihao.dai@zilliz.com>pull/7616/head
parent
da405486ed
commit
5906551f41
|
@ -518,6 +518,19 @@ func genSimpleRowIDField() []IntPrimaryKey {
|
|||
return ids
|
||||
}
|
||||
|
||||
func genMsgStreamBaseMsg() msgstream.BaseMsg {
|
||||
return msgstream.BaseMsg{
|
||||
HashValues: []uint32{0},
|
||||
}
|
||||
}
|
||||
|
||||
func genCommonMsgBase(msgType commonpb.MsgType) *commonpb.MsgBase {
|
||||
return &commonpb.MsgBase{
|
||||
MsgType: msgType,
|
||||
MsgID: rand.Int63(),
|
||||
}
|
||||
}
|
||||
|
||||
func genSimpleInsertMsg() (*msgstream.InsertMsg, error) {
|
||||
rowData, err := genSimpleCommonBlob()
|
||||
if err != nil {
|
||||
|
@ -525,14 +538,9 @@ func genSimpleInsertMsg() (*msgstream.InsertMsg, error) {
|
|||
}
|
||||
|
||||
return &msgstream.InsertMsg{
|
||||
BaseMsg: msgstream.BaseMsg{
|
||||
HashValues: []uint32{0},
|
||||
},
|
||||
BaseMsg: genMsgStreamBaseMsg(),
|
||||
InsertRequest: internalpb.InsertRequest{
|
||||
Base: &commonpb.MsgBase{
|
||||
MsgType: commonpb.MsgType_Retrieve,
|
||||
MsgID: rand.Int63(),
|
||||
},
|
||||
Base: genCommonMsgBase(commonpb.MsgType_Retrieve),
|
||||
CollectionName: defaultCollectionName,
|
||||
PartitionName: defaultPartitionName,
|
||||
CollectionID: defaultCollectionID,
|
||||
|
@ -861,10 +869,7 @@ func genSimpleSearchRequest() (*internalpb.SearchRequest, error) {
|
|||
return nil, err
|
||||
}
|
||||
return &internalpb.SearchRequest{
|
||||
Base: &commonpb.MsgBase{
|
||||
MsgType: commonpb.MsgType_Search,
|
||||
MsgID: rand.Int63(), // TODO: random msgID?
|
||||
},
|
||||
Base: genCommonMsgBase(commonpb.MsgType_Search),
|
||||
CollectionID: defaultCollectionID,
|
||||
PartitionIDs: []UniqueID{defaultPartitionID},
|
||||
Dsl: simpleDSL,
|
||||
|
@ -898,9 +903,7 @@ func genSimpleSearchMsg() (*msgstream.SearchMsg, error) {
|
|||
return nil, err
|
||||
}
|
||||
return &msgstream.SearchMsg{
|
||||
BaseMsg: msgstream.BaseMsg{
|
||||
HashValues: []uint32{0},
|
||||
},
|
||||
BaseMsg: genMsgStreamBaseMsg(),
|
||||
SearchRequest: *req,
|
||||
}, nil
|
||||
}
|
||||
|
@ -911,9 +914,7 @@ func genSimpleRetrieveMsg() (*msgstream.RetrieveMsg, error) {
|
|||
return nil, err
|
||||
}
|
||||
return &msgstream.RetrieveMsg{
|
||||
BaseMsg: msgstream.BaseMsg{
|
||||
HashValues: []uint32{0},
|
||||
},
|
||||
BaseMsg: genMsgStreamBaseMsg(),
|
||||
RetrieveRequest: *req,
|
||||
}, nil
|
||||
}
|
||||
|
|
|
@ -12,10 +12,301 @@
|
|||
package querynode
|
||||
|
||||
import (
|
||||
"context"
|
||||
"testing"
|
||||
|
||||
"github.com/stretchr/testify/assert"
|
||||
|
||||
"github.com/milvus-io/milvus/internal/msgstream"
|
||||
"github.com/milvus-io/milvus/internal/proto/commonpb"
|
||||
"github.com/milvus-io/milvus/internal/proto/datapb"
|
||||
"github.com/milvus-io/milvus/internal/proto/querypb"
|
||||
)
|
||||
|
||||
// TODO: add task ut
|
||||
func TestTask_watchDmChannelsTask(t *testing.T) {
|
||||
ctx, cancel := context.WithCancel(context.Background())
|
||||
defer cancel()
|
||||
|
||||
genWatchDMChannelsRequest := func() *querypb.WatchDmChannelsRequest {
|
||||
schema, _ := genSimpleSchema()
|
||||
req := &querypb.WatchDmChannelsRequest{
|
||||
Base: genCommonMsgBase(commonpb.MsgType_WatchDmChannels),
|
||||
CollectionID: defaultCollectionID,
|
||||
PartitionID: defaultPartitionID,
|
||||
Schema: schema,
|
||||
}
|
||||
return req
|
||||
}
|
||||
|
||||
t.Run("test timestamp", func(t *testing.T) {
|
||||
task := watchDmChannelsTask{
|
||||
req: genWatchDMChannelsRequest(),
|
||||
}
|
||||
timestamp := Timestamp(1000)
|
||||
task.req.Base.Timestamp = timestamp
|
||||
resT := task.Timestamp()
|
||||
assert.Equal(t, timestamp, resT)
|
||||
task.req.Base = nil
|
||||
resT = task.Timestamp()
|
||||
assert.Equal(t, Timestamp(0), resT)
|
||||
})
|
||||
|
||||
t.Run("test OnEnqueue", func(t *testing.T) {
|
||||
task := watchDmChannelsTask{
|
||||
req: genWatchDMChannelsRequest(),
|
||||
}
|
||||
err := task.OnEnqueue()
|
||||
assert.NoError(t, err)
|
||||
task.req.Base = nil
|
||||
err = task.OnEnqueue()
|
||||
assert.NoError(t, err)
|
||||
})
|
||||
|
||||
t.Run("test execute loadCollection", func(t *testing.T) {
|
||||
node, err := genSimpleQueryNode(ctx)
|
||||
assert.NoError(t, err)
|
||||
|
||||
task := watchDmChannelsTask{
|
||||
req: genWatchDMChannelsRequest(),
|
||||
node: node,
|
||||
}
|
||||
task.req.Infos = []*datapb.VchannelInfo{
|
||||
{
|
||||
CollectionID: defaultCollectionID,
|
||||
ChannelName: defaultVChannel,
|
||||
},
|
||||
}
|
||||
task.req.PartitionID = 0
|
||||
err = task.Execute(ctx)
|
||||
assert.NoError(t, err)
|
||||
})
|
||||
|
||||
t.Run("test execute loadPartition", func(t *testing.T) {
|
||||
node, err := genSimpleQueryNode(ctx)
|
||||
assert.NoError(t, err)
|
||||
|
||||
task := watchDmChannelsTask{
|
||||
req: genWatchDMChannelsRequest(),
|
||||
node: node,
|
||||
}
|
||||
task.req.Infos = []*datapb.VchannelInfo{
|
||||
{
|
||||
CollectionID: defaultCollectionID,
|
||||
ChannelName: defaultVChannel,
|
||||
},
|
||||
}
|
||||
err = task.Execute(ctx)
|
||||
assert.NoError(t, err)
|
||||
})
|
||||
|
||||
// TODO: time consuming, reduce seek error time
|
||||
t.Run("test execute seek error", func(t *testing.T) {
|
||||
node, err := genSimpleQueryNode(ctx)
|
||||
assert.NoError(t, err)
|
||||
|
||||
task := watchDmChannelsTask{
|
||||
req: genWatchDMChannelsRequest(),
|
||||
node: node,
|
||||
}
|
||||
task.req.Infos = []*datapb.VchannelInfo{
|
||||
{
|
||||
CollectionID: defaultCollectionID,
|
||||
ChannelName: defaultVChannel,
|
||||
SeekPosition: &msgstream.MsgPosition{
|
||||
ChannelName: defaultVChannel,
|
||||
MsgID: []byte{1, 2, 3},
|
||||
MsgGroup: defaultSubName,
|
||||
Timestamp: 0,
|
||||
},
|
||||
},
|
||||
}
|
||||
err = task.Execute(ctx)
|
||||
assert.Error(t, err)
|
||||
})
|
||||
}
|
||||
|
||||
func TestTask_loadSegmentsTask(t *testing.T) {
|
||||
ctx, cancel := context.WithCancel(context.Background())
|
||||
defer cancel()
|
||||
|
||||
genLoadSegmentsRequest := func() *querypb.LoadSegmentsRequest {
|
||||
_, schema := genSimpleSchema()
|
||||
req := &querypb.LoadSegmentsRequest{
|
||||
Base: genCommonMsgBase(commonpb.MsgType_LoadSegments),
|
||||
Schema: schema,
|
||||
LoadCondition: querypb.TriggerCondition_grpcRequest,
|
||||
}
|
||||
return req
|
||||
}
|
||||
|
||||
t.Run("test timestamp", func(t *testing.T) {
|
||||
task := loadSegmentsTask{
|
||||
req: genLoadSegmentsRequest(),
|
||||
}
|
||||
timestamp := Timestamp(1000)
|
||||
task.req.Base.Timestamp = timestamp
|
||||
resT := task.Timestamp()
|
||||
assert.Equal(t, timestamp, resT)
|
||||
task.req.Base = nil
|
||||
resT = task.Timestamp()
|
||||
assert.Equal(t, Timestamp(0), resT)
|
||||
})
|
||||
|
||||
t.Run("test OnEnqueue", func(t *testing.T) {
|
||||
task := loadSegmentsTask{
|
||||
req: genLoadSegmentsRequest(),
|
||||
}
|
||||
err := task.OnEnqueue()
|
||||
assert.NoError(t, err)
|
||||
task.req.Base = nil
|
||||
err = task.OnEnqueue()
|
||||
assert.NoError(t, err)
|
||||
})
|
||||
|
||||
t.Run("test execute grpc", func(t *testing.T) {
|
||||
node, err := genSimpleQueryNode(ctx)
|
||||
assert.NoError(t, err)
|
||||
|
||||
task := loadSegmentsTask{
|
||||
req: genLoadSegmentsRequest(),
|
||||
node: node,
|
||||
}
|
||||
task.req.Infos = []*querypb.SegmentLoadInfo{
|
||||
{
|
||||
SegmentID: defaultSegmentID + 1,
|
||||
PartitionID: defaultPartitionID + 1,
|
||||
CollectionID: defaultCollectionID + 1,
|
||||
},
|
||||
}
|
||||
err = task.Execute(ctx)
|
||||
assert.Error(t, err)
|
||||
})
|
||||
|
||||
t.Run("test execute node down", func(t *testing.T) {
|
||||
node, err := genSimpleQueryNode(ctx)
|
||||
assert.NoError(t, err)
|
||||
|
||||
task := loadSegmentsTask{
|
||||
req: genLoadSegmentsRequest(),
|
||||
node: node,
|
||||
}
|
||||
task.req.Infos = []*querypb.SegmentLoadInfo{
|
||||
{
|
||||
SegmentID: defaultSegmentID + 1,
|
||||
PartitionID: defaultPartitionID + 1,
|
||||
CollectionID: defaultCollectionID + 1,
|
||||
},
|
||||
}
|
||||
task.req.LoadCondition = querypb.TriggerCondition_nodeDown
|
||||
err = task.Execute(ctx)
|
||||
assert.Error(t, err)
|
||||
})
|
||||
|
||||
t.Run("test execute load balance", func(t *testing.T) {
|
||||
node, err := genSimpleQueryNode(ctx)
|
||||
assert.NoError(t, err)
|
||||
|
||||
task := loadSegmentsTask{
|
||||
req: genLoadSegmentsRequest(),
|
||||
node: node,
|
||||
}
|
||||
task.req.Infos = []*querypb.SegmentLoadInfo{
|
||||
{
|
||||
SegmentID: defaultSegmentID + 1,
|
||||
PartitionID: defaultPartitionID + 1,
|
||||
CollectionID: defaultCollectionID + 1,
|
||||
},
|
||||
}
|
||||
task.req.LoadCondition = querypb.TriggerCondition_loadBalance
|
||||
err = task.Execute(ctx)
|
||||
assert.Error(t, err)
|
||||
})
|
||||
}
|
||||
|
||||
func TestTask_releaseCollectionTask(t *testing.T) {
|
||||
genReleaseCollectionRequest := func() *querypb.ReleaseCollectionRequest {
|
||||
req := &querypb.ReleaseCollectionRequest{
|
||||
Base: genCommonMsgBase(commonpb.MsgType_LoadSegments),
|
||||
CollectionID: defaultCollectionID,
|
||||
}
|
||||
return req
|
||||
}
|
||||
|
||||
t.Run("test timestamp", func(t *testing.T) {
|
||||
task := releaseCollectionTask{
|
||||
req: genReleaseCollectionRequest(),
|
||||
}
|
||||
timestamp := Timestamp(1000)
|
||||
task.req.Base.Timestamp = timestamp
|
||||
resT := task.Timestamp()
|
||||
assert.Equal(t, timestamp, resT)
|
||||
task.req.Base = nil
|
||||
resT = task.Timestamp()
|
||||
assert.Equal(t, Timestamp(0), resT)
|
||||
})
|
||||
|
||||
t.Run("test OnEnqueue", func(t *testing.T) {
|
||||
task := releaseCollectionTask{
|
||||
req: genReleaseCollectionRequest(),
|
||||
}
|
||||
err := task.OnEnqueue()
|
||||
assert.NoError(t, err)
|
||||
task.req.Base = nil
|
||||
err = task.OnEnqueue()
|
||||
assert.NoError(t, err)
|
||||
})
|
||||
}
|
||||
|
||||
func TestTask_releasePartitionTask(t *testing.T) {
|
||||
ctx, cancel := context.WithCancel(context.Background())
|
||||
defer cancel()
|
||||
|
||||
genReleasePartitionsRequest := func() *querypb.ReleasePartitionsRequest {
|
||||
req := &querypb.ReleasePartitionsRequest{
|
||||
Base: genCommonMsgBase(commonpb.MsgType_LoadSegments),
|
||||
CollectionID: defaultCollectionID,
|
||||
PartitionIDs: []UniqueID{defaultPartitionID},
|
||||
}
|
||||
return req
|
||||
}
|
||||
|
||||
t.Run("test timestamp", func(t *testing.T) {
|
||||
task := releasePartitionsTask{
|
||||
req: genReleasePartitionsRequest(),
|
||||
}
|
||||
timestamp := Timestamp(1000)
|
||||
task.req.Base.Timestamp = timestamp
|
||||
resT := task.Timestamp()
|
||||
assert.Equal(t, timestamp, resT)
|
||||
task.req.Base = nil
|
||||
resT = task.Timestamp()
|
||||
assert.Equal(t, Timestamp(0), resT)
|
||||
})
|
||||
|
||||
t.Run("test OnEnqueue", func(t *testing.T) {
|
||||
task := releasePartitionsTask{
|
||||
req: genReleasePartitionsRequest(),
|
||||
}
|
||||
err := task.OnEnqueue()
|
||||
assert.NoError(t, err)
|
||||
task.req.Base = nil
|
||||
err = task.OnEnqueue()
|
||||
assert.NoError(t, err)
|
||||
})
|
||||
|
||||
t.Run("test execute", func(t *testing.T) {
|
||||
node, err := genSimpleQueryNode(ctx)
|
||||
assert.NoError(t, err)
|
||||
|
||||
task := releasePartitionsTask{
|
||||
req: genReleasePartitionsRequest(),
|
||||
node: node,
|
||||
}
|
||||
err = task.node.streaming.dataSyncService.addPartitionFlowGraph(defaultCollectionID,
|
||||
defaultPartitionID,
|
||||
[]Channel{defaultVChannel})
|
||||
assert.NoError(t, err)
|
||||
err = task.Execute(ctx)
|
||||
assert.NoError(t, err)
|
||||
})
|
||||
}
|
||||
|
|
Loading…
Reference in New Issue