Add unittests for querynode (#7510)

Signed-off-by: bigsheeper <yihao.dai@zilliz.com>
pull/7542/head
bigsheeper 2021-09-07 15:45:59 +08:00 committed by GitHub
parent 29756c6ce8
commit bcccd767ac
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
24 changed files with 2253 additions and 448 deletions

View File

@ -79,10 +79,23 @@ func (c *Collection) removePartitionID(partitionID UniqueID) {
}
func (c *Collection) addVChannels(channels []Channel) {
log.Debug("add vChannels to collection",
zap.Any("channels", channels),
zap.Any("collectionID", c.ID()))
c.vChannels = append(c.vChannels, channels...)
OUTER:
for _, dstChan := range channels {
for _, srcChan := range c.vChannels {
if dstChan == srcChan {
log.Debug("vChannel has been existed in collection's vChannels",
zap.Any("collectionID", c.ID()),
zap.Any("vChannel", dstChan),
)
continue OUTER
}
}
log.Debug("add vChannel to collection",
zap.Any("collectionID", c.ID()),
zap.Any("vChannel", dstChan),
)
c.vChannels = append(c.vChannels, dstChan)
}
}
func (c *Collection) getVChannels() []Channel {
@ -90,10 +103,23 @@ func (c *Collection) getVChannels() []Channel {
}
func (c *Collection) addPChannels(channels []Channel) {
log.Debug("add pChannels to collection",
zap.Any("channels", channels),
zap.Any("collectionID", c.ID()))
c.pChannels = append(c.pChannels, channels...)
OUTER:
for _, dstChan := range channels {
for _, srcChan := range c.pChannels {
if dstChan == srcChan {
log.Debug("pChannel has been existed in collection's pChannels",
zap.Any("collectionID", c.ID()),
zap.Any("pChannel", dstChan),
)
continue OUTER
}
}
log.Debug("add pChannel to collection",
zap.Any("collectionID", c.ID()),
zap.Any("pChannel", dstChan),
)
c.pChannels = append(c.pChannels, dstChan)
}
}
func (c *Collection) getPChannels() []Channel {

View File

@ -87,7 +87,7 @@ func TestCollectionReplica_getPartitionNum(t *testing.T) {
}
partitionNum := node.historical.replica.getPartitionNum()
assert.Equal(t, partitionNum, len(partitionIDs)+1)
assert.Equal(t, partitionNum, len(partitionIDs))
err := node.Stop()
assert.NoError(t, err)
}

View File

@ -33,3 +33,80 @@ func TestCollection_deleteCollection(t *testing.T) {
assert.Equal(t, collection.ID(), collectionID)
deleteCollection(collection)
}
func TestCollection_schema(t *testing.T) {
collectionID := UniqueID(0)
collectionMeta := genTestCollectionMeta(collectionID, false)
collection := newCollection(collectionMeta.ID, collectionMeta.Schema)
schema := collection.Schema()
assert.Equal(t, collectionMeta.Schema.Name, schema.Name)
assert.Equal(t, len(collectionMeta.Schema.Fields), len(schema.Fields))
deleteCollection(collection)
}
func TestCollection_vChannel(t *testing.T) {
collectionID := UniqueID(0)
collectionMeta := genTestCollectionMeta(collectionID, false)
collection := newCollection(collectionMeta.ID, collectionMeta.Schema)
collection.addVChannels([]string{defaultVChannel})
collection.addVChannels([]string{defaultVChannel})
collection.addVChannels([]string{"TestCollection_addVChannel_channel"})
channels := collection.getVChannels()
assert.Equal(t, 2, len(channels))
}
func TestCollection_pChannel(t *testing.T) {
collectionID := UniqueID(0)
collectionMeta := genTestCollectionMeta(collectionID, false)
collection := newCollection(collectionMeta.ID, collectionMeta.Schema)
collection.addPChannels([]string{"TestCollection_addPChannel_channel-0"})
collection.addPChannels([]string{"TestCollection_addPChannel_channel-0"})
collection.addPChannels([]string{"TestCollection_addPChannel_channel-1"})
channels := collection.getPChannels()
assert.Equal(t, 2, len(channels))
}
func TestCollection_releaseTime(t *testing.T) {
collectionID := UniqueID(0)
collectionMeta := genTestCollectionMeta(collectionID, false)
collection := newCollection(collectionMeta.ID, collectionMeta.Schema)
t0 := Timestamp(1000)
collection.setReleaseTime(t0)
t1 := collection.getReleaseTime()
assert.Equal(t, t0, t1)
}
func TestCollection_releasePartition(t *testing.T) {
collectionID := UniqueID(0)
collectionMeta := genTestCollectionMeta(collectionID, false)
collection := newCollection(collectionMeta.ID, collectionMeta.Schema)
collection.addReleasedPartition(defaultPartitionID)
assert.Equal(t, 1, len(collection.releasedPartitions))
err := collection.checkReleasedPartitions([]UniqueID{defaultPartitionID})
assert.Error(t, err)
err = collection.checkReleasedPartitions([]UniqueID{UniqueID(1000)})
assert.NoError(t, err)
collection.deleteReleasedPartition(defaultPartitionID)
assert.Equal(t, 0, len(collection.releasedPartitions))
}
func TestCollection_loadType(t *testing.T) {
collectionID := UniqueID(0)
collectionMeta := genTestCollectionMeta(collectionID, false)
collection := newCollection(collectionMeta.ID, collectionMeta.Schema)
collection.setLoadType(loadTypeCollection)
lt := collection.getLoadType()
assert.Equal(t, loadTypeCollection, lt)
collection.setLoadType(loadTypePartition)
lt = collection.getLoadType()
assert.Equal(t, loadTypePartition, lt)
}

View File

@ -26,8 +26,8 @@ import (
type loadType = int32
const (
loadTypeCollection = 0
loadTypePartition = 1
loadTypeCollection loadType = 0
loadTypePartition loadType = 1
)
type dataSyncService struct {

View File

@ -12,6 +12,7 @@
package querynode
import (
"context"
"encoding/binary"
"math"
"testing"
@ -122,5 +123,96 @@ func TestDataSyncService_Start(t *testing.T) {
assert.NoError(t, err)
<-node.queryNodeLoopCtx.Done()
node.Stop()
node.streaming.dataSyncService.close()
err = node.Stop()
assert.NoError(t, err)
}
func TestDataSyncService_collectionFlowGraphs(t *testing.T) {
ctx, cancel := context.WithCancel(context.Background())
defer cancel()
streaming, err := genSimpleStreaming(ctx)
assert.NoError(t, err)
fac, err := genFactory()
assert.NoError(t, err)
dataSyncService := newDataSyncService(ctx, streaming.replica, streaming.tSafeReplica, fac)
assert.NotNil(t, dataSyncService)
err = dataSyncService.addCollectionFlowGraph(defaultCollectionID, []Channel{defaultVChannel})
assert.NoError(t, err)
fg, err := dataSyncService.getCollectionFlowGraphs(defaultCollectionID, []Channel{defaultVChannel})
assert.NotNil(t, fg)
assert.NoError(t, err)
assert.Equal(t, 1, len(fg))
fg, err = dataSyncService.getCollectionFlowGraphs(UniqueID(1000), []Channel{defaultVChannel})
assert.Nil(t, fg)
assert.Error(t, err)
fg, err = dataSyncService.getCollectionFlowGraphs(defaultCollectionID, []Channel{"invalid-vChannel"})
assert.NotNil(t, fg)
assert.NoError(t, err)
assert.Equal(t, 0, len(fg))
fg, err = dataSyncService.getCollectionFlowGraphs(UniqueID(1000), []Channel{"invalid-vChannel"})
assert.Nil(t, fg)
assert.Error(t, err)
err = dataSyncService.startCollectionFlowGraph(defaultCollectionID, []Channel{defaultVChannel})
assert.NoError(t, err)
dataSyncService.removeCollectionFlowGraph(defaultCollectionID)
fg, err = dataSyncService.getCollectionFlowGraphs(defaultCollectionID, []Channel{defaultVChannel})
assert.Nil(t, fg)
assert.Error(t, err)
}
func TestDataSyncService_partitionFlowGraphs(t *testing.T) {
ctx, cancel := context.WithCancel(context.Background())
defer cancel()
streaming, err := genSimpleStreaming(ctx)
assert.NoError(t, err)
fac, err := genFactory()
assert.NoError(t, err)
dataSyncService := newDataSyncService(ctx, streaming.replica, streaming.tSafeReplica, fac)
assert.NotNil(t, dataSyncService)
err = dataSyncService.addPartitionFlowGraph(defaultPartitionID, defaultPartitionID, []Channel{defaultVChannel})
assert.NoError(t, err)
fg, err := dataSyncService.getPartitionFlowGraphs(defaultPartitionID, []Channel{defaultVChannel})
assert.NotNil(t, fg)
assert.NoError(t, err)
assert.Equal(t, 1, len(fg))
fg, err = dataSyncService.getPartitionFlowGraphs(UniqueID(1000), []Channel{defaultVChannel})
assert.Nil(t, fg)
assert.Error(t, err)
fg, err = dataSyncService.getPartitionFlowGraphs(defaultPartitionID, []Channel{"invalid-vChannel"})
assert.NotNil(t, fg)
assert.NoError(t, err)
assert.Equal(t, 0, len(fg))
fg, err = dataSyncService.getPartitionFlowGraphs(UniqueID(1000), []Channel{"invalid-vChannel"})
assert.Nil(t, fg)
assert.Error(t, err)
err = dataSyncService.startPartitionFlowGraph(defaultPartitionID, []Channel{defaultVChannel})
assert.NoError(t, err)
dataSyncService.removePartitionFlowGraph(defaultPartitionID)
fg, err = dataSyncService.getPartitionFlowGraphs(defaultPartitionID, []Channel{defaultVChannel})
assert.Nil(t, fg)
assert.Error(t, err)
}

View File

@ -1,198 +0,0 @@
// Copyright (C) 2019-2020 Zilliz. All rights reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance
// with the License. You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software distributed under the License
// is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express
// or implied. See the License for the specific language governing permissions and limitations under the License.
package querynode
import (
"github.com/golang/protobuf/proto"
"github.com/opentracing/opentracing-go"
"go.uber.org/zap"
"github.com/milvus-io/milvus/internal/log"
"github.com/milvus-io/milvus/internal/msgstream"
"github.com/milvus-io/milvus/internal/proto/schemapb"
"github.com/milvus-io/milvus/internal/util/flowgraph"
"github.com/milvus-io/milvus/internal/util/trace"
)
type ddNode struct {
baseNode
ddMsg *ddMsg
replica ReplicaInterface
}
func (ddNode *ddNode) Name() string {
return "ddNode"
}
func (ddNode *ddNode) Operate(in []flowgraph.Msg) []flowgraph.Msg {
//log.Debug("Do filterDmNode operation")
if len(in) != 1 {
log.Error("Invalid operate message input in ddNode", zap.Int("input length", len(in)))
// TODO: add error handling
}
msMsg, ok := in[0].(*MsgStreamMsg)
if !ok {
log.Warn("type assertion failed for MsgStreamMsg")
// TODO: add error handling
}
var spans []opentracing.Span
for _, msg := range msMsg.TsMessages() {
sp, ctx := trace.StartSpanFromContext(msg.TraceCtx())
spans = append(spans, sp)
msg.SetTraceCtx(ctx)
}
var ddMsg = ddMsg{
collectionRecords: make(map[UniqueID][]metaOperateRecord),
partitionRecords: make(map[UniqueID][]metaOperateRecord),
timeRange: TimeRange{
timestampMin: msMsg.TimestampMin(),
timestampMax: msMsg.TimestampMax(),
},
}
ddNode.ddMsg = &ddMsg
gcRecord := gcRecord{
collections: make([]UniqueID, 0),
partitions: make([]partitionWithID, 0),
}
ddNode.ddMsg.gcRecord = &gcRecord
// sort tsMessages
//tsMessages := msMsg.TsMessages()
//sort.Slice(tsMessages,
// func(i, j int) bool {
// return tsMessages[i].BeginTs() < tsMessages[j].BeginTs()
// })
// do dd tasks
//for _, msg := range tsMessages {
// switch msg.Type() {
// case commonpb.MsgType_kCreateCollection:
// ddNode.createCollection(msg.(*msgstream.CreateCollectionMsg))
// case commonpb.MsgType_kDropCollection:
// ddNode.dropCollection(msg.(*msgstream.DropCollectionMsg))
// case commonpb.MsgType_kCreatePartition:
// ddNode.createPartition(msg.(*msgstream.CreatePartitionMsg))
// case commonpb.MsgType_kDropPartition:
// ddNode.dropPartition(msg.(*msgstream.DropPartitionMsg))
// default:
// log.Println("Non supporting message type:", msg.Type())
// }
//}
var res Msg = ddNode.ddMsg
for _, span := range spans {
span.Finish()
}
return []Msg{res}
}
func (ddNode *ddNode) createCollection(msg *msgstream.CreateCollectionMsg) {
collectionID := msg.CollectionID
partitionID := msg.PartitionID
hasCollection := ddNode.replica.hasCollection(collectionID)
if hasCollection {
log.Debug("collection already exists", zap.Int64("collectionID", collectionID))
return
}
var schema schemapb.CollectionSchema
err := proto.Unmarshal(msg.Schema, &schema)
if err != nil {
log.Warn(err.Error())
return
}
// add collection
err = ddNode.replica.addCollection(collectionID, &schema)
if err != nil {
log.Warn(err.Error())
return
}
// add default partition
// TODO: allocate default partition id in master
err = ddNode.replica.addPartition(collectionID, partitionID)
if err != nil {
log.Warn(err.Error())
return
}
ddNode.ddMsg.collectionRecords[collectionID] = append(ddNode.ddMsg.collectionRecords[collectionID],
metaOperateRecord{
createOrDrop: true,
timestamp: msg.Base.Timestamp,
})
}
func (ddNode *ddNode) dropCollection(msg *msgstream.DropCollectionMsg) {
collectionID := msg.CollectionID
ddNode.ddMsg.collectionRecords[collectionID] = append(ddNode.ddMsg.collectionRecords[collectionID],
metaOperateRecord{
createOrDrop: false,
timestamp: msg.Base.Timestamp,
})
ddNode.ddMsg.gcRecord.collections = append(ddNode.ddMsg.gcRecord.collections, collectionID)
}
func (ddNode *ddNode) createPartition(msg *msgstream.CreatePartitionMsg) {
collectionID := msg.CollectionID
partitionID := msg.PartitionID
err := ddNode.replica.addPartition(collectionID, partitionID)
if err != nil {
log.Warn(err.Error())
return
}
ddNode.ddMsg.partitionRecords[partitionID] = append(ddNode.ddMsg.partitionRecords[partitionID],
metaOperateRecord{
createOrDrop: true,
timestamp: msg.Base.Timestamp,
})
}
func (ddNode *ddNode) dropPartition(msg *msgstream.DropPartitionMsg) {
collectionID := msg.CollectionID
partitionID := msg.PartitionID
ddNode.ddMsg.partitionRecords[partitionID] = append(ddNode.ddMsg.partitionRecords[partitionID],
metaOperateRecord{
createOrDrop: false,
timestamp: msg.Base.Timestamp,
})
ddNode.ddMsg.gcRecord.partitions = append(ddNode.ddMsg.gcRecord.partitions, partitionWithID{
partitionID: partitionID,
collectionID: collectionID,
})
}
func newDDNode(replica ReplicaInterface) *ddNode {
maxQueueLength := Params.FlowGraphMaxQueueLength
maxParallelism := Params.FlowGraphMaxParallelism
baseNode := baseNode{}
baseNode.SetMaxQueueLength(maxQueueLength)
baseNode.SetMaxParallelism(maxParallelism)
return &ddNode{
baseNode: baseNode,
replica: replica,
}
}

View File

@ -185,6 +185,7 @@ func newFilteredDmNode(replica ReplicaInterface,
if loadType != loadTypeCollection && loadType != loadTypePartition {
err := errors.New("invalid flow graph type")
log.Warn(err.Error())
return nil
}
return &filterDmNode{

View File

@ -0,0 +1,199 @@
// Copyright (C) 2019-2020 Zilliz. All rights reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance
// with the License. You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software distributed under the License
// is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express
// or implied. See the License for the specific language governing permissions and limitations under the License.
package querynode
import (
"context"
"testing"
"github.com/stretchr/testify/assert"
"github.com/milvus-io/milvus/internal/msgstream"
"github.com/milvus-io/milvus/internal/proto/commonpb"
"github.com/milvus-io/milvus/internal/proto/datapb"
"github.com/milvus-io/milvus/internal/proto/internalpb"
"github.com/milvus-io/milvus/internal/util/flowgraph"
)
func getFilterDMNode(ctx context.Context) (*filterDmNode, error) {
streaming, err := genSimpleStreaming(ctx)
if err != nil {
return nil, err
}
streaming.replica.initExcludedSegments(defaultCollectionID)
return newFilteredDmNode(streaming.replica, loadTypeCollection, defaultCollectionID, defaultPartitionID), nil
}
func TestFlowGraphFilterDmNode_filterDmNode(t *testing.T) {
ctx, cancel := context.WithCancel(context.Background())
defer cancel()
fg, err := getFilterDMNode(ctx)
assert.NoError(t, err)
fg.Name()
}
func TestFlowGraphFilterDmNode_invalidLoadType(t *testing.T) {
const invalidLoadType = -1
fg := newFilteredDmNode(nil, invalidLoadType, defaultCollectionID, defaultPartitionID)
assert.Nil(t, fg)
}
func TestFlowGraphFilterDmNode_filterInvalidInsertMessage(t *testing.T) {
ctx, cancel := context.WithCancel(context.Background())
defer cancel()
t.Run("valid test", func(t *testing.T) {
msg, err := genSimpleInsertMsg()
assert.NoError(t, err)
fg, err := getFilterDMNode(ctx)
assert.NoError(t, err)
res := fg.filterInvalidInsertMessage(msg)
assert.NotNil(t, res)
})
t.Run("test no collection", func(t *testing.T) {
msg, err := genSimpleInsertMsg()
assert.NoError(t, err)
msg.CollectionID = UniqueID(1000)
fg, err := getFilterDMNode(ctx)
assert.NoError(t, err)
res := fg.filterInvalidInsertMessage(msg)
assert.Nil(t, res)
})
t.Run("test no partition", func(t *testing.T) {
msg, err := genSimpleInsertMsg()
assert.NoError(t, err)
msg.PartitionID = UniqueID(1000)
fg, err := getFilterDMNode(ctx)
assert.NoError(t, err)
fg.loadType = loadTypePartition
res := fg.filterInvalidInsertMessage(msg)
assert.Nil(t, res)
})
t.Run("test not target collection", func(t *testing.T) {
msg, err := genSimpleInsertMsg()
assert.NoError(t, err)
fg, err := getFilterDMNode(ctx)
assert.NoError(t, err)
fg.collectionID = UniqueID(1000)
res := fg.filterInvalidInsertMessage(msg)
assert.Nil(t, res)
})
t.Run("test not target partition", func(t *testing.T) {
msg, err := genSimpleInsertMsg()
assert.NoError(t, err)
fg, err := getFilterDMNode(ctx)
assert.NoError(t, err)
fg.loadType = loadTypePartition
fg.partitionID = UniqueID(1000)
res := fg.filterInvalidInsertMessage(msg)
assert.Nil(t, res)
})
t.Run("test released partition", func(t *testing.T) {
msg, err := genSimpleInsertMsg()
assert.NoError(t, err)
fg, err := getFilterDMNode(ctx)
assert.NoError(t, err)
col, err := fg.replica.getCollectionByID(defaultCollectionID)
assert.NoError(t, err)
col.addReleasedPartition(defaultPartitionID)
res := fg.filterInvalidInsertMessage(msg)
assert.Nil(t, res)
})
t.Run("test no exclude segment", func(t *testing.T) {
msg, err := genSimpleInsertMsg()
assert.NoError(t, err)
fg, err := getFilterDMNode(ctx)
assert.NoError(t, err)
fg.replica.removeExcludedSegments(defaultCollectionID)
res := fg.filterInvalidInsertMessage(msg)
assert.Nil(t, res)
})
t.Run("test segment is exclude segment", func(t *testing.T) {
msg, err := genSimpleInsertMsg()
assert.NoError(t, err)
fg, err := getFilterDMNode(ctx)
assert.NoError(t, err)
err = fg.replica.addExcludedSegments(defaultCollectionID, []*datapb.SegmentInfo{
{
ID: defaultSegmentID,
CollectionID: defaultCollectionID,
PartitionID: defaultPartitionID,
DmlPosition: &internalpb.MsgPosition{
Timestamp: Timestamp(1000),
},
},
})
assert.NoError(t, err)
res := fg.filterInvalidInsertMessage(msg)
assert.Nil(t, res)
})
t.Run("test misaligned messages", func(t *testing.T) {
msg, err := genSimpleInsertMsg()
assert.NoError(t, err)
fg, err := getFilterDMNode(ctx)
assert.NoError(t, err)
msg.Timestamps = make([]Timestamp, 0)
res := fg.filterInvalidInsertMessage(msg)
assert.Nil(t, res)
})
t.Run("test no data", func(t *testing.T) {
msg, err := genSimpleInsertMsg()
assert.NoError(t, err)
fg, err := getFilterDMNode(ctx)
assert.NoError(t, err)
msg.Timestamps = make([]Timestamp, 0)
msg.RowIDs = make([]IntPrimaryKey, 0)
msg.RowData = make([]*commonpb.Blob, 0)
res := fg.filterInvalidInsertMessage(msg)
assert.Nil(t, res)
})
}
func TestFlowGraphFilterDmNode_Operate(t *testing.T) {
ctx, cancel := context.WithCancel(context.Background())
defer cancel()
genFilterDMMsg := func() []flowgraph.Msg {
iMsg, err := genSimpleInsertMsg()
assert.NoError(t, err)
msg := flowgraph.GenerateMsgStreamMsg([]msgstream.TsMsg{iMsg}, 0, 1000, nil, nil)
return []flowgraph.Msg{msg}
}
t.Run("valid test", func(t *testing.T) {
msg := genFilterDMMsg()
fg, err := getFilterDMNode(ctx)
assert.NoError(t, err)
res := fg.Operate(msg)
assert.NotNil(t, res)
})
t.Run("invalid input length", func(t *testing.T) {
msg := genFilterDMMsg()
fg, err := getFilterDMNode(ctx)
assert.NoError(t, err)
var m flowgraph.Msg
msg = append(msg, m)
res := fg.Operate(msg)
assert.NotNil(t, res)
})
}

View File

@ -1,79 +0,0 @@
// Copyright (C) 2019-2020 Zilliz. All rights reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance
// with the License. You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software distributed under the License
// is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express
// or implied. See the License for the specific language governing permissions and limitations under the License.
package querynode
import (
"go.uber.org/zap"
"github.com/milvus-io/milvus/internal/log"
"github.com/milvus-io/milvus/internal/util/flowgraph"
)
type gcNode struct {
baseNode
replica ReplicaInterface
}
func (gcNode *gcNode) Name() string {
return "gcNode"
}
func (gcNode *gcNode) Operate(in []flowgraph.Msg) []flowgraph.Msg {
//log.Debug("Do gcNode operation")
if len(in) != 1 {
log.Error("Invalid operate message input in gcNode", zap.Int("input length", len(in)))
// TODO: add error handling
}
_, ok := in[0].(*gcMsg)
if !ok {
log.Warn("type assertion failed for gcMsg")
// TODO: add error handling
}
// Use `releasePartition` and `releaseCollection`,
// because if we drop collections or partitions here, query service doesn't know this behavior,
// which would lead the wrong result of `showCollections` or `showPartition`
//// drop collections
//for _, collectionID := range gcMsg.gcRecord.collections {
// err := gcNode.replica.removeCollection(collectionID)
// if err != nil {
// log.Println(err)
// }
//}
//
//// drop partitions
//for _, partition := range gcMsg.gcRecord.partitions {
// err := gcNode.replica.removePartition(partition.partitionID)
// if err != nil {
// log.Println(err)
// }
//}
return nil
}
func newGCNode(replica ReplicaInterface) *gcNode {
maxQueueLength := Params.FlowGraphMaxQueueLength
maxParallelism := Params.FlowGraphMaxParallelism
baseNode := baseNode{}
baseNode.SetMaxQueueLength(maxQueueLength)
baseNode.SetMaxParallelism(maxParallelism)
return &gcNode{
baseNode: baseNode,
replica: replica,
}
}

View File

@ -1,63 +0,0 @@
// Copyright (C) 2019-2020 Zilliz. All rights reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance
// with the License. You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software distributed under the License
// is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express
// or implied. See the License for the specific language governing permissions and limitations under the License.
package querynode
type key2SegNode struct {
baseNode
key2SegMsg key2SegMsg
}
func (ksNode *key2SegNode) Name() string {
return "ksNode"
}
func (ksNode *key2SegNode) Operate(in []*Msg) []*Msg {
return in
}
func newKey2SegNode() *key2SegNode {
maxQueueLength := Params.FlowGraphMaxQueueLength
maxParallelism := Params.FlowGraphMaxParallelism
baseNode := baseNode{}
baseNode.SetMaxQueueLength(maxQueueLength)
baseNode.SetMaxParallelism(maxParallelism)
return &key2SegNode{
baseNode: baseNode,
}
}
/************************************** util functions ***************************************/
// Function `GetSegmentByEntityId` should return entityIDs, timestamps and segmentIDs
//func (node *QueryNode) GetKey2Segments() (*[]int64, *[]uint64, *[]int64) {
// var entityIDs = make([]int64, 0)
// var timestamps = make([]uint64, 0)
// var segmentIDs = make([]int64, 0)
//
// var key2SegMsg = node.messageClient.Key2SegMsg
// for _, msg := range key2SegMsg {
// if msg.SegmentID == nil {
// segmentIDs = append(segmentIDs, -1)
// entityIDs = append(entityIDs, msg.Uid)
// timestamps = append(timestamps, msg.Timestamp)
// } else {
// for _, segmentID := range msg.SegmentID {
// segmentIDs = append(segmentIDs, segmentID)
// entityIDs = append(entityIDs, msg.Uid)
// timestamps = append(timestamps, msg.Timestamp)
// }
// }
// }
//
// return &entityIDs, &timestamps, &segmentIDs
//}

View File

@ -19,34 +19,12 @@ import (
type Msg = flowgraph.Msg
type MsgStreamMsg = flowgraph.MsgStreamMsg
type key2SegMsg struct {
tsMessages []msgstream.TsMsg
timeRange TimeRange
}
type ddMsg struct {
collectionRecords map[UniqueID][]metaOperateRecord
partitionRecords map[UniqueID][]metaOperateRecord
gcRecord *gcRecord
timeRange TimeRange
}
type metaOperateRecord struct {
createOrDrop bool // create: true, drop: false
timestamp Timestamp
}
type insertMsg struct {
insertMessages []*msgstream.InsertMsg
gcRecord *gcRecord
timeRange TimeRange
}
type deleteMsg struct {
deleteMessages []*msgstream.DeleteMsg
timeRange TimeRange
}
type serviceTimeMsg struct {
gcRecord *gcRecord
timeRange TimeRange
@ -86,22 +64,10 @@ type gcRecord struct {
partitions []partitionWithID
}
func (ksMsg *key2SegMsg) TimeTick() Timestamp {
return ksMsg.timeRange.timestampMax
}
func (suMsg *ddMsg) TimeTick() Timestamp {
return suMsg.timeRange.timestampMax
}
func (iMsg *insertMsg) TimeTick() Timestamp {
return iMsg.timeRange.timestampMax
}
func (dMsg *deleteMsg) TimeTick() Timestamp {
return dMsg.timeRange.timestampMax
}
func (stMsg *serviceTimeMsg) TimeTick() Timestamp {
return stMsg.timeRange.timestampMax
}

View File

@ -0,0 +1,263 @@
// Copyright (C) 2019-2020 Zilliz. All rights reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance
// with the License. You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software distributed under the License
// is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express
// or implied. See the License for the specific language governing permissions and limitations under the License.
package querynode
import (
"context"
"math/rand"
"testing"
"github.com/milvus-io/milvus/internal/proto/commonpb"
"github.com/milvus-io/milvus/internal/proto/internalpb"
"github.com/milvus-io/milvus/internal/proto/milvuspb"
queryPb "github.com/milvus-io/milvus/internal/proto/querypb"
"github.com/stretchr/testify/assert"
)
func TestImpl_GetComponentStates(t *testing.T) {
ctx, cancel := context.WithCancel(context.Background())
defer cancel()
node, err := genSimpleQueryNode(ctx)
assert.NoError(t, err)
rsp, err := node.GetComponentStates(ctx)
assert.NoError(t, err)
assert.Equal(t, commonpb.ErrorCode_Success, rsp.Status.ErrorCode)
node.UpdateStateCode(internalpb.StateCode_Abnormal)
rsp, err = node.GetComponentStates(ctx)
assert.Error(t, err)
assert.Equal(t, commonpb.ErrorCode_UnexpectedError, rsp.Status.ErrorCode)
}
func TestImpl_GetTimeTickChannel(t *testing.T) {
ctx, cancel := context.WithCancel(context.Background())
defer cancel()
node, err := genSimpleQueryNode(ctx)
assert.NoError(t, err)
rsp, err := node.GetTimeTickChannel(ctx)
assert.NoError(t, err)
assert.Equal(t, commonpb.ErrorCode_Success, rsp.Status.ErrorCode)
}
func TestImpl_GetStatisticsChannel(t *testing.T) {
ctx, cancel := context.WithCancel(context.Background())
defer cancel()
node, err := genSimpleQueryNode(ctx)
assert.NoError(t, err)
rsp, err := node.GetStatisticsChannel(ctx)
assert.NoError(t, err)
assert.Equal(t, commonpb.ErrorCode_Success, rsp.Status.ErrorCode)
}
func TestImpl_AddQueryChannel(t *testing.T) {
ctx, cancel := context.WithCancel(context.Background())
defer cancel()
node, err := genSimpleQueryNode(ctx)
assert.NoError(t, err)
req := &queryPb.AddQueryChannelRequest{
Base: &commonpb.MsgBase{
MsgType: commonpb.MsgType_WatchQueryChannels,
MsgID: rand.Int63(),
},
NodeID: 0,
CollectionID: defaultCollectionID,
RequestChannelID: genQueryChannel(),
ResultChannelID: genQueryResultChannel(),
}
status, err := node.AddQueryChannel(ctx, req)
assert.NoError(t, err)
assert.Equal(t, commonpb.ErrorCode_Success, status.ErrorCode)
node.UpdateStateCode(internalpb.StateCode_Abnormal)
status, err = node.AddQueryChannel(ctx, req)
assert.Error(t, err)
assert.Equal(t, commonpb.ErrorCode_UnexpectedError, status.ErrorCode)
}
func TestImpl_RemoveQueryChannel(t *testing.T) {
ctx, cancel := context.WithCancel(context.Background())
defer cancel()
node, err := genSimpleQueryNode(ctx)
assert.NoError(t, err)
status, err := node.RemoveQueryChannel(ctx, nil)
assert.NoError(t, err)
assert.Equal(t, commonpb.ErrorCode_Success, status.ErrorCode)
}
func TestImpl_WatchDmChannels(t *testing.T) {
ctx, cancel := context.WithCancel(context.Background())
defer cancel()
node, err := genSimpleQueryNode(ctx)
assert.NoError(t, err)
_, schema := genSimpleSchema()
req := &queryPb.WatchDmChannelsRequest{
Base: &commonpb.MsgBase{
MsgType: commonpb.MsgType_WatchQueryChannels,
MsgID: rand.Int63(),
},
NodeID: 0,
CollectionID: defaultCollectionID,
PartitionID: defaultPartitionID,
Schema: schema,
}
status, err := node.WatchDmChannels(ctx, req)
assert.NoError(t, err)
assert.Equal(t, commonpb.ErrorCode_Success, status.ErrorCode)
node.UpdateStateCode(internalpb.StateCode_Abnormal)
status, err = node.WatchDmChannels(ctx, req)
assert.Error(t, err)
assert.Equal(t, commonpb.ErrorCode_UnexpectedError, status.ErrorCode)
}
func TestImpl_LoadSegments(t *testing.T) {
ctx, cancel := context.WithCancel(context.Background())
defer cancel()
node, err := genSimpleQueryNode(ctx)
assert.NoError(t, err)
_, schema := genSimpleSchema()
req := &queryPb.LoadSegmentsRequest{
Base: &commonpb.MsgBase{
MsgType: commonpb.MsgType_WatchQueryChannels,
MsgID: rand.Int63(),
},
NodeID: 0,
Schema: schema,
LoadCondition: queryPb.TriggerCondition_grpcRequest,
}
status, err := node.LoadSegments(ctx, req)
assert.NoError(t, err)
assert.Equal(t, commonpb.ErrorCode_Success, status.ErrorCode)
node.UpdateStateCode(internalpb.StateCode_Abnormal)
status, err = node.LoadSegments(ctx, req)
assert.Error(t, err)
assert.Equal(t, commonpb.ErrorCode_UnexpectedError, status.ErrorCode)
}
func TestImpl_ReleaseCollection(t *testing.T) {
ctx, cancel := context.WithCancel(context.Background())
defer cancel()
node, err := genSimpleQueryNode(ctx)
assert.NoError(t, err)
req := &queryPb.ReleaseCollectionRequest{
Base: &commonpb.MsgBase{
MsgType: commonpb.MsgType_WatchQueryChannels,
MsgID: rand.Int63(),
},
NodeID: 0,
CollectionID: defaultCollectionID,
}
status, err := node.ReleaseCollection(ctx, req)
assert.NoError(t, err)
assert.Equal(t, commonpb.ErrorCode_Success, status.ErrorCode)
node.UpdateStateCode(internalpb.StateCode_Abnormal)
status, err = node.ReleaseCollection(ctx, req)
assert.Error(t, err)
assert.Equal(t, commonpb.ErrorCode_UnexpectedError, status.ErrorCode)
}
func TestImpl_ReleasePartitions(t *testing.T) {
ctx, cancel := context.WithCancel(context.Background())
defer cancel()
node, err := genSimpleQueryNode(ctx)
assert.NoError(t, err)
req := &queryPb.ReleasePartitionsRequest{
Base: &commonpb.MsgBase{
MsgType: commonpb.MsgType_WatchQueryChannels,
MsgID: rand.Int63(),
},
NodeID: 0,
CollectionID: defaultCollectionID,
PartitionIDs: []UniqueID{defaultPartitionID},
}
status, err := node.ReleasePartitions(ctx, req)
assert.NoError(t, err)
assert.Equal(t, commonpb.ErrorCode_Success, status.ErrorCode)
node.UpdateStateCode(internalpb.StateCode_Abnormal)
status, err = node.ReleasePartitions(ctx, req)
assert.Error(t, err)
assert.Equal(t, commonpb.ErrorCode_UnexpectedError, status.ErrorCode)
}
func TestImpl_GetSegmentInfo(t *testing.T) {
ctx, cancel := context.WithCancel(context.Background())
defer cancel()
node, err := genSimpleQueryNode(ctx)
assert.NoError(t, err)
req := &queryPb.GetSegmentInfoRequest{
Base: &commonpb.MsgBase{
MsgType: commonpb.MsgType_WatchQueryChannels,
MsgID: rand.Int63(),
},
SegmentIDs: []UniqueID{defaultSegmentID},
}
rsp, err := node.GetSegmentInfo(ctx, req)
assert.NoError(t, err)
assert.Equal(t, commonpb.ErrorCode_Success, rsp.Status.ErrorCode)
node.UpdateStateCode(internalpb.StateCode_Abnormal)
rsp, err = node.GetSegmentInfo(ctx, req)
assert.Error(t, err)
assert.Equal(t, commonpb.ErrorCode_UnexpectedError, rsp.Status.ErrorCode)
}
func TestImpl_isHealthy(t *testing.T) {
ctx, cancel := context.WithCancel(context.Background())
defer cancel()
node, err := genSimpleQueryNode(ctx)
assert.NoError(t, err)
isHealthy := node.isHealthy()
assert.True(t, isHealthy)
}
func TestImpl_GetMetrics(t *testing.T) {
ctx, cancel := context.WithCancel(context.Background())
defer cancel()
node, err := genSimpleQueryNode(ctx)
assert.NoError(t, err)
req := &milvuspb.GetMetricsRequest{
Base: &commonpb.MsgBase{
MsgType: commonpb.MsgType_WatchQueryChannels,
MsgID: rand.Int63(),
},
}
_, err = node.GetMetrics(ctx, req)
assert.NoError(t, err)
node.UpdateStateCode(internalpb.StateCode_Abnormal)
_, err = node.GetMetrics(ctx, req)
assert.NoError(t, err)
}

View File

@ -287,6 +287,11 @@ func (loader *indexLoader) sendQueryNodeStats() error {
}
func (loader *indexLoader) setIndexInfo(collectionID UniqueID, segment *Segment, fieldID UniqueID) error {
if loader.indexCoord == nil || loader.rootCoord == nil {
return errors.New("null index coordinator client or root coordinator client, collectionID = " +
fmt.Sprintln(collectionID))
}
ctx := context.TODO()
req := &milvuspb.DescribeSegmentRequest{
Base: &commonpb.MsgBase{
@ -307,10 +312,6 @@ func (loader *indexLoader) setIndexInfo(collectionID UniqueID, segment *Segment,
return errors.New("there are no indexes on this segment")
}
if loader.indexCoord == nil {
return errors.New("null index coordinator client")
}
indexFilePathRequest := &indexpb.GetIndexFilePathsRequest{
IndexBuildIDs: []UniqueID{response.BuildID},
}

File diff suppressed because it is too large Load Diff

View File

@ -253,33 +253,6 @@ func (q *queryCollection) consumeQuery() {
func (q *queryCollection) loadBalance(msg *msgstream.LoadBalanceSegmentsMsg) {
//TODO:: get loadBalance info from etcd
//log.Debug("consume load balance message",
// zap.Int64("msgID", msg.ID()))
//nodeID := Params.QueryNodeID
//for _, info := range msg.Infos {
// segmentID := info.SegmentID
// if nodeID == info.SourceNodeID {
// err := s.historical.replica.removeSegment(segmentID)
// if err != nil {
// log.Warn("loadBalance failed when remove segment",
// zap.Error(err),
// zap.Any("segmentID", segmentID))
// }
// }
// if nodeID == info.DstNodeID {
// segment, err := s.historical.replica.getSegmentByID(segmentID)
// if err != nil {
// log.Warn("loadBalance failed when making segment on service",
// zap.Error(err),
// zap.Any("segmentID", segmentID))
// continue // not return, try to load balance all segment
// }
// segment.setOnService(true)
// }
//}
//log.Debug("load balance done",
// zap.Int64("msgID", msg.ID()),
// zap.Int("num of segment", len(msg.Infos)))
}
func (q *queryCollection) receiveQueryMsg(msg queryMsg) error {

View File

@ -1,11 +1,14 @@
package querynode
import (
"bytes"
"context"
"encoding/binary"
"errors"
"math"
"math/rand"
"testing"
"time"
"github.com/bits-and-blooms/bloom/v3"
"github.com/golang/protobuf/proto"
@ -16,8 +19,50 @@ import (
"github.com/milvus-io/milvus/internal/proto/commonpb"
"github.com/milvus-io/milvus/internal/proto/internalpb"
"github.com/milvus-io/milvus/internal/proto/milvuspb"
"github.com/milvus-io/milvus/internal/proto/schemapb"
"github.com/milvus-io/milvus/internal/util/typeutil"
)
func genSimpleQueryCollection(ctx context.Context, cancel context.CancelFunc) (*queryCollection, error) {
historical, err := genSimpleHistorical(ctx)
if err != nil {
return nil, err
}
streaming, err := genSimpleStreaming(ctx)
if err != nil {
return nil, err
}
fac, err := genFactory()
if err != nil {
return nil, err
}
localCM, err := genLocalChunkManager()
if err != nil {
return nil, err
}
remoteCM, err := genRemoteChunkManager(ctx)
if err != nil {
return nil, err
}
queryCollection := newQueryCollection(ctx, cancel,
defaultCollectionID,
historical,
streaming,
fac,
localCM,
remoteCM,
false)
if queryCollection == nil {
return nil, errors.New("nil simple query collection")
}
return queryCollection, nil
}
func TestQueryCollection_withoutVChannel(t *testing.T) {
m := map[string]interface{}{
"PulsarAddress": Params.PulsarAddress,
@ -179,3 +224,226 @@ func TestGetSegmentsByPKs(t *testing.T) {
_, err = getSegmentsByPKs([]int64{0, 1, 2, 3, 4}, nil)
assert.NotNil(t, err)
}
func TestQueryCollection_unsolvedMsg(t *testing.T) {
ctx, cancel := context.WithCancel(context.Background())
queryCollection, err := genSimpleQueryCollection(ctx, cancel)
assert.NoError(t, err)
qm, err := genSimpleSearchMsg()
assert.NoError(t, err)
queryCollection.addToUnsolvedMsg(qm)
res := queryCollection.popAllUnsolvedMsg()
assert.NotNil(t, res)
assert.Len(t, res, 1)
}
func TestQueryCollection_consumeQuery(t *testing.T) {
ctx, cancel := context.WithCancel(context.Background())
runConsumeQuery := func(msg msgstream.TsMsg) {
queryCollection, err := genSimpleQueryCollection(ctx, cancel)
assert.NoError(t, err)
queryChannel := genQueryChannel()
queryCollection.queryMsgStream.AsConsumer([]Channel{queryChannel}, defaultSubName)
queryCollection.queryMsgStream.Start()
go queryCollection.consumeQuery()
producer, err := genQueryMsgStream(ctx)
assert.NoError(t, err)
producer.AsProducer([]Channel{queryChannel})
producer.Start()
msgPack := &msgstream.MsgPack{
BeginTs: 0,
EndTs: 10,
Msgs: []msgstream.TsMsg{msg},
}
err = producer.Produce(msgPack)
assert.NoError(t, err)
time.Sleep(20 * time.Millisecond)
}
t.Run("consume search", func(t *testing.T) {
msg, err := genSimpleSearchMsg()
assert.NoError(t, err)
runConsumeQuery(msg)
})
t.Run("consume retrieve", func(t *testing.T) {
msg, err := genSimpleRetrieveMsg()
assert.NoError(t, err)
runConsumeQuery(msg)
})
t.Run("consume load balance", func(t *testing.T) {
msg := &msgstream.LoadBalanceSegmentsMsg{
BaseMsg: msgstream.BaseMsg{
HashValues: []uint32{0},
},
LoadBalanceSegmentsRequest: internalpb.LoadBalanceSegmentsRequest{
Base: &commonpb.MsgBase{
MsgType: commonpb.MsgType_LoadBalanceSegments,
MsgID: rand.Int63(), // TODO: random msgID?
},
SegmentIDs: []UniqueID{defaultSegmentID},
},
}
runConsumeQuery(msg)
})
t.Run("consume invalid msg", func(t *testing.T) {
msg, err := genSimpleRetrieveMsg()
assert.NoError(t, err)
msg.Base.MsgType = commonpb.MsgType_CreateCollection
runConsumeQuery(msg)
})
}
func TestResultHandlerStage_TranslateHits(t *testing.T) {
fieldID := FieldID(0)
fieldIDs := []FieldID{fieldID}
genRawHits := func(dataType schemapb.DataType) [][]byte {
// ids
ids := make([]int64, 0)
for i := 0; i < defaultMsgLength; i++ {
ids = append(ids, int64(i))
}
// raw data
rawData := make([][]byte, 0)
switch dataType {
case schemapb.DataType_Bool:
var buf bytes.Buffer
for i := 0; i < defaultMsgLength; i++ {
err := binary.Write(&buf, binary.LittleEndian, true)
assert.NoError(t, err)
}
rawData = append(rawData, buf.Bytes())
case schemapb.DataType_Int8:
var buf bytes.Buffer
for i := 0; i < defaultMsgLength; i++ {
err := binary.Write(&buf, binary.LittleEndian, int8(i))
assert.NoError(t, err)
}
rawData = append(rawData, buf.Bytes())
case schemapb.DataType_Int16:
var buf bytes.Buffer
for i := 0; i < defaultMsgLength; i++ {
err := binary.Write(&buf, binary.LittleEndian, int16(i))
assert.NoError(t, err)
}
rawData = append(rawData, buf.Bytes())
case schemapb.DataType_Int32:
var buf bytes.Buffer
for i := 0; i < defaultMsgLength; i++ {
err := binary.Write(&buf, binary.LittleEndian, int32(i))
assert.NoError(t, err)
}
rawData = append(rawData, buf.Bytes())
case schemapb.DataType_Int64:
var buf bytes.Buffer
for i := 0; i < defaultMsgLength; i++ {
err := binary.Write(&buf, binary.LittleEndian, int64(i))
assert.NoError(t, err)
}
rawData = append(rawData, buf.Bytes())
case schemapb.DataType_Float:
var buf bytes.Buffer
for i := 0; i < defaultMsgLength; i++ {
err := binary.Write(&buf, binary.LittleEndian, float32(i))
assert.NoError(t, err)
}
rawData = append(rawData, buf.Bytes())
case schemapb.DataType_Double:
var buf bytes.Buffer
for i := 0; i < defaultMsgLength; i++ {
err := binary.Write(&buf, binary.LittleEndian, float64(i))
assert.NoError(t, err)
}
rawData = append(rawData, buf.Bytes())
}
hit := &milvuspb.Hits{
IDs: ids,
RowData: rawData,
}
hits := []*milvuspb.Hits{hit}
rawHits := make([][]byte, 0)
for _, h := range hits {
rawHit, err := proto.Marshal(h)
assert.NoError(t, err)
rawHits = append(rawHits, rawHit)
}
return rawHits
}
genSchema := func(dataType schemapb.DataType) *typeutil.SchemaHelper {
schema := &schemapb.CollectionSchema{
Name: defaultCollectionName,
AutoID: true,
Fields: []*schemapb.FieldSchema{
genConstantField(constFieldParam{
id: fieldID,
dataType: dataType,
}),
},
}
schemaHelper, err := typeutil.CreateSchemaHelper(schema)
assert.NoError(t, err)
return schemaHelper
}
t.Run("test bool field", func(t *testing.T) {
dataType := schemapb.DataType_Bool
_, err := translateHits(genSchema(dataType), fieldIDs, genRawHits(dataType))
assert.NoError(t, err)
})
t.Run("test int8 field", func(t *testing.T) {
dataType := schemapb.DataType_Int8
_, err := translateHits(genSchema(dataType), fieldIDs, genRawHits(dataType))
assert.NoError(t, err)
})
t.Run("test int16 field", func(t *testing.T) {
dataType := schemapb.DataType_Int16
_, err := translateHits(genSchema(dataType), fieldIDs, genRawHits(dataType))
assert.NoError(t, err)
})
t.Run("test int32 field", func(t *testing.T) {
dataType := schemapb.DataType_Int32
_, err := translateHits(genSchema(dataType), fieldIDs, genRawHits(dataType))
assert.NoError(t, err)
})
t.Run("test int64 field", func(t *testing.T) {
dataType := schemapb.DataType_Int64
_, err := translateHits(genSchema(dataType), fieldIDs, genRawHits(dataType))
assert.NoError(t, err)
})
t.Run("test float field", func(t *testing.T) {
dataType := schemapb.DataType_Float
_, err := translateHits(genSchema(dataType), fieldIDs, genRawHits(dataType))
assert.NoError(t, err)
})
t.Run("test double field", func(t *testing.T) {
dataType := schemapb.DataType_Double
_, err := translateHits(genSchema(dataType), fieldIDs, genRawHits(dataType))
assert.NoError(t, err)
})
t.Run("test field with error type", func(t *testing.T) {
dataType := schemapb.DataType_FloatVector
_, err := translateHits(genSchema(dataType), fieldIDs, genRawHits(dataType))
assert.Error(t, err)
})
}

View File

@ -30,11 +30,6 @@ import (
"github.com/milvus-io/milvus/internal/types"
)
const ctxTimeInMillisecond = 5000
const debug = false
const defaultPartitionID = UniqueID(2021)
type queryCoordMock struct {
types.QueryCoord
}

View File

@ -209,3 +209,31 @@ func TestSearch_SearchMultiSegments(t *testing.T) {
err = node.Stop()
assert.NoError(t, err)
}
func TestQueryService_addQueryCollection(t *testing.T) {
ctx, cancel := context.WithCancel(context.Background())
defer cancel()
his, err := genSimpleHistorical(ctx)
assert.NoError(t, err)
str, err := genSimpleStreaming(ctx)
assert.NoError(t, err)
fac, err := genFactory()
assert.NoError(t, err)
// start search service
qs := newQueryService(ctx, his, str, fac)
assert.NotNil(t, qs)
qs.addQueryCollection(defaultCollectionID)
assert.Len(t, qs.queryCollections, 1)
qs.addQueryCollection(defaultCollectionID)
assert.Len(t, qs.queryCollections, 1)
const invalidCollectionID = 10000
qs.addQueryCollection(invalidCollectionID)
assert.Len(t, qs.queryCollections, 2)
}

View File

@ -573,6 +573,7 @@ func (s *Segment) segmentPreDelete(numOfRecords int) int64 {
return int64(offset)
}
// TODO: remove reference of slice
func (s *Segment) segmentInsert(offset int64, entityIDs *[]UniqueID, timestamps *[]Timestamp, records *[]*commonpb.Blob) error {
/*
CStatus

View File

@ -13,14 +13,59 @@ package querynode
import (
"context"
"math/rand"
"testing"
"github.com/stretchr/testify/assert"
"github.com/milvus-io/milvus/internal/proto/commonpb"
"github.com/milvus-io/milvus/internal/proto/querypb"
"github.com/milvus-io/milvus/internal/util/metricsinfo"
"github.com/milvus-io/milvus/internal/util/typeutil"
"github.com/stretchr/testify/assert"
)
func TestSegmentLoader_loadSegment(t *testing.T) {
ctx, cancel := context.WithCancel(context.Background())
defer cancel()
historical, err := genSimpleHistorical(ctx)
assert.NoError(t, err)
err = historical.replica.removeSegment(defaultSegmentID)
assert.NoError(t, err)
kv, err := genEtcdKV()
assert.NoError(t, err)
loader := newSegmentLoader(ctx, nil, nil, historical.replica, kv)
assert.NotNil(t, loader)
schema, _ := genSimpleSchema()
fieldBinlog, err := saveSimpleBinLog(ctx)
assert.NoError(t, err)
req := &querypb.LoadSegmentsRequest{
Base: &commonpb.MsgBase{
MsgType: commonpb.MsgType_WatchQueryChannels,
MsgID: rand.Int63(),
},
NodeID: 0,
Schema: schema,
LoadCondition: querypb.TriggerCondition_grpcRequest,
Infos: []*querypb.SegmentLoadInfo{
{
SegmentID: defaultSegmentID,
PartitionID: defaultPartitionID,
CollectionID: defaultCollectionID,
BinlogPaths: fieldBinlog,
},
},
}
err = loader.loadSegment(req, true)
assert.Error(t, err)
}
func TestSegmentLoader_CheckSegmentMemory(t *testing.T) {
ctx, cancel := context.WithCancel(context.Background())
defer cancel()

View File

@ -12,6 +12,7 @@
package querynode
import (
"context"
"encoding/binary"
"log"
"math"
@ -624,7 +625,7 @@ func TestSegment_ConcurrentOperation(t *testing.T) {
assert.Equal(t, collection.ID(), collectionID)
wg := sync.WaitGroup{}
for i := 0; i < 1000; i++ {
for i := 0; i < 100; i++ {
segmentID := UniqueID(i)
segment := newSegment(collection, segmentID, partitionID, collectionID, "", segmentTypeSealed, true)
assert.Equal(t, segmentID, segment.segmentID)
@ -644,3 +645,109 @@ func TestSegment_ConcurrentOperation(t *testing.T) {
wg.Wait()
deleteCollection(collection)
}
func TestSegment_indexInfoTest(t *testing.T) {
t.Run("Test_valid", func(t *testing.T) {
ctx, cancel := context.WithCancel(context.Background())
defer cancel()
h, err := genSimpleHistorical(ctx)
assert.NoError(t, err)
seg, err := h.replica.getSegmentByID(defaultSegmentID)
assert.NoError(t, err)
fieldID := simpleVecField.id
err = seg.setIndexInfo(fieldID, &indexInfo{})
assert.NoError(t, err)
indexName := "query-node-test-index"
err = seg.setIndexName(fieldID, indexName)
assert.NoError(t, err)
name := seg.getIndexName(fieldID)
assert.Equal(t, indexName, name)
indexParam := make(map[string]string)
indexParam["index_type"] = "IVF_PQ"
indexParam["index_mode"] = "cpu"
err = seg.setIndexParam(fieldID, indexParam)
assert.NoError(t, err)
param := seg.getIndexParams(fieldID)
assert.Equal(t, len(indexParam), len(param))
assert.Equal(t, indexParam["index_type"], param["index_type"])
assert.Equal(t, indexParam["index_mode"], param["index_mode"])
indexPaths := []string{"query-node-test-index-path"}
err = seg.setIndexPaths(fieldID, indexPaths)
assert.NoError(t, err)
paths := seg.getIndexPaths(fieldID)
assert.Equal(t, len(indexPaths), len(paths))
assert.Equal(t, indexPaths[0], paths[0])
indexID := UniqueID(0)
err = seg.setIndexID(fieldID, indexID)
assert.NoError(t, err)
id := seg.getIndexID(fieldID)
assert.Equal(t, indexID, id)
buildID := UniqueID(0)
err = seg.setBuildID(fieldID, buildID)
assert.NoError(t, err)
id = seg.getBuildID(fieldID)
assert.Equal(t, buildID, id)
// TODO: add match index test
})
t.Run("Test_invalid", func(t *testing.T) {
ctx, cancel := context.WithCancel(context.Background())
defer cancel()
h, err := genSimpleHistorical(ctx)
assert.NoError(t, err)
seg, err := h.replica.getSegmentByID(defaultSegmentID)
assert.NoError(t, err)
fieldID := simpleVecField.id
indexName := "query-node-test-index"
err = seg.setIndexName(fieldID, indexName)
assert.Error(t, err)
name := seg.getIndexName(fieldID)
assert.Equal(t, "", name)
indexParam := make(map[string]string)
indexParam["index_type"] = "IVF_PQ"
indexParam["index_mode"] = "cpu"
err = seg.setIndexParam(fieldID, indexParam)
assert.Error(t, err)
err = seg.setIndexParam(fieldID, nil)
assert.Error(t, err)
param := seg.getIndexParams(fieldID)
assert.Nil(t, param)
indexPaths := []string{"query-node-test-index-path"}
err = seg.setIndexPaths(fieldID, indexPaths)
assert.Error(t, err)
paths := seg.getIndexPaths(fieldID)
assert.Nil(t, paths)
indexID := UniqueID(0)
err = seg.setIndexID(fieldID, indexID)
assert.Error(t, err)
id := seg.getIndexID(fieldID)
assert.Equal(t, int64(-1), id)
buildID := UniqueID(0)
err = seg.setBuildID(fieldID, buildID)
assert.Error(t, err)
id = seg.getBuildID(fieldID)
assert.Equal(t, int64(-1), id)
seg.indexInfos = nil
err = seg.setIndexInfo(fieldID, &indexInfo{})
assert.Error(t, err)
})
}

View File

@ -0,0 +1,83 @@
// Copyright (C) 2019-2020 Zilliz. All rights reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance
// with the License. You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software distributed under the License
// is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express
// or implied. See the License for the specific language governing permissions and limitations under the License.
package querynode
import (
"context"
"testing"
"github.com/stretchr/testify/assert"
)
func TestStreaming_streaming(t *testing.T) {
ctx, cancel := context.WithCancel(context.Background())
defer cancel()
streaming, err := genSimpleStreaming(ctx)
assert.NoError(t, err)
defer streaming.close()
streaming.start()
}
func TestStreaming_search(t *testing.T) {
ctx, cancel := context.WithCancel(context.Background())
defer cancel()
streaming, err := genSimpleStreaming(ctx)
assert.NoError(t, err)
defer streaming.close()
plan, searchReqs, err := genSimpleSearchPlanAndRequests()
assert.NoError(t, err)
res, err := streaming.search(searchReqs,
defaultCollectionID,
[]UniqueID{defaultPartitionID},
defaultVChannel,
plan,
Timestamp(0))
assert.NoError(t, err)
assert.Len(t, res, 1)
}
func TestStreaming_retrieve(t *testing.T) {
ctx, cancel := context.WithCancel(context.Background())
defer cancel()
streaming, err := genSimpleStreaming(ctx)
assert.NoError(t, err)
defer streaming.close()
plan, err := genSimpleRetrievePlan()
assert.NoError(t, err)
insertMsg, err := genSimpleInsertMsg()
assert.NoError(t, err)
segment, err := streaming.replica.getSegmentByID(defaultSegmentID)
assert.NoError(t, err)
offset, err := segment.segmentPreInsert(len(insertMsg.RowIDs))
assert.NoError(t, err)
err = segment.segmentInsert(offset, &insertMsg.RowIDs, &insertMsg.Timestamps, &insertMsg.RowData)
assert.NoError(t, err)
res, ids, err := streaming.retrieve(defaultCollectionID, []UniqueID{defaultPartitionID}, plan)
assert.NoError(t, err)
assert.Len(t, res, 1)
assert.Len(t, ids, 1)
//assert.Error(t, err)
//assert.Len(t, res, 0)
//assert.Len(t, ids, 0)
}

View File

@ -11,28 +11,11 @@
package querynode
type deleteNode struct {
baseNode
deleteMsg deleteMsg
}
func (dNode *deleteNode) Name() string {
return "dNode"
}
func (dNode *deleteNode) Operate(in []*Msg) []*Msg {
return in
}
func newDeleteNode() *deleteNode {
maxQueueLength := Params.FlowGraphMaxQueueLength
maxParallelism := Params.FlowGraphMaxParallelism
baseNode := baseNode{}
baseNode.SetMaxQueueLength(maxQueueLength)
baseNode.SetMaxParallelism(maxParallelism)
return &deleteNode{
baseNode: baseNode,
}
import (
"testing"
)
// TODO: add task ut
func TestTask_watchDmChannelsTask(t *testing.T) {
}

View File

@ -26,6 +26,7 @@ type (
UniqueID = typeutil.UniqueID
// Timestamp is timestamp
Timestamp = typeutil.Timestamp
FieldID = int64
// IntPrimaryKey is the primary key of int type
IntPrimaryKey = typeutil.IntPrimaryKey
// DSL is the Domain Specific Language