mirror of https://github.com/milvus-io/milvus.git
parent 8701c477e2
commit 055d94ede1

@@ -17,7 +17,9 @@ import (
    "fmt"
    "time"

    "go.uber.org/zap"
    "google.golang.org/grpc"
    "google.golang.org/grpc/codes"

    grpc_middleware "github.com/grpc-ecosystem/go-grpc-middleware"
    grpc_retry "github.com/grpc-ecosystem/go-grpc-middleware/retry"
@@ -29,8 +31,6 @@ import (
    "github.com/milvus-io/milvus/internal/proto/querypb"
    "github.com/milvus-io/milvus/internal/util/retry"
    "github.com/milvus-io/milvus/internal/util/trace"
-   "go.uber.org/zap"
-   "google.golang.org/grpc/codes"
)

type Client struct {
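
The retry imports added above come from the go-grpc-middleware family. For context, this is the usual way grpc_retry is wired into a client connection; the attempt count and retryable codes below are illustrative, not necessarily what this client configures:

```go
package main

import (
	"google.golang.org/grpc"
	"google.golang.org/grpc/codes"

	grpc_retry "github.com/grpc-ecosystem/go-grpc-middleware/retry"
)

// dialWithRetry sketches how the retry interceptor is typically attached:
// transient failures on unary RPCs are retried transparently.
func dialWithRetry(addr string) (*grpc.ClientConn, error) {
	retryOpts := []grpc_retry.CallOption{
		grpc_retry.WithMax(3),
		grpc_retry.WithCodes(codes.Unavailable, codes.Aborted),
	}
	return grpc.Dial(addr,
		grpc.WithInsecure(),
		grpc.WithUnaryInterceptor(grpc_retry.UnaryClientInterceptor(retryOpts...)),
	)
}
```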

@@ -20,14 +20,13 @@ import (
    "strconv"
    "sync"

-   "github.com/milvus-io/milvus/internal/proto/milvuspb"

    "github.com/golang/protobuf/proto"
    "go.uber.org/zap"

    etcdkv "github.com/milvus-io/milvus/internal/kv/etcd"
    "github.com/milvus-io/milvus/internal/log"
    "github.com/milvus-io/milvus/internal/proto/internalpb"
    "github.com/milvus-io/milvus/internal/proto/milvuspb"
    "github.com/milvus-io/milvus/internal/proto/querypb"
    "github.com/milvus-io/milvus/internal/util/sessionutil"
)

@@ -66,20 +65,24 @@ type Cluster interface {
    printMeta()
}

type newQueryNodeFn func(ctx context.Context, address string, id UniqueID, kv *etcdkv.EtcdKV) (Node, error)

type queryNodeCluster struct {
    client *etcdkv.EtcdKV

    sync.RWMutex
    clusterMeta Meta
    nodes       map[int64]Node
    newNodeFn   newQueryNodeFn
}

-func newQueryNodeCluster(clusterMeta Meta, kv *etcdkv.EtcdKV) (*queryNodeCluster, error) {
+func newQueryNodeCluster(clusterMeta Meta, kv *etcdkv.EtcdKV, newNodeFn newQueryNodeFn) (*queryNodeCluster, error) {
    nodes := make(map[int64]Node)
    c := &queryNodeCluster{
        client:      kv,
        clusterMeta: clusterMeta,
        nodes:       nodes,
        newNodeFn:   newNodeFn,
    }
    err := c.reloadFromKV()
    if err != nil {

@@ -428,7 +431,11 @@ func (c *queryNodeCluster) registerNode(ctx context.Context, session *sessionuti
    if err != nil {
        return err
    }
-   c.nodes[id] = newQueryNode(ctx, session.Address, id, c.client)
+   c.nodes[id], err = c.newNodeFn(ctx, session.Address, id, c.client)
+   if err != nil {
+       log.Debug("RegisterNode: create a new query node failed", zap.Int64("nodeID", id), zap.Error(err))
+       return err
+   }
    log.Debug("RegisterNode: create a new query node", zap.Int64("nodeID", id), zap.String("address", session.Address))

    go func() {

@@ -480,6 +487,9 @@ func (c *queryNodeCluster) removeNodeInfo(nodeID int64) error {
}

func (c *queryNodeCluster) stopNode(nodeID int64) {
    c.Lock()
    defer c.Unlock()

    if node, ok := c.nodes[nodeID]; ok {
        node.stop()
        log.Debug("StopNode: queryNode offline", zap.Int64("nodeID", nodeID))
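
The heart of this change is the newQueryNodeFn hook: newQueryNodeCluster and registerNode no longer hard-code newQueryNode, so tests can pass newQueryNodeTest and get mock-backed nodes. The same constructor-injection shape in isolation (all names below are illustrative, not the Milvus API):

```go
package main

import "fmt"

// Node is the minimal behavior a cluster needs from a member.
type Node interface{ start() error }

// newNodeFn is the injected constructor.
type newNodeFn func(addr string) (Node, error)

type cluster struct {
	nodes map[int64]Node
	newFn newNodeFn
}

// The cluster never names a concrete Node type; whoever builds the
// cluster decides whether nodes are real gRPC clients or test doubles.
func newCluster(fn newNodeFn) *cluster {
	return &cluster{nodes: make(map[int64]Node), newFn: fn}
}

func (c *cluster) register(id int64, addr string) error {
	n, err := c.newFn(addr)
	if err != nil {
		return err
	}
	c.nodes[id] = n
	return n.start()
}

type mockNode struct{}

func (mockNode) start() error { return nil }

func main() {
	c := newCluster(func(addr string) (Node, error) { return mockNode{}, nil })
	fmt.Println(c.register(100, "localhost:21123")) // <nil>
}
```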

@@ -12,11 +12,57 @@
package querycoord

import (
    "context"
    "encoding/json"
    "fmt"
    "testing"

    "github.com/golang/protobuf/proto"
    "github.com/stretchr/testify/assert"

    etcdkv "github.com/milvus-io/milvus/internal/kv/etcd"
    "github.com/milvus-io/milvus/internal/log"
    "github.com/milvus-io/milvus/internal/proto/querypb"
    "github.com/milvus-io/milvus/internal/util/sessionutil"
)

func TestQueryNodeCluster_getMetrics(t *testing.T) {
    log.Info("TestQueryNodeCluster_getMetrics, todo")
}

func TestReloadClusterFromKV(t *testing.T) {
    refreshParams()
    kv, err := etcdkv.NewEtcdKV(Params.EtcdEndpoints, Params.MetaRootPath)
    assert.Nil(t, err)
    cluster := &queryNodeCluster{
        client:    kv,
        nodes:     make(map[int64]Node),
        newNodeFn: newQueryNodeTest,
    }

    kvs := make(map[string]string)
    session := &sessionutil.Session{
        ServerID: 100,
        Address:  "localhost",
    }
    sessionBlob, err := json.Marshal(session)
    assert.Nil(t, err)
    sessionKey := fmt.Sprintf("%s/%d", queryNodeInfoPrefix, 100)
    kvs[sessionKey] = string(sessionBlob)

    collectionInfo := &querypb.CollectionInfo{
        CollectionID: defaultCollectionID,
    }
    collectionBlobs := proto.MarshalTextString(collectionInfo)
    nodeKey := fmt.Sprintf("%s/%d", queryNodeMetaPrefix, 100)
    kvs[nodeKey] = collectionBlobs

    err = kv.MultiSave(kvs)
    assert.Nil(t, err)

    cluster.reloadFromKV()

    assert.Equal(t, 1, len(cluster.nodes))
    collection := cluster.getCollectionInfosByID(context.Background(), 100)
    assert.Equal(t, defaultCollectionID, collection[0].CollectionID)
}

@@ -559,7 +559,6 @@ func (qc *QueryCoord) GetMetrics(ctx context.Context, req *milvuspb.GetMetricsRe

        return metrics, err
    }

    log.Debug("QueryCoord.GetMetrics failed, request metric type is not implemented yet",
        zap.Int64("node_id", Params.QueryCoordID),
        zap.String("req", req.Request),

@@ -0,0 +1,259 @@
package querycoord

import (
    "context"
    "encoding/json"
    "testing"
    "time"

    "github.com/stretchr/testify/assert"

    "github.com/milvus-io/milvus/internal/proto/commonpb"
    "github.com/milvus-io/milvus/internal/proto/internalpb"
    "github.com/milvus-io/milvus/internal/proto/milvuspb"
    "github.com/milvus-io/milvus/internal/proto/querypb"
    "github.com/milvus-io/milvus/internal/util/metricsinfo"
)

func TestGrpcTask(t *testing.T) {
    ctx := context.Background()
    queryCoord, err := startQueryCoord(ctx)
    assert.Nil(t, err)

    node, err := startQueryNodeServer(ctx)
    assert.Nil(t, err)

    t.Run("Test LoadPartition", func(t *testing.T) {
        status, err := queryCoord.LoadPartitions(ctx, &querypb.LoadPartitionsRequest{
            Base: &commonpb.MsgBase{
                MsgType: commonpb.MsgType_LoadPartitions,
            },
            CollectionID: defaultCollectionID,
            PartitionIDs: []UniqueID{defaultPartitionID},
            Schema:       genCollectionSchema(defaultCollectionID, false),
        })
        assert.Equal(t, status.ErrorCode, commonpb.ErrorCode_Success)
        assert.Nil(t, err)
    })

    t.Run("Test ShowPartitions", func(t *testing.T) {
        res, err := queryCoord.ShowPartitions(ctx, &querypb.ShowPartitionsRequest{
            Base: &commonpb.MsgBase{
                MsgType: commonpb.MsgType_ShowCollections,
            },
            CollectionID: defaultCollectionID,
            PartitionIDs: []UniqueID{defaultPartitionID},
        })
        assert.Equal(t, res.Status.ErrorCode, commonpb.ErrorCode_Success)
        assert.Nil(t, err)
    })

    t.Run("Test ShowAllPartitions", func(t *testing.T) {
        res, err := queryCoord.ShowPartitions(ctx, &querypb.ShowPartitionsRequest{
            Base: &commonpb.MsgBase{
                MsgType: commonpb.MsgType_ShowCollections,
            },
            CollectionID: defaultCollectionID,
        })
        assert.Equal(t, res.Status.ErrorCode, commonpb.ErrorCode_Success)
        assert.Nil(t, err)
    })

    t.Run("Test GetPartitionStates", func(t *testing.T) {
        res, err := queryCoord.GetPartitionStates(ctx, &querypb.GetPartitionStatesRequest{
            Base: &commonpb.MsgBase{
                MsgType: commonpb.MsgType_GetPartitionStatistics,
            },
            CollectionID: defaultCollectionID,
            PartitionIDs: []UniqueID{defaultPartitionID},
        })
        assert.Equal(t, res.Status.ErrorCode, commonpb.ErrorCode_Success)
        assert.Nil(t, err)
    })

    t.Run("Test LoadCollection", func(t *testing.T) {
        status, err := queryCoord.LoadCollection(ctx, &querypb.LoadCollectionRequest{
            Base: &commonpb.MsgBase{
                MsgType: commonpb.MsgType_LoadCollection,
            },
            CollectionID: defaultCollectionID,
            Schema:       genCollectionSchema(defaultCollectionID, false),
        })
        assert.Equal(t, status.ErrorCode, commonpb.ErrorCode_Success)
        assert.Nil(t, err)
    })

    t.Run("Test ShowCollections", func(t *testing.T) {
        res, err := queryCoord.ShowCollections(ctx, &querypb.ShowCollectionsRequest{
            Base: &commonpb.MsgBase{
                MsgType: commonpb.MsgType_ShowCollections,
            },
            CollectionIDs: []UniqueID{defaultCollectionID},
        })
        assert.Equal(t, 100, int(res.InMemoryPercentages[0]))
        assert.Equal(t, res.Status.ErrorCode, commonpb.ErrorCode_Success)
        assert.Nil(t, err)
    })

    t.Run("Test ShowAllCollections", func(t *testing.T) {
        res, err := queryCoord.ShowCollections(ctx, &querypb.ShowCollectionsRequest{
            Base: &commonpb.MsgBase{
                MsgType: commonpb.MsgType_ShowCollections,
            },
        })
        assert.Equal(t, res.Status.ErrorCode, commonpb.ErrorCode_Success)
        assert.Nil(t, err)
    })

    t.Run("Test GetSegmentInfo", func(t *testing.T) {
        res, err := queryCoord.GetSegmentInfo(ctx, &querypb.GetSegmentInfoRequest{
            Base: &commonpb.MsgBase{
                MsgType: commonpb.MsgType_SegmentInfo,
            },
            SegmentIDs: []UniqueID{defaultSegmentID},
        })
        assert.Equal(t, res.Status.ErrorCode, commonpb.ErrorCode_Success)
        assert.Nil(t, err)
    })

    t.Run("Test ReleasePartition", func(t *testing.T) {
        status, err := queryCoord.ReleasePartitions(ctx, &querypb.ReleasePartitionsRequest{
            Base: &commonpb.MsgBase{
                MsgType: commonpb.MsgType_ReleasePartitions,
            },
            CollectionID: defaultCollectionID,
            PartitionIDs: []UniqueID{defaultPartitionID},
        })
        assert.Equal(t, status.ErrorCode, commonpb.ErrorCode_Success)
        assert.Nil(t, err)

    })

    t.Run("Test ReleaseCollection", func(t *testing.T) {
        status, err := queryCoord.ReleaseCollection(ctx, &querypb.ReleaseCollectionRequest{
            Base: &commonpb.MsgBase{
                MsgType: commonpb.MsgType_ReleaseCollection,
            },
            CollectionID: defaultCollectionID,
        })
        assert.Equal(t, status.ErrorCode, commonpb.ErrorCode_Success)
        assert.Nil(t, err)
    })

    t.Run("Test GetStatisticsChannel", func(t *testing.T) {
        _, err = queryCoord.GetStatisticsChannel(ctx)
        assert.Nil(t, err)
    })

    t.Run("Test GetTimeTickChannel", func(t *testing.T) {
        _, err = queryCoord.GetTimeTickChannel(ctx)
        assert.Nil(t, err)
    })

    t.Run("Test GetComponentStates", func(t *testing.T) {
        states, err := queryCoord.GetComponentStates(ctx)
        assert.Equal(t, states.Status.ErrorCode, commonpb.ErrorCode_Success)
        assert.Equal(t, states.State.StateCode, internalpb.StateCode_Healthy)
        assert.Nil(t, err)
    })

    t.Run("Test CreateQueryChannel", func(t *testing.T) {
        res, err := queryCoord.CreateQueryChannel(ctx, &querypb.CreateQueryChannelRequest{
            CollectionID: defaultCollectionID,
        })
        assert.Equal(t, res.Status.ErrorCode, commonpb.ErrorCode_Success)
        assert.Nil(t, err)
    })

    t.Run("Test GetMetrics", func(t *testing.T) {
        metricReq := make(map[string]string)
        metricReq[metricsinfo.MetricTypeKey] = "system_info"
        req, err := json.Marshal(metricReq)
        assert.Nil(t, err)
        res, err := queryCoord.GetMetrics(ctx, &milvuspb.GetMetricsRequest{
            Base:    &commonpb.MsgBase{},
            Request: string(req),
        })

        assert.Nil(t, err)
        assert.Equal(t, commonpb.ErrorCode_Success, res.Status.ErrorCode)
    })

    //nodes, err := queryCoord.cluster.getOnServiceNodes()
    //assert.Nil(t, err)

    err = node.stop()
    //assert.Nil(t, err)

    //allNodeOffline := waitAllQueryNodeOffline(queryCoord.cluster, nodes)
    //assert.Equal(t, allNodeOffline, true)
    queryCoord.Stop()
}

func TestLoadBalanceTask(t *testing.T) {
    baseCtx := context.Background()

    queryCoord, err := startQueryCoord(baseCtx)
    assert.Nil(t, err)

    queryNode1, err := startQueryNodeServer(baseCtx)
    assert.Nil(t, err)

    queryNode2, err := startQueryNodeServer(baseCtx)
    assert.Nil(t, err)

    time.Sleep(time.Second)
    res, err := queryCoord.LoadCollection(baseCtx, &querypb.LoadCollectionRequest{
        Base: &commonpb.MsgBase{
            MsgType: commonpb.MsgType_LoadCollection,
        },
        CollectionID: defaultCollectionID,
        Schema:       genCollectionSchema(defaultCollectionID, false),
    })
    assert.Nil(t, err)
    assert.Equal(t, commonpb.ErrorCode_Success, res.ErrorCode)

    time.Sleep(time.Second)
    for {
        collectionInfo := queryCoord.meta.showCollections()
        if collectionInfo[0].InMemoryPercentage == 100 {
            break
        }
    }
    nodeID := queryNode1.queryNodeID
    queryCoord.cluster.stopNode(nodeID)
    loadBalanceSegment := &querypb.LoadBalanceRequest{
        Base: &commonpb.MsgBase{
            MsgType:  commonpb.MsgType_LoadBalanceSegments,
            SourceID: nodeID,
        },
        SourceNodeIDs: []int64{nodeID},
        BalanceReason: querypb.TriggerCondition_nodeDown,
    }

    loadBalanceTask := &LoadBalanceTask{
        BaseTask: BaseTask{
            ctx:              baseCtx,
            Condition:        NewTaskCondition(baseCtx),
            triggerCondition: querypb.TriggerCondition_nodeDown,
        },
        LoadBalanceRequest: loadBalanceSegment,
        rootCoord:          queryCoord.rootCoordClient,
        dataCoord:          queryCoord.dataCoordClient,
        cluster:            queryCoord.cluster,
        meta:               queryCoord.meta,
    }
    queryCoord.scheduler.Enqueue([]task{loadBalanceTask})

    res, err = queryCoord.ReleaseCollection(baseCtx, &querypb.ReleaseCollectionRequest{
        Base: &commonpb.MsgBase{
            MsgType: commonpb.MsgType_ReleaseCollection,
        },
        CollectionID: defaultCollectionID,
    })
    assert.Nil(t, err)

    queryNode1.stop()
    queryNode2.stop()
    queryCoord.Stop()
}
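
One caveat in TestLoadBalanceTask above: the bare for loop polling InMemoryPercentage has no exit path, so if loading never completes the test hangs instead of failing. A bounded variant of the same wait (an illustrative helper, not part of this commit):

```go
// waitFor polls cond every interval until it returns true or the timeout
// elapses; it reports whether the condition was met in time.
func waitFor(timeout, interval time.Duration, cond func() bool) bool {
	deadline := time.Now().Add(timeout)
	for time.Now().Before(deadline) {
		if cond() {
			return true
		}
		time.Sleep(interval)
	}
	return false
}
```

With it, the loop would become a single assertion such as assert.True(t, waitFor(30*time.Second, 100*time.Millisecond, func() bool { return queryCoord.meta.showCollections()[0].InMemoryPercentage == 100 })).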

@@ -57,8 +57,6 @@ type Meta interface {

    getPartitionStatesByID(collectionID UniqueID, partitionID UniqueID) (*querypb.PartitionStates, error)

    hasWatchedDmChannel(collectionID UniqueID, channelID string) (bool, error)
    getDmChannelsByCollectionID(collectionID UniqueID) ([]string, error)
    getDmChannelsByNodeID(collectionID UniqueID, nodeID int64) ([]string, error)
    addDmChannel(collectionID UniqueID, nodeID int64, channels []string) error
    removeDmChannel(collectionID UniqueID, nodeID int64, channels []string) error

@@ -488,40 +486,6 @@ func (m *MetaReplica) releasePartition(collectionID UniqueID, partitionID Unique
    return nil
}

func (m *MetaReplica) hasWatchedDmChannel(collectionID UniqueID, channelID string) (bool, error) {
    m.RLock()
    defer m.RUnlock()

    if info, ok := m.collectionInfos[collectionID]; ok {
        channelInfos := info.ChannelInfos
        for _, channelInfo := range channelInfos {
            for _, channel := range channelInfo.ChannelIDs {
                if channel == channelID {
                    return true, nil
                }
            }
        }
        return false, nil
    }

    return false, errors.New("hasWatchedDmChannel: can't find collection in collectionInfos")
}

func (m *MetaReplica) getDmChannelsByCollectionID(collectionID UniqueID) ([]string, error) {
    m.RLock()
    defer m.RUnlock()

    if info, ok := m.collectionInfos[collectionID]; ok {
        channels := make([]string, 0)
        for _, channelsInfo := range info.ChannelInfos {
            channels = append(channels, channelsInfo.ChannelIDs...)
        }
        return channels, nil
    }

    return nil, errors.New("getDmChannelsByCollectionID: can't find collection in collectionInfos")
}

func (m *MetaReplica) getDmChannelsByNodeID(collectionID UniqueID, nodeID int64) ([]string, error) {
    m.RLock()
    defer m.RUnlock()

@@ -12,15 +12,19 @@
package querycoord

import (
    "fmt"
    "testing"

    "github.com/golang/protobuf/proto"
    "github.com/stretchr/testify/assert"
    "github.com/stretchr/testify/require"

    etcdkv "github.com/milvus-io/milvus/internal/kv/etcd"
    "github.com/milvus-io/milvus/internal/proto/querypb"
)

func TestReplica_Release(t *testing.T) {
    refreshParams()
    etcdKV, err := etcdkv.NewEtcdKV(Params.EtcdEndpoints, Params.MetaRootPath)
    assert.Nil(t, err)
    meta, err := newMeta(etcdKV)

@@ -48,3 +52,53 @@ func TestReplica_Release(t *testing.T) {
    assert.Equal(t, 0, len(collections))
    meta.releaseCollection(1)
}

func TestReloadMetaFromKV(t *testing.T) {
    refreshParams()
    kv, err := etcdkv.NewEtcdKV(Params.EtcdEndpoints, Params.MetaRootPath)
    assert.Nil(t, err)
    meta := &MetaReplica{
        client:            kv,
        collectionInfos:   map[UniqueID]*querypb.CollectionInfo{},
        segmentInfos:      map[UniqueID]*querypb.SegmentInfo{},
        queryChannelInfos: map[UniqueID]*querypb.QueryChannelInfo{},
    }

    kvs := make(map[string]string)
    collectionInfo := &querypb.CollectionInfo{
        CollectionID: defaultCollectionID,
    }
    collectionBlobs := proto.MarshalTextString(collectionInfo)
    collectionKey := fmt.Sprintf("%s/%d", collectionMetaPrefix, defaultCollectionID)
    kvs[collectionKey] = collectionBlobs

    segmentInfo := &querypb.SegmentInfo{
        SegmentID: defaultSegmentID,
    }
    segmentBlobs := proto.MarshalTextString(segmentInfo)
    segmentKey := fmt.Sprintf("%s/%d", segmentMetaPrefix, defaultSegmentID)
    kvs[segmentKey] = segmentBlobs

    queryChannelInfo := &querypb.QueryChannelInfo{
        CollectionID: defaultCollectionID,
    }
    queryChannelBlobs := proto.MarshalTextString(queryChannelInfo)
    queryChannelKey := fmt.Sprintf("%s/%d", queryChannelMetaPrefix, defaultCollectionID)
    kvs[queryChannelKey] = queryChannelBlobs

    err = kv.MultiSave(kvs)
    assert.Nil(t, err)

    err = meta.reloadFromKV()
    assert.Nil(t, err)

    assert.Equal(t, 1, len(meta.collectionInfos))
    assert.Equal(t, 1, len(meta.segmentInfos))
    assert.Equal(t, 1, len(meta.queryChannelInfos))
    _, ok := meta.collectionInfos[defaultCollectionID]
    assert.Equal(t, true, ok)
    _, ok = meta.segmentInfos[defaultSegmentID]
    assert.Equal(t, true, ok)
    _, ok = meta.queryChannelInfos[defaultCollectionID]
    assert.Equal(t, true, ok)
}
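
Both reload tests persist metadata as protobuf text format via proto.MarshalTextString and expect reloadFromKV to parse it back. The round trip in isolation, assuming the same milvus internal packages these tests use (any generated proto message behaves the same):

```go
package main

import (
	"fmt"

	"github.com/golang/protobuf/proto"

	"github.com/milvus-io/milvus/internal/proto/querypb"
)

func main() {
	info := &querypb.CollectionInfo{CollectionID: 1}
	blob := proto.MarshalTextString(info) // human-readable text stored under the etcd key

	restored := &querypb.CollectionInfo{}
	if err := proto.UnmarshalText(blob, restored); err != nil {
		panic(err)
	}
	fmt.Println(restored.CollectionID) // 1
}
```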

@@ -16,17 +16,13 @@ import (
    "errors"
    "fmt"
    "math/rand"
    "net"
    "path"
    "strconv"
    "sync"

    "google.golang.org/grpc"

    "github.com/milvus-io/milvus/internal/kv"
    minioKV "github.com/milvus-io/milvus/internal/kv/minio"
    "github.com/milvus-io/milvus/internal/log"
    "github.com/milvus-io/milvus/internal/msgstream"
    "github.com/milvus-io/milvus/internal/proto/commonpb"
    "github.com/milvus-io/milvus/internal/proto/datapb"
    "github.com/milvus-io/milvus/internal/proto/etcdpb"
@@ -34,13 +30,10 @@ import (
    "github.com/milvus-io/milvus/internal/proto/internalpb"
    "github.com/milvus-io/milvus/internal/proto/milvuspb"
    "github.com/milvus-io/milvus/internal/proto/proxypb"
    "github.com/milvus-io/milvus/internal/proto/querypb"
    "github.com/milvus-io/milvus/internal/proto/schemapb"
    qn "github.com/milvus-io/milvus/internal/querynode"
    "github.com/milvus-io/milvus/internal/storage"
    "github.com/milvus-io/milvus/internal/types"
    "github.com/milvus-io/milvus/internal/util/funcutil"
    "github.com/milvus-io/milvus/internal/util/retry"
)

const (

@@ -373,163 +366,3 @@ func newIndexCoordMock() *indexCoordMock {
func (c *indexCoordMock) GetIndexFilePaths(ctx context.Context, req *indexpb.GetIndexFilePathsRequest) (*indexpb.GetIndexFilePathsResponse, error) {
    return nil, errors.New("get index file path fail")
}

type queryNodeServerMock struct {
    querypb.QueryNodeServer
    ctx         context.Context
    cancel      context.CancelFunc
    queryNode   *qn.QueryNode
    grpcErrChan chan error
    grpcServer  *grpc.Server

    addQueryChannels  func() (*commonpb.Status, error)
    watchDmChannels   func() (*commonpb.Status, error)
    loadSegment       func() (*commonpb.Status, error)
    releaseCollection func() (*commonpb.Status, error)
    releasePartition  func() (*commonpb.Status, error)
    releaseSegment    func() (*commonpb.Status, error)
}

func newQueryNodeServerMock(ctx context.Context) *queryNodeServerMock {
    ctx1, cancel := context.WithCancel(ctx)
    factory := msgstream.NewPmsFactory()

    return &queryNodeServerMock{
        ctx:         ctx,
        cancel:      cancel,
        queryNode:   qn.NewQueryNode(ctx1, factory),
        grpcErrChan: make(chan error),

        addQueryChannels:  returnSuccessResult,
        watchDmChannels:   returnSuccessResult,
        loadSegment:       returnSuccessResult,
        releaseCollection: returnSuccessResult,
        releasePartition:  returnSuccessResult,
        releaseSegment:    returnSuccessResult,
    }
}

func (qs *queryNodeServerMock) init() error {
    qn.Params.Init()
    qn.Params.MetaRootPath = Params.MetaRootPath
    qn.Params.QueryNodeIP = funcutil.GetLocalIP()
    grpcPort := Params.Port

    go func() {
        var lis net.Listener
        var err error

        err = retry.Do(qs.ctx, func() error {
            addr := ":" + strconv.Itoa(grpcPort)
            lis, err = net.Listen("tcp", addr)
            if err == nil {
                qn.Params.QueryNodePort = int64(lis.Addr().(*net.TCPAddr).Port)
            } else {
                // set port=0 to get next available port
                grpcPort = 0
            }
            return err
        }, retry.Attempts(10))

        if err != nil {
            log.Error(err.Error())
        }

        qs.grpcServer = grpc.NewServer()
        querypb.RegisterQueryNodeServer(qs.grpcServer, qs)
        if err = qs.grpcServer.Serve(lis); err != nil {
            log.Error(err.Error())
        }
    }()

    rootCoord := newRootCoordMock()
    indexCoord := newIndexCoordMock()
    qs.queryNode.SetRootCoord(rootCoord)
    qs.queryNode.SetIndexCoord(indexCoord)
    err := qs.queryNode.Init()
    if err != nil {
        return err
    }

    if err = qs.queryNode.Register(); err != nil {
        return err
    }

    return nil
}

func (qs *queryNodeServerMock) start() error {
    return qs.queryNode.Start()
}

func (qs *queryNodeServerMock) stop() error {
    qs.cancel()
    if qs.grpcServer != nil {
        qs.grpcServer.GracefulStop()
    }

    err := qs.queryNode.Stop()
    if err != nil {
        return err
    }

    return nil
}

func (qs *queryNodeServerMock) run() error {
    if err := qs.init(); err != nil {
        return err
    }

    if err := qs.start(); err != nil {
        return err
    }

    return nil
}

func (qs *queryNodeServerMock) AddQueryChannel(ctx context.Context, req *querypb.AddQueryChannelRequest) (*commonpb.Status, error) {
    return qs.addQueryChannels()
}

func (qs *queryNodeServerMock) WatchDmChannels(ctx context.Context, req *querypb.WatchDmChannelsRequest) (*commonpb.Status, error) {
    return qs.watchDmChannels()
}

func (qs *queryNodeServerMock) LoadSegments(ctx context.Context, req *querypb.LoadSegmentsRequest) (*commonpb.Status, error) {
    return qs.loadSegment()
}

func (qs *queryNodeServerMock) ReleaseCollection(ctx context.Context, req *querypb.ReleaseCollectionRequest) (*commonpb.Status, error) {
    return qs.releaseCollection()
}

func (qs *queryNodeServerMock) ReleasePartitions(ctx context.Context, req *querypb.ReleasePartitionsRequest) (*commonpb.Status, error) {
    return qs.releasePartition()
}

func (qs *queryNodeServerMock) ReleaseSegments(ctx context.Context, req *querypb.ReleaseSegmentsRequest) (*commonpb.Status, error) {
    return qs.releaseSegment()
}

func startQueryNodeServer(ctx context.Context) (*queryNodeServerMock, error) {
    node := newQueryNodeServerMock(ctx)
    err := node.run()
    if err != nil {
        return nil, err
    }

    return node, nil
}

func returnSuccessResult() (*commonpb.Status, error) {
    return &commonpb.Status{
        ErrorCode: commonpb.ErrorCode_Success,
    }, nil
}

func returnFailedResult() (*commonpb.Status, error) {
    return &commonpb.Status{
        ErrorCode: commonpb.ErrorCode_UnexpectedError,
    }, errors.New("query node do task failed")
}

@@ -0,0 +1,136 @@
package querycoord

import (
    "context"
    "fmt"
    "time"

    "google.golang.org/grpc"

    etcdkv "github.com/milvus-io/milvus/internal/kv/etcd"
    "github.com/milvus-io/milvus/internal/log"
    "github.com/milvus-io/milvus/internal/proto/commonpb"
    "github.com/milvus-io/milvus/internal/proto/internalpb"
    "github.com/milvus-io/milvus/internal/proto/milvuspb"
    "github.com/milvus-io/milvus/internal/proto/querypb"
)

type queryNodeClientMock struct {
    ctx    context.Context
    cancel context.CancelFunc

    grpcClient querypb.QueryNodeClient
    conn       *grpc.ClientConn

    addr string
}

func newQueryNodeTest(ctx context.Context, address string, id UniqueID, kv *etcdkv.EtcdKV) (Node, error) {
    collectionInfo := make(map[UniqueID]*querypb.CollectionInfo)
    watchedChannels := make(map[UniqueID]*querypb.QueryChannelInfo)
    childCtx, cancel := context.WithCancel(ctx)
    client, err := newQueryNodeClientMock(childCtx, address)
    if err != nil {
        cancel()
        return nil, err
    }
    node := &queryNode{
        ctx:                  childCtx,
        cancel:               cancel,
        id:                   id,
        address:              address,
        client:               client,
        kvClient:             kv,
        collectionInfos:      collectionInfo,
        watchedQueryChannels: watchedChannels,
        onService:            false,
    }

    return node, nil
}

func newQueryNodeClientMock(ctx context.Context, addr string) (*queryNodeClientMock, error) {
    if addr == "" {
        return nil, fmt.Errorf("addr is empty")
    }
    ctx, cancel := context.WithCancel(ctx)
    return &queryNodeClientMock{
        ctx:    ctx,
        cancel: cancel,
        addr:   addr,
    }, nil
}

func (client *queryNodeClientMock) Init() error {
    ctx, cancel := context.WithTimeout(client.ctx, time.Second*2)
    defer cancel()
    conn, err := grpc.DialContext(ctx, client.addr, grpc.WithInsecure(), grpc.WithBlock())
    if err != nil {
        return err
    }
    client.conn = conn
    log.Debug("QueryNodeClient try connect success")
    client.grpcClient = querypb.NewQueryNodeClient(conn)
    return nil
}

func (client *queryNodeClientMock) Start() error {
    return nil
}

func (client *queryNodeClientMock) Stop() error {
    client.cancel()
    return client.conn.Close()
}

func (client *queryNodeClientMock) Register() error {
    return nil
}

func (client *queryNodeClientMock) GetComponentStates(ctx context.Context) (*internalpb.ComponentStates, error) {
    return client.grpcClient.GetComponentStates(ctx, &internalpb.GetComponentStatesRequest{})
}

func (client *queryNodeClientMock) GetTimeTickChannel(ctx context.Context) (*milvuspb.StringResponse, error) {
    return client.grpcClient.GetTimeTickChannel(ctx, &internalpb.GetTimeTickChannelRequest{})
}

func (client *queryNodeClientMock) GetStatisticsChannel(ctx context.Context) (*milvuspb.StringResponse, error) {
    return client.grpcClient.GetStatisticsChannel(ctx, &internalpb.GetStatisticsChannelRequest{})
}

func (client *queryNodeClientMock) AddQueryChannel(ctx context.Context, req *querypb.AddQueryChannelRequest) (*commonpb.Status, error) {
    return client.grpcClient.AddQueryChannel(ctx, req)
}

func (client *queryNodeClientMock) RemoveQueryChannel(ctx context.Context, req *querypb.RemoveQueryChannelRequest) (*commonpb.Status, error) {
    return client.grpcClient.RemoveQueryChannel(ctx, req)
}

func (client *queryNodeClientMock) WatchDmChannels(ctx context.Context, req *querypb.WatchDmChannelsRequest) (*commonpb.Status, error) {
    return client.grpcClient.WatchDmChannels(ctx, req)
}

func (client *queryNodeClientMock) LoadSegments(ctx context.Context, req *querypb.LoadSegmentsRequest) (*commonpb.Status, error) {
    return client.grpcClient.LoadSegments(ctx, req)
}

func (client *queryNodeClientMock) ReleaseCollection(ctx context.Context, req *querypb.ReleaseCollectionRequest) (*commonpb.Status, error) {
    return client.grpcClient.ReleaseCollection(ctx, req)
}

func (client *queryNodeClientMock) ReleasePartitions(ctx context.Context, req *querypb.ReleasePartitionsRequest) (*commonpb.Status, error) {
    return client.grpcClient.ReleasePartitions(ctx, req)
}

func (client *queryNodeClientMock) ReleaseSegments(ctx context.Context, req *querypb.ReleaseSegmentsRequest) (*commonpb.Status, error) {
    return client.grpcClient.ReleaseSegments(ctx, req)
}

func (client *queryNodeClientMock) GetSegmentInfo(ctx context.Context, req *querypb.GetSegmentInfoRequest) (*querypb.GetSegmentInfoResponse, error) {
    return client.grpcClient.GetSegmentInfo(ctx, req)
}

func (client *queryNodeClientMock) GetMetrics(ctx context.Context, req *milvuspb.GetMetricsRequest) (*milvuspb.GetMetricsResponse, error) {
    return client.grpcClient.GetMetrics(ctx, req)
}
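
Init above uses the bounded blocking-dial pattern: grpc.WithBlock makes DialContext wait until the connection is ready, and the 2-second context turns that wait into a hard deadline. The core of the pattern in isolation (the address is illustrative; assumes the grpc-go v1.x API this code already uses):

```go
package main

import (
	"context"
	"time"

	"google.golang.org/grpc"
)

// dialNode blocks until the connection is ready or the deadline hits.
// Without WithBlock, DialContext returns immediately and connection
// failures only surface on the first RPC.
func dialNode(addr string) (*grpc.ClientConn, error) {
	ctx, cancel := context.WithTimeout(context.Background(), 2*time.Second)
	defer cancel()
	return grpc.DialContext(ctx, addr, grpc.WithInsecure(), grpc.WithBlock())
}
```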

@@ -0,0 +1,196 @@
package querycoord

import (
    "context"
    "errors"
    "net"
    "strconv"

    "go.uber.org/zap"
    "google.golang.org/grpc"

    "github.com/milvus-io/milvus/internal/log"
    "github.com/milvus-io/milvus/internal/proto/commonpb"
    "github.com/milvus-io/milvus/internal/proto/milvuspb"
    "github.com/milvus-io/milvus/internal/proto/querypb"
    "github.com/milvus-io/milvus/internal/util/funcutil"
    "github.com/milvus-io/milvus/internal/util/retry"
    "github.com/milvus-io/milvus/internal/util/sessionutil"
    "github.com/milvus-io/milvus/internal/util/typeutil"
)

type queryNodeServerMock struct {
    querypb.QueryNodeServer
    ctx         context.Context
    cancel      context.CancelFunc
    session     *sessionutil.Session
    grpcErrChan chan error
    grpcServer  *grpc.Server

    queryNodeIP   string
    queryNodePort int64
    queryNodeID   int64

    addQueryChannels  func() (*commonpb.Status, error)
    watchDmChannels   func() (*commonpb.Status, error)
    loadSegment       func() (*commonpb.Status, error)
    releaseCollection func() (*commonpb.Status, error)
    releasePartition  func() (*commonpb.Status, error)
    releaseSegment    func() (*commonpb.Status, error)
}

func newQueryNodeServerMock(ctx context.Context) *queryNodeServerMock {
    ctx1, cancel := context.WithCancel(ctx)
    return &queryNodeServerMock{
        ctx:         ctx1,
        cancel:      cancel,
        grpcErrChan: make(chan error),

        addQueryChannels:  returnSuccessResult,
        watchDmChannels:   returnSuccessResult,
        loadSegment:       returnSuccessResult,
        releaseCollection: returnSuccessResult,
        releasePartition:  returnSuccessResult,
        releaseSegment:    returnSuccessResult,
    }
}

func (qs *queryNodeServerMock) Register() error {
    log.Debug("query node session info", zap.String("metaPath", Params.MetaRootPath), zap.Strings("etcdEndPoints", Params.EtcdEndpoints))
    qs.session = sessionutil.NewSession(qs.ctx, Params.MetaRootPath, Params.EtcdEndpoints)
    qs.session.Init(typeutil.QueryNodeRole, qs.queryNodeIP+":"+strconv.FormatInt(qs.queryNodePort, 10), false)
    qs.queryNodeID = qs.session.ServerID
    log.Debug("query nodeID", zap.Int64("nodeID", qs.queryNodeID))
    log.Debug("query node address", zap.String("address", qs.session.Address))

    return nil
}

func (qs *queryNodeServerMock) init() error {
    qs.queryNodeIP = funcutil.GetLocalIP()
    grpcPort := Params.Port

    go func() {
        var lis net.Listener
        var err error
        err = retry.Do(qs.ctx, func() error {
            addr := ":" + strconv.Itoa(grpcPort)
            lis, err = net.Listen("tcp", addr)
            if err == nil {
                qs.queryNodePort = int64(lis.Addr().(*net.TCPAddr).Port)
            } else {
                // set port=0 to get next available port
                grpcPort = 0
            }
            return err
        }, retry.Attempts(10))
        if err != nil {
            qs.grpcErrChan <- err
        }

        qs.grpcServer = grpc.NewServer()
        querypb.RegisterQueryNodeServer(qs.grpcServer, qs)
        go funcutil.CheckGrpcReady(qs.ctx, qs.grpcErrChan)
        if err = qs.grpcServer.Serve(lis); err != nil {
            qs.grpcErrChan <- err
        }
    }()

    err := <-qs.grpcErrChan
    if err != nil {
        return err
    }

    if err := qs.Register(); err != nil {
        return err
    }

    return nil
}

func (qs *queryNodeServerMock) start() error {
    return nil
}

func (qs *queryNodeServerMock) stop() error {
    qs.cancel()
    if qs.grpcServer != nil {
        qs.grpcServer.GracefulStop()
    }

    return nil
}

func (qs *queryNodeServerMock) run() error {
    if err := qs.init(); err != nil {
        return err
    }

    if err := qs.start(); err != nil {
        return err
    }

    return nil
}

func (qs *queryNodeServerMock) AddQueryChannel(ctx context.Context, req *querypb.AddQueryChannelRequest) (*commonpb.Status, error) {
    return qs.addQueryChannels()
}

func (qs *queryNodeServerMock) WatchDmChannels(ctx context.Context, req *querypb.WatchDmChannelsRequest) (*commonpb.Status, error) {
    return qs.watchDmChannels()
}

func (qs *queryNodeServerMock) LoadSegments(ctx context.Context, req *querypb.LoadSegmentsRequest) (*commonpb.Status, error) {
    return qs.loadSegment()
}

func (qs *queryNodeServerMock) ReleaseCollection(ctx context.Context, req *querypb.ReleaseCollectionRequest) (*commonpb.Status, error) {
    return qs.releaseCollection()
}

func (qs *queryNodeServerMock) ReleasePartitions(ctx context.Context, req *querypb.ReleasePartitionsRequest) (*commonpb.Status, error) {
    return qs.releasePartition()
}

func (qs *queryNodeServerMock) ReleaseSegments(ctx context.Context, req *querypb.ReleaseSegmentsRequest) (*commonpb.Status, error) {
    return qs.releaseSegment()
}

func (qs *queryNodeServerMock) GetSegmentInfo(context.Context, *querypb.GetSegmentInfoRequest) (*querypb.GetSegmentInfoResponse, error) {
    return &querypb.GetSegmentInfoResponse{
        Status: &commonpb.Status{
            ErrorCode: commonpb.ErrorCode_Success,
        },
    }, nil
}

func (qs *queryNodeServerMock) GetMetrics(ctx context.Context, req *milvuspb.GetMetricsRequest) (*milvuspb.GetMetricsResponse, error) {
    return &milvuspb.GetMetricsResponse{
        Status: &commonpb.Status{
            ErrorCode: commonpb.ErrorCode_Success,
        },
    }, nil
}

func startQueryNodeServer(ctx context.Context) (*queryNodeServerMock, error) {
    node := newQueryNodeServerMock(ctx)
    err := node.run()
    if err != nil {
        return nil, err
    }

    return node, nil
}

func returnSuccessResult() (*commonpb.Status, error) {
    return &commonpb.Status{
        ErrorCode: commonpb.ErrorCode_Success,
    }, nil
}

func returnFailedResult() (*commonpb.Status, error) {
    return &commonpb.Status{
        ErrorCode: commonpb.ErrorCode_UnexpectedError,
    }, errors.New("query node do task failed")
}
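
The init above combines two startup tricks: retry.Do re-runs net.Listen with the port reset to 0 so the OS assigns a free one (and the actual bound port is read back from lis.Addr()), while the error channel surfaces goroutine failures to the caller. The port-fallback core, using only the standard library:

```go
package main

import (
	"fmt"
	"net"
)

// listenWithFallback tries the preferred port once, then lets the kernel
// choose a free port (":0") and reports the port actually bound.
func listenWithFallback(preferred int) (net.Listener, int, error) {
	lis, err := net.Listen("tcp", fmt.Sprintf(":%d", preferred))
	if err != nil {
		if lis, err = net.Listen("tcp", ":0"); err != nil {
			return nil, 0, err
		}
	}
	return lis, lis.Addr().(*net.TCPAddr).Port, nil
}
```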

@@ -52,6 +52,7 @@ type QueryCoord struct {
    queryCoordID uint64
    meta         Meta
    cluster      *queryNodeCluster
    newNodeFn    newQueryNodeFn
    scheduler    *TaskScheduler

    dataCoordClient types.DataCoord
@@ -98,7 +99,7 @@ func (qc *QueryCoord) Init() error {
        return err
    }

-   qc.cluster, err = newQueryNodeCluster(qc.meta, qc.kvClient)
+   qc.cluster, err = newQueryNodeCluster(qc.meta, qc.kvClient, qc.newNodeFn)
    if err != nil {
        log.Error("query coordinator init cluster failed", zap.Error(err))
        return err
@@ -160,6 +161,7 @@ func NewQueryCoord(ctx context.Context, factory msgstream.Factory) (*QueryCoord,
        loopCtx:    ctx1,
        loopCancel: cancel,
        msFactory:  factory,
        newNodeFn:  newQueryNode,
    }

    service.UpdateStateCode(internalpb.StateCode_Abnormal)

@@ -14,91 +14,61 @@ package querycoord
import (
    "context"
    "math/rand"
    "os"
    "strconv"
    "testing"
    "time"

    "github.com/stretchr/testify/assert"
    "go.uber.org/zap"

    "github.com/milvus-io/milvus/internal/log"
    "github.com/milvus-io/milvus/internal/msgstream"
    "github.com/milvus-io/milvus/internal/proto/internalpb"
)

func setup() {
    Params.Init()
}

func refreshParams() {
    rand.Seed(time.Now().UnixNano())
    suffix := "-test-query-Coord" + strconv.FormatInt(rand.Int63(), 10)
    Params.StatsChannelName = Params.StatsChannelName + suffix
    Params.TimeTickChannelName = Params.TimeTickChannelName + suffix
    Params.MetaRootPath = Params.MetaRootPath + suffix
}

func refreshChannelNames() {
    suffix := "-test-query-Coord" + strconv.FormatInt(rand.Int63n(1000000), 10)
    Params.StatsChannelName = Params.StatsChannelName + suffix
    Params.TimeTickChannelName = Params.TimeTickChannelName + suffix
}

func TestMain(m *testing.M) {
    /*
        setup()
        //refreshChannelNames()
        exitCode := m.Run()
        os.Exit(exitCode)
    */
    setup()
    //refreshChannelNames()
    exitCode := m.Run()
    os.Exit(exitCode)
}

-func TestQueryCoord_Init(t *testing.T) {
-   ctx := context.Background()
-   msFactory := msgstream.NewPmsFactory()
-   service, err := NewQueryCoord(context.Background(), msFactory)
-   assert.Nil(t, err)
-   service.Register()
-   service.Init()
-   service.Start()
+func NewQueryCoordTest(ctx context.Context, factory msgstream.Factory) (*QueryCoord, error) {
+   refreshParams()
+   rand.Seed(time.Now().UnixNano())
+   queryChannels := make([]*queryChannelInfo, 0)
+   channelID := len(queryChannels)
+   searchPrefix := Params.SearchChannelPrefix
+   searchResultPrefix := Params.SearchResultChannelPrefix
+   allocatedQueryChannel := searchPrefix + "-" + strconv.FormatInt(int64(channelID), 10)
+   allocatedQueryResultChannel := searchResultPrefix + "-" + strconv.FormatInt(int64(channelID), 10)

-   t.Run("Test Get statistics channel", func(t *testing.T) {
-       response, err := service.GetStatisticsChannel(ctx)
-       assert.Nil(t, err)
-       assert.Equal(t, response.Value, "query-node-stats")
+   queryChannels = append(queryChannels, &queryChannelInfo{
+       requestChannel:  allocatedQueryChannel,
+       responseChannel: allocatedQueryResultChannel,
    })

-   t.Run("Test Get timeTick channel", func(t *testing.T) {
-       response, err := service.GetTimeTickChannel(ctx)
-       assert.Nil(t, err)
-       assert.Equal(t, response.Value, "queryTimeTick")
-   })
+   ctx1, cancel := context.WithCancel(ctx)
+   service := &QueryCoord{
+       loopCtx:    ctx1,
+       loopCancel: cancel,
+       msFactory:  factory,
+       newNodeFn:  newQueryNodeTest,
+   }

-   service.Stop()
+   service.UpdateStateCode(internalpb.StateCode_Abnormal)
+   log.Debug("query coordinator", zap.Any("queryChannels", queryChannels))
+   return service, nil
}

//func TestQueryCoord_load(t *testing.T) {
//    ctx := context.Background()
//    msFactory := msgstream.NewPmsFactory()
//    service, err := NewQueryCoord(context.Background(), msFactory)
//    assert.Nil(t, err)
//    service.Init()
//    service.Start()
//    service.SetRootCoord(newRootCoordMock())
//    service.SetDataCoord(NewDataMock())
//    registerNodeRequest := &querypb.RegisterNodeRequest{
//        Address: &commonpb.Address{},
//    }
//    service.RegisterNode(ctx, registerNodeRequest)
//
//    t.Run("Test LoadSegment", func(t *testing.T) {
//        loadCollectionRequest := &querypb.LoadCollectionRequest{
//            CollectionID: 1,
//        }
//        response, err := service.LoadCollection(ctx, loadCollectionRequest)
//        assert.Nil(t, err)
//        assert.Equal(t, response.ErrorCode, commonpb.ErrorCode_Success)
//    })
//
//    t.Run("Test LoadPartition", func(t *testing.T) {
//        loadPartitionRequest := &querypb.LoadPartitionsRequest{
//            CollectionID: 1,
//            PartitionIDs: []UniqueID{1},
//        }
//        response, err := service.LoadPartitions(ctx, loadPartitionRequest)
//        assert.Nil(t, err)
//        assert.Equal(t, response.ErrorCode, commonpb.ErrorCode_Success)
//    })
//}

@@ -37,19 +37,14 @@ type Node interface {
    stop()
    clearNodeInfo() error

    hasCollection(collectionID UniqueID) bool
    addCollection(collectionID UniqueID, schema *schemapb.CollectionSchema) error
    setCollectionInfo(info *querypb.CollectionInfo) error
    getCollectionInfoByID(collectionID UniqueID) (*querypb.CollectionInfo, error)
    showCollections() []*querypb.CollectionInfo
    releaseCollection(ctx context.Context, in *querypb.ReleaseCollectionRequest) error

    hasPartition(collectionID UniqueID, partitionID UniqueID) bool
    addPartition(collectionID UniqueID, partitionID UniqueID) error
    releasePartitions(ctx context.Context, in *querypb.ReleasePartitionsRequest) error

    hasWatchedDmChannel(collectionID UniqueID, channelID string) (bool, error)
    getDmChannelsByCollectionID(collectionID UniqueID) ([]string, error)
    watchDmChannels(ctx context.Context, in *querypb.WatchDmChannelsRequest) error
    removeDmChannel(collectionID UniqueID, channels []string) error

@@ -84,37 +79,38 @@ type queryNode struct {
    serviceLock sync.RWMutex
}

-func newQueryNode(ctx context.Context, address string, id UniqueID, kv *etcdkv.EtcdKV) Node {
+func newQueryNode(ctx context.Context, address string, id UniqueID, kv *etcdkv.EtcdKV) (Node, error) {
    collectionInfo := make(map[UniqueID]*querypb.CollectionInfo)
    watchedChannels := make(map[UniqueID]*querypb.QueryChannelInfo)
    childCtx, cancel := context.WithCancel(ctx)
+   client, err := nodeclient.NewClient(childCtx, address)
+   if err != nil {
+       cancel()
+       return nil, err
+   }
    node := &queryNode{
        ctx:                  childCtx,
        cancel:               cancel,
        id:                   id,
        address:              address,
        client:               client,
        kvClient:             kv,
        collectionInfos:      collectionInfo,
        watchedQueryChannels: watchedChannels,
        onService:            false,
    }

-   return node
+   return node, nil
}

func (qn *queryNode) start() error {
-   client, err := nodeclient.NewClient(qn.ctx, qn.address)
-   if err != nil {
+   if err := qn.client.Init(); err != nil {
        return err
    }
-   if err = client.Init(); err != nil {
-       return err
-   }
-   if err = client.Start(); err != nil {
+   if err := qn.client.Start(); err != nil {
        return err
    }

-   qn.client = client
    qn.serviceLock.Lock()
    qn.onService = true
    qn.serviceLock.Unlock()

@@ -132,32 +128,6 @@ func (qn *queryNode) stop() {
    qn.cancel()
}

func (qn *queryNode) hasCollection(collectionID UniqueID) bool {
    qn.RLock()
    defer qn.RUnlock()

    if _, ok := qn.collectionInfos[collectionID]; ok {
        return true
    }

    return false
}

func (qn *queryNode) hasPartition(collectionID UniqueID, partitionID UniqueID) bool {
    qn.RLock()
    defer qn.RUnlock()

    if info, ok := qn.collectionInfos[collectionID]; ok {
        for _, id := range info.PartitionIDs {
            if partitionID == id {
                return true
            }
        }
    }

    return false
}

func (qn *queryNode) addCollection(collectionID UniqueID, schema *schemapb.CollectionSchema) error {
    qn.Lock()
    defer qn.Unlock()

@@ -195,16 +165,6 @@ func (qn *queryNode) setCollectionInfo(info *querypb.CollectionInfo) error {
    return nil
}

func (qn *queryNode) getCollectionInfoByID(collectionID UniqueID) (*querypb.CollectionInfo, error) {
    qn.Lock()
    defer qn.Lock()

    if _, ok := qn.collectionInfos[collectionID]; ok {
        return proto.Clone(qn.collectionInfos[collectionID]).(*querypb.CollectionInfo), nil
    }
    return nil, errors.New("GetCollectionInfoByID: can't find collection")
}

func (qn *queryNode) showCollections() []*querypb.CollectionInfo {
    qn.RLock()
    defer qn.RUnlock()

@@ -281,40 +241,6 @@ func (qn *queryNode) releasePartitionsInfo(collectionID UniqueID, partitionIDs [
    return nil
}

func (qn *queryNode) hasWatchedDmChannel(collectionID UniqueID, channelID string) (bool, error) {
    qn.RLock()
    defer qn.RUnlock()

    if info, ok := qn.collectionInfos[collectionID]; ok {
        channelInfos := info.ChannelInfos
        for _, channelInfo := range channelInfos {
            for _, channel := range channelInfo.ChannelIDs {
                if channel == channelID {
                    return true, nil
                }
            }
        }
        return false, nil
    }

    return false, errors.New("HasWatchedDmChannel: can't find collection in collectionInfos")
}

func (qn *queryNode) getDmChannelsByCollectionID(collectionID UniqueID) ([]string, error) {
    qn.RLock()
    defer qn.RUnlock()

    if info, ok := qn.collectionInfos[collectionID]; ok {
        channels := make([]string, 0)
        for _, channelsInfo := range info.ChannelInfos {
            channels = append(channels, channelsInfo.ChannelIDs...)
        }
        return channels, nil
    }

    return nil, errors.New("GetDmChannelsByCollectionID: can't find collection in collectionInfos")
}

func (qn *queryNode) addDmChannel(collectionID UniqueID, channels []string) error {
    qn.Lock()
    defer qn.Unlock()
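
The newQueryNode refactor above moves gRPC client creation from start() into the constructor and makes the constructor fallible, so registerNode can reject an unreachable node immediately instead of discovering the failure later. The general shape of the change, with stand-in types (dial here is a hypothetical helper, not the nodeclient API):

```go
package main

import (
	"context"
	"errors"
)

type client struct{}

// dial stands in for the real gRPC client constructor.
func dial(ctx context.Context, addr string) (*client, error) {
	if addr == "" {
		return nil, errors.New("empty address")
	}
	return &client{}, nil
}

type node struct {
	ctx    context.Context
	cancel context.CancelFunc
	client *client
}

// Creating the client here (rather than in start) lets registration fail
// fast; note cancel() on the error path, which keeps the child context
// from leaking when construction aborts.
func newNode(ctx context.Context, addr string) (*node, error) {
	childCtx, cancel := context.WithCancel(ctx)
	c, err := dial(childCtx, addr)
	if err != nil {
		cancel()
		return nil, err
	}
	return &node{ctx: childCtx, cancel: cancel, client: c}, nil
}
```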

@@ -27,7 +27,7 @@ import (
func startQueryCoord(ctx context.Context) (*QueryCoord, error) {
    factory := msgstream.NewPmsFactory()

-   coord, err := NewQueryCoord(ctx, factory)
+   coord, err := NewQueryCoordTest(ctx, factory)
    if err != nil {
        return nil, err
    }

@@ -59,6 +59,31 @@ func startQueryCoord(ctx context.Context) (*QueryCoord, error) {
    return coord, nil
}

//func waitQueryNodeOnline(cluster *queryNodeCluster, nodeID int64)

func waitAllQueryNodeOffline(cluster *queryNodeCluster, nodes map[int64]Node) bool {
    reDoCount := 20
    for {
        if reDoCount <= 0 {
            return false
        }
        allOffline := true
        for nodeID := range nodes {
            _, err := cluster.getNodeByID(nodeID)
            if err == nil {
                allOffline = false
                break
            }
        }
        if allOffline {
            return true
        }
        log.Debug("wait all queryNode offline")
        time.Sleep(time.Second)
        reDoCount--
    }
}

func TestQueryNode_MultiNode_stop(t *testing.T) {
    baseCtx := context.Background()

@@ -68,23 +93,11 @@ func TestQueryNode_MultiNode_stop(t *testing.T) {
    queryNode1, err := startQueryNodeServer(baseCtx)
    assert.Nil(t, err)

    //queryNode2, err := startQueryNodeServer(baseCtx)
    //assert.Nil(t, err)

    //queryNode3, err := startQueryNodeServer(baseCtx)
    //assert.Nil(t, err)

    //queryNode4, err := startQueryNodeServer(baseCtx)
    //assert.Nil(t, err)

    queryNode5, err := startQueryNodeServer(baseCtx)
    assert.Nil(t, err)

    time.Sleep(2 * time.Second)
    queryNode1.stop()
    //queryNode2.stop()
    //queryNode3.stop()
    //queryNode4.stop()

    queryCoord.LoadCollection(baseCtx, &querypb.LoadCollectionRequest{
        Base: &commonpb.MsgBase{

@@ -106,21 +119,8 @@ func TestQueryNode_MultiNode_stop(t *testing.T) {
    assert.Nil(t, err)
    queryNode5.stop()

-   for {
-       allOffline := true
-       for nodeID := range nodes {
-           _, err = queryCoord.cluster.getNodeByID(nodeID)
-           if err == nil {
-               allOffline = false
-               time.Sleep(time.Second)
-               break
-           }
-       }
-       if allOffline {
-           break
-       }
-       log.Debug("wait all queryNode offline")
-   }
+   allNodeOffline := waitAllQueryNodeOffline(queryCoord.cluster, nodes)
+   assert.Equal(t, allNodeOffline, true)
    queryCoord.Stop()
}

@@ -133,9 +133,6 @@ func TestQueryNode_MultiNode_reStart(t *testing.T) {
    queryNode1, err := startQueryNodeServer(baseCtx)
    assert.Nil(t, err)

    //queryNode2, err := startQueryNodeServer(baseCtx)
    //assert.Nil(t, err)

    time.Sleep(2 * time.Second)
    queryCoord.LoadCollection(baseCtx, &querypb.LoadCollectionRequest{
        Base: &commonpb.MsgBase{

@@ -145,13 +142,8 @@ func TestQueryNode_MultiNode_reStart(t *testing.T) {
        Schema: genCollectionSchema(defaultCollectionID, false),
    })
    queryNode1.stop()
    //queryNode2.stop()
    queryNode3, err := startQueryNodeServer(baseCtx)
    assert.Nil(t, err)
    //queryNode4, err := startQueryNodeServer(baseCtx)
    //assert.Nil(t, err)
    //queryNode5, err := startQueryNodeServer(baseCtx)
    //assert.Nil(t, err)

    time.Sleep(2 * time.Second)
    _, err = queryCoord.ReleaseCollection(baseCtx, &querypb.ReleaseCollectionRequest{

@@ -164,24 +156,9 @@ func TestQueryNode_MultiNode_reStart(t *testing.T) {
    nodes, err := queryCoord.cluster.onServiceNodes()
    assert.Nil(t, err)
    queryNode3.stop()
    //queryNode4.stop()
    //queryNode5.stop()

-   for {
-       allOffline := true
-       for nodeID := range nodes {
-           _, err = queryCoord.cluster.getNodeByID(nodeID)
-           if err == nil {
-               allOffline = false
-               time.Sleep(time.Second)
-               break
-           }
-       }
-       if allOffline {
-           break
-       }
-       log.Debug("wait all queryNode offline")
-   }
+   allNodeOffline := waitAllQueryNodeOffline(queryCoord.cluster, nodes)
+   assert.Equal(t, allNodeOffline, true)
    queryCoord.Stop()
}
@ -2,6 +2,8 @@ package querycoord
|
|||
|
||||
import (
|
||||
"context"
|
||||
"fmt"
|
||||
"strconv"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
|
@ -142,13 +144,8 @@ func TestWatchQueryChannel_ClearEtcdInfoAfterAssignedNodeDown(t *testing.T) {
|
|||
time.Sleep(time.Second)
|
||||
queryNode.stop()
|
||||
|
||||
for {
|
||||
_, err = queryCoord.cluster.getNodeByID(nodeID)
|
||||
if err == nil {
|
||||
time.Sleep(time.Second)
|
||||
break
|
||||
}
|
||||
}
|
||||
allNodeOffline := waitAllQueryNodeOffline(queryCoord.cluster, nodes)
|
||||
assert.Equal(t, allNodeOffline, true)
|
||||
|
||||
time.Sleep(time.Second)
|
||||
newActiveTaskIDKeys, _, err := queryCoord.scheduler.client.LoadWithPrefix(activeTaskPrefix)
|
||||
|
@ -157,27 +154,247 @@ func TestWatchQueryChannel_ClearEtcdInfoAfterAssignedNodeDown(t *testing.T) {
|
|||
queryCoord.Stop()
|
||||
}
|
||||
|
||||
func TestUnMarshalTask(t *testing.T) {
	refreshParams()
	kv, err := etcdkv.NewEtcdKV(Params.EtcdEndpoints, Params.MetaRootPath)
	assert.Nil(t, err)
	taskScheduler := &TaskScheduler{}

	t.Run("Test LoadCollectionTask", func(t *testing.T) {
		loadTask := &LoadCollectionTask{
			LoadCollectionRequest: &querypb.LoadCollectionRequest{
				Base: &commonpb.MsgBase{
					MsgType: commonpb.MsgType_LoadCollection,
				},
			},
		}
		blobs, err := loadTask.Marshal()
		assert.Nil(t, err)
		err = kv.Save("testMarshalLoadCollection", string(blobs))
		assert.Nil(t, err)
		defer kv.RemoveWithPrefix("testMarshalLoadCollection")
		value, err := kv.Load("testMarshalLoadCollection")
		assert.Nil(t, err)

		task, err := taskScheduler.unmarshalTask(value)
		assert.Nil(t, err)
		assert.Equal(t, task.Type(), commonpb.MsgType_LoadCollection)
	})

	t.Run("Test LoadPartitionsTask", func(t *testing.T) {
		loadTask := &LoadPartitionTask{
			LoadPartitionsRequest: &querypb.LoadPartitionsRequest{
				Base: &commonpb.MsgBase{
					MsgType: commonpb.MsgType_LoadPartitions,
				},
			},
		}
		blobs, err := loadTask.Marshal()
		assert.Nil(t, err)
		err = kv.Save("testMarshalLoadPartition", string(blobs))
		assert.Nil(t, err)
		defer kv.RemoveWithPrefix("testMarshalLoadPartition")
		value, err := kv.Load("testMarshalLoadPartition")
		assert.Nil(t, err)

		task, err := taskScheduler.unmarshalTask(value)
		assert.Nil(t, err)
		assert.Equal(t, task.Type(), commonpb.MsgType_LoadPartitions)
	})

	t.Run("Test ReleaseCollectionTask", func(t *testing.T) {
		releaseTask := &ReleaseCollectionTask{
			ReleaseCollectionRequest: &querypb.ReleaseCollectionRequest{
				Base: &commonpb.MsgBase{
					MsgType: commonpb.MsgType_ReleaseCollection,
				},
			},
		}
		blobs, err := releaseTask.Marshal()
		assert.Nil(t, err)
		err = kv.Save("testMarshalReleaseCollection", string(blobs))
		assert.Nil(t, err)
		defer kv.RemoveWithPrefix("testMarshalReleaseCollection")
		value, err := kv.Load("testMarshalReleaseCollection")
		assert.Nil(t, err)

		task, err := taskScheduler.unmarshalTask(value)
		assert.Nil(t, err)
		assert.Equal(t, task.Type(), commonpb.MsgType_ReleaseCollection)
	})

	t.Run("Test ReleasePartitionTask", func(t *testing.T) {
		releaseTask := &ReleasePartitionTask{
			ReleasePartitionsRequest: &querypb.ReleasePartitionsRequest{
				Base: &commonpb.MsgBase{
					MsgType: commonpb.MsgType_ReleasePartitions,
				},
			},
		}
		blobs, err := releaseTask.Marshal()
		assert.Nil(t, err)
		err = kv.Save("testMarshalReleasePartition", string(blobs))
		assert.Nil(t, err)
		defer kv.RemoveWithPrefix("testMarshalReleasePartition")
		value, err := kv.Load("testMarshalReleasePartition")
		assert.Nil(t, err)

		task, err := taskScheduler.unmarshalTask(value)
		assert.Nil(t, err)
		assert.Equal(t, task.Type(), commonpb.MsgType_ReleasePartitions)
	})

	t.Run("Test LoadSegmentTask", func(t *testing.T) {
		loadTask := &LoadSegmentTask{
			LoadSegmentsRequest: &querypb.LoadSegmentsRequest{
				Base: &commonpb.MsgBase{
					MsgType: commonpb.MsgType_LoadSegments,
				},
			},
		}
		blobs, err := loadTask.Marshal()
		assert.Nil(t, err)
		err = kv.Save("testMarshalLoadSegment", string(blobs))
		assert.Nil(t, err)
		defer kv.RemoveWithPrefix("testMarshalLoadSegment")
		value, err := kv.Load("testMarshalLoadSegment")
		assert.Nil(t, err)

		task, err := taskScheduler.unmarshalTask(value)
		assert.Nil(t, err)
		assert.Equal(t, task.Type(), commonpb.MsgType_LoadSegments)
	})

	t.Run("Test ReleaseSegmentTask", func(t *testing.T) {
		releaseTask := &ReleaseSegmentTask{
			ReleaseSegmentsRequest: &querypb.ReleaseSegmentsRequest{
				Base: &commonpb.MsgBase{
					MsgType: commonpb.MsgType_ReleaseSegments,
				},
			},
		}
		blobs, err := releaseTask.Marshal()
		assert.Nil(t, err)
		err = kv.Save("testMarshalReleaseSegment", string(blobs))
		assert.Nil(t, err)
		defer kv.RemoveWithPrefix("testMarshalReleaseSegment")
		value, err := kv.Load("testMarshalReleaseSegment")
		assert.Nil(t, err)

		task, err := taskScheduler.unmarshalTask(value)
		assert.Nil(t, err)
		assert.Equal(t, task.Type(), commonpb.MsgType_ReleaseSegments)
	})

	t.Run("Test WatchDmChannelTask", func(t *testing.T) {
		watchTask := &WatchDmChannelTask{
			WatchDmChannelsRequest: &querypb.WatchDmChannelsRequest{
				Base: &commonpb.MsgBase{
					MsgType: commonpb.MsgType_WatchDmChannels,
				},
			},
		}
		blobs, err := watchTask.Marshal()
		assert.Nil(t, err)
		err = kv.Save("testMarshalWatchDmChannel", string(blobs))
		assert.Nil(t, err)
		defer kv.RemoveWithPrefix("testMarshalWatchDmChannel")
		value, err := kv.Load("testMarshalWatchDmChannel")
		assert.Nil(t, err)

		task, err := taskScheduler.unmarshalTask(value)
		assert.Nil(t, err)
		assert.Equal(t, task.Type(), commonpb.MsgType_WatchDmChannels)
	})

	t.Run("Test WatchQueryChannelTask", func(t *testing.T) {
		watchTask := &WatchQueryChannelTask{
			AddQueryChannelRequest: &querypb.AddQueryChannelRequest{
				Base: &commonpb.MsgBase{
					MsgType: commonpb.MsgType_WatchQueryChannels,
				},
			},
		}
		blobs, err := watchTask.Marshal()
		assert.Nil(t, err)
		err = kv.Save("testMarshalWatchQueryChannel", string(blobs))
		assert.Nil(t, err)
		defer kv.RemoveWithPrefix("testMarshalWatchQueryChannel")
		value, err := kv.Load("testMarshalWatchQueryChannel")
		assert.Nil(t, err)

		task, err := taskScheduler.unmarshalTask(value)
		assert.Nil(t, err)
		assert.Equal(t, task.Type(), commonpb.MsgType_WatchQueryChannels)
	})

	t.Run("Test LoadBalanceTask", func(t *testing.T) {
		loadBalanceTask := &LoadBalanceTask{
			LoadBalanceRequest: &querypb.LoadBalanceRequest{
				Base: &commonpb.MsgBase{
					MsgType: commonpb.MsgType_LoadBalanceSegments,
				},
			},
		}

		blobs, err := loadBalanceTask.Marshal()
		assert.Nil(t, err)
		err = kv.Save("testMarshalLoadBalanceTask", string(blobs))
		assert.Nil(t, err)
		defer kv.RemoveWithPrefix("testMarshalLoadBalanceTask")
		value, err := kv.Load("testMarshalLoadBalanceTask")
		assert.Nil(t, err)

		task, err := taskScheduler.unmarshalTask(value)
		assert.Nil(t, err)
		assert.Equal(t, task.Type(), commonpb.MsgType_LoadBalanceSegments)
	})
}
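
Every subtest above repeats the same marshal -> Save -> Load -> unmarshalTask round trip. Below is a minimal sketch of a helper that could factor that pattern out; assertTaskRoundTrip is hypothetical (not part of this commit), and it assumes the package's task interface exposes Marshal() and Type(), as the subtests rely on.

// assertTaskRoundTrip is a hypothetical helper sketch, not part of this commit.
func assertTaskRoundTrip(t *testing.T, kv *etcdkv.EtcdKV, scheduler *TaskScheduler, key string, tsk task, expected commonpb.MsgType) {
	// Serialize the task request and persist the blob under the test key.
	blobs, err := tsk.Marshal()
	assert.Nil(t, err)
	err = kv.Save(key, string(blobs))
	assert.Nil(t, err)
	defer kv.RemoveWithPrefix(key) // clean up the test key afterwards

	// Load the blob back and rebuild the task through the scheduler.
	value, err := kv.Load(key)
	assert.Nil(t, err)
	restored, err := scheduler.unmarshalTask(value)
	assert.Nil(t, err)
	assert.Equal(t, expected, restored.Type())
}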

func TestReloadTaskFromKV(t *testing.T) {
	refreshParams()
	kv, err := etcdkv.NewEtcdKV(Params.EtcdEndpoints, Params.MetaRootPath)
	assert.Nil(t, err)
	taskScheduler := &TaskScheduler{
		client:           kv,
		triggerTaskQueue: NewTaskQueue(),
	}

	kvs := make(map[string]string)
	triggerTask := &LoadCollectionTask{
		LoadCollectionRequest: &querypb.LoadCollectionRequest{
			Base: &commonpb.MsgBase{
				Timestamp: 1,
				MsgType:   commonpb.MsgType_LoadCollection,
			},
		},
	}
	triggerBlobs, err := triggerTask.Marshal()
	assert.Nil(t, err)
	triggerTaskKey := fmt.Sprintf("%s/%d", triggerTaskPrefix, 100)
	kvs[triggerTaskKey] = string(triggerBlobs)

	activeTask := &LoadSegmentTask{
		LoadSegmentsRequest: &querypb.LoadSegmentsRequest{
			Base: &commonpb.MsgBase{
				Timestamp: 2,
				MsgType:   commonpb.MsgType_LoadSegments,
			},
		},
	}
	activeBlobs, err := activeTask.Marshal()
	assert.Nil(t, err)
	activeTaskKey := fmt.Sprintf("%s/%d", activeTaskPrefix, 101)
	kvs[activeTaskKey] = string(activeBlobs)

	stateKey := fmt.Sprintf("%s/%d", taskInfoPrefix, 100)
	kvs[stateKey] = strconv.Itoa(int(taskDone))
	err = kv.MultiSave(kvs)
	assert.Nil(t, err)

	taskScheduler.reloadFromKV()

	task := taskScheduler.triggerTaskQueue.PopTask()
	assert.Equal(t, taskDone, task.State())
	assert.Equal(t, 1, len(task.GetChildTask()))
}
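
For reference, here is the state this test seeds and what its assertions require, assuming the triggerTaskPrefix, activeTaskPrefix, and taskInfoPrefix constants defined elsewhere in this package:

// Etcd layout seeded by TestReloadTaskFromKV (100 and 101 are arbitrary test IDs):
//
//	{triggerTaskPrefix}/100 -> marshaled LoadCollectionTask (the trigger task)
//	{activeTaskPrefix}/101  -> marshaled LoadSegmentTask (an active child task)
//	{taskInfoPrefix}/100    -> strconv.Itoa(int(taskDone)) (the trigger task's state)
//
// The assertions imply that reloadFromKV rebuilds the trigger task on
// triggerTaskQueue, restores its taskDone state, and reattaches the active
// task as its single child.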

@@ -0,0 +1,124 @@
package querycoord

import (
	"context"
	"testing"

	"github.com/stretchr/testify/assert"

	"github.com/milvus-io/milvus/internal/proto/commonpb"
	"github.com/milvus-io/milvus/internal/proto/querypb"
)

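// TestTriggerTask drives one gRPC-triggered task of each type (load/release
// collection, load/release partition) through scheduler.processTask against a
// queryCoord and a query node started for the test; each subtest asserts the
// task completes without error.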
func TestTriggerTask(t *testing.T) {
	ctx := context.Background()
	queryCoord, err := startQueryCoord(ctx)
	assert.Nil(t, err)

	node, err := startQueryNodeServer(ctx)
	assert.Nil(t, err)

	t.Run("Test LoadCollection", func(t *testing.T) {
		req := &querypb.LoadCollectionRequest{
			Base: &commonpb.MsgBase{
				MsgType: commonpb.MsgType_LoadCollection,
			},
			CollectionID: defaultCollectionID,
			Schema:       genCollectionSchema(defaultCollectionID, false),
		}
		loadCollectionTask := &LoadCollectionTask{
			BaseTask: BaseTask{
				ctx:              ctx,
				Condition:        NewTaskCondition(ctx),
				triggerCondition: querypb.TriggerCondition_grpcRequest,
			},
			LoadCollectionRequest: req,
			rootCoord:             queryCoord.rootCoordClient,
			dataCoord:             queryCoord.dataCoordClient,
			cluster:               queryCoord.cluster,
			meta:                  queryCoord.meta,
		}

		err = queryCoord.scheduler.processTask(loadCollectionTask)
		assert.Nil(t, err)
	})

	t.Run("Test ReleaseCollection", func(t *testing.T) {
		req := &querypb.ReleaseCollectionRequest{
			Base: &commonpb.MsgBase{
				MsgType: commonpb.MsgType_ReleaseCollection,
			},
			CollectionID: defaultCollectionID,
		}
		releaseCollectionTask := &ReleaseCollectionTask{
			BaseTask: BaseTask{
				ctx:              ctx,
				Condition:        NewTaskCondition(ctx),
				triggerCondition: querypb.TriggerCondition_grpcRequest,
			},
			ReleaseCollectionRequest: req,
			rootCoord:                queryCoord.rootCoordClient,
			cluster:                  queryCoord.cluster,
			meta:                     queryCoord.meta,
		}

		err = queryCoord.scheduler.processTask(releaseCollectionTask)
		assert.Nil(t, err)
	})

	t.Run("Test LoadPartition", func(t *testing.T) {
		req := &querypb.LoadPartitionsRequest{
			Base: &commonpb.MsgBase{
				MsgType: commonpb.MsgType_LoadPartitions,
			},
			CollectionID: defaultCollectionID,
			PartitionIDs: []UniqueID{defaultPartitionID},
		}
		loadPartitionTask := &LoadPartitionTask{
			BaseTask: BaseTask{
				ctx:              ctx,
				Condition:        NewTaskCondition(ctx),
				triggerCondition: querypb.TriggerCondition_grpcRequest,
			},
			LoadPartitionsRequest: req,
			dataCoord:             queryCoord.dataCoordClient,
			cluster:               queryCoord.cluster,
			meta:                  queryCoord.meta,
		}

		err = queryCoord.scheduler.processTask(loadPartitionTask)
		assert.Nil(t, err)
	})

	t.Run("Test ReleasePartition", func(t *testing.T) {
		req := &querypb.ReleasePartitionsRequest{
			Base: &commonpb.MsgBase{
				MsgType: commonpb.MsgType_ReleasePartitions,
			},
			CollectionID: defaultCollectionID,
			PartitionIDs: []UniqueID{defaultPartitionID},
		}
		releasePartitionTask := &ReleasePartitionTask{
			BaseTask: BaseTask{
				ctx:              ctx,
				Condition:        NewTaskCondition(ctx),
				triggerCondition: querypb.TriggerCondition_grpcRequest,
			},
			ReleasePartitionsRequest: req,
			cluster:                  queryCoord.cluster,
		}

		err = queryCoord.scheduler.processTask(releasePartitionTask)
		assert.Nil(t, err)
	})

	//nodes, err := queryCoord.cluster.getOnServiceNodes()
	//assert.Nil(t, err)

	err = node.stop()
	//assert.Nil(t, err)

	//allNodeOffline := waitAllQueryNodeOffline(queryCoord.cluster, nodes)
	//assert.Equal(t, allNodeOffline, true)
	queryCoord.Stop()
}