diff --git a/internal/querycoord/cluster.go b/internal/querycoord/cluster.go index a11b417525..2d89779bde 100644 --- a/internal/querycoord/cluster.go +++ b/internal/querycoord/cluster.go @@ -29,6 +29,7 @@ import ( "github.com/milvus-io/milvus/internal/proto/milvuspb" "github.com/milvus-io/milvus/internal/proto/querypb" "github.com/milvus-io/milvus/internal/util/sessionutil" + "github.com/milvus-io/milvus/internal/util/typeutil" ) const ( @@ -45,6 +46,7 @@ type Cluster interface { getNumSegments(nodeID int64) (int, error) watchDmChannels(ctx context.Context, nodeID int64, in *querypb.WatchDmChannelsRequest) error + //TODO:: removeDmChannel getNumDmChannels(nodeID int64) (int, error) hasWatchedQueryChannel(ctx context.Context, nodeID int64, collectionID UniqueID) bool @@ -55,31 +57,51 @@ type Cluster interface { releasePartitions(ctx context.Context, nodeID int64, in *querypb.ReleasePartitionsRequest) error getSegmentInfo(ctx context.Context, in *querypb.GetSegmentInfoRequest) ([]*querypb.SegmentInfo, error) - registerNode(ctx context.Context, session *sessionutil.Session, id UniqueID) error + registerNode(ctx context.Context, session *sessionutil.Session, id UniqueID, state nodeState) error getNodeByID(nodeID int64) (Node, error) removeNodeInfo(nodeID int64) error stopNode(nodeID int64) - onServiceNodes() (map[int64]Node, error) - isOnService(nodeID int64) (bool, error) + onlineNodes() (map[int64]Node, error) + isOnline(nodeID int64) (bool, error) + offlineNodes() (map[int64]Node, error) - printMeta() + getSessionVersion() int64 + + getMetrics(ctx context.Context, in *milvuspb.GetMetricsRequest) []queryNodeGetMetricsResponse } type newQueryNodeFn func(ctx context.Context, address string, id UniqueID, kv *etcdkv.EtcdKV) (Node, error) +type nodeState int + +const ( + disConnect nodeState = 0 + online nodeState = 1 + offline nodeState = 2 +) + type queryNodeCluster struct { + ctx context.Context + cancel context.CancelFunc client *etcdkv.EtcdKV + session 
*sessionutil.Session + sessionVersion int64 + sync.RWMutex clusterMeta Meta nodes map[int64]Node newNodeFn newQueryNodeFn } -func newQueryNodeCluster(clusterMeta Meta, kv *etcdkv.EtcdKV, newNodeFn newQueryNodeFn) (*queryNodeCluster, error) { +func newQueryNodeCluster(ctx context.Context, clusterMeta Meta, kv *etcdkv.EtcdKV, newNodeFn newQueryNodeFn, session *sessionutil.Session) (*queryNodeCluster, error) { + childCtx, cancel := context.WithCancel(ctx) nodes := make(map[int64]Node) c := &queryNodeCluster{ + ctx: childCtx, + cancel: cancel, client: kv, + session: session, clusterMeta: clusterMeta, nodes: nodes, newNodeFn: newNodeFn, @@ -93,30 +115,55 @@ func newQueryNodeCluster(clusterMeta Meta, kv *etcdkv.EtcdKV, newNodeFn newQuery } func (c *queryNodeCluster) reloadFromKV() error { - nodeIDs := make([]UniqueID, 0) - keys, values, err := c.client.LoadWithPrefix(queryNodeInfoPrefix) + toLoadMetaNodeIDs := make([]int64, 0) + // get current online session + onlineNodeSessions, version, _ := c.session.GetSessions(typeutil.QueryNodeRole) + onlineSessionMap := make(map[int64]*sessionutil.Session) + for _, session := range onlineNodeSessions { + nodeID := session.ServerID + onlineSessionMap[nodeID] = session + } + for nodeID, session := range onlineSessionMap { + log.Debug("ReloadFromKV: register a queryNode to cluster", zap.Any("nodeID", nodeID)) + err := c.registerNode(c.ctx, session, nodeID, disConnect) + if err != nil { + log.Error("query node failed to register", zap.Int64("nodeID", nodeID), zap.String("error info", err.Error())) + return err + } + toLoadMetaNodeIDs = append(toLoadMetaNodeIDs, nodeID) + } + c.sessionVersion = version + + // load node information before power off from etcd + oldStringNodeIDs, oldNodeSessions, err := c.client.LoadWithPrefix(queryNodeInfoPrefix) if err != nil { + log.Error("reloadFromKV: get previous node info from etcd error", zap.Error(err)) return err } - for index := range keys { - nodeID, err := 
strconv.ParseInt(filepath.Base(keys[index]), 10, 64) + for index := range oldStringNodeIDs { + nodeID, err := strconv.ParseInt(filepath.Base(oldStringNodeIDs[index]), 10, 64) if err != nil { + log.Error("WatchNodeLoop: parse nodeID error", zap.Error(err)) return err } - - session := &sessionutil.Session{} - err = json.Unmarshal([]byte(values[index]), session) - if err != nil { - return err + if _, ok := onlineSessionMap[nodeID]; !ok { + session := &sessionutil.Session{} + err = json.Unmarshal([]byte(oldNodeSessions[index]), session) + if err != nil { + log.Error("WatchNodeLoop: unmarshal session error", zap.Error(err)) + return err + } + err = c.registerNode(context.Background(), session, nodeID, offline) + if err != nil { + log.Debug("ReloadFromKV: failed to add queryNode to cluster", zap.Int64("nodeID", nodeID), zap.String("error info", err.Error())) + return err + } + toLoadMetaNodeIDs = append(toLoadMetaNodeIDs, nodeID) } - err = c.registerNode(context.Background(), session, nodeID) - if err != nil { - log.Debug("ReloadFromKV: failed to add queryNode to cluster", zap.Int64("nodeID", nodeID), zap.String("error info", err.Error())) - continue - } - nodeIDs = append(nodeIDs, nodeID) } - for _, nodeID := range nodeIDs { + + // load collection meta of queryNode from etcd + for _, nodeID := range toLoadMetaNodeIDs { infoPrefix := fmt.Sprintf("%s/%d", queryNodeMetaPrefix, nodeID) _, collectionValues, err := c.client.LoadWithPrefix(infoPrefix) if err != nil { @@ -138,11 +185,15 @@ func (c *queryNodeCluster) reloadFromKV() error { return nil } +func (c *queryNodeCluster) getSessionVersion() int64 { + return c.sessionVersion +} + func (c *queryNodeCluster) getComponentInfos(ctx context.Context) ([]*internalpb.ComponentInfo, error) { c.RLock() defer c.RUnlock() subComponentInfos := make([]*internalpb.ComponentInfo, 0) - nodes, err := c.getOnServiceNodes() + nodes, err := c.getOnlineNodes() if err != nil { log.Debug("GetComponentInfos: failed get on service nodes", 
zap.String("error info", err.Error())) return nil, err @@ -204,7 +255,7 @@ func (c *queryNodeCluster) releaseSegments(ctx context.Context, nodeID int64, in defer c.Unlock() if node, ok := c.nodes[nodeID]; ok { - if !node.isOnService() { + if !node.isOnline() { return errors.New("node offline") } @@ -416,7 +467,7 @@ func (c *queryNodeCluster) getNumSegments(nodeID int64) (int, error) { return numSegment, nil } -func (c *queryNodeCluster) registerNode(ctx context.Context, session *sessionutil.Session, id UniqueID) error { +func (c *queryNodeCluster) registerNode(ctx context.Context, session *sessionutil.Session, id UniqueID, state nodeState) error { c.Lock() defer c.Unlock() @@ -431,23 +482,17 @@ func (c *queryNodeCluster) registerNode(ctx context.Context, session *sessionuti if err != nil { return err } - c.nodes[id], err = c.newNodeFn(ctx, session.Address, id, c.client) + node, err := c.newNodeFn(ctx, session.Address, id, c.client) if err != nil { log.Debug("RegisterNode: create a new query node failed", zap.Int64("nodeID", id), zap.Error(err)) return err } + node.setState(state) + if state < online { + go node.start() + } + c.nodes[id] = node log.Debug("RegisterNode: create a new query node", zap.Int64("nodeID", id), zap.String("address", session.Address)) - - go func() { - err = c.nodes[id].start() - if err != nil { - log.Error("RegisterNode: start queryNode client failed", zap.Int64("nodeID", id), zap.String("error", err.Error())) - return - } - log.Debug("RegisterNode: start queryNode success, print cluster MetaReplica info", zap.Int64("nodeID", id)) - c.printMeta() - }() - return nil } return fmt.Errorf("RegisterNode: node %d alredy exists in cluster", id) @@ -496,56 +541,77 @@ func (c *queryNodeCluster) stopNode(nodeID int64) { } } -func (c *queryNodeCluster) onServiceNodes() (map[int64]Node, error) { +func (c *queryNodeCluster) onlineNodes() (map[int64]Node, error) { c.RLock() defer c.RUnlock() - return c.getOnServiceNodes() + return c.getOnlineNodes() } 
-func (c *queryNodeCluster) getOnServiceNodes() (map[int64]Node, error) { +func (c *queryNodeCluster) getOnlineNodes() (map[int64]Node, error) { nodes := make(map[int64]Node) for nodeID, node := range c.nodes { - if node.isOnService() { + if node.isOnline() { nodes[nodeID] = node } } if len(nodes) == 0 { - return nil, errors.New("GetOnServiceNodes: no queryNode is alive") + return nil, errors.New("GetOnlineNodes: no queryNode is alive") } return nodes, nil } -func (c *queryNodeCluster) isOnService(nodeID int64) (bool, error) { +func (c *queryNodeCluster) offlineNodes() (map[int64]Node, error) { + c.RLock() + defer c.RUnlock() + + return c.getOfflineNodes() +} + +func (c *queryNodeCluster) getOfflineNodes() (map[int64]Node, error) { + nodes := make(map[int64]Node) + for nodeID, node := range c.nodes { + if node.isOffline() { + nodes[nodeID] = node + } + } + if len(nodes) == 0 { + return nil, errors.New("GetOfflineNodes: no queryNode is offline") + } + + return nodes, nil +} + +func (c *queryNodeCluster) isOnline(nodeID int64) (bool, error) { c.Lock() defer c.Unlock() if node, ok := c.nodes[nodeID]; ok { - return node.isOnService(), nil + return node.isOnline(), nil } return false, fmt.Errorf("IsOnService: query node %d not exist", nodeID) } -func (c *queryNodeCluster) printMeta() { - c.RLock() - defer c.RUnlock() - - for id, node := range c.nodes { - if node.isOnService() { - collectionInfos := node.showCollections() - for _, info := range collectionInfos { - log.Debug("PrintMeta: query coordinator cluster info: collectionInfo", zap.Int64("nodeID", id), zap.Int64("collectionID", info.CollectionID), zap.Any("info", info)) - } - - queryChannelInfos := node.showWatchedQueryChannels() - for _, info := range queryChannelInfos { - log.Debug("PrintMeta: query coordinator cluster info: watchedQueryChannelInfo", zap.Int64("nodeID", id), zap.Int64("collectionID", info.CollectionID), zap.Any("info", info)) - } - } - } -} +//func (c *queryNodeCluster) printMeta() { +// 
c.RLock() +// defer c.RUnlock() +// +// for id, node := range c.nodes { +// if node.isOnline() { +// collectionInfos := node.showCollections() +// for _, info := range collectionInfos { +// log.Debug("PrintMeta: query coordinator cluster info: collectionInfo", zap.Int64("nodeID", id), zap.Int64("collectionID", info.CollectionID), zap.Any("info", info)) +// } +// +// queryChannelInfos := node.showWatchedQueryChannels() +// for _, info := range queryChannelInfos { +// log.Debug("PrintMeta: query coordinator cluster info: watchedQueryChannelInfo", zap.Int64("nodeID", id), zap.Int64("collectionID", info.CollectionID), zap.Any("info", info)) +// } +// } +// } +//} func (c *queryNodeCluster) getCollectionInfosByID(ctx context.Context, nodeID int64) []*querypb.CollectionInfo { c.RLock() diff --git a/internal/querycoord/cluster_test.go b/internal/querycoord/cluster_test.go index 9489f34705..f21c741194 100644 --- a/internal/querycoord/cluster_test.go +++ b/internal/querycoord/cluster_test.go @@ -24,6 +24,7 @@ import ( "github.com/milvus-io/milvus/internal/log" "github.com/milvus-io/milvus/internal/proto/querypb" "github.com/milvus-io/milvus/internal/util/sessionutil" + "github.com/milvus-io/milvus/internal/util/typeutil" ) func TestQueryNodeCluster_getMetrics(t *testing.T) { @@ -31,38 +32,163 @@ func TestQueryNodeCluster_getMetrics(t *testing.T) { } func TestReloadClusterFromKV(t *testing.T) { + t.Run("Test LoadOnlineNodes", func(t *testing.T) { + refreshParams() + baseCtx := context.Background() + kv, err := etcdkv.NewEtcdKV(Params.EtcdEndpoints, Params.MetaRootPath) + assert.Nil(t, err) + clusterSession := sessionutil.NewSession(context.Background(), Params.MetaRootPath, Params.EtcdEndpoints) + clusterSession.Init(typeutil.QueryCoordRole, Params.Address, true) + cluster := &queryNodeCluster{ + ctx: baseCtx, + client: kv, + nodes: make(map[int64]Node), + newNodeFn: newQueryNodeTest, + session: clusterSession, + } + + queryNode, err := startQueryNodeServer(baseCtx) + 
assert.Nil(t, err) + + cluster.reloadFromKV() + + nodeID := queryNode.queryNodeID + for { + _, err = cluster.getNodeByID(nodeID) + if err == nil { + break + } + } + queryNode.stop() + }) + + t.Run("Test LoadOfflineNodes", func(t *testing.T) { + refreshParams() + kv, err := etcdkv.NewEtcdKV(Params.EtcdEndpoints, Params.MetaRootPath) + assert.Nil(t, err) + clusterSession := sessionutil.NewSession(context.Background(), Params.MetaRootPath, Params.EtcdEndpoints) + clusterSession.Init(typeutil.QueryCoordRole, Params.Address, true) + cluster := &queryNodeCluster{ + client: kv, + nodes: make(map[int64]Node), + newNodeFn: newQueryNodeTest, + session: clusterSession, + } + + kvs := make(map[string]string) + session := &sessionutil.Session{ + ServerID: 100, + Address: "localhost", + } + sessionBlob, err := json.Marshal(session) + assert.Nil(t, err) + sessionKey := fmt.Sprintf("%s/%d", queryNodeInfoPrefix, 100) + kvs[sessionKey] = string(sessionBlob) + + collectionInfo := &querypb.CollectionInfo{ + CollectionID: defaultCollectionID, + } + collectionBlobs := proto.MarshalTextString(collectionInfo) + nodeKey := fmt.Sprintf("%s/%d", queryNodeMetaPrefix, 100) + kvs[nodeKey] = collectionBlobs + + err = kv.MultiSave(kvs) + assert.Nil(t, err) + + cluster.reloadFromKV() + + assert.Equal(t, 1, len(cluster.nodes)) + collection := cluster.getCollectionInfosByID(context.Background(), 100) + assert.Equal(t, defaultCollectionID, collection[0].CollectionID) + }) +} + +func TestGrpcRequest(t *testing.T) { refreshParams() + baseCtx, cancel := context.WithCancel(context.Background()) kv, err := etcdkv.NewEtcdKV(Params.EtcdEndpoints, Params.MetaRootPath) assert.Nil(t, err) + clusterSession := sessionutil.NewSession(context.Background(), Params.MetaRootPath, Params.EtcdEndpoints) + clusterSession.Init(typeutil.QueryCoordRole, Params.Address, true) + meta, err := newMeta(kv) + assert.Nil(t, err) cluster := &queryNodeCluster{ - client: kv, - nodes: make(map[int64]Node), - newNodeFn: 
newQueryNodeTest, + ctx: baseCtx, + cancel: cancel, + client: kv, + clusterMeta: meta, + nodes: make(map[int64]Node), + newNodeFn: newQueryNodeTest, + session: clusterSession, } - kvs := make(map[string]string) - session := &sessionutil.Session{ - ServerID: 100, - Address: "localhost", - } - sessionBlob, err := json.Marshal(session) + node, err := startQueryNodeServer(baseCtx) assert.Nil(t, err) - sessionKey := fmt.Sprintf("%s/%d", queryNodeInfoPrefix, 100) - kvs[sessionKey] = string(sessionBlob) + nodeSession := node.session + nodeID := node.queryNodeID + cluster.registerNode(baseCtx, nodeSession, nodeID, disConnect) - collectionInfo := &querypb.CollectionInfo{ - CollectionID: defaultCollectionID, + for { + online, err := cluster.isOnline(nodeID) + assert.Nil(t, err) + if online { + break + } } - collectionBlobs := proto.MarshalTextString(collectionInfo) - nodeKey := fmt.Sprintf("%s/%d", queryNodeMetaPrefix, 100) - kvs[nodeKey] = collectionBlobs - err = kv.MultiSave(kvs) - assert.Nil(t, err) + t.Run("Test GetComponentInfos", func(t *testing.T) { + _, err := cluster.getComponentInfos(baseCtx) + assert.Nil(t, err) + }) - cluster.reloadFromKV() + t.Run("Test LoadSegments", func(t *testing.T) { + segmentLoadInfo := &querypb.SegmentLoadInfo{ + SegmentID: defaultSegmentID, + PartitionID: defaultPartitionID, + CollectionID: defaultCollectionID, + } + loadSegmentReq := &querypb.LoadSegmentsRequest{ + NodeID: nodeID, + Infos: []*querypb.SegmentLoadInfo{segmentLoadInfo}, + } + err := cluster.loadSegments(baseCtx, nodeID, loadSegmentReq) + assert.Nil(t, err) + }) - assert.Equal(t, 1, len(cluster.nodes)) - collection := cluster.getCollectionInfosByID(context.Background(), 100) - assert.Equal(t, defaultCollectionID, collection[0].CollectionID) + t.Run("Test ReleaseSegments", func(t *testing.T) { + releaseSegmentReq := &querypb.ReleaseSegmentsRequest{ + NodeID: nodeID, + CollectionID: defaultCollectionID, + PartitionIDs: []UniqueID{defaultPartitionID}, + SegmentIDs: 
[]UniqueID{defaultSegmentID}, + } + err := cluster.releaseSegments(baseCtx, nodeID, releaseSegmentReq) + assert.Nil(t, err) + }) + + t.Run("Test AddQueryChannel", func(t *testing.T) { + reqChannel, resChannel := cluster.clusterMeta.GetQueryChannel(defaultCollectionID) + addQueryChannelReq := &querypb.AddQueryChannelRequest{ + NodeID: nodeID, + CollectionID: defaultCollectionID, + RequestChannelID: reqChannel, + ResultChannelID: resChannel, + } + err := cluster.addQueryChannel(baseCtx, nodeID, addQueryChannelReq) + assert.Nil(t, err) + }) + + t.Run("Test RemoveQueryChannel", func(t *testing.T) { + reqChannel, resChannel := cluster.clusterMeta.GetQueryChannel(defaultCollectionID) + removeQueryChannelReq := &querypb.RemoveQueryChannelRequest{ + NodeID: nodeID, + CollectionID: defaultCollectionID, + RequestChannelID: reqChannel, + ResultChannelID: resChannel, + } + err := cluster.removeQueryChannel(baseCtx, nodeID, removeQueryChannelReq) + assert.Nil(t, err) + }) + + node.stop() } diff --git a/internal/querycoord/impl_test.go b/internal/querycoord/impl_test.go index e3cc3dc980..a13865f748 100644 --- a/internal/querycoord/impl_test.go +++ b/internal/querycoord/impl_test.go @@ -1,3 +1,13 @@ +// Copyright (C) 2019-2020 Zilliz. All rights reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software distributed under the License +// is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express +// or implied. See the License for the specific language governing permissions and limitations under the License. 
package querycoord import ( @@ -16,6 +26,7 @@ import ( ) func TestGrpcTask(t *testing.T) { + refreshParams() ctx := context.Background() queryCoord, err := startQueryCoord(ctx) assert.Nil(t, err) @@ -321,6 +332,7 @@ func TestGrpcTask(t *testing.T) { } func TestLoadBalanceTask(t *testing.T) { + refreshParams() baseCtx := context.Background() queryCoord, err := startQueryCoord(baseCtx) diff --git a/internal/querycoord/mock_querynode_client_test.go b/internal/querycoord/mock_querynode_client_test.go index d04d592848..960e3944c9 100644 --- a/internal/querycoord/mock_querynode_client_test.go +++ b/internal/querycoord/mock_querynode_client_test.go @@ -1,3 +1,13 @@ +// Copyright (C) 2019-2020 Zilliz. All rights reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software distributed under the License +// is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express +// or implied. See the License for the specific language governing permissions and limitations under the License. 
package querycoord import ( @@ -43,7 +53,6 @@ func newQueryNodeTest(ctx context.Context, address string, id UniqueID, kv *etcd kvClient: kv, collectionInfos: collectionInfo, watchedQueryChannels: watchedChannels, - onService: false, } return node, nil @@ -80,7 +89,10 @@ func (client *queryNodeClientMock) Start() error { func (client *queryNodeClientMock) Stop() error { client.cancel() - return client.conn.Close() + if client.conn != nil { + return client.conn.Close() + } + return nil } func (client *queryNodeClientMock) Register() error { diff --git a/internal/querycoord/mock_querynode_server_test.go b/internal/querycoord/mock_querynode_server_test.go index 74f720833f..5d2987335e 100644 --- a/internal/querycoord/mock_querynode_server_test.go +++ b/internal/querycoord/mock_querynode_server_test.go @@ -1,3 +1,13 @@ +// Copyright (C) 2019-2020 Zilliz. All rights reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software distributed under the License +// is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express +// or implied. See the License for the specific language governing permissions and limitations under the License. 
package querycoord import ( @@ -11,6 +21,7 @@ import ( "github.com/milvus-io/milvus/internal/log" "github.com/milvus-io/milvus/internal/proto/commonpb" + "github.com/milvus-io/milvus/internal/proto/internalpb" "github.com/milvus-io/milvus/internal/proto/milvuspb" "github.com/milvus-io/milvus/internal/proto/querypb" "github.com/milvus-io/milvus/internal/util/funcutil" @@ -31,12 +42,13 @@ type queryNodeServerMock struct { queryNodePort int64 queryNodeID int64 - addQueryChannels func() (*commonpb.Status, error) - watchDmChannels func() (*commonpb.Status, error) - loadSegment func() (*commonpb.Status, error) - releaseCollection func() (*commonpb.Status, error) - releasePartition func() (*commonpb.Status, error) - releaseSegment func() (*commonpb.Status, error) + addQueryChannels func() (*commonpb.Status, error) + removeQueryChannels func() (*commonpb.Status, error) + watchDmChannels func() (*commonpb.Status, error) + loadSegment func() (*commonpb.Status, error) + releaseCollection func() (*commonpb.Status, error) + releasePartition func() (*commonpb.Status, error) + releaseSegments func() (*commonpb.Status, error) } func newQueryNodeServerMock(ctx context.Context) *queryNodeServerMock { @@ -46,12 +58,13 @@ func newQueryNodeServerMock(ctx context.Context) *queryNodeServerMock { cancel: cancel, grpcErrChan: make(chan error), - addQueryChannels: returnSuccessResult, - watchDmChannels: returnSuccessResult, - loadSegment: returnSuccessResult, - releaseCollection: returnSuccessResult, - releasePartition: returnSuccessResult, - releaseSegment: returnSuccessResult, + addQueryChannels: returnSuccessResult, + removeQueryChannels: returnSuccessResult, + watchDmChannels: returnSuccessResult, + loadSegment: returnSuccessResult, + releaseCollection: returnSuccessResult, + releasePartition: returnSuccessResult, + releaseSegments: returnSuccessResult, } } @@ -83,7 +96,7 @@ func (qs *queryNodeServerMock) init() error { grpcPort = 0 } return err - }, retry.Attempts(10)) + }, 
retry.Attempts(2)) if err != nil { qs.grpcErrChan <- err } @@ -133,10 +146,22 @@ func (qs *queryNodeServerMock) run() error { return nil } +func (qs *queryNodeServerMock) GetComponentStates(ctx context.Context, req *internalpb.GetComponentStatesRequest) (*internalpb.ComponentStates, error) { + return &internalpb.ComponentStates{ + Status: &commonpb.Status{ + ErrorCode: commonpb.ErrorCode_Success, + }, + }, nil +} + func (qs *queryNodeServerMock) AddQueryChannel(ctx context.Context, req *querypb.AddQueryChannelRequest) (*commonpb.Status, error) { return qs.addQueryChannels() } +func (qs *queryNodeServerMock) RemoveQueryChannel(ctx context.Context, req *querypb.RemoveQueryChannelRequest) (*commonpb.Status, error) { + return qs.removeQueryChannels() +} + func (qs *queryNodeServerMock) WatchDmChannels(ctx context.Context, req *querypb.WatchDmChannelsRequest) (*commonpb.Status, error) { return qs.watchDmChannels() } @@ -154,7 +179,7 @@ func (qs *queryNodeServerMock) ReleasePartitions(ctx context.Context, req *query } func (qs *queryNodeServerMock) ReleaseSegments(ctx context.Context, req *querypb.ReleaseSegmentsRequest) (*commonpb.Status, error) { - return qs.releaseSegment() + return qs.releaseSegments() } func (qs *queryNodeServerMock) GetSegmentInfo(context.Context, *querypb.GetSegmentInfoRequest) (*querypb.GetSegmentInfoResponse, error) { diff --git a/internal/querycoord/query_coord.go b/internal/querycoord/query_coord.go index cd5a067e18..e581038b87 100644 --- a/internal/querycoord/query_coord.go +++ b/internal/querycoord/query_coord.go @@ -20,8 +20,6 @@ import ( "sync/atomic" "time" - "github.com/milvus-io/milvus/internal/util/metricsinfo" - "github.com/golang/protobuf/proto" "go.etcd.io/etcd/api/v3/mvccpb" "go.uber.org/zap" @@ -33,6 +31,7 @@ import ( "github.com/milvus-io/milvus/internal/proto/internalpb" "github.com/milvus-io/milvus/internal/proto/querypb" "github.com/milvus-io/milvus/internal/types" + "github.com/milvus-io/milvus/internal/util/metricsinfo" 
"github.com/milvus-io/milvus/internal/util/retry" "github.com/milvus-io/milvus/internal/util/sessionutil" "github.com/milvus-io/milvus/internal/util/typeutil" @@ -53,7 +52,7 @@ type QueryCoord struct { queryCoordID uint64 meta Meta - cluster *queryNodeCluster + cluster Cluster newNodeFn newQueryNodeFn scheduler *TaskScheduler @@ -103,7 +102,7 @@ func (qc *QueryCoord) Init() error { return err } - qc.cluster, err = newQueryNodeCluster(qc.meta, qc.kvClient, qc.newNodeFn) + qc.cluster, err = newQueryNodeCluster(qc.loopCtx, qc.meta, qc.kvClient, qc.newNodeFn, qc.session) if err != nil { log.Error("query coordinator init cluster failed", zap.Error(err)) return err @@ -189,50 +188,37 @@ func (qc *QueryCoord) watchNodeLoop() { defer qc.loopWg.Done() log.Debug("query coordinator start watch node loop") - clusterStartSession, version, _ := qc.session.GetSessions(typeutil.QueryNodeRole) - sessionMap := make(map[int64]*sessionutil.Session) - for _, session := range clusterStartSession { - nodeID := session.ServerID - sessionMap[nodeID] = session - } - for nodeID, session := range sessionMap { - if _, ok := qc.cluster.nodes[nodeID]; !ok { - serverID := session.ServerID - log.Debug("start add a queryNode to cluster", zap.Any("nodeID", serverID)) - err := qc.cluster.registerNode(ctx, session, serverID) - if err != nil { - log.Error("query node failed to register", zap.Int64("nodeID", serverID), zap.String("error info", err.Error())) - } + offlineNodes, err := qc.cluster.offlineNodes() + if err == nil { + offlineNodeIDs := make([]int64, 0) + for id := range offlineNodes { + offlineNodeIDs = append(offlineNodeIDs, id) + } + loadBalanceSegment := &querypb.LoadBalanceRequest{ + Base: &commonpb.MsgBase{ + MsgType: commonpb.MsgType_LoadBalanceSegments, + SourceID: qc.session.ServerID, + }, + SourceNodeIDs: offlineNodeIDs, } - } - for nodeID := range qc.cluster.nodes { - if _, ok := sessionMap[nodeID]; !ok { - qc.cluster.stopNode(nodeID) - loadBalanceSegment := 
&querypb.LoadBalanceRequest{ - Base: &commonpb.MsgBase{ - MsgType: commonpb.MsgType_LoadBalanceSegments, - SourceID: qc.session.ServerID, - }, - SourceNodeIDs: []int64{nodeID}, - } - loadBalanceTask := &LoadBalanceTask{ - BaseTask: BaseTask{ - ctx: qc.loopCtx, - Condition: NewTaskCondition(qc.loopCtx), - triggerCondition: querypb.TriggerCondition_nodeDown, - }, - LoadBalanceRequest: loadBalanceSegment, - rootCoord: qc.rootCoordClient, - dataCoord: qc.dataCoordClient, - cluster: qc.cluster, - meta: qc.meta, - } - qc.scheduler.Enqueue([]task{loadBalanceTask}) + loadBalanceTask := &LoadBalanceTask{ + BaseTask: BaseTask{ + ctx: qc.loopCtx, + Condition: NewTaskCondition(qc.loopCtx), + triggerCondition: querypb.TriggerCondition_nodeDown, + }, + LoadBalanceRequest: loadBalanceSegment, + rootCoord: qc.rootCoordClient, + dataCoord: qc.dataCoordClient, + cluster: qc.cluster, + meta: qc.meta, } + qc.scheduler.Enqueue([]task{loadBalanceTask}) + log.Debug("start a loadBalance task", zap.Any("task", loadBalanceTask)) } - qc.eventChan = qc.session.WatchServices(typeutil.QueryNodeRole, version+1) + qc.eventChan = qc.session.WatchServices(typeutil.QueryNodeRole, qc.cluster.getSessionVersion()+1) for { select { case <-ctx.Done(): @@ -242,7 +228,7 @@ func (qc *QueryCoord) watchNodeLoop() { case sessionutil.SessionAddEvent: serverID := event.Session.ServerID log.Debug("start add a queryNode to cluster", zap.Any("nodeID", serverID)) - err := qc.cluster.registerNode(ctx, event.Session, serverID) + err := qc.cluster.registerNode(ctx, event.Session, serverID, disConnect) if err != nil { log.Error("query node failed to register", zap.Int64("nodeID", serverID), zap.String("error info", err.Error())) } @@ -279,6 +265,7 @@ func (qc *QueryCoord) watchNodeLoop() { meta: qc.meta, } qc.scheduler.Enqueue([]task{loadBalanceTask}) + log.Debug("start a loadBalance task", zap.Any("task", loadBalanceTask)) qc.metricsCacheManager.InvalidateSystemInfoMetrics() } } diff --git 
a/internal/querycoord/query_coord_test.go b/internal/querycoord/query_coord_test.go index 402f470970..00f53d2d18 100644 --- a/internal/querycoord/query_coord_test.go +++ b/internal/querycoord/query_coord_test.go @@ -13,17 +13,21 @@ package querycoord import ( "context" + "encoding/json" + "fmt" "math/rand" "os" "strconv" "testing" "time" - "go.uber.org/zap" + "github.com/golang/protobuf/proto" + "github.com/stretchr/testify/assert" - "github.com/milvus-io/milvus/internal/log" + etcdkv "github.com/milvus-io/milvus/internal/kv/etcd" "github.com/milvus-io/milvus/internal/msgstream" - "github.com/milvus-io/milvus/internal/proto/internalpb" + "github.com/milvus-io/milvus/internal/proto/querypb" + "github.com/milvus-io/milvus/internal/util/sessionutil" ) func setup() { @@ -46,29 +50,158 @@ func TestMain(m *testing.M) { } func NewQueryCoordTest(ctx context.Context, factory msgstream.Factory) (*QueryCoord, error) { - refreshParams() - rand.Seed(time.Now().UnixNano()) - queryChannels := make([]*queryChannelInfo, 0) - channelID := len(queryChannels) - searchPrefix := Params.SearchChannelPrefix - searchResultPrefix := Params.SearchResultChannelPrefix - allocatedQueryChannel := searchPrefix + "-" + strconv.FormatInt(int64(channelID), 10) - allocatedQueryResultChannel := searchResultPrefix + "-" + strconv.FormatInt(int64(channelID), 10) + queryCoord, err := NewQueryCoord(ctx, factory) + if err != nil { + return nil, err + } + queryCoord.newNodeFn = newQueryNodeTest + return queryCoord, nil +} - queryChannels = append(queryChannels, &queryChannelInfo{ - requestChannel: allocatedQueryChannel, - responseChannel: allocatedQueryResultChannel, - }) +func startQueryCoord(ctx context.Context) (*QueryCoord, error) { + factory := msgstream.NewPmsFactory() - ctx1, cancel := context.WithCancel(ctx) - service := &QueryCoord{ - loopCtx: ctx1, - loopCancel: cancel, - msFactory: factory, - newNodeFn: newQueryNodeTest, + coord, err := NewQueryCoordTest(ctx, factory) + if err != nil { + return 
nil, err } - service.UpdateStateCode(internalpb.StateCode_Abnormal) - log.Debug("query coordinator", zap.Any("queryChannels", queryChannels)) - return service, nil + rootCoord := newRootCoordMock() + rootCoord.createCollection(defaultCollectionID) + rootCoord.createPartition(defaultCollectionID, defaultPartitionID) + + dataCoord, err := newDataCoordMock(ctx) + if err != nil { + return nil, err + } + + coord.SetRootCoord(rootCoord) + coord.SetDataCoord(dataCoord) + + err = coord.Register() + if err != nil { + return nil, err + } + err = coord.Init() + if err != nil { + return nil, err + } + err = coord.Start() + if err != nil { + return nil, err + } + return coord, nil +} + +func startUnHealthyQueryCoord(ctx context.Context) (*QueryCoord, error) { + factory := msgstream.NewPmsFactory() + + coord, err := NewQueryCoordTest(ctx, factory) + if err != nil { + return nil, err + } + + rootCoord := newRootCoordMock() + rootCoord.createCollection(defaultCollectionID) + rootCoord.createPartition(defaultCollectionID, defaultPartitionID) + + dataCoord, err := newDataCoordMock(ctx) + if err != nil { + return nil, err + } + + coord.SetRootCoord(rootCoord) + coord.SetDataCoord(dataCoord) + + err = coord.Register() + if err != nil { + return nil, err + } + err = coord.Init() + if err != nil { + return nil, err + } + + return coord, nil +} + +func TestWatchNodeLoop(t *testing.T) { + baseCtx := context.Background() + + t.Run("Test OfflineNodes", func(t *testing.T) { + refreshParams() + kv, err := etcdkv.NewEtcdKV(Params.EtcdEndpoints, Params.MetaRootPath) + assert.Nil(t, err) + + kvs := make(map[string]string) + session := &sessionutil.Session{ + ServerID: 100, + Address: "localhost", + } + sessionBlob, err := json.Marshal(session) + assert.Nil(t, err) + sessionKey := fmt.Sprintf("%s/%d", queryNodeInfoPrefix, 100) + kvs[sessionKey] = string(sessionBlob) + + collectionInfo := &querypb.CollectionInfo{ + CollectionID: defaultCollectionID, + } + collectionBlobs := 
proto.MarshalTextString(collectionInfo) + nodeKey := fmt.Sprintf("%s/%d", queryNodeMetaPrefix, 100) + kvs[nodeKey] = collectionBlobs + + err = kv.MultiSave(kvs) + assert.Nil(t, err) + + queryCoord, err := startQueryCoord(baseCtx) + assert.Nil(t, err) + + for { + _, err = queryCoord.cluster.offlineNodes() + if err == nil { + break + } + } + + queryCoord.Stop() + }) + + t.Run("Test RegisterNewNode", func(t *testing.T) { + refreshParams() + queryCoord, err := startQueryCoord(baseCtx) + assert.Nil(t, err) + + queryNode1, err := startQueryNodeServer(baseCtx) + assert.Nil(t, err) + + nodeID := queryNode1.queryNodeID + for { + _, err = queryCoord.cluster.getNodeByID(nodeID) + if err == nil { + break + } + } + + queryCoord.Stop() + queryNode1.stop() + }) + + t.Run("Test RemoveNode", func(t *testing.T) { + refreshParams() + queryNode1, err := startQueryNodeServer(baseCtx) + assert.Nil(t, err) + + queryCoord, err := startQueryCoord(baseCtx) + assert.Nil(t, err) + + nodeID := queryNode1.queryNodeID + queryNode1.stop() + for { + _, err = queryCoord.cluster.getNodeByID(nodeID) + if err != nil { + break + } + } + queryCoord.Stop() + }) } diff --git a/internal/querycoord/querynode.go b/internal/querycoord/querynode.go index e4cc3c8ec9..12e10fd5aa 100644 --- a/internal/querycoord/querynode.go +++ b/internal/querycoord/querynode.go @@ -17,8 +17,6 @@ import ( "fmt" "sync" - "github.com/milvus-io/milvus/internal/proto/milvuspb" - "github.com/golang/protobuf/proto" "go.uber.org/zap" @@ -27,6 +25,7 @@ import ( "github.com/milvus-io/milvus/internal/log" "github.com/milvus-io/milvus/internal/proto/commonpb" "github.com/milvus-io/milvus/internal/proto/internalpb" + "github.com/milvus-io/milvus/internal/proto/milvuspb" "github.com/milvus-io/milvus/internal/proto/querypb" "github.com/milvus-io/milvus/internal/proto/schemapb" "github.com/milvus-io/milvus/internal/types" @@ -46,15 +45,17 @@ type Node interface { releasePartitions(ctx context.Context, in *querypb.ReleasePartitionsRequest) 
error watchDmChannels(ctx context.Context, in *querypb.WatchDmChannelsRequest) error - removeDmChannel(collectionID UniqueID, channels []string) error + //removeDmChannel(collectionID UniqueID, channels []string) error hasWatchedQueryChannel(collectionID UniqueID) bool - showWatchedQueryChannels() []*querypb.QueryChannelInfo + //showWatchedQueryChannels() []*querypb.QueryChannelInfo addQueryChannel(ctx context.Context, in *querypb.AddQueryChannelRequest) error removeQueryChannel(ctx context.Context, in *querypb.RemoveQueryChannelRequest) error - setNodeState(onService bool) - isOnService() bool + setState(state nodeState) + getState() nodeState + isOnline() bool + isOffline() bool getSegmentInfo(ctx context.Context, in *querypb.GetSegmentInfoRequest) (*querypb.GetSegmentInfoResponse, error) loadSegments(ctx context.Context, in *querypb.LoadSegmentsRequest) error @@ -75,8 +76,8 @@ type queryNode struct { sync.RWMutex collectionInfos map[UniqueID]*querypb.CollectionInfo watchedQueryChannels map[UniqueID]*querypb.QueryChannelInfo - onService bool - serviceLock sync.RWMutex + state nodeState + stateLock sync.RWMutex } func newQueryNode(ctx context.Context, address string, id UniqueID, kv *etcdkv.EtcdKV) (Node, error) { @@ -97,7 +98,7 @@ func newQueryNode(ctx context.Context, address string, id UniqueID, kv *etcdkv.E kvClient: kv, collectionInfos: collectionInfo, watchedQueryChannels: watchedChannels, - onService: false, + state: disConnect, } return node, nil @@ -105,23 +106,27 @@ func newQueryNode(ctx context.Context, address string, id UniqueID, kv *etcdkv.E func (qn *queryNode) start() error { if err := qn.client.Init(); err != nil { + log.Error("Start: init queryNode client failed", zap.Int64("nodeID", qn.id), zap.String("error", err.Error())) return err } if err := qn.client.Start(); err != nil { + log.Error("Start: start queryNode client failed", zap.Int64("nodeID", qn.id), zap.String("error", err.Error())) return err } - qn.serviceLock.Lock() - qn.onService = 
true - qn.serviceLock.Unlock() + qn.stateLock.Lock() + if qn.state < online { + qn.state = online + } + qn.stateLock.Unlock() log.Debug("Start: queryNode client start success", zap.Int64("nodeID", qn.id), zap.String("address", qn.address)) return nil } func (qn *queryNode) stop() { - qn.serviceLock.Lock() - defer qn.serviceLock.Unlock() - qn.onService = false + qn.stateLock.Lock() + defer qn.stateLock.Unlock() + qn.state = offline if qn.client != nil { qn.client.Stop() } @@ -272,37 +277,37 @@ func (qn *queryNode) addDmChannel(collectionID UniqueID, channels []string) erro return errors.New("AddDmChannels: can't find collection in watchedQueryChannel") } -func (qn *queryNode) removeDmChannel(collectionID UniqueID, channels []string) error { - qn.Lock() - defer qn.Unlock() - - if info, ok := qn.collectionInfos[collectionID]; ok { - for _, channelInfo := range info.ChannelInfos { - if channelInfo.NodeIDLoaded == qn.id { - newChannelIDs := make([]string, 0) - for _, channelID := range channelInfo.ChannelIDs { - findChannel := false - for _, channel := range channels { - if channelID == channel { - findChannel = true - } - } - if !findChannel { - newChannelIDs = append(newChannelIDs, channelID) - } - } - channelInfo.ChannelIDs = newChannelIDs - } - } - - err := saveNodeCollectionInfo(collectionID, info, qn.id, qn.kvClient) - if err != nil { - log.Error("RemoveDmChannel: save collectionInfo error", zap.Any("error", err.Error()), zap.Int64("collectionID", collectionID)) - } - } - - return errors.New("RemoveDmChannel: can't find collection in watchedQueryChannel") -} +//func (qn *queryNode) removeDmChannel(collectionID UniqueID, channels []string) error { +// qn.Lock() +// defer qn.Unlock() +// +// if info, ok := qn.collectionInfos[collectionID]; ok { +// for _, channelInfo := range info.ChannelInfos { +// if channelInfo.NodeIDLoaded == qn.id { +// newChannelIDs := make([]string, 0) +// for _, channelID := range channelInfo.ChannelIDs { +// findChannel := false +// for _, 
channel := range channels { +// if channelID == channel { +// findChannel = true +// } +// } +// if !findChannel { +// newChannelIDs = append(newChannelIDs, channelID) +// } +// } +// channelInfo.ChannelIDs = newChannelIDs +// } +// } +// +// err := saveNodeCollectionInfo(collectionID, info, qn.id, qn.kvClient) +// if err != nil { +// log.Error("RemoveDmChannel: save collectionInfo error", zap.Any("error", err.Error()), zap.Int64("collectionID", collectionID)) +// } +// } +// +// return errors.New("RemoveDmChannel: can't find collection in watchedQueryChannel") +//} func (qn *queryNode) hasWatchedQueryChannel(collectionID UniqueID) bool { qn.RLock() @@ -315,17 +320,17 @@ func (qn *queryNode) hasWatchedQueryChannel(collectionID UniqueID) bool { return false } -func (qn *queryNode) showWatchedQueryChannels() []*querypb.QueryChannelInfo { - qn.RLock() - defer qn.RUnlock() - - results := make([]*querypb.QueryChannelInfo, 0) - for _, info := range qn.watchedQueryChannels { - results = append(results, proto.Clone(info).(*querypb.QueryChannelInfo)) - } - - return results -} +//func (qn *queryNode) showWatchedQueryChannels() []*querypb.QueryChannelInfo { +// qn.RLock() +// defer qn.RUnlock() +// +// results := make([]*querypb.QueryChannelInfo, 0) +// for _, info := range qn.watchedQueryChannels { +// results = append(results, proto.Clone(info).(*querypb.QueryChannelInfo)) +// } +// +// return results +//} func (qn *queryNode) setQueryChannelInfo(info *querypb.QueryChannelInfo) { qn.Lock() @@ -354,26 +359,37 @@ func (qn *queryNode) clearNodeInfo() error { return nil } -func (qn *queryNode) setNodeState(onService bool) { - qn.serviceLock.Lock() - defer qn.serviceLock.Unlock() +func (qn *queryNode) setState(state nodeState) { + qn.stateLock.Lock() + defer qn.stateLock.Unlock() - qn.onService = onService + qn.state = state } -func (qn *queryNode) isOnService() bool { - qn.serviceLock.RLock() - defer qn.serviceLock.RUnlock() +func (qn *queryNode) getState() nodeState { + 
qn.stateLock.RLock() + defer qn.stateLock.RUnlock() - return qn.onService + return qn.state +} + +func (qn *queryNode) isOnline() bool { + qn.stateLock.RLock() + defer qn.stateLock.RUnlock() + + return qn.state == online +} + +func (qn *queryNode) isOffline() bool { + qn.stateLock.RLock() + defer qn.stateLock.RUnlock() + + return qn.state == offline } //***********************grpc req*************************// func (qn *queryNode) watchDmChannels(ctx context.Context, in *querypb.WatchDmChannelsRequest) error { - qn.serviceLock.RLock() - onService := qn.onService - qn.serviceLock.RUnlock() - if !onService { + if !qn.isOnline() { return errors.New("WatchDmChannels: queryNode is offline") } @@ -397,10 +413,7 @@ func (qn *queryNode) watchDmChannels(ctx context.Context, in *querypb.WatchDmCha } func (qn *queryNode) addQueryChannel(ctx context.Context, in *querypb.AddQueryChannelRequest) error { - qn.serviceLock.RLock() - onService := qn.onService - qn.serviceLock.RUnlock() - if !onService { + if !qn.isOnline() { return errors.New("AddQueryChannel: queryNode is offline") } @@ -422,10 +435,7 @@ func (qn *queryNode) addQueryChannel(ctx context.Context, in *querypb.AddQueryCh } func (qn *queryNode) removeQueryChannel(ctx context.Context, in *querypb.RemoveQueryChannelRequest) error { - qn.serviceLock.RLock() - onService := qn.onService - qn.serviceLock.RUnlock() - if !onService { + if !qn.isOnline() { return nil } @@ -442,10 +452,7 @@ func (qn *queryNode) removeQueryChannel(ctx context.Context, in *querypb.RemoveQ } func (qn *queryNode) releaseCollection(ctx context.Context, in *querypb.ReleaseCollectionRequest) error { - qn.serviceLock.RLock() - onService := qn.onService - qn.serviceLock.RUnlock() - if !onService { + if !qn.isOnline() { return nil } @@ -466,10 +473,7 @@ func (qn *queryNode) releaseC } func (qn *queryNode) releasePartitions(ctx context.Context, in *querypb.ReleasePartitionsRequest) error { -
qn.serviceLock.RLock() - onService := qn.onService - qn.serviceLock.RUnlock() - if !onService { + if !qn.isOnline() { return nil } @@ -489,11 +493,9 @@ func (qn *queryNode) releasePartitions(ctx context.Context, in *querypb.ReleaseP } func (qn *queryNode) getSegmentInfo(ctx context.Context, in *querypb.GetSegmentInfoRequest) (*querypb.GetSegmentInfoResponse, error) { - qn.serviceLock.RLock() - if !qn.onService { + if !qn.isOnline() { return nil, nil } - qn.serviceLock.RUnlock() res, err := qn.client.GetSegmentInfo(ctx, in) if err == nil && res.Status.ErrorCode == commonpb.ErrorCode_Success { @@ -504,14 +506,12 @@ func (qn *queryNode) getSegmentInfo(ctx context.Context, in *querypb.GetSegmentI } func (qn *queryNode) getComponentInfo(ctx context.Context) *internalpb.ComponentInfo { - qn.serviceLock.RLock() - if !qn.onService { + if !qn.isOnline() { return &internalpb.ComponentInfo{ NodeID: qn.id, StateCode: internalpb.StateCode_Abnormal, } } - qn.serviceLock.RUnlock() res, err := qn.client.GetComponentStates(ctx) if err != nil || res.Status.ErrorCode != commonpb.ErrorCode_Success { @@ -525,20 +525,15 @@ func (qn *queryNode) getComponentInfo(ctx context.Context) *internalpb.Component } func (qn *queryNode) getMetrics(ctx context.Context, in *milvuspb.GetMetricsRequest) (*milvuspb.GetMetricsResponse, error) { - qn.serviceLock.RLock() - if !qn.onService { + if !qn.isOnline() { return nil, errQueryNodeIsNotOnService(qn.id) } - qn.serviceLock.RUnlock() return qn.client.GetMetrics(ctx, in) } func (qn *queryNode) loadSegments(ctx context.Context, in *querypb.LoadSegmentsRequest) error { - qn.serviceLock.RLock() - onService := qn.onService - qn.serviceLock.RUnlock() - if !onService { + if !qn.isOnline() { return errors.New("LoadSegments: queryNode is offline") } @@ -564,10 +559,7 @@ func (qn *queryNode) loadSegments(ctx context.Context, in *querypb.LoadSegmentsR } func (qn *queryNode) releaseSegments(ctx context.Context, in *querypb.ReleaseSegmentsRequest) error { - 
qn.serviceLock.RLock() - onService := qn.onService - qn.serviceLock.RUnlock() - if !onService { + if !qn.isOnline() { return errors.New("ReleaseSegments: queryNode is offline") } diff --git a/internal/querycoord/querynode_test.go b/internal/querycoord/querynode_test.go index 2d923362e2..7f9e1345a1 100644 --- a/internal/querycoord/querynode_test.go +++ b/internal/querycoord/querynode_test.go @@ -18,82 +18,15 @@ import ( "github.com/stretchr/testify/assert" + etcdkv "github.com/milvus-io/milvus/internal/kv/etcd" "github.com/milvus-io/milvus/internal/log" - "github.com/milvus-io/milvus/internal/msgstream" "github.com/milvus-io/milvus/internal/proto/commonpb" "github.com/milvus-io/milvus/internal/proto/querypb" ) -func startQueryCoord(ctx context.Context) (*QueryCoord, error) { - factory := msgstream.NewPmsFactory() - - coord, err := NewQueryCoordTest(ctx, factory) - if err != nil { - return nil, err - } - - rootCoord := newRootCoordMock() - rootCoord.createCollection(defaultCollectionID) - rootCoord.createPartition(defaultCollectionID, defaultPartitionID) - - dataCoord, err := newDataCoordMock(ctx) - if err != nil { - return nil, err - } - - coord.SetRootCoord(rootCoord) - coord.SetDataCoord(dataCoord) - - err = coord.Register() - if err != nil { - return nil, err - } - err = coord.Init() - if err != nil { - return nil, err - } - err = coord.Start() - if err != nil { - return nil, err - } - return coord, nil -} - -func startUnHealthyQueryCoord(ctx context.Context) (*QueryCoord, error) { - factory := msgstream.NewPmsFactory() - - coord, err := NewQueryCoordTest(ctx, factory) - if err != nil { - return nil, err - } - - rootCoord := newRootCoordMock() - rootCoord.createCollection(defaultCollectionID) - rootCoord.createPartition(defaultCollectionID, defaultPartitionID) - - dataCoord, err := newDataCoordMock(ctx) - if err != nil { - return nil, err - } - - coord.SetRootCoord(rootCoord) - coord.SetDataCoord(dataCoord) - - err = coord.Register() - if err != nil { - return 
nil, err - } - err = coord.Init() - if err != nil { - return nil, err - } - - return coord, nil -} - //func waitQueryNodeOnline(cluster *queryNodeCluster, nodeID int64) -func waitAllQueryNodeOffline(cluster *queryNodeCluster, nodes map[int64]Node) bool { +func waitAllQueryNodeOffline(cluster Cluster, nodes map[int64]Node) bool { reDoCount := 20 for { if reDoCount <= 0 { @@ -117,6 +50,7 @@ func waitAllQueryNodeOffline(cluster *queryNodeCluster, nodes map[int64]Node) bo } func TestQueryNode_MultiNode_stop(t *testing.T) { + refreshParams() baseCtx := context.Background() queryCoord, err := startQueryCoord(baseCtx) @@ -147,7 +81,7 @@ func TestQueryNode_MultiNode_stop(t *testing.T) { }) assert.Nil(t, err) time.Sleep(2 * time.Second) - nodes, err := queryCoord.cluster.onServiceNodes() + nodes, err := queryCoord.cluster.onlineNodes() assert.Nil(t, err) queryNode5.stop() @@ -157,6 +91,7 @@ func TestQueryNode_MultiNode_stop(t *testing.T) { } func TestQueryNode_MultiNode_reStart(t *testing.T) { + refreshParams() baseCtx := context.Background() queryCoord, err := startQueryCoord(baseCtx) @@ -185,7 +120,7 @@ func TestQueryNode_MultiNode_reStart(t *testing.T) { CollectionID: defaultCollectionID, }) assert.Nil(t, err) - nodes, err := queryCoord.cluster.onServiceNodes() + nodes, err := queryCoord.cluster.onlineNodes() assert.Nil(t, err) queryNode3.stop() @@ -197,3 +132,25 @@ func TestQueryNode_MultiNode_reStart(t *testing.T) { func TestQueryNode_getMetrics(t *testing.T) { log.Info("TestQueryNode_getMetrics, todo") } + +func TestNewQueryNode(t *testing.T) { + refreshParams() + baseCtx, cancel := context.WithCancel(context.Background()) + kv, err := etcdkv.NewEtcdKV(Params.EtcdEndpoints, Params.MetaRootPath) + assert.Nil(t, err) + + queryNode1, err := startQueryNodeServer(baseCtx) + assert.Nil(t, err) + + addr := queryNode1.session.Address + nodeID := queryNode1.queryNodeID + node, err := newQueryNode(baseCtx, addr, nodeID, kv) + assert.Nil(t, err) + + err = node.start() + 
assert.Nil(t, err) + + cancel() + node.stop() + queryNode1.stop() +} diff --git a/internal/querycoord/task.go b/internal/querycoord/task.go index 9f2e7a493a..0612561925 100644 --- a/internal/querycoord/task.go +++ b/internal/querycoord/task.go @@ -134,7 +134,7 @@ type LoadCollectionTask struct { *querypb.LoadCollectionRequest rootCoord types.RootCoord dataCoord types.DataCoord - cluster *queryNodeCluster + cluster Cluster meta Meta } @@ -323,7 +323,7 @@ func (lct *LoadCollectionTask) PostExecute(ctx context.Context) error { } if lct.result.ErrorCode != commonpb.ErrorCode_Success { lct.childTasks = make([]task, 0) - nodes, err := lct.cluster.onServiceNodes() + nodes, err := lct.cluster.onlineNodes() if err != nil { log.Debug(err.Error()) } @@ -362,7 +362,7 @@ func (lct *LoadCollectionTask) PostExecute(ctx context.Context) error { type ReleaseCollectionTask struct { BaseTask *querypb.ReleaseCollectionRequest - cluster *queryNodeCluster + cluster Cluster meta Meta rootCoord types.RootCoord } @@ -427,7 +427,7 @@ func (rct *ReleaseCollectionTask) Execute(ctx context.Context) error { return err } - nodes, err := rct.cluster.onServiceNodes() + nodes, err := rct.cluster.onlineNodes() if err != nil { log.Debug(err.Error()) } @@ -477,7 +477,7 @@ type LoadPartitionTask struct { BaseTask *querypb.LoadPartitionsRequest dataCoord types.DataCoord - cluster *queryNodeCluster + cluster Cluster meta Meta addCol bool } @@ -606,7 +606,7 @@ func (lpt *LoadPartitionTask) PostExecute(ctx context.Context) error { if lpt.result.ErrorCode != commonpb.ErrorCode_Success { lpt.childTasks = make([]task, 0) if lpt.addCol { - nodes, err := lpt.cluster.onServiceNodes() + nodes, err := lpt.cluster.onlineNodes() if err != nil { log.Debug(err.Error()) } @@ -635,7 +635,7 @@ func (lpt *LoadPartitionTask) PostExecute(ctx context.Context) error { log.Debug("loadPartitionTask: add a releaseCollectionTask to loadPartitionTask's childTask", zap.Any("task", releaseCollectionTask)) } } else { - nodes, err := 
lpt.cluster.onServiceNodes() + nodes, err := lpt.cluster.onlineNodes() if err != nil { log.Debug(err.Error()) } @@ -678,7 +678,7 @@ func (lpt *LoadPartitionTask) PostExecute(ctx context.Context) error { type ReleasePartitionTask struct { BaseTask *querypb.ReleasePartitionsRequest - cluster *queryNodeCluster + cluster Cluster } func (rpt *ReleasePartitionTask) MsgBase() *commonpb.MsgBase { @@ -717,7 +717,7 @@ func (rpt *ReleasePartitionTask) Execute(ctx context.Context) error { } if rpt.NodeID <= 0 { - nodes, err := rpt.cluster.onServiceNodes() + nodes, err := rpt.cluster.onlineNodes() if err != nil { log.Debug(err.Error()) } @@ -772,7 +772,7 @@ type LoadSegmentTask struct { BaseTask *querypb.LoadSegmentsRequest meta Meta - cluster *queryNodeCluster + cluster Cluster } func (lst *LoadSegmentTask) MsgBase() *commonpb.MsgBase { @@ -784,7 +784,7 @@ func (lst *LoadSegmentTask) Marshal() ([]byte, error) { } func (lst *LoadSegmentTask) IsValid() bool { - onService, err := lst.cluster.isOnService(lst.NodeID) + onService, err := lst.cluster.isOnline(lst.NodeID) if err != nil { return false } @@ -909,7 +909,7 @@ func (lst *LoadSegmentTask) Reschedule() ([]task, error) { type ReleaseSegmentTask struct { BaseTask *querypb.ReleaseSegmentsRequest - cluster *queryNodeCluster + cluster Cluster } func (rst *ReleaseSegmentTask) MsgBase() *commonpb.MsgBase { @@ -921,7 +921,7 @@ func (rst *ReleaseSegmentTask) Marshal() ([]byte, error) { } func (rst *ReleaseSegmentTask) IsValid() bool { - onService, err := rst.cluster.isOnService(rst.NodeID) + onService, err := rst.cluster.isOnline(rst.NodeID) if err != nil { return false } @@ -979,7 +979,7 @@ type WatchDmChannelTask struct { BaseTask *querypb.WatchDmChannelsRequest meta Meta - cluster *queryNodeCluster + cluster Cluster } func (wdt *WatchDmChannelTask) MsgBase() *commonpb.MsgBase { @@ -991,7 +991,7 @@ func (wdt *WatchDmChannelTask) Marshal() ([]byte, error) { } func (wdt *WatchDmChannelTask) IsValid() bool { - onService, err := 
wdt.cluster.isOnService(wdt.NodeID) + onService, err := wdt.cluster.isOnline(wdt.NodeID) if err != nil { return false } @@ -1120,7 +1120,7 @@ func (wdt *WatchDmChannelTask) Reschedule() ([]task, error) { type WatchQueryChannelTask struct { BaseTask *querypb.AddQueryChannelRequest - cluster *queryNodeCluster + cluster Cluster } func (wqt *WatchQueryChannelTask) MsgBase() *commonpb.MsgBase { @@ -1132,7 +1132,7 @@ func (wqt *WatchQueryChannelTask) Marshal() ([]byte, error) { } func (wqt *WatchQueryChannelTask) IsValid() bool { - onService, err := wqt.cluster.isOnService(wqt.NodeID) + onService, err := wqt.cluster.isOnline(wqt.NodeID) if err != nil { return false } @@ -1201,7 +1201,7 @@ type LoadBalanceTask struct { *querypb.LoadBalanceRequest rootCoord types.RootCoord dataCoord types.DataCoord - cluster *queryNodeCluster + cluster Cluster meta Meta } @@ -1379,12 +1379,12 @@ func (lbt *LoadBalanceTask) PostExecute(context.Context) error { return nil } -func shuffleChannelsToQueryNode(dmChannels []string, cluster *queryNodeCluster) []int64 { +func shuffleChannelsToQueryNode(dmChannels []string, cluster Cluster) []int64 { maxNumChannels := 0 nodes := make(map[int64]Node) var err error for { - nodes, err = cluster.onServiceNodes() + nodes, err = cluster.onlineNodes() if err != nil { log.Debug(err.Error()) time.Sleep(1 * time.Second) @@ -1435,12 +1435,12 @@ func shuffleChannelsToQueryNode(dmChannels []string, cluster *queryNodeCluster) } } -func shuffleSegmentsToQueryNode(segmentIDs []UniqueID, cluster *queryNodeCluster) []int64 { +func shuffleSegmentsToQueryNode(segmentIDs []UniqueID, cluster Cluster) []int64 { maxNumSegments := 0 nodes := make(map[int64]Node) var err error for { - nodes, err = cluster.onServiceNodes() + nodes, err = cluster.onlineNodes() if err != nil { log.Debug(err.Error()) time.Sleep(1 * time.Second) @@ -1526,7 +1526,7 @@ func assignInternalTask(ctx context.Context, collectionID UniqueID, parentTask task, meta Meta, - cluster *queryNodeCluster, + 
cluster Cluster, loadSegmentRequests []*querypb.LoadSegmentsRequest, watchDmChannelRequests []*querypb.WatchDmChannelsRequest) { diff --git a/internal/querycoord/task_scheduler.go b/internal/querycoord/task_scheduler.go index 6f3d585655..9ba3e73400 100644 --- a/internal/querycoord/task_scheduler.go +++ b/internal/querycoord/task_scheduler.go @@ -122,7 +122,7 @@ type TaskScheduler struct { triggerTaskQueue *TaskQueue activateTaskChan chan task meta Meta - cluster *queryNodeCluster + cluster Cluster taskIDAllocator func() (UniqueID, error) client *etcdkv.EtcdKV @@ -134,7 +134,7 @@ type TaskScheduler struct { cancel context.CancelFunc } -func NewTaskScheduler(ctx context.Context, meta Meta, cluster *queryNodeCluster, kv *etcdkv.EtcdKV, rootCoord types.RootCoord, dataCoord types.DataCoord) (*TaskScheduler, error) { +func NewTaskScheduler(ctx context.Context, meta Meta, cluster Cluster, kv *etcdkv.EtcdKV, rootCoord types.RootCoord, dataCoord types.DataCoord) (*TaskScheduler, error) { ctx1, cancel := context.WithCancel(ctx) taskChan := make(chan task, 1024) s := &TaskScheduler{ diff --git a/internal/querycoord/task_scheduler_test.go b/internal/querycoord/task_scheduler_test.go index 44a098b7a6..92c589285a 100644 --- a/internal/querycoord/task_scheduler_test.go +++ b/internal/querycoord/task_scheduler_test.go @@ -1,3 +1,13 @@ +// Copyright (C) 2019-2020 Zilliz. All rights reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software distributed under the License +// is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express +// or implied. See the License for the specific language governing permissions and limitations under the License. 
package querycoord import ( @@ -7,17 +17,18 @@ import ( "testing" "time" + "github.com/stretchr/testify/assert" + etcdkv "github.com/milvus-io/milvus/internal/kv/etcd" "github.com/milvus-io/milvus/internal/log" "github.com/milvus-io/milvus/internal/proto/commonpb" "github.com/milvus-io/milvus/internal/proto/querypb" - "github.com/stretchr/testify/assert" ) type testTask struct { BaseTask baseMsg *commonpb.MsgBase - cluster *queryNodeCluster + cluster Cluster meta Meta nodeID int64 } @@ -108,6 +119,7 @@ func (tt *testTask) PostExecute(ctx context.Context) error { } func TestWatchQueryChannel_ClearEtcdInfoAfterAssignedNodeDown(t *testing.T) { + refreshParams() baseCtx := context.Background() queryCoord, err := startQueryCoord(baseCtx) assert.Nil(t, err) @@ -117,14 +129,12 @@ func TestWatchQueryChannel_ClearEtcdInfoAfterAssignedNodeDown(t *testing.T) { assert.Nil(t, err) queryNode.addQueryChannels = returnFailedResult - time.Sleep(time.Second) - nodes, err := queryCoord.cluster.onServiceNodes() - assert.Nil(t, err) - assert.Equal(t, len(nodes), 1) - var nodeID int64 - for id := range nodes { - nodeID = id - break + nodeID := queryNode.queryNodeID + for { + _, err = queryCoord.cluster.getNodeByID(nodeID) + if err == nil { + break + } } testTask := &testTask{ BaseTask: BaseTask{ @@ -142,16 +152,16 @@ func TestWatchQueryChannel_ClearEtcdInfoAfterAssignedNodeDown(t *testing.T) { queryCoord.scheduler.Enqueue([]task{testTask}) time.Sleep(time.Second) - queryNode.stop() - - allNodeOffline := waitAllQueryNodeOffline(queryCoord.cluster, nodes) - assert.Equal(t, allNodeOffline, true) - - time.Sleep(time.Second) - newActiveTaskIDKeys, _, err := queryCoord.scheduler.client.LoadWithPrefix(activeTaskPrefix) - assert.Nil(t, err) - assert.Equal(t, len(newActiveTaskIDKeys), len(activeTaskIDKeys)) + queryCoord.cluster.stopNode(nodeID) + for { + newActiveTaskIDKeys, _, err := queryCoord.scheduler.client.LoadWithPrefix(activeTaskPrefix) + assert.Nil(t, err) + if len(newActiveTaskIDKeys) 
== len(activeTaskIDKeys) { + break + } + } queryCoord.Stop() + queryNode.stop() } func TestUnMarshalTask(t *testing.T) { diff --git a/internal/querycoord/task_test.go b/internal/querycoord/task_test.go index e740d2381c..ef2121d35a 100644 --- a/internal/querycoord/task_test.go +++ b/internal/querycoord/task_test.go @@ -1,3 +1,13 @@ +// Copyright (C) 2019-2020 Zilliz. All rights reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software distributed under the License +// is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express +// or implied. See the License for the specific language governing permissions and limitations under the License. package querycoord import ( @@ -11,6 +21,7 @@ import ( ) func TestTriggerTask(t *testing.T) { + refreshParams() ctx := context.Background() queryCoord, err := startQueryCoord(ctx) assert.Nil(t, err)