Add ut for queryCoord (#7212)

Signed-off-by: xige-16 <xi.ge@zilliz.com>
pull/7314/head
xige-16 2021-08-26 14:17:54 +08:00 committed by GitHub
parent 8701c477e2
commit 055d94ede1
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
16 changed files with 1144 additions and 431 deletions

View File

@ -17,7 +17,9 @@ import (
"fmt"
"time"
"go.uber.org/zap"
"google.golang.org/grpc"
"google.golang.org/grpc/codes"
grpc_middleware "github.com/grpc-ecosystem/go-grpc-middleware"
grpc_retry "github.com/grpc-ecosystem/go-grpc-middleware/retry"
@ -29,8 +31,6 @@ import (
"github.com/milvus-io/milvus/internal/proto/querypb"
"github.com/milvus-io/milvus/internal/util/retry"
"github.com/milvus-io/milvus/internal/util/trace"
"go.uber.org/zap"
"google.golang.org/grpc/codes"
)
type Client struct {

View File

@ -20,14 +20,13 @@ import (
"strconv"
"sync"
"github.com/milvus-io/milvus/internal/proto/milvuspb"
"github.com/golang/protobuf/proto"
"go.uber.org/zap"
etcdkv "github.com/milvus-io/milvus/internal/kv/etcd"
"github.com/milvus-io/milvus/internal/log"
"github.com/milvus-io/milvus/internal/proto/internalpb"
"github.com/milvus-io/milvus/internal/proto/milvuspb"
"github.com/milvus-io/milvus/internal/proto/querypb"
"github.com/milvus-io/milvus/internal/util/sessionutil"
)
@ -66,20 +65,24 @@ type Cluster interface {
printMeta()
}
type newQueryNodeFn func(ctx context.Context, address string, id UniqueID, kv *etcdkv.EtcdKV) (Node, error)
type queryNodeCluster struct {
client *etcdkv.EtcdKV
sync.RWMutex
clusterMeta Meta
nodes map[int64]Node
newNodeFn newQueryNodeFn
}
func newQueryNodeCluster(clusterMeta Meta, kv *etcdkv.EtcdKV) (*queryNodeCluster, error) {
func newQueryNodeCluster(clusterMeta Meta, kv *etcdkv.EtcdKV, newNodeFn newQueryNodeFn) (*queryNodeCluster, error) {
nodes := make(map[int64]Node)
c := &queryNodeCluster{
client: kv,
clusterMeta: clusterMeta,
nodes: nodes,
newNodeFn: newNodeFn,
}
err := c.reloadFromKV()
if err != nil {
@ -428,7 +431,11 @@ func (c *queryNodeCluster) registerNode(ctx context.Context, session *sessionuti
if err != nil {
return err
}
c.nodes[id] = newQueryNode(ctx, session.Address, id, c.client)
c.nodes[id], err = c.newNodeFn(ctx, session.Address, id, c.client)
if err != nil {
log.Debug("RegisterNode: create a new query node failed", zap.Int64("nodeID", id), zap.Error(err))
return err
}
log.Debug("RegisterNode: create a new query node", zap.Int64("nodeID", id), zap.String("address", session.Address))
go func() {
@ -480,6 +487,9 @@ func (c *queryNodeCluster) removeNodeInfo(nodeID int64) error {
}
func (c *queryNodeCluster) stopNode(nodeID int64) {
c.Lock()
defer c.Unlock()
if node, ok := c.nodes[nodeID]; ok {
node.stop()
log.Debug("StopNode: queryNode offline", zap.Int64("nodeID", nodeID))

View File

@ -12,11 +12,57 @@
package querycoord
import (
"context"
"encoding/json"
"fmt"
"testing"
"github.com/golang/protobuf/proto"
"github.com/stretchr/testify/assert"
etcdkv "github.com/milvus-io/milvus/internal/kv/etcd"
"github.com/milvus-io/milvus/internal/log"
"github.com/milvus-io/milvus/internal/proto/querypb"
"github.com/milvus-io/milvus/internal/util/sessionutil"
)
func TestQueryNodeCluster_getMetrics(t *testing.T) {
log.Info("TestQueryNodeCluster_getMetrics, todo")
}
func TestReloadClusterFromKV(t *testing.T) {
refreshParams()
kv, err := etcdkv.NewEtcdKV(Params.EtcdEndpoints, Params.MetaRootPath)
assert.Nil(t, err)
cluster := &queryNodeCluster{
client: kv,
nodes: make(map[int64]Node),
newNodeFn: newQueryNodeTest,
}
kvs := make(map[string]string)
session := &sessionutil.Session{
ServerID: 100,
Address: "localhost",
}
sessionBlob, err := json.Marshal(session)
assert.Nil(t, err)
sessionKey := fmt.Sprintf("%s/%d", queryNodeInfoPrefix, 100)
kvs[sessionKey] = string(sessionBlob)
collectionInfo := &querypb.CollectionInfo{
CollectionID: defaultCollectionID,
}
collectionBlobs := proto.MarshalTextString(collectionInfo)
nodeKey := fmt.Sprintf("%s/%d", queryNodeMetaPrefix, 100)
kvs[nodeKey] = collectionBlobs
err = kv.MultiSave(kvs)
assert.Nil(t, err)
cluster.reloadFromKV()
assert.Equal(t, 1, len(cluster.nodes))
collection := cluster.getCollectionInfosByID(context.Background(), 100)
assert.Equal(t, defaultCollectionID, collection[0].CollectionID)
}

View File

@ -559,7 +559,6 @@ func (qc *QueryCoord) GetMetrics(ctx context.Context, req *milvuspb.GetMetricsRe
return metrics, err
}
log.Debug("QueryCoord.GetMetrics failed, request metric type is not implemented yet",
zap.Int64("node_id", Params.QueryCoordID),
zap.String("req", req.Request),

View File

@ -0,0 +1,259 @@
package querycoord
import (
"context"
"encoding/json"
"testing"
"time"
"github.com/stretchr/testify/assert"
"github.com/milvus-io/milvus/internal/proto/commonpb"
"github.com/milvus-io/milvus/internal/proto/internalpb"
"github.com/milvus-io/milvus/internal/proto/milvuspb"
"github.com/milvus-io/milvus/internal/proto/querypb"
"github.com/milvus-io/milvus/internal/util/metricsinfo"
)
func TestGrpcTask(t *testing.T) {
ctx := context.Background()
queryCoord, err := startQueryCoord(ctx)
assert.Nil(t, err)
node, err := startQueryNodeServer(ctx)
assert.Nil(t, err)
t.Run("Test LoadPartition", func(t *testing.T) {
status, err := queryCoord.LoadPartitions(ctx, &querypb.LoadPartitionsRequest{
Base: &commonpb.MsgBase{
MsgType: commonpb.MsgType_LoadPartitions,
},
CollectionID: defaultCollectionID,
PartitionIDs: []UniqueID{defaultPartitionID},
Schema: genCollectionSchema(defaultCollectionID, false),
})
assert.Equal(t, status.ErrorCode, commonpb.ErrorCode_Success)
assert.Nil(t, err)
})
t.Run("Test ShowPartitions", func(t *testing.T) {
res, err := queryCoord.ShowPartitions(ctx, &querypb.ShowPartitionsRequest{
Base: &commonpb.MsgBase{
MsgType: commonpb.MsgType_ShowCollections,
},
CollectionID: defaultCollectionID,
PartitionIDs: []UniqueID{defaultPartitionID},
})
assert.Equal(t, res.Status.ErrorCode, commonpb.ErrorCode_Success)
assert.Nil(t, err)
})
t.Run("Test ShowAllPartitions", func(t *testing.T) {
res, err := queryCoord.ShowPartitions(ctx, &querypb.ShowPartitionsRequest{
Base: &commonpb.MsgBase{
MsgType: commonpb.MsgType_ShowCollections,
},
CollectionID: defaultCollectionID,
})
assert.Equal(t, res.Status.ErrorCode, commonpb.ErrorCode_Success)
assert.Nil(t, err)
})
t.Run("Test GetPartitionStates", func(t *testing.T) {
res, err := queryCoord.GetPartitionStates(ctx, &querypb.GetPartitionStatesRequest{
Base: &commonpb.MsgBase{
MsgType: commonpb.MsgType_GetPartitionStatistics,
},
CollectionID: defaultCollectionID,
PartitionIDs: []UniqueID{defaultPartitionID},
})
assert.Equal(t, res.Status.ErrorCode, commonpb.ErrorCode_Success)
assert.Nil(t, err)
})
t.Run("Test LoadCollection", func(t *testing.T) {
status, err := queryCoord.LoadCollection(ctx, &querypb.LoadCollectionRequest{
Base: &commonpb.MsgBase{
MsgType: commonpb.MsgType_LoadCollection,
},
CollectionID: defaultCollectionID,
Schema: genCollectionSchema(defaultCollectionID, false),
})
assert.Equal(t, status.ErrorCode, commonpb.ErrorCode_Success)
assert.Nil(t, err)
})
t.Run("Test ShowCollections", func(t *testing.T) {
res, err := queryCoord.ShowCollections(ctx, &querypb.ShowCollectionsRequest{
Base: &commonpb.MsgBase{
MsgType: commonpb.MsgType_ShowCollections,
},
CollectionIDs: []UniqueID{defaultCollectionID},
})
assert.Equal(t, 100, int(res.InMemoryPercentages[0]))
assert.Equal(t, res.Status.ErrorCode, commonpb.ErrorCode_Success)
assert.Nil(t, err)
})
t.Run("Test ShowAllCollections", func(t *testing.T) {
res, err := queryCoord.ShowCollections(ctx, &querypb.ShowCollectionsRequest{
Base: &commonpb.MsgBase{
MsgType: commonpb.MsgType_ShowCollections,
},
})
assert.Equal(t, res.Status.ErrorCode, commonpb.ErrorCode_Success)
assert.Nil(t, err)
})
t.Run("Test GetSegmentInfo", func(t *testing.T) {
res, err := queryCoord.GetSegmentInfo(ctx, &querypb.GetSegmentInfoRequest{
Base: &commonpb.MsgBase{
MsgType: commonpb.MsgType_SegmentInfo,
},
SegmentIDs: []UniqueID{defaultSegmentID},
})
assert.Equal(t, res.Status.ErrorCode, commonpb.ErrorCode_Success)
assert.Nil(t, err)
})
t.Run("Test ReleasePartition", func(t *testing.T) {
status, err := queryCoord.ReleasePartitions(ctx, &querypb.ReleasePartitionsRequest{
Base: &commonpb.MsgBase{
MsgType: commonpb.MsgType_ReleasePartitions,
},
CollectionID: defaultCollectionID,
PartitionIDs: []UniqueID{defaultPartitionID},
})
assert.Equal(t, status.ErrorCode, commonpb.ErrorCode_Success)
assert.Nil(t, err)
})
t.Run("Test ReleaseCollection", func(t *testing.T) {
status, err := queryCoord.ReleaseCollection(ctx, &querypb.ReleaseCollectionRequest{
Base: &commonpb.MsgBase{
MsgType: commonpb.MsgType_ReleaseCollection,
},
CollectionID: defaultCollectionID,
})
assert.Equal(t, status.ErrorCode, commonpb.ErrorCode_Success)
assert.Nil(t, err)
})
t.Run("Test GetStatisticsChannel", func(t *testing.T) {
_, err = queryCoord.GetStatisticsChannel(ctx)
assert.Nil(t, err)
})
t.Run("Test GetTimeTickChannel", func(t *testing.T) {
_, err = queryCoord.GetTimeTickChannel(ctx)
assert.Nil(t, err)
})
t.Run("Test GetComponentStates", func(t *testing.T) {
states, err := queryCoord.GetComponentStates(ctx)
assert.Equal(t, states.Status.ErrorCode, commonpb.ErrorCode_Success)
assert.Equal(t, states.State.StateCode, internalpb.StateCode_Healthy)
assert.Nil(t, err)
})
t.Run("Test CreateQueryChannel", func(t *testing.T) {
res, err := queryCoord.CreateQueryChannel(ctx, &querypb.CreateQueryChannelRequest{
CollectionID: defaultCollectionID,
})
assert.Equal(t, res.Status.ErrorCode, commonpb.ErrorCode_Success)
assert.Nil(t, err)
})
t.Run("Test GetMetrics", func(t *testing.T) {
metricReq := make(map[string]string)
metricReq[metricsinfo.MetricTypeKey] = "system_info"
req, err := json.Marshal(metricReq)
assert.Nil(t, err)
res, err := queryCoord.GetMetrics(ctx, &milvuspb.GetMetricsRequest{
Base: &commonpb.MsgBase{},
Request: string(req),
})
assert.Nil(t, err)
assert.Equal(t, commonpb.ErrorCode_Success, res.Status.ErrorCode)
})
//nodes, err := queryCoord.cluster.getOnServiceNodes()
//assert.Nil(t, err)
err = node.stop()
//assert.Nil(t, err)
//allNodeOffline := waitAllQueryNodeOffline(queryCoord.cluster, nodes)
//assert.Equal(t, allNodeOffline, true)
queryCoord.Stop()
}
func TestLoadBalanceTask(t *testing.T) {
baseCtx := context.Background()
queryCoord, err := startQueryCoord(baseCtx)
assert.Nil(t, err)
queryNode1, err := startQueryNodeServer(baseCtx)
assert.Nil(t, err)
queryNode2, err := startQueryNodeServer(baseCtx)
assert.Nil(t, err)
time.Sleep(time.Second)
res, err := queryCoord.LoadCollection(baseCtx, &querypb.LoadCollectionRequest{
Base: &commonpb.MsgBase{
MsgType: commonpb.MsgType_LoadCollection,
},
CollectionID: defaultCollectionID,
Schema: genCollectionSchema(defaultCollectionID, false),
})
assert.Nil(t, err)
assert.Equal(t, commonpb.ErrorCode_Success, res.ErrorCode)
time.Sleep(time.Second)
for {
collectionInfo := queryCoord.meta.showCollections()
if collectionInfo[0].InMemoryPercentage == 100 {
break
}
}
nodeID := queryNode1.queryNodeID
queryCoord.cluster.stopNode(nodeID)
loadBalanceSegment := &querypb.LoadBalanceRequest{
Base: &commonpb.MsgBase{
MsgType: commonpb.MsgType_LoadBalanceSegments,
SourceID: nodeID,
},
SourceNodeIDs: []int64{nodeID},
BalanceReason: querypb.TriggerCondition_nodeDown,
}
loadBalanceTask := &LoadBalanceTask{
BaseTask: BaseTask{
ctx: baseCtx,
Condition: NewTaskCondition(baseCtx),
triggerCondition: querypb.TriggerCondition_nodeDown,
},
LoadBalanceRequest: loadBalanceSegment,
rootCoord: queryCoord.rootCoordClient,
dataCoord: queryCoord.dataCoordClient,
cluster: queryCoord.cluster,
meta: queryCoord.meta,
}
queryCoord.scheduler.Enqueue([]task{loadBalanceTask})
res, err = queryCoord.ReleaseCollection(baseCtx, &querypb.ReleaseCollectionRequest{
Base: &commonpb.MsgBase{
MsgType: commonpb.MsgType_ReleaseCollection,
},
CollectionID: defaultCollectionID,
})
assert.Nil(t, err)
queryNode1.stop()
queryNode2.stop()
queryCoord.Stop()
}

View File

@ -57,8 +57,6 @@ type Meta interface {
getPartitionStatesByID(collectionID UniqueID, partitionID UniqueID) (*querypb.PartitionStates, error)
hasWatchedDmChannel(collectionID UniqueID, channelID string) (bool, error)
getDmChannelsByCollectionID(collectionID UniqueID) ([]string, error)
getDmChannelsByNodeID(collectionID UniqueID, nodeID int64) ([]string, error)
addDmChannel(collectionID UniqueID, nodeID int64, channels []string) error
removeDmChannel(collectionID UniqueID, nodeID int64, channels []string) error
@ -488,40 +486,6 @@ func (m *MetaReplica) releasePartition(collectionID UniqueID, partitionID Unique
return nil
}
func (m *MetaReplica) hasWatchedDmChannel(collectionID UniqueID, channelID string) (bool, error) {
m.RLock()
defer m.RUnlock()
if info, ok := m.collectionInfos[collectionID]; ok {
channelInfos := info.ChannelInfos
for _, channelInfo := range channelInfos {
for _, channel := range channelInfo.ChannelIDs {
if channel == channelID {
return true, nil
}
}
}
return false, nil
}
return false, errors.New("hasWatchedDmChannel: can't find collection in collectionInfos")
}
func (m *MetaReplica) getDmChannelsByCollectionID(collectionID UniqueID) ([]string, error) {
m.RLock()
defer m.RUnlock()
if info, ok := m.collectionInfos[collectionID]; ok {
channels := make([]string, 0)
for _, channelsInfo := range info.ChannelInfos {
channels = append(channels, channelsInfo.ChannelIDs...)
}
return channels, nil
}
return nil, errors.New("getDmChannelsByCollectionID: can't find collection in collectionInfos")
}
func (m *MetaReplica) getDmChannelsByNodeID(collectionID UniqueID, nodeID int64) ([]string, error) {
m.RLock()
defer m.RUnlock()

View File

@ -12,15 +12,19 @@
package querycoord
import (
"fmt"
"testing"
"github.com/golang/protobuf/proto"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
etcdkv "github.com/milvus-io/milvus/internal/kv/etcd"
"github.com/milvus-io/milvus/internal/proto/querypb"
)
func TestReplica_Release(t *testing.T) {
refreshParams()
etcdKV, err := etcdkv.NewEtcdKV(Params.EtcdEndpoints, Params.MetaRootPath)
assert.Nil(t, err)
meta, err := newMeta(etcdKV)
@ -48,3 +52,53 @@ func TestReplica_Release(t *testing.T) {
assert.Equal(t, 0, len(collections))
meta.releaseCollection(1)
}
func TestReloadMetaFromKV(t *testing.T) {
refreshParams()
kv, err := etcdkv.NewEtcdKV(Params.EtcdEndpoints, Params.MetaRootPath)
assert.Nil(t, err)
meta := &MetaReplica{
client: kv,
collectionInfos: map[UniqueID]*querypb.CollectionInfo{},
segmentInfos: map[UniqueID]*querypb.SegmentInfo{},
queryChannelInfos: map[UniqueID]*querypb.QueryChannelInfo{},
}
kvs := make(map[string]string)
collectionInfo := &querypb.CollectionInfo{
CollectionID: defaultCollectionID,
}
collectionBlobs := proto.MarshalTextString(collectionInfo)
collectionKey := fmt.Sprintf("%s/%d", collectionMetaPrefix, defaultCollectionID)
kvs[collectionKey] = collectionBlobs
segmentInfo := &querypb.SegmentInfo{
SegmentID: defaultSegmentID,
}
segmentBlobs := proto.MarshalTextString(segmentInfo)
segmentKey := fmt.Sprintf("%s/%d", segmentMetaPrefix, defaultSegmentID)
kvs[segmentKey] = segmentBlobs
queryChannelInfo := &querypb.QueryChannelInfo{
CollectionID: defaultCollectionID,
}
queryChannelBlobs := proto.MarshalTextString(queryChannelInfo)
queryChannelKey := fmt.Sprintf("%s/%d", queryChannelMetaPrefix, defaultCollectionID)
kvs[queryChannelKey] = queryChannelBlobs
err = kv.MultiSave(kvs)
assert.Nil(t, err)
err = meta.reloadFromKV()
assert.Nil(t, err)
assert.Equal(t, 1, len(meta.collectionInfos))
assert.Equal(t, 1, len(meta.segmentInfos))
assert.Equal(t, 1, len(meta.queryChannelInfos))
_, ok := meta.collectionInfos[defaultCollectionID]
assert.Equal(t, true, ok)
_, ok = meta.segmentInfos[defaultSegmentID]
assert.Equal(t, true, ok)
_, ok = meta.queryChannelInfos[defaultCollectionID]
assert.Equal(t, true, ok)
}

View File

@ -16,17 +16,13 @@ import (
"errors"
"fmt"
"math/rand"
"net"
"path"
"strconv"
"sync"
"google.golang.org/grpc"
"github.com/milvus-io/milvus/internal/kv"
minioKV "github.com/milvus-io/milvus/internal/kv/minio"
"github.com/milvus-io/milvus/internal/log"
"github.com/milvus-io/milvus/internal/msgstream"
"github.com/milvus-io/milvus/internal/proto/commonpb"
"github.com/milvus-io/milvus/internal/proto/datapb"
"github.com/milvus-io/milvus/internal/proto/etcdpb"
@ -34,13 +30,10 @@ import (
"github.com/milvus-io/milvus/internal/proto/internalpb"
"github.com/milvus-io/milvus/internal/proto/milvuspb"
"github.com/milvus-io/milvus/internal/proto/proxypb"
"github.com/milvus-io/milvus/internal/proto/querypb"
"github.com/milvus-io/milvus/internal/proto/schemapb"
qn "github.com/milvus-io/milvus/internal/querynode"
"github.com/milvus-io/milvus/internal/storage"
"github.com/milvus-io/milvus/internal/types"
"github.com/milvus-io/milvus/internal/util/funcutil"
"github.com/milvus-io/milvus/internal/util/retry"
)
const (
@ -373,163 +366,3 @@ func newIndexCoordMock() *indexCoordMock {
func (c *indexCoordMock) GetIndexFilePaths(ctx context.Context, req *indexpb.GetIndexFilePathsRequest) (*indexpb.GetIndexFilePathsResponse, error) {
return nil, errors.New("get index file path fail")
}
type queryNodeServerMock struct {
querypb.QueryNodeServer
ctx context.Context
cancel context.CancelFunc
queryNode *qn.QueryNode
grpcErrChan chan error
grpcServer *grpc.Server
addQueryChannels func() (*commonpb.Status, error)
watchDmChannels func() (*commonpb.Status, error)
loadSegment func() (*commonpb.Status, error)
releaseCollection func() (*commonpb.Status, error)
releasePartition func() (*commonpb.Status, error)
releaseSegment func() (*commonpb.Status, error)
}
func newQueryNodeServerMock(ctx context.Context) *queryNodeServerMock {
ctx1, cancel := context.WithCancel(ctx)
factory := msgstream.NewPmsFactory()
return &queryNodeServerMock{
ctx: ctx,
cancel: cancel,
queryNode: qn.NewQueryNode(ctx1, factory),
grpcErrChan: make(chan error),
addQueryChannels: returnSuccessResult,
watchDmChannels: returnSuccessResult,
loadSegment: returnSuccessResult,
releaseCollection: returnSuccessResult,
releasePartition: returnSuccessResult,
releaseSegment: returnSuccessResult,
}
}
func (qs *queryNodeServerMock) init() error {
qn.Params.Init()
qn.Params.MetaRootPath = Params.MetaRootPath
qn.Params.QueryNodeIP = funcutil.GetLocalIP()
grpcPort := Params.Port
go func() {
var lis net.Listener
var err error
err = retry.Do(qs.ctx, func() error {
addr := ":" + strconv.Itoa(grpcPort)
lis, err = net.Listen("tcp", addr)
if err == nil {
qn.Params.QueryNodePort = int64(lis.Addr().(*net.TCPAddr).Port)
} else {
// set port=0 to get next available port
grpcPort = 0
}
return err
}, retry.Attempts(10))
if err != nil {
log.Error(err.Error())
}
qs.grpcServer = grpc.NewServer()
querypb.RegisterQueryNodeServer(qs.grpcServer, qs)
if err = qs.grpcServer.Serve(lis); err != nil {
log.Error(err.Error())
}
}()
rootCoord := newRootCoordMock()
indexCoord := newIndexCoordMock()
qs.queryNode.SetRootCoord(rootCoord)
qs.queryNode.SetIndexCoord(indexCoord)
err := qs.queryNode.Init()
if err != nil {
return err
}
if err = qs.queryNode.Register(); err != nil {
return err
}
return nil
}
func (qs *queryNodeServerMock) start() error {
return qs.queryNode.Start()
}
func (qs *queryNodeServerMock) stop() error {
qs.cancel()
if qs.grpcServer != nil {
qs.grpcServer.GracefulStop()
}
err := qs.queryNode.Stop()
if err != nil {
return err
}
return nil
}
func (qs *queryNodeServerMock) run() error {
if err := qs.init(); err != nil {
return err
}
if err := qs.start(); err != nil {
return err
}
return nil
}
func (qs *queryNodeServerMock) AddQueryChannel(ctx context.Context, req *querypb.AddQueryChannelRequest) (*commonpb.Status, error) {
return qs.addQueryChannels()
}
func (qs *queryNodeServerMock) WatchDmChannels(ctx context.Context, req *querypb.WatchDmChannelsRequest) (*commonpb.Status, error) {
return qs.watchDmChannels()
}
func (qs *queryNodeServerMock) LoadSegments(ctx context.Context, req *querypb.LoadSegmentsRequest) (*commonpb.Status, error) {
return qs.loadSegment()
}
func (qs *queryNodeServerMock) ReleaseCollection(ctx context.Context, req *querypb.ReleaseCollectionRequest) (*commonpb.Status, error) {
return qs.releaseCollection()
}
func (qs *queryNodeServerMock) ReleasePartitions(ctx context.Context, req *querypb.ReleasePartitionsRequest) (*commonpb.Status, error) {
return qs.releasePartition()
}
func (qs *queryNodeServerMock) ReleaseSegments(ctx context.Context, req *querypb.ReleaseSegmentsRequest) (*commonpb.Status, error) {
return qs.releaseSegment()
}
func startQueryNodeServer(ctx context.Context) (*queryNodeServerMock, error) {
node := newQueryNodeServerMock(ctx)
err := node.run()
if err != nil {
return nil, err
}
return node, nil
}
func returnSuccessResult() (*commonpb.Status, error) {
return &commonpb.Status{
ErrorCode: commonpb.ErrorCode_Success,
}, nil
}
func returnFailedResult() (*commonpb.Status, error) {
return &commonpb.Status{
ErrorCode: commonpb.ErrorCode_UnexpectedError,
}, errors.New("query node do task failed")
}

View File

@ -0,0 +1,136 @@
package querycoord
import (
"context"
"fmt"
"time"
"google.golang.org/grpc"
etcdkv "github.com/milvus-io/milvus/internal/kv/etcd"
"github.com/milvus-io/milvus/internal/log"
"github.com/milvus-io/milvus/internal/proto/commonpb"
"github.com/milvus-io/milvus/internal/proto/internalpb"
"github.com/milvus-io/milvus/internal/proto/milvuspb"
"github.com/milvus-io/milvus/internal/proto/querypb"
)
type queryNodeClientMock struct {
ctx context.Context
cancel context.CancelFunc
grpcClient querypb.QueryNodeClient
conn *grpc.ClientConn
addr string
}
func newQueryNodeTest(ctx context.Context, address string, id UniqueID, kv *etcdkv.EtcdKV) (Node, error) {
collectionInfo := make(map[UniqueID]*querypb.CollectionInfo)
watchedChannels := make(map[UniqueID]*querypb.QueryChannelInfo)
childCtx, cancel := context.WithCancel(ctx)
client, err := newQueryNodeClientMock(childCtx, address)
if err != nil {
cancel()
return nil, err
}
node := &queryNode{
ctx: childCtx,
cancel: cancel,
id: id,
address: address,
client: client,
kvClient: kv,
collectionInfos: collectionInfo,
watchedQueryChannels: watchedChannels,
onService: false,
}
return node, nil
}
func newQueryNodeClientMock(ctx context.Context, addr string) (*queryNodeClientMock, error) {
if addr == "" {
return nil, fmt.Errorf("addr is empty")
}
ctx, cancel := context.WithCancel(ctx)
return &queryNodeClientMock{
ctx: ctx,
cancel: cancel,
addr: addr,
}, nil
}
func (client *queryNodeClientMock) Init() error {
ctx, cancel := context.WithTimeout(client.ctx, time.Second*2)
defer cancel()
conn, err := grpc.DialContext(ctx, client.addr, grpc.WithInsecure(), grpc.WithBlock())
if err != nil {
return err
}
client.conn = conn
log.Debug("QueryNodeClient try connect success")
client.grpcClient = querypb.NewQueryNodeClient(conn)
return nil
}
func (client *queryNodeClientMock) Start() error {
return nil
}
func (client *queryNodeClientMock) Stop() error {
client.cancel()
return client.conn.Close()
}
func (client *queryNodeClientMock) Register() error {
return nil
}
func (client *queryNodeClientMock) GetComponentStates(ctx context.Context) (*internalpb.ComponentStates, error) {
return client.grpcClient.GetComponentStates(ctx, &internalpb.GetComponentStatesRequest{})
}
func (client *queryNodeClientMock) GetTimeTickChannel(ctx context.Context) (*milvuspb.StringResponse, error) {
return client.grpcClient.GetTimeTickChannel(ctx, &internalpb.GetTimeTickChannelRequest{})
}
func (client *queryNodeClientMock) GetStatisticsChannel(ctx context.Context) (*milvuspb.StringResponse, error) {
return client.grpcClient.GetStatisticsChannel(ctx, &internalpb.GetStatisticsChannelRequest{})
}
func (client *queryNodeClientMock) AddQueryChannel(ctx context.Context, req *querypb.AddQueryChannelRequest) (*commonpb.Status, error) {
return client.grpcClient.AddQueryChannel(ctx, req)
}
func (client *queryNodeClientMock) RemoveQueryChannel(ctx context.Context, req *querypb.RemoveQueryChannelRequest) (*commonpb.Status, error) {
return client.grpcClient.RemoveQueryChannel(ctx, req)
}
func (client *queryNodeClientMock) WatchDmChannels(ctx context.Context, req *querypb.WatchDmChannelsRequest) (*commonpb.Status, error) {
return client.grpcClient.WatchDmChannels(ctx, req)
}
func (client *queryNodeClientMock) LoadSegments(ctx context.Context, req *querypb.LoadSegmentsRequest) (*commonpb.Status, error) {
return client.grpcClient.LoadSegments(ctx, req)
}
func (client *queryNodeClientMock) ReleaseCollection(ctx context.Context, req *querypb.ReleaseCollectionRequest) (*commonpb.Status, error) {
return client.grpcClient.ReleaseCollection(ctx, req)
}
func (client *queryNodeClientMock) ReleasePartitions(ctx context.Context, req *querypb.ReleasePartitionsRequest) (*commonpb.Status, error) {
return client.grpcClient.ReleasePartitions(ctx, req)
}
func (client *queryNodeClientMock) ReleaseSegments(ctx context.Context, req *querypb.ReleaseSegmentsRequest) (*commonpb.Status, error) {
return client.grpcClient.ReleaseSegments(ctx, req)
}
func (client *queryNodeClientMock) GetSegmentInfo(ctx context.Context, req *querypb.GetSegmentInfoRequest) (*querypb.GetSegmentInfoResponse, error) {
return client.grpcClient.GetSegmentInfo(ctx, req)
}
func (client *queryNodeClientMock) GetMetrics(ctx context.Context, req *milvuspb.GetMetricsRequest) (*milvuspb.GetMetricsResponse, error) {
return client.grpcClient.GetMetrics(ctx, req)
}

View File

@ -0,0 +1,196 @@
package querycoord
import (
"context"
"errors"
"net"
"strconv"
"go.uber.org/zap"
"google.golang.org/grpc"
"github.com/milvus-io/milvus/internal/log"
"github.com/milvus-io/milvus/internal/proto/commonpb"
"github.com/milvus-io/milvus/internal/proto/milvuspb"
"github.com/milvus-io/milvus/internal/proto/querypb"
"github.com/milvus-io/milvus/internal/util/funcutil"
"github.com/milvus-io/milvus/internal/util/retry"
"github.com/milvus-io/milvus/internal/util/sessionutil"
"github.com/milvus-io/milvus/internal/util/typeutil"
)
type queryNodeServerMock struct {
querypb.QueryNodeServer
ctx context.Context
cancel context.CancelFunc
session *sessionutil.Session
grpcErrChan chan error
grpcServer *grpc.Server
queryNodeIP string
queryNodePort int64
queryNodeID int64
addQueryChannels func() (*commonpb.Status, error)
watchDmChannels func() (*commonpb.Status, error)
loadSegment func() (*commonpb.Status, error)
releaseCollection func() (*commonpb.Status, error)
releasePartition func() (*commonpb.Status, error)
releaseSegment func() (*commonpb.Status, error)
}
func newQueryNodeServerMock(ctx context.Context) *queryNodeServerMock {
ctx1, cancel := context.WithCancel(ctx)
return &queryNodeServerMock{
ctx: ctx1,
cancel: cancel,
grpcErrChan: make(chan error),
addQueryChannels: returnSuccessResult,
watchDmChannels: returnSuccessResult,
loadSegment: returnSuccessResult,
releaseCollection: returnSuccessResult,
releasePartition: returnSuccessResult,
releaseSegment: returnSuccessResult,
}
}
func (qs *queryNodeServerMock) Register() error {
log.Debug("query node session info", zap.String("metaPath", Params.MetaRootPath), zap.Strings("etcdEndPoints", Params.EtcdEndpoints))
qs.session = sessionutil.NewSession(qs.ctx, Params.MetaRootPath, Params.EtcdEndpoints)
qs.session.Init(typeutil.QueryNodeRole, qs.queryNodeIP+":"+strconv.FormatInt(qs.queryNodePort, 10), false)
qs.queryNodeID = qs.session.ServerID
log.Debug("query nodeID", zap.Int64("nodeID", qs.queryNodeID))
log.Debug("query node address", zap.String("address", qs.session.Address))
return nil
}
func (qs *queryNodeServerMock) init() error {
qs.queryNodeIP = funcutil.GetLocalIP()
grpcPort := Params.Port
go func() {
var lis net.Listener
var err error
err = retry.Do(qs.ctx, func() error {
addr := ":" + strconv.Itoa(grpcPort)
lis, err = net.Listen("tcp", addr)
if err == nil {
qs.queryNodePort = int64(lis.Addr().(*net.TCPAddr).Port)
} else {
// set port=0 to get next available port
grpcPort = 0
}
return err
}, retry.Attempts(10))
if err != nil {
qs.grpcErrChan <- err
}
qs.grpcServer = grpc.NewServer()
querypb.RegisterQueryNodeServer(qs.grpcServer, qs)
go funcutil.CheckGrpcReady(qs.ctx, qs.grpcErrChan)
if err = qs.grpcServer.Serve(lis); err != nil {
qs.grpcErrChan <- err
}
}()
err := <-qs.grpcErrChan
if err != nil {
return err
}
if err := qs.Register(); err != nil {
return err
}
return nil
}
func (qs *queryNodeServerMock) start() error {
return nil
}
func (qs *queryNodeServerMock) stop() error {
qs.cancel()
if qs.grpcServer != nil {
qs.grpcServer.GracefulStop()
}
return nil
}
func (qs *queryNodeServerMock) run() error {
if err := qs.init(); err != nil {
return err
}
if err := qs.start(); err != nil {
return err
}
return nil
}
func (qs *queryNodeServerMock) AddQueryChannel(ctx context.Context, req *querypb.AddQueryChannelRequest) (*commonpb.Status, error) {
return qs.addQueryChannels()
}
func (qs *queryNodeServerMock) WatchDmChannels(ctx context.Context, req *querypb.WatchDmChannelsRequest) (*commonpb.Status, error) {
return qs.watchDmChannels()
}
func (qs *queryNodeServerMock) LoadSegments(ctx context.Context, req *querypb.LoadSegmentsRequest) (*commonpb.Status, error) {
return qs.loadSegment()
}
func (qs *queryNodeServerMock) ReleaseCollection(ctx context.Context, req *querypb.ReleaseCollectionRequest) (*commonpb.Status, error) {
return qs.releaseCollection()
}
func (qs *queryNodeServerMock) ReleasePartitions(ctx context.Context, req *querypb.ReleasePartitionsRequest) (*commonpb.Status, error) {
return qs.releasePartition()
}
func (qs *queryNodeServerMock) ReleaseSegments(ctx context.Context, req *querypb.ReleaseSegmentsRequest) (*commonpb.Status, error) {
return qs.releaseSegment()
}
func (qs *queryNodeServerMock) GetSegmentInfo(context.Context, *querypb.GetSegmentInfoRequest) (*querypb.GetSegmentInfoResponse, error) {
return &querypb.GetSegmentInfoResponse{
Status: &commonpb.Status{
ErrorCode: commonpb.ErrorCode_Success,
},
}, nil
}
func (qs *queryNodeServerMock) GetMetrics(ctx context.Context, req *milvuspb.GetMetricsRequest) (*milvuspb.GetMetricsResponse, error) {
return &milvuspb.GetMetricsResponse{
Status: &commonpb.Status{
ErrorCode: commonpb.ErrorCode_Success,
},
}, nil
}
func startQueryNodeServer(ctx context.Context) (*queryNodeServerMock, error) {
node := newQueryNodeServerMock(ctx)
err := node.run()
if err != nil {
return nil, err
}
return node, nil
}
func returnSuccessResult() (*commonpb.Status, error) {
return &commonpb.Status{
ErrorCode: commonpb.ErrorCode_Success,
}, nil
}
func returnFailedResult() (*commonpb.Status, error) {
return &commonpb.Status{
ErrorCode: commonpb.ErrorCode_UnexpectedError,
}, errors.New("query node do task failed")
}

View File

@ -52,6 +52,7 @@ type QueryCoord struct {
queryCoordID uint64
meta Meta
cluster *queryNodeCluster
newNodeFn newQueryNodeFn
scheduler *TaskScheduler
dataCoordClient types.DataCoord
@ -98,7 +99,7 @@ func (qc *QueryCoord) Init() error {
return err
}
qc.cluster, err = newQueryNodeCluster(qc.meta, qc.kvClient)
qc.cluster, err = newQueryNodeCluster(qc.meta, qc.kvClient, qc.newNodeFn)
if err != nil {
log.Error("query coordinator init cluster failed", zap.Error(err))
return err
@ -160,6 +161,7 @@ func NewQueryCoord(ctx context.Context, factory msgstream.Factory) (*QueryCoord,
loopCtx: ctx1,
loopCancel: cancel,
msFactory: factory,
newNodeFn: newQueryNode,
}
service.UpdateStateCode(internalpb.StateCode_Abnormal)

View File

@ -14,91 +14,61 @@ package querycoord
import (
"context"
"math/rand"
"os"
"strconv"
"testing"
"time"
"github.com/stretchr/testify/assert"
"go.uber.org/zap"
"github.com/milvus-io/milvus/internal/log"
"github.com/milvus-io/milvus/internal/msgstream"
"github.com/milvus-io/milvus/internal/proto/internalpb"
)
func setup() {
Params.Init()
}
func refreshParams() {
rand.Seed(time.Now().UnixNano())
suffix := "-test-query-Coord" + strconv.FormatInt(rand.Int63(), 10)
Params.StatsChannelName = Params.StatsChannelName + suffix
Params.TimeTickChannelName = Params.TimeTickChannelName + suffix
Params.MetaRootPath = Params.MetaRootPath + suffix
}
func refreshChannelNames() {
suffix := "-test-query-Coord" + strconv.FormatInt(rand.Int63n(1000000), 10)
Params.StatsChannelName = Params.StatsChannelName + suffix
Params.TimeTickChannelName = Params.TimeTickChannelName + suffix
}
func TestMain(m *testing.M) {
/*
setup()
//refreshChannelNames()
exitCode := m.Run()
os.Exit(exitCode)
*/
setup()
//refreshChannelNames()
exitCode := m.Run()
os.Exit(exitCode)
}
func TestQueryCoord_Init(t *testing.T) {
ctx := context.Background()
msFactory := msgstream.NewPmsFactory()
service, err := NewQueryCoord(context.Background(), msFactory)
assert.Nil(t, err)
service.Register()
service.Init()
service.Start()
func NewQueryCoordTest(ctx context.Context, factory msgstream.Factory) (*QueryCoord, error) {
refreshParams()
rand.Seed(time.Now().UnixNano())
queryChannels := make([]*queryChannelInfo, 0)
channelID := len(queryChannels)
searchPrefix := Params.SearchChannelPrefix
searchResultPrefix := Params.SearchResultChannelPrefix
allocatedQueryChannel := searchPrefix + "-" + strconv.FormatInt(int64(channelID), 10)
allocatedQueryResultChannel := searchResultPrefix + "-" + strconv.FormatInt(int64(channelID), 10)
t.Run("Test Get statistics channel", func(t *testing.T) {
response, err := service.GetStatisticsChannel(ctx)
assert.Nil(t, err)
assert.Equal(t, response.Value, "query-node-stats")
queryChannels = append(queryChannels, &queryChannelInfo{
requestChannel: allocatedQueryChannel,
responseChannel: allocatedQueryResultChannel,
})
t.Run("Test Get timeTick channel", func(t *testing.T) {
response, err := service.GetTimeTickChannel(ctx)
assert.Nil(t, err)
assert.Equal(t, response.Value, "queryTimeTick")
})
ctx1, cancel := context.WithCancel(ctx)
service := &QueryCoord{
loopCtx: ctx1,
loopCancel: cancel,
msFactory: factory,
newNodeFn: newQueryNodeTest,
}
service.Stop()
service.UpdateStateCode(internalpb.StateCode_Abnormal)
log.Debug("query coordinator", zap.Any("queryChannels", queryChannels))
return service, nil
}
//func TestQueryCoord_load(t *testing.T) {
// ctx := context.Background()
// msFactory := msgstream.NewPmsFactory()
// service, err := NewQueryCoord(context.Background(), msFactory)
// assert.Nil(t, err)
// service.Init()
// service.Start()
// service.SetRootCoord(newRootCoordMock())
// service.SetDataCoord(NewDataMock())
// registerNodeRequest := &querypb.RegisterNodeRequest{
// Address: &commonpb.Address{},
// }
// service.RegisterNode(ctx, registerNodeRequest)
//
// t.Run("Test LoadSegment", func(t *testing.T) {
// loadCollectionRequest := &querypb.LoadCollectionRequest{
// CollectionID: 1,
// }
// response, err := service.LoadCollection(ctx, loadCollectionRequest)
// assert.Nil(t, err)
// assert.Equal(t, response.ErrorCode, commonpb.ErrorCode_Success)
// })
//
// t.Run("Test LoadPartition", func(t *testing.T) {
// loadPartitionRequest := &querypb.LoadPartitionsRequest{
// CollectionID: 1,
// PartitionIDs: []UniqueID{1},
// }
// response, err := service.LoadPartitions(ctx, loadPartitionRequest)
// assert.Nil(t, err)
// assert.Equal(t, response.ErrorCode, commonpb.ErrorCode_Success)
// })
//}

View File

@ -37,19 +37,14 @@ type Node interface {
stop()
clearNodeInfo() error
hasCollection(collectionID UniqueID) bool
addCollection(collectionID UniqueID, schema *schemapb.CollectionSchema) error
setCollectionInfo(info *querypb.CollectionInfo) error
getCollectionInfoByID(collectionID UniqueID) (*querypb.CollectionInfo, error)
showCollections() []*querypb.CollectionInfo
releaseCollection(ctx context.Context, in *querypb.ReleaseCollectionRequest) error
hasPartition(collectionID UniqueID, partitionID UniqueID) bool
addPartition(collectionID UniqueID, partitionID UniqueID) error
releasePartitions(ctx context.Context, in *querypb.ReleasePartitionsRequest) error
hasWatchedDmChannel(collectionID UniqueID, channelID string) (bool, error)
getDmChannelsByCollectionID(collectionID UniqueID) ([]string, error)
watchDmChannels(ctx context.Context, in *querypb.WatchDmChannelsRequest) error
removeDmChannel(collectionID UniqueID, channels []string) error
@ -84,37 +79,38 @@ type queryNode struct {
serviceLock sync.RWMutex
}
func newQueryNode(ctx context.Context, address string, id UniqueID, kv *etcdkv.EtcdKV) Node {
func newQueryNode(ctx context.Context, address string, id UniqueID, kv *etcdkv.EtcdKV) (Node, error) {
collectionInfo := make(map[UniqueID]*querypb.CollectionInfo)
watchedChannels := make(map[UniqueID]*querypb.QueryChannelInfo)
childCtx, cancel := context.WithCancel(ctx)
client, err := nodeclient.NewClient(childCtx, address)
if err != nil {
cancel()
return nil, err
}
node := &queryNode{
ctx: childCtx,
cancel: cancel,
id: id,
address: address,
client: client,
kvClient: kv,
collectionInfos: collectionInfo,
watchedQueryChannels: watchedChannels,
onService: false,
}
return node
return node, nil
}
func (qn *queryNode) start() error {
client, err := nodeclient.NewClient(qn.ctx, qn.address)
if err != nil {
if err := qn.client.Init(); err != nil {
return err
}
if err = client.Init(); err != nil {
return err
}
if err = client.Start(); err != nil {
if err := qn.client.Start(); err != nil {
return err
}
qn.client = client
qn.serviceLock.Lock()
qn.onService = true
qn.serviceLock.Unlock()
@ -132,32 +128,6 @@ func (qn *queryNode) stop() {
qn.cancel()
}
func (qn *queryNode) hasCollection(collectionID UniqueID) bool {
qn.RLock()
defer qn.RUnlock()
if _, ok := qn.collectionInfos[collectionID]; ok {
return true
}
return false
}
func (qn *queryNode) hasPartition(collectionID UniqueID, partitionID UniqueID) bool {
qn.RLock()
defer qn.RUnlock()
if info, ok := qn.collectionInfos[collectionID]; ok {
for _, id := range info.PartitionIDs {
if partitionID == id {
return true
}
}
}
return false
}
func (qn *queryNode) addCollection(collectionID UniqueID, schema *schemapb.CollectionSchema) error {
qn.Lock()
defer qn.Unlock()
@ -195,16 +165,6 @@ func (qn *queryNode) setCollectionInfo(info *querypb.CollectionInfo) error {
return nil
}
func (qn *queryNode) getCollectionInfoByID(collectionID UniqueID) (*querypb.CollectionInfo, error) {
qn.Lock()
defer qn.Lock()
if _, ok := qn.collectionInfos[collectionID]; ok {
return proto.Clone(qn.collectionInfos[collectionID]).(*querypb.CollectionInfo), nil
}
return nil, errors.New("GetCollectionInfoByID: can't find collection")
}
func (qn *queryNode) showCollections() []*querypb.CollectionInfo {
qn.RLock()
defer qn.RUnlock()
@ -281,40 +241,6 @@ func (qn *queryNode) releasePartitionsInfo(collectionID UniqueID, partitionIDs [
return nil
}
func (qn *queryNode) hasWatchedDmChannel(collectionID UniqueID, channelID string) (bool, error) {
qn.RLock()
defer qn.RUnlock()
if info, ok := qn.collectionInfos[collectionID]; ok {
channelInfos := info.ChannelInfos
for _, channelInfo := range channelInfos {
for _, channel := range channelInfo.ChannelIDs {
if channel == channelID {
return true, nil
}
}
}
return false, nil
}
return false, errors.New("HasWatchedDmChannel: can't find collection in collectionInfos")
}
func (qn *queryNode) getDmChannelsByCollectionID(collectionID UniqueID) ([]string, error) {
qn.RLock()
defer qn.RUnlock()
if info, ok := qn.collectionInfos[collectionID]; ok {
channels := make([]string, 0)
for _, channelsInfo := range info.ChannelInfos {
channels = append(channels, channelsInfo.ChannelIDs...)
}
return channels, nil
}
return nil, errors.New("GetDmChannelsByCollectionID: can't find collection in collectionInfos")
}
func (qn *queryNode) addDmChannel(collectionID UniqueID, channels []string) error {
qn.Lock()
defer qn.Unlock()

View File

@ -27,7 +27,7 @@ import (
func startQueryCoord(ctx context.Context) (*QueryCoord, error) {
factory := msgstream.NewPmsFactory()
coord, err := NewQueryCoord(ctx, factory)
coord, err := NewQueryCoordTest(ctx, factory)
if err != nil {
return nil, err
}
@ -59,6 +59,31 @@ func startQueryCoord(ctx context.Context) (*QueryCoord, error) {
return coord, nil
}
//func waitQueryNodeOnline(cluster *queryNodeCluster, nodeID int64)
func waitAllQueryNodeOffline(cluster *queryNodeCluster, nodes map[int64]Node) bool {
reDoCount := 20
for {
if reDoCount <= 0 {
return false
}
allOffline := true
for nodeID := range nodes {
_, err := cluster.getNodeByID(nodeID)
if err == nil {
allOffline = false
break
}
}
if allOffline {
return true
}
log.Debug("wait all queryNode offline")
time.Sleep(time.Second)
reDoCount--
}
}
func TestQueryNode_MultiNode_stop(t *testing.T) {
baseCtx := context.Background()
@ -68,23 +93,11 @@ func TestQueryNode_MultiNode_stop(t *testing.T) {
queryNode1, err := startQueryNodeServer(baseCtx)
assert.Nil(t, err)
//queryNode2, err := startQueryNodeServer(baseCtx)
//assert.Nil(t, err)
//queryNode3, err := startQueryNodeServer(baseCtx)
//assert.Nil(t, err)
//queryNode4, err := startQueryNodeServer(baseCtx)
//assert.Nil(t, err)
queryNode5, err := startQueryNodeServer(baseCtx)
assert.Nil(t, err)
time.Sleep(2 * time.Second)
queryNode1.stop()
//queryNode2.stop()
//queryNode3.stop()
//queryNode4.stop()
queryCoord.LoadCollection(baseCtx, &querypb.LoadCollectionRequest{
Base: &commonpb.MsgBase{
@ -106,21 +119,8 @@ func TestQueryNode_MultiNode_stop(t *testing.T) {
assert.Nil(t, err)
queryNode5.stop()
for {
allOffline := true
for nodeID := range nodes {
_, err = queryCoord.cluster.getNodeByID(nodeID)
if err == nil {
allOffline = false
time.Sleep(time.Second)
break
}
}
if allOffline {
break
}
log.Debug("wait all queryNode offline")
}
allNodeOffline := waitAllQueryNodeOffline(queryCoord.cluster, nodes)
assert.Equal(t, allNodeOffline, true)
queryCoord.Stop()
}
@ -133,9 +133,6 @@ func TestQueryNode_MultiNode_reStart(t *testing.T) {
queryNode1, err := startQueryNodeServer(baseCtx)
assert.Nil(t, err)
//queryNode2, err := startQueryNodeServer(baseCtx)
//assert.Nil(t, err)
time.Sleep(2 * time.Second)
queryCoord.LoadCollection(baseCtx, &querypb.LoadCollectionRequest{
Base: &commonpb.MsgBase{
@ -145,13 +142,8 @@ func TestQueryNode_MultiNode_reStart(t *testing.T) {
Schema: genCollectionSchema(defaultCollectionID, false),
})
queryNode1.stop()
//queryNode2.stop()
queryNode3, err := startQueryNodeServer(baseCtx)
assert.Nil(t, err)
//queryNode4, err := startQueryNodeServer(baseCtx)
//assert.Nil(t, err)
//queryNode5, err := startQueryNodeServer(baseCtx)
//assert.Nil(t, err)
time.Sleep(2 * time.Second)
_, err = queryCoord.ReleaseCollection(baseCtx, &querypb.ReleaseCollectionRequest{
@ -164,24 +156,9 @@ func TestQueryNode_MultiNode_reStart(t *testing.T) {
nodes, err := queryCoord.cluster.onServiceNodes()
assert.Nil(t, err)
queryNode3.stop()
//queryNode4.stop()
//queryNode5.stop()
for {
allOffline := true
for nodeID := range nodes {
_, err = queryCoord.cluster.getNodeByID(nodeID)
if err == nil {
allOffline = false
time.Sleep(time.Second)
break
}
}
if allOffline {
break
}
log.Debug("wait all queryNode offline")
}
allNodeOffline := waitAllQueryNodeOffline(queryCoord.cluster, nodes)
assert.Equal(t, allNodeOffline, true)
queryCoord.Stop()
}

View File

@ -2,6 +2,8 @@ package querycoord
import (
"context"
"fmt"
"strconv"
"testing"
"time"
@ -142,13 +144,8 @@ func TestWatchQueryChannel_ClearEtcdInfoAfterAssignedNodeDown(t *testing.T) {
time.Sleep(time.Second)
queryNode.stop()
for {
_, err = queryCoord.cluster.getNodeByID(nodeID)
if err == nil {
time.Sleep(time.Second)
break
}
}
allNodeOffline := waitAllQueryNodeOffline(queryCoord.cluster, nodes)
assert.Equal(t, allNodeOffline, true)
time.Sleep(time.Second)
newActiveTaskIDKeys, _, err := queryCoord.scheduler.client.LoadWithPrefix(activeTaskPrefix)
@ -157,27 +154,247 @@ func TestWatchQueryChannel_ClearEtcdInfoAfterAssignedNodeDown(t *testing.T) {
queryCoord.Stop()
}
func TestUnMarshalTask_LoadCollection(t *testing.T) {
func TestUnMarshalTask(t *testing.T) {
refreshParams()
kv, err := etcdkv.NewEtcdKV(Params.EtcdEndpoints, Params.MetaRootPath)
assert.Nil(t, err)
taskScheduler := &TaskScheduler{}
loadTask := &LoadCollectionTask{
t.Run("Test LoadCollectionTask", func(t *testing.T) {
loadTask := &LoadCollectionTask{
LoadCollectionRequest: &querypb.LoadCollectionRequest{
Base: &commonpb.MsgBase{
MsgType: commonpb.MsgType_LoadCollection,
},
},
}
blobs, err := loadTask.Marshal()
assert.Nil(t, err)
err = kv.Save("testMarshalLoadCollection", string(blobs))
assert.Nil(t, err)
defer kv.RemoveWithPrefix("testMarshalLoadCollection")
value, err := kv.Load("testMarshalLoadCollection")
assert.Nil(t, err)
task, err := taskScheduler.unmarshalTask(value)
assert.Nil(t, err)
assert.Equal(t, task.Type(), commonpb.MsgType_LoadCollection)
})
t.Run("Test LoadPartitionsTask", func(t *testing.T) {
loadTask := &LoadPartitionTask{
LoadPartitionsRequest: &querypb.LoadPartitionsRequest{
Base: &commonpb.MsgBase{
MsgType: commonpb.MsgType_LoadPartitions,
},
},
}
blobs, err := loadTask.Marshal()
assert.Nil(t, err)
err = kv.Save("testMarshalLoadPartition", string(blobs))
assert.Nil(t, err)
defer kv.RemoveWithPrefix("testMarshalLoadPartition")
value, err := kv.Load("testMarshalLoadPartition")
assert.Nil(t, err)
task, err := taskScheduler.unmarshalTask(value)
assert.Nil(t, err)
assert.Equal(t, task.Type(), commonpb.MsgType_LoadPartitions)
})
t.Run("Test ReleaseCollectionTask", func(t *testing.T) {
releaseTask := &ReleaseCollectionTask{
ReleaseCollectionRequest: &querypb.ReleaseCollectionRequest{
Base: &commonpb.MsgBase{
MsgType: commonpb.MsgType_ReleaseCollection,
},
},
}
blobs, err := releaseTask.Marshal()
assert.Nil(t, err)
err = kv.Save("testMarshalReleaseCollection", string(blobs))
assert.Nil(t, err)
defer kv.RemoveWithPrefix("testMarshalReleaseCollection")
value, err := kv.Load("testMarshalReleaseCollection")
assert.Nil(t, err)
task, err := taskScheduler.unmarshalTask(value)
assert.Nil(t, err)
assert.Equal(t, task.Type(), commonpb.MsgType_ReleaseCollection)
})
t.Run("Test ReleasePartitionTask", func(t *testing.T) {
releaseTask := &ReleasePartitionTask{
ReleasePartitionsRequest: &querypb.ReleasePartitionsRequest{
Base: &commonpb.MsgBase{
MsgType: commonpb.MsgType_ReleasePartitions,
},
},
}
blobs, err := releaseTask.Marshal()
assert.Nil(t, err)
err = kv.Save("testMarshalReleasePartition", string(blobs))
assert.Nil(t, err)
defer kv.RemoveWithPrefix("testMarshalReleasePartition")
value, err := kv.Load("testMarshalReleasePartition")
assert.Nil(t, err)
task, err := taskScheduler.unmarshalTask(value)
assert.Nil(t, err)
assert.Equal(t, task.Type(), commonpb.MsgType_ReleasePartitions)
})
t.Run("Test LoadSegmentTask", func(t *testing.T) {
loadTask := &LoadSegmentTask{
LoadSegmentsRequest: &querypb.LoadSegmentsRequest{
Base: &commonpb.MsgBase{
MsgType: commonpb.MsgType_LoadSegments,
},
},
}
blobs, err := loadTask.Marshal()
assert.Nil(t, err)
err = kv.Save("testMarshalLoadSegment", string(blobs))
assert.Nil(t, err)
defer kv.RemoveWithPrefix("testMarshalLoadSegment")
value, err := kv.Load("testMarshalLoadSegment")
assert.Nil(t, err)
task, err := taskScheduler.unmarshalTask(value)
assert.Nil(t, err)
assert.Equal(t, task.Type(), commonpb.MsgType_LoadSegments)
})
t.Run("Test ReleaseSegmentTask", func(t *testing.T) {
releaseTask := &ReleaseSegmentTask{
ReleaseSegmentsRequest: &querypb.ReleaseSegmentsRequest{
Base: &commonpb.MsgBase{
MsgType: commonpb.MsgType_ReleaseSegments,
},
},
}
blobs, err := releaseTask.Marshal()
assert.Nil(t, err)
err = kv.Save("testMarshalReleaseSegment", string(blobs))
assert.Nil(t, err)
defer kv.RemoveWithPrefix("testMarshalReleaseSegment")
value, err := kv.Load("testMarshalReleaseSegment")
assert.Nil(t, err)
task, err := taskScheduler.unmarshalTask(value)
assert.Nil(t, err)
assert.Equal(t, task.Type(), commonpb.MsgType_ReleaseSegments)
})
t.Run("Test WatchDmChannelTask", func(t *testing.T) {
watchTask := &WatchDmChannelTask{
WatchDmChannelsRequest: &querypb.WatchDmChannelsRequest{
Base: &commonpb.MsgBase{
MsgType: commonpb.MsgType_WatchDmChannels,
},
},
}
blobs, err := watchTask.Marshal()
assert.Nil(t, err)
err = kv.Save("testMarshalWatchDmChannel", string(blobs))
assert.Nil(t, err)
defer kv.RemoveWithPrefix("testMarshalWatchDmChannel")
value, err := kv.Load("testMarshalWatchDmChannel")
assert.Nil(t, err)
task, err := taskScheduler.unmarshalTask(value)
assert.Nil(t, err)
assert.Equal(t, task.Type(), commonpb.MsgType_WatchDmChannels)
})
t.Run("Test WatchQueryChannelTask", func(t *testing.T) {
watchTask := &WatchQueryChannelTask{
AddQueryChannelRequest: &querypb.AddQueryChannelRequest{
Base: &commonpb.MsgBase{
MsgType: commonpb.MsgType_WatchQueryChannels,
},
},
}
blobs, err := watchTask.Marshal()
assert.Nil(t, err)
err = kv.Save("testMarshalWatchQueryChannel", string(blobs))
assert.Nil(t, err)
defer kv.RemoveWithPrefix("testMarshalWatchQueryChannel")
value, err := kv.Load("testMarshalWatchQueryChannel")
assert.Nil(t, err)
task, err := taskScheduler.unmarshalTask(value)
assert.Nil(t, err)
assert.Equal(t, task.Type(), commonpb.MsgType_WatchQueryChannels)
})
t.Run("Test LoadBalanceTask", func(t *testing.T) {
loadBalanceTask := &LoadBalanceTask{
LoadBalanceRequest: &querypb.LoadBalanceRequest{
Base: &commonpb.MsgBase{
MsgType: commonpb.MsgType_LoadBalanceSegments,
},
},
}
blobs, err := loadBalanceTask.Marshal()
assert.Nil(t, err)
err = kv.Save("testMarshalLoadBalanceTask", string(blobs))
assert.Nil(t, err)
defer kv.RemoveWithPrefix("testMarshalLoadBalanceTask")
value, err := kv.Load("testMarshalLoadBalanceTask")
assert.Nil(t, err)
task, err := taskScheduler.unmarshalTask(value)
assert.Nil(t, err)
assert.Equal(t, task.Type(), commonpb.MsgType_LoadBalanceSegments)
})
}
func TestReloadTaskFromKV(t *testing.T) {
refreshParams()
kv, err := etcdkv.NewEtcdKV(Params.EtcdEndpoints, Params.MetaRootPath)
assert.Nil(t, err)
taskScheduler := &TaskScheduler{
client: kv,
triggerTaskQueue: NewTaskQueue(),
}
kvs := make(map[string]string)
triggerTask := &LoadCollectionTask{
LoadCollectionRequest: &querypb.LoadCollectionRequest{
Base: &commonpb.MsgBase{
MsgType: commonpb.MsgType_LoadCollection,
Timestamp: 1,
MsgType: commonpb.MsgType_LoadCollection,
},
},
}
blobs, err := loadTask.Marshal()
triggerBlobs, err := triggerTask.Marshal()
assert.Nil(t, err)
err = kv.Save("testMarshalLoadCollection", string(blobs))
triggerTaskKey := fmt.Sprintf("%s/%d", triggerTaskPrefix, 100)
kvs[triggerTaskKey] = string(triggerBlobs)
activeTask := &LoadSegmentTask{
LoadSegmentsRequest: &querypb.LoadSegmentsRequest{
Base: &commonpb.MsgBase{
Timestamp: 2,
MsgType: commonpb.MsgType_LoadSegments,
},
},
}
activeBlobs, err := activeTask.Marshal()
assert.Nil(t, err)
defer kv.RemoveWithPrefix("testMarshalLoadCollection")
value, err := kv.Load("testMarshalLoadCollection")
activeTaskKey := fmt.Sprintf("%s/%d", activeTaskPrefix, 101)
kvs[activeTaskKey] = string(activeBlobs)
stateKey := fmt.Sprintf("%s/%d", taskInfoPrefix, 100)
kvs[stateKey] = strconv.Itoa(int(taskDone))
err = kv.MultiSave(kvs)
assert.Nil(t, err)
taskScheduler := &TaskScheduler{}
task, err := taskScheduler.unmarshalTask(value)
assert.Nil(t, err)
assert.Equal(t, task.Type(), commonpb.MsgType_LoadCollection)
taskScheduler.reloadFromKV()
task := taskScheduler.triggerTaskQueue.PopTask()
assert.Equal(t, taskDone, task.State())
assert.Equal(t, 1, len(task.GetChildTask()))
}

View File

@ -0,0 +1,124 @@
package querycoord
import (
"context"
"testing"
"github.com/stretchr/testify/assert"
"github.com/milvus-io/milvus/internal/proto/commonpb"
"github.com/milvus-io/milvus/internal/proto/querypb"
)
func TestTriggerTask(t *testing.T) {
ctx := context.Background()
queryCoord, err := startQueryCoord(ctx)
assert.Nil(t, err)
node, err := startQueryNodeServer(ctx)
assert.Nil(t, err)
t.Run("Test LoadCollection", func(t *testing.T) {
req := &querypb.LoadCollectionRequest{
Base: &commonpb.MsgBase{
MsgType: commonpb.MsgType_LoadCollection,
},
CollectionID: defaultCollectionID,
Schema: genCollectionSchema(defaultCollectionID, false),
}
loadCollectionTask := &LoadCollectionTask{
BaseTask: BaseTask{
ctx: ctx,
Condition: NewTaskCondition(ctx),
triggerCondition: querypb.TriggerCondition_grpcRequest,
},
LoadCollectionRequest: req,
rootCoord: queryCoord.rootCoordClient,
dataCoord: queryCoord.dataCoordClient,
cluster: queryCoord.cluster,
meta: queryCoord.meta,
}
err = queryCoord.scheduler.processTask(loadCollectionTask)
assert.Nil(t, err)
})
t.Run("Test ReleaseCollection", func(t *testing.T) {
req := &querypb.ReleaseCollectionRequest{
Base: &commonpb.MsgBase{
MsgType: commonpb.MsgType_ReleaseCollection,
},
CollectionID: defaultCollectionID,
}
loadCollectionTask := &ReleaseCollectionTask{
BaseTask: BaseTask{
ctx: ctx,
Condition: NewTaskCondition(ctx),
triggerCondition: querypb.TriggerCondition_grpcRequest,
},
ReleaseCollectionRequest: req,
rootCoord: queryCoord.rootCoordClient,
cluster: queryCoord.cluster,
meta: queryCoord.meta,
}
err = queryCoord.scheduler.processTask(loadCollectionTask)
assert.Nil(t, err)
})
t.Run("Test LoadPartition", func(t *testing.T) {
req := &querypb.LoadPartitionsRequest{
Base: &commonpb.MsgBase{
MsgType: commonpb.MsgType_LoadPartitions,
},
CollectionID: defaultCollectionID,
PartitionIDs: []UniqueID{defaultPartitionID},
}
loadCollectionTask := &LoadPartitionTask{
BaseTask: BaseTask{
ctx: ctx,
Condition: NewTaskCondition(ctx),
triggerCondition: querypb.TriggerCondition_grpcRequest,
},
LoadPartitionsRequest: req,
dataCoord: queryCoord.dataCoordClient,
cluster: queryCoord.cluster,
meta: queryCoord.meta,
}
err = queryCoord.scheduler.processTask(loadCollectionTask)
assert.Nil(t, err)
})
t.Run("Test ReleasePartition", func(t *testing.T) {
req := &querypb.ReleasePartitionsRequest{
Base: &commonpb.MsgBase{
MsgType: commonpb.MsgType_ReleasePartitions,
},
CollectionID: defaultCollectionID,
PartitionIDs: []UniqueID{defaultPartitionID},
}
loadCollectionTask := &ReleasePartitionTask{
BaseTask: BaseTask{
ctx: ctx,
Condition: NewTaskCondition(ctx),
triggerCondition: querypb.TriggerCondition_grpcRequest,
},
ReleasePartitionsRequest: req,
cluster: queryCoord.cluster,
}
err = queryCoord.scheduler.processTask(loadCollectionTask)
assert.Nil(t, err)
})
//nodes, err := queryCoord.cluster.getOnServiceNodes()
//assert.Nil(t, err)
err = node.stop()
//assert.Nil(t, err)
//allNodeOffline := waitAllQueryNodeOffline(queryCoord.cluster, nodes)
//assert.Equal(t, allNodeOffline, true)
queryCoord.Stop()
}