Add dataservice register and discovery (#5435)

Signed-off-by: sunby <bingyi.sun@zilliz.com>
pull/5779/head
sunby 2021-05-26 19:06:56 +08:00 committed by zhenshan.cao
parent e3956ad13f
commit 07c6a4a669
9 changed files with 595 additions and 949 deletions


@ -11,153 +11,226 @@
package dataservice
import (
"context"
"errors"
"fmt"
"sync"
"go.uber.org/zap"
"github.com/milvus-io/milvus/internal/log"
"github.com/milvus-io/milvus/internal/types"
"github.com/milvus-io/milvus/internal/proto/commonpb"
"github.com/milvus-io/milvus/internal/proto/datapb"
"github.com/milvus-io/milvus/internal/proto/internalpb"
"go.uber.org/zap"
"golang.org/x/net/context"
)
type dataNode struct {
id int64
address struct {
ip string
port int64
}
client types.DataNode
channelNum int
}
type dataNodeCluster struct {
sync.RWMutex
nodes []*dataNode
type cluster struct {
mu sync.RWMutex
ctx context.Context
dataManager *clusterNodeManager
sessionManager sessionManager
startupPolicy clusterStartupPolicy
registerPolicy dataNodeRegisterPolicy
unregisterPolicy dataNodeUnregisterPolicy
assginPolicy channelAssignPolicy
}
func (node *dataNode) String() string {
return fmt.Sprintf("id: %d, address: %s:%d", node.id, node.address.ip, node.address.port)
type clusterOption struct {
apply func(c *cluster)
}
func newDataNodeCluster() *dataNodeCluster {
return &dataNodeCluster{
nodes: make([]*dataNode, 0),
func withStartupPolicy(p clusterStartupPolicy) clusterOption {
return clusterOption{
apply: func(c *cluster) { c.startupPolicy = p },
}
}
func (c *dataNodeCluster) Register(dataNode *dataNode) error {
c.Lock()
defer c.Unlock()
if c.checkDataNodeNotExist(dataNode.address.ip, dataNode.address.port) {
c.nodes = append(c.nodes, dataNode)
return nil
func withRegisterPolicy(p dataNodeRegisterPolicy) clusterOption {
return clusterOption{
apply: func(c *cluster) { c.registerPolicy = p },
}
return errors.New("datanode already exist")
}
func (c *dataNodeCluster) checkDataNodeNotExist(ip string, port int64) bool {
for _, node := range c.nodes {
if node.address.ip == ip && node.address.port == port {
return false
func withUnregistorPolicy(p dataNodeUnregisterPolicy) clusterOption {
return clusterOption{
apply: func(c *cluster) { c.unregisterPolicy = p },
}
}
func withAssignPolicy(p channelAssignPolicy) clusterOption {
return clusterOption{
apply: func(c *cluster) { c.assginPolicy = p },
}
}
func defaultStartupPolicy() clusterStartupPolicy {
return newReWatchOnRestartsStartupPolicy()
}
func defaultRegisterPolicy() dataNodeRegisterPolicy {
return newDoNothingRegisterPolicy()
}
func defaultUnregisterPolicy() dataNodeUnregisterPolicy {
return newDoNothingUnregisterPolicy()
}
func defaultAssignPolicy() channelAssignPolicy {
return newAllAssignPolicy()
}
func newCluster(ctx context.Context, dataManager *clusterNodeManager, sessionManager sessionManager, opts ...clusterOption) *cluster {
c := &cluster{
ctx: ctx,
sessionManager: sessionManager,
dataManager: dataManager,
startupPolicy: defaultStartupPolicy(),
registerPolicy: defaultRegisterPolicy(),
unregisterPolicy: defaultUnregisterPolicy(),
assginPolicy: defaultAssignPolicy(),
}
for _, opt := range opts {
opt.apply(c)
}
return c
}
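newCluster applies the defaults first and then each clusterOption in order, so callers only override the policies they care about. A minimal usage sketch (package dataservice); customPolicy stands in for any dataNodeRegisterPolicy implementation and is not part of this change:

// Build a cluster that keeps every default except the register policy.
// customPolicy is a hypothetical dataNodeRegisterPolicy value.
func newClusterWithCustomRegister(ctx context.Context, dm *clusterNodeManager, sm sessionManager, customPolicy dataNodeRegisterPolicy) *cluster {
    return newCluster(ctx, dm, sm,
        withRegisterPolicy(customPolicy), // overrides defaultRegisterPolicy()
    )
}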
func (c *cluster) startup(dataNodes []*datapb.DataNodeInfo) error {
deltaChange := c.dataManager.updateCluster(dataNodes)
nodes := c.dataManager.getDataNodes(false)
rets := c.startupPolicy.apply(nodes, deltaChange)
c.dataManager.updateDataNodes(rets)
rets = c.watch(rets)
c.dataManager.updateDataNodes(rets)
return nil
}
func (c *cluster) watch(nodes []*datapb.DataNodeInfo) []*datapb.DataNodeInfo {
for _, n := range nodes {
uncompletes := make([]string, 0)
for _, ch := range n.Channels {
if ch.State == datapb.ChannelWatchState_Uncomplete {
uncompletes = append(uncompletes, ch.Name)
}
}
}
return true
}
func (c *dataNodeCluster) GetNumOfNodes() int {
c.RLock()
defer c.RUnlock()
return len(c.nodes)
}
func (c *dataNodeCluster) GetNodeIDs() []int64 {
c.RLock()
defer c.RUnlock()
ret := make([]int64, 0, len(c.nodes))
for _, node := range c.nodes {
ret = append(ret, node.id)
}
return ret
}
func (c *dataNodeCluster) WatchInsertChannels(channels []string) {
ctx := context.TODO()
c.Lock()
defer c.Unlock()
var groups [][]string
if len(channels) < len(c.nodes) {
groups = make([][]string, len(channels))
} else {
groups = make([][]string, len(c.nodes))
}
length := len(groups)
for i, channel := range channels {
groups[i%length] = append(groups[i%length], channel)
}
for i, group := range groups {
resp, err := c.nodes[i].client.WatchDmChannels(ctx, &datapb.WatchDmChannelsRequest{
Base: &commonpb.MsgBase{
MsgType: commonpb.MsgType_DescribeCollection,
MsgID: -1, // todo
Timestamp: 0, // todo
SourceID: Params.NodeID,
},
// ChannelNames: group, // TODO
})
if err = VerifyResponse(resp, err); err != nil {
log.Error("watch dm channels error", zap.Stringer("dataNode", c.nodes[i]), zap.Error(err))
continue
}
c.nodes[i].channelNum += len(group)
}
}
func (c *dataNodeCluster) GetDataNodeStates(ctx context.Context) ([]*internalpb.ComponentInfo, error) {
c.RLock()
defer c.RUnlock()
ret := make([]*internalpb.ComponentInfo, 0)
for _, node := range c.nodes {
states, err := node.client.GetComponentStates(ctx)
cli, err := c.sessionManager.getOrCreateSession(n.Address)
if err != nil {
log.Error("get component states error", zap.Stringer("dataNode", node), zap.Error(err))
log.Warn("get session failed", zap.String("addr", n.Address), zap.Error(err))
continue
}
ret = append(ret, states.State)
req := &datapb.WatchDmChannelsRequest{
Base: &commonpb.MsgBase{
SourceID: Params.NodeID,
},
//ChannelNames: uncompletes,
}
resp, err := cli.WatchDmChannels(c.ctx, req)
if err != nil {
log.Warn("watch dm channel failed", zap.String("addr", n.Address), zap.Error(err))
continue
}
if resp.ErrorCode != commonpb.ErrorCode_Success {
log.Warn("watch channels failed", zap.String("address", n.Address), zap.String("reason", resp.Reason))
continue
}
for _, ch := range n.Channels {
if ch.State == datapb.ChannelWatchState_Uncomplete {
ch.State = datapb.ChannelWatchState_Complete
}
}
}
return ret, nil
return nodes
}
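Since the removed dataNodeCluster methods are interleaved with the added lines in this hunk, the new watch flow is hard to read in place. A minimal sketch of the intended logic, using only the types and calls that appear in this diff (ChannelNames is still commented out upstream, so it is omitted here as well; the empty-channel guard is a sketch-level shortcut):

// For each node: collect channels still marked Uncomplete, ask the node to
// watch them through its session, and flip them to Complete on success.
func watchUncompleteChannels(ctx context.Context, sm sessionManager, nodes []*datapb.DataNodeInfo) []*datapb.DataNodeInfo {
    for _, n := range nodes {
        uncompletes := make([]string, 0, len(n.Channels))
        for _, ch := range n.Channels {
            if ch.State == datapb.ChannelWatchState_Uncomplete {
                uncompletes = append(uncompletes, ch.Name)
            }
        }
        if len(uncompletes) == 0 {
            continue // nothing left to watch on this node
        }
        cli, err := sm.getOrCreateSession(n.Address)
        if err != nil {
            continue // unreachable node keeps its Uncomplete channels for a later retry
        }
        resp, err := cli.WatchDmChannels(ctx, &datapb.WatchDmChannelsRequest{
            Base: &commonpb.MsgBase{SourceID: Params.NodeID},
        })
        if err != nil || resp.ErrorCode != commonpb.ErrorCode_Success {
            continue // leave the channels Uncomplete so they can be re-watched
        }
        for _, ch := range n.Channels {
            if ch.State == datapb.ChannelWatchState_Uncomplete {
                ch.State = datapb.ChannelWatchState_Complete
            }
        }
    }
    return nodes
}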
func (c *dataNodeCluster) FlushSegment(request *datapb.FlushSegmentsRequest) {
ctx := context.TODO()
c.Lock()
defer c.Unlock()
for _, node := range c.nodes {
if _, err := node.client.FlushSegments(ctx, request); err != nil {
log.Error("flush segment err", zap.Stringer("dataNode", node), zap.Error(err))
func (c *cluster) register(n *datapb.DataNodeInfo) {
c.mu.Lock()
defer c.mu.Unlock()
c.dataManager.register(n)
cNodes := c.dataManager.getDataNodes(true)
rets := c.registerPolicy.apply(cNodes, n)
c.dataManager.updateDataNodes(rets)
rets = c.watch(rets)
c.dataManager.updateDataNodes(rets)
}
func (c *cluster) unregister(n *datapb.DataNodeInfo) {
c.mu.Lock()
defer c.mu.Unlock()
c.dataManager.unregister(n)
cNodes := c.dataManager.getDataNodes(true)
rets := c.unregisterPolicy.apply(cNodes, n)
c.dataManager.updateDataNodes(rets)
rets = c.watch(rets)
c.dataManager.updateDataNodes(rets)
}
func (c *cluster) watchIfNeeded(channel string) {
c.mu.Lock()
defer c.mu.Unlock()
cNodes := c.dataManager.getDataNodes(true)
rets := c.assginPolicy.apply(cNodes, channel)
c.dataManager.updateDataNodes(rets)
rets = c.watch(rets)
c.dataManager.updateDataNodes(rets)
}
func (c *cluster) flush(segments []*datapb.SegmentInfo) {
log.Debug("prepare to flush", zap.Any("segments", segments))
c.mu.Lock()
defer c.mu.Unlock()
m := make(map[string]map[UniqueID][]UniqueID) // channel-> map[collectionID]segmentIDs
for _, seg := range segments {
if _, ok := m[seg.InsertChannel]; !ok {
m[seg.InsertChannel] = make(map[UniqueID][]UniqueID)
}
m[seg.InsertChannel][seg.CollectionID] = append(m[seg.InsertChannel][seg.CollectionID], seg.ID)
}
dataNodes := c.dataManager.getDataNodes(true)
channel2Node := make(map[string]string)
for _, node := range dataNodes {
for _, chstatus := range node.Channels {
channel2Node[chstatus.Name] = node.Address
}
}
for ch, coll2seg := range m {
node, ok := channel2Node[ch]
if !ok {
continue
}
cli, err := c.sessionManager.getOrCreateSession(node)
if err != nil {
log.Warn("get session failed", zap.String("addr", node), zap.Error(err))
continue
}
for coll, segs := range coll2seg {
req := &datapb.FlushSegmentsRequest{
Base: &commonpb.MsgBase{
MsgType: commonpb.MsgType_Flush,
SourceID: Params.NodeID,
},
CollectionID: coll,
SegmentIDs: segs,
}
resp, err := cli.FlushSegments(c.ctx, req)
if err != nil {
log.Warn("flush segment failed", zap.String("addr", node), zap.Error(err))
continue
}
if resp.ErrorCode != commonpb.ErrorCode_Success {
log.Warn("flush segment failed", zap.String("dataNode", node), zap.String("reason", resp.Reason))
continue
}
log.Debug("flush segments succeed", zap.Any("segmentIDs", segs))
}
}
}
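flush first folds the segments into a channel → collection → segment-ID map and then sends one FlushSegmentsRequest per (channel, collection) pair to whichever node currently owns the channel. The grouping step, isolated as a sketch with the same map shape:

// Group segment IDs by insert channel and then by collection, matching the
// map used by cluster.flush above.
func groupSegmentsForFlush(segments []*datapb.SegmentInfo) map[string]map[UniqueID][]UniqueID {
    m := make(map[string]map[UniqueID][]UniqueID)
    for _, seg := range segments {
        if _, ok := m[seg.InsertChannel]; !ok {
            m[seg.InsertChannel] = make(map[UniqueID][]UniqueID)
        }
        m[seg.InsertChannel][seg.CollectionID] = append(m[seg.InsertChannel][seg.CollectionID], seg.ID)
    }
    return m
}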
func (c *dataNodeCluster) ShutDownClients() {
c.Lock()
defer c.Unlock()
for _, node := range c.nodes {
if err := node.client.Stop(); err != nil {
log.Error("stop client error", zap.Stringer("dataNode", node), zap.Error(err))
continue
}
}
}
// Clear only for test
func (c *dataNodeCluster) Clear() {
c.Lock()
defer c.Unlock()
c.nodes = make([]*dataNode, 0)
func (c *cluster) releaseSessions() {
c.mu.Lock()
defer c.mu.Unlock()
c.sessionManager.release()
}


@ -12,28 +12,34 @@ package dataservice
import (
"sync"
"time"
grpcdatanodeclient "github.com/milvus-io/milvus/internal/distributed/datanode/client"
"github.com/milvus-io/milvus/internal/types"
)
const retryTimes = 2
type sessionManager interface {
sendRequest(addr string, executor func(node types.DataNode) error) error
getOrCreateSession(addr string) (types.DataNode, error)
releaseSession(addr string)
release()
}
type clusterSessionManager struct {
mu sync.RWMutex
sessions map[string]types.DataNode
dataClientCreator func(addr string, timeout time.Duration) (types.DataNode, error)
}
func newClusterSessionManager() *clusterSessionManager {
return &clusterSessionManager{sessions: make(map[string]types.DataNode)}
func newClusterSessionManager(dataClientCreator func(addr string, timeout time.Duration) (types.DataNode, error)) *clusterSessionManager {
return &clusterSessionManager{
sessions: make(map[string]types.DataNode),
dataClientCreator: dataClientCreator,
}
}
func (m *clusterSessionManager) createSession(addr string) error {
cli, err := grpcdatanodeclient.NewClient(addr, 0, []string{}, 0)
cli, err := m.dataClientCreator(addr, 0)
if err != nil {
return err
}
@ -47,8 +53,13 @@ func (m *clusterSessionManager) createSession(addr string) error {
return nil
}
func (m *clusterSessionManager) getSession(addr string) types.DataNode {
return m.sessions[addr]
func (m *clusterSessionManager) getOrCreateSession(addr string) (types.DataNode, error) {
if !m.hasSession(addr) {
if err := m.createSession(addr); err != nil {
return nil, err
}
}
return m.sessions[addr], nil
}
func (m *clusterSessionManager) hasSession(addr string) bool {
@ -56,19 +67,17 @@ func (m *clusterSessionManager) hasSession(addr string) bool {
return ok
}
func (m *clusterSessionManager) sendRequest(addr string, executor func(node types.DataNode) error) error {
m.mu.Lock()
defer m.mu.Unlock()
success := false
var err error
for i := 0; !success && i < retryTimes; i++ {
if i != 0 || !m.hasSession(addr) {
m.createSession(addr)
}
err = executor(m.getSession(addr))
if err == nil {
return nil
}
func (m *clusterSessionManager) releaseSession(addr string) {
cli, ok := m.sessions[addr]
if !ok {
return
}
_ = cli.Stop()
delete(m.sessions, addr)
}
func (m *clusterSessionManager) release() {
for _, cli := range m.sessions {
_ = cli.Stop()
}
}
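The session manager now hands out cached clients instead of running callbacks (the old sendRequest retry helper is gone). A caller-side sketch, assuming package dataservice and that the node behind addr speaks the types.DataNode interface:

// Fetch (or lazily create) a client for addr, call it, and drop the session
// if the node looks dead so the next call re-dials.
func pingDataNode(ctx context.Context, sm sessionManager, addr string) error {
    cli, err := sm.getOrCreateSession(addr)
    if err != nil {
        return err
    }
    if _, err := cli.GetComponentStates(ctx); err != nil {
        sm.releaseSession(addr)
        return err
    }
    return nil
}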


@ -11,90 +11,133 @@
package dataservice
import (
"context"
"testing"
"github.com/milvus-io/milvus/internal/proto/internalpb"
memkv "github.com/milvus-io/milvus/internal/kv/mem"
"github.com/milvus-io/milvus/internal/proto/datapb"
"github.com/stretchr/testify/assert"
"golang.org/x/net/context"
)
func TestDataNodeClusterRegister(t *testing.T) {
Params.Init()
cluster := newDataNodeCluster()
dataNodeNum := 3
ids := make([]int64, 0, dataNodeNum)
for i := 0; i < dataNodeNum; i++ {
c, err := newMockDataNodeClient(int64(i))
assert.Nil(t, err)
err = c.Init()
assert.Nil(t, err)
err = c.Start()
assert.Nil(t, err)
cluster.Register(&dataNode{
id: int64(i),
address: struct {
ip string
port int64
}{"localhost", int64(9999 + i)},
client: c,
channelNum: 0,
})
ids = append(ids, int64(i))
func TestClusterCreate(t *testing.T) {
cPolicy := newMockStartupPolicy()
cluster := createCluster(t, nil, withStartupPolicy(cPolicy))
addr := "localhost:8080"
nodes := []*datapb.DataNodeInfo{
{
Address: addr,
Version: 1,
Channels: []*datapb.ChannelStatus{},
},
}
assert.EqualValues(t, dataNodeNum, cluster.GetNumOfNodes())
assert.EqualValues(t, ids, cluster.GetNodeIDs())
states, err := cluster.GetDataNodeStates(context.TODO())
err := cluster.startup(nodes)
assert.Nil(t, err)
assert.EqualValues(t, dataNodeNum, len(states))
for _, s := range states {
assert.EqualValues(t, internalpb.StateCode_Healthy, s.StateCode)
}
cluster.ShutDownClients()
states, err = cluster.GetDataNodeStates(context.TODO())
assert.Nil(t, err)
assert.EqualValues(t, dataNodeNum, len(states))
for _, s := range states {
assert.EqualValues(t, internalpb.StateCode_Abnormal, s.StateCode)
}
dataNodes := cluster.dataManager.getDataNodes(true)
assert.EqualValues(t, 1, len(dataNodes))
assert.EqualValues(t, "localhost:8080", dataNodes[addr].Address)
}
func TestWatchChannels(t *testing.T) {
Params.Init()
dataNodeNum := 3
cases := []struct {
collectionID UniqueID
channels []string
channelNums []int
}{
{1, []string{"c1"}, []int{1, 0, 0}},
{1, []string{"c1", "c2", "c3"}, []int{1, 1, 1}},
{1, []string{"c1", "c2", "c3", "c4"}, []int{2, 1, 1}},
{1, []string{"c1", "c2", "c3", "c4", "c5", "c6", "c7"}, []int{3, 2, 2}},
func TestRegister(t *testing.T) {
cPolicy := newMockStartupPolicy()
registerPolicy := newDoNothingRegisterPolicy()
cluster := createCluster(t, nil, withStartupPolicy(cPolicy), withRegisterPolicy(registerPolicy))
addr := "localhost:8080"
err := cluster.startup(nil)
assert.Nil(t, err)
cluster.register(&datapb.DataNodeInfo{
Address: addr,
Version: 1,
Channels: []*datapb.ChannelStatus{},
})
dataNodes := cluster.dataManager.getDataNodes(true)
assert.EqualValues(t, 1, len(dataNodes))
assert.EqualValues(t, "localhost:8080", dataNodes[addr].Address)
}
func TestUnregister(t *testing.T) {
cPolicy := newMockStartupPolicy()
unregisterPolicy := newDoNothingUnregisterPolicy()
cluster := createCluster(t, nil, withStartupPolicy(cPolicy), withUnregistorPolicy(unregisterPolicy))
addr := "localhost:8080"
nodes := []*datapb.DataNodeInfo{
{
Address: addr,
Version: 1,
Channels: []*datapb.ChannelStatus{},
},
}
err := cluster.startup(nodes)
assert.Nil(t, err)
dataNodes := cluster.dataManager.getDataNodes(true)
assert.EqualValues(t, 1, len(dataNodes))
assert.EqualValues(t, "localhost:8080", dataNodes[addr].Address)
cluster.unregister(&datapb.DataNodeInfo{
Address: addr,
Version: 1,
Channels: []*datapb.ChannelStatus{},
})
dataNodes = cluster.dataManager.getDataNodes(false)
assert.EqualValues(t, 1, len(dataNodes))
assert.EqualValues(t, offline, cluster.dataManager.dataNodes[addr].status)
assert.EqualValues(t, "localhost:8080", dataNodes[addr].Address)
}
func TestWatchIfNeeded(t *testing.T) {
cPolicy := newMockStartupPolicy()
cluster := createCluster(t, nil, withStartupPolicy(cPolicy))
addr := "localhost:8080"
nodes := []*datapb.DataNodeInfo{
{
Address: addr,
Version: 1,
Channels: []*datapb.ChannelStatus{},
},
}
err := cluster.startup(nodes)
assert.Nil(t, err)
dataNodes := cluster.dataManager.getDataNodes(true)
assert.EqualValues(t, 1, len(dataNodes))
assert.EqualValues(t, "localhost:8080", dataNodes[addr].Address)
chName := "ch1"
cluster.watchIfNeeded(chName)
dataNodes = cluster.dataManager.getDataNodes(true)
assert.EqualValues(t, 1, len(dataNodes[addr].Channels))
assert.EqualValues(t, chName, dataNodes[addr].Channels[0].Name)
cluster.watchIfNeeded(chName)
assert.EqualValues(t, 1, len(dataNodes[addr].Channels))
assert.EqualValues(t, chName, dataNodes[addr].Channels[0].Name)
}
func TestFlushSegments(t *testing.T) {
cPolicy := newMockStartupPolicy()
cluster := createCluster(t, nil, withStartupPolicy(cPolicy))
addr := "localhost:8080"
nodes := []*datapb.DataNodeInfo{
{
Address: addr,
Version: 1,
Channels: []*datapb.ChannelStatus{},
},
}
err := cluster.startup(nodes)
assert.Nil(t, err)
segments := []*datapb.SegmentInfo{
{
ID: 0,
CollectionID: 0,
InsertChannel: "ch1",
},
}
cluster := newDataNodeCluster()
for _, c := range cases {
for i := 0; i < dataNodeNum; i++ {
c, err := newMockDataNodeClient(int64(i))
assert.Nil(t, err)
err = c.Init()
assert.Nil(t, err)
err = c.Start()
assert.Nil(t, err)
cluster.Register(&dataNode{
id: int64(i),
address: struct {
ip string
port int64
}{"localhost", int64(9999 + i)},
client: c,
channelNum: 0,
})
}
cluster.WatchInsertChannels(c.channels)
for i := 0; i < len(cluster.nodes); i++ {
assert.EqualValues(t, c.channelNums[i], cluster.nodes[i].channelNum)
}
cluster.Clear()
}
cluster.flush(segments)
}
func createCluster(t *testing.T, ch chan interface{}, options ...clusterOption) *cluster {
kv := memkv.NewMemoryKV()
sessionManager := newMockSessionManager(ch)
dataManager, err := newClusterNodeManager(kv)
assert.Nil(t, err)
return newCluster(context.TODO(), dataManager, sessionManager, options...)
}


@ -1,229 +0,0 @@
// Copyright (C) 2019-2020 Zilliz. All rights reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance
// with the License. You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software distributed under the License
// is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express
// or implied. See the License for the specific language governing permissions and limitations under the License.
package dataservice
import (
"sync"
"github.com/milvus-io/milvus/internal/log"
"github.com/milvus-io/milvus/internal/proto/commonpb"
"github.com/milvus-io/milvus/internal/proto/datapb"
"github.com/milvus-io/milvus/internal/types"
"go.uber.org/zap"
"golang.org/x/net/context"
)
type cluster struct {
mu sync.RWMutex
dataManager *clusterNodeManager
sessionManager sessionManager
startupPolicy clusterStartupPolicy
registerPolicy dataNodeRegisterPolicy
unregisterPolicy dataNodeUnregisterPolicy
assginPolicy channelAssignPolicy
}
type clusterOption struct {
apply func(c *cluster)
}
func withStartupPolicy(p clusterStartupPolicy) clusterOption {
return clusterOption{
apply: func(c *cluster) { c.startupPolicy = p },
}
}
func withRegisterPolicy(p dataNodeRegisterPolicy) clusterOption {
return clusterOption{
apply: func(c *cluster) { c.registerPolicy = p },
}
}
func withUnregistorPolicy(p dataNodeUnregisterPolicy) clusterOption {
return clusterOption{
apply: func(c *cluster) { c.unregisterPolicy = p },
}
}
func withAssignPolicy(p channelAssignPolicy) clusterOption {
return clusterOption{
apply: func(c *cluster) { c.assginPolicy = p },
}
}
func defaultStartupPolicy() clusterStartupPolicy {
return newReWatchOnRestartsStartupPolicy()
}
func defaultRegisterPolicy() dataNodeRegisterPolicy {
return newDoNothingRegisterPolicy()
}
func defaultUnregisterPolicy() dataNodeUnregisterPolicy {
return newDoNothingUnregisterPolicy()
}
func defaultAssignPolicy() channelAssignPolicy {
return newAllAssignPolicy()
}
func newCluster(dataManager *clusterNodeManager, sessionManager sessionManager, opts ...clusterOption) *cluster {
c := &cluster{
dataManager: dataManager,
sessionManager: sessionManager,
}
c.startupPolicy = defaultStartupPolicy()
c.registerPolicy = defaultRegisterPolicy()
c.unregisterPolicy = defaultUnregisterPolicy()
c.assginPolicy = defaultAssignPolicy()
for _, opt := range opts {
opt.apply(c)
}
return c
}
func (c *cluster) startup(dataNodes []*datapb.DataNodeInfo) error {
deltaChange := c.dataManager.updateCluster(dataNodes)
nodes := c.dataManager.getDataNodes(false)
rets := c.startupPolicy.apply(nodes, deltaChange)
c.dataManager.updateDataNodes(rets)
rets = c.watch(rets)
c.dataManager.updateDataNodes(rets)
return nil
}
func (c *cluster) watch(nodes []*datapb.DataNodeInfo) []*datapb.DataNodeInfo {
for _, n := range nodes {
uncompletes := make([]string, 0)
for _, ch := range n.Channels {
if ch.State == datapb.ChannelWatchState_Uncomplete {
uncompletes = append(uncompletes, ch.Name)
}
}
executor := func(cli types.DataNode) error {
req := &datapb.WatchDmChannelsRequest{
Base: &commonpb.MsgBase{
SourceID: Params.NodeID,
},
// ChannelNames: uncompletes, // TODO
}
resp, err := cli.WatchDmChannels(context.Background(), req)
if err != nil {
return err
}
if resp.ErrorCode != commonpb.ErrorCode_Success {
log.Warn("watch channels failed", zap.String("address", n.Address), zap.Error(err))
return nil
}
for _, ch := range n.Channels {
if ch.State == datapb.ChannelWatchState_Uncomplete {
ch.State = datapb.ChannelWatchState_Complete
}
}
return nil
}
if err := c.sessionManager.sendRequest(n.Address, executor); err != nil {
log.Warn("watch channels failed", zap.String("address", n.Address), zap.Error(err))
}
}
return nodes
}
func (c *cluster) register(n *datapb.DataNodeInfo) {
c.mu.Lock()
defer c.mu.Unlock()
c.dataManager.register(n)
cNodes := c.dataManager.getDataNodes(true)
rets := c.registerPolicy.apply(cNodes, n)
c.dataManager.updateDataNodes(rets)
rets = c.watch(rets)
c.dataManager.updateDataNodes(rets)
}
func (c *cluster) unregister(n *datapb.DataNodeInfo) {
c.mu.Lock()
defer c.mu.Unlock()
c.dataManager.unregister(n)
cNodes := c.dataManager.getDataNodes(true)
rets := c.unregisterPolicy.apply(cNodes, n)
c.dataManager.updateDataNodes(rets)
rets = c.watch(rets)
c.dataManager.updateDataNodes(rets)
}
func (c *cluster) watchIfNeeded(channel string) {
c.mu.Lock()
defer c.mu.Unlock()
cNodes := c.dataManager.getDataNodes(true)
rets := c.assginPolicy.apply(cNodes, channel)
c.dataManager.updateDataNodes(rets)
rets = c.watch(rets)
c.dataManager.updateDataNodes(rets)
}
func (c *cluster) flush(segments []*datapb.SegmentInfo) {
c.mu.Lock()
defer c.mu.Unlock()
m := make(map[string]map[UniqueID][]UniqueID) // channel-> map[collectionID]segmentIDs
for _, seg := range segments {
if _, ok := m[seg.InsertChannel]; !ok {
m[seg.InsertChannel] = make(map[UniqueID][]UniqueID)
}
m[seg.InsertChannel][seg.CollectionID] = append(m[seg.InsertChannel][seg.CollectionID], seg.ID)
}
dataNodes := c.dataManager.getDataNodes(true)
channel2Node := make(map[string]string)
for _, node := range dataNodes {
for _, chstatus := range node.Channels {
channel2Node[chstatus.Name] = node.Address
}
}
for ch, coll2seg := range m {
node, ok := channel2Node[ch]
if !ok {
continue
}
for coll, segs := range coll2seg {
executor := func(cli types.DataNode) error {
req := &datapb.FlushSegmentsRequest{
Base: &commonpb.MsgBase{
MsgType: commonpb.MsgType_Flush,
SourceID: Params.NodeID,
},
CollectionID: coll,
SegmentIDs: segs,
}
resp, err := cli.FlushSegments(context.Background(), req)
if err != nil {
return err
}
if resp.ErrorCode != commonpb.ErrorCode_Success {
log.Warn("flush segment error", zap.String("dataNode", node), zap.Error(err))
}
return nil
}
if err := c.sessionManager.sendRequest(node, executor); err != nil {
log.Warn("flush segment error", zap.String("dataNode", node), zap.Error(err))
}
}
}
}


@ -1,142 +0,0 @@
// Copyright (C) 2019-2020 Zilliz. All rights reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance
// with the License. You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software distributed under the License
// is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express
// or implied. See the License for the specific language governing permissions and limitations under the License.
package dataservice
import (
"testing"
memkv "github.com/milvus-io/milvus/internal/kv/mem"
"github.com/milvus-io/milvus/internal/proto/datapb"
"github.com/stretchr/testify/assert"
)
func TestClusterCreate(t *testing.T) {
cPolicy := newMockStartupPolicy()
cluster := createCluster(t, withStartupPolicy(cPolicy))
addr := "localhost:8080"
nodes := []*datapb.DataNodeInfo{
{
Address: addr,
Version: 1,
Channels: []*datapb.ChannelStatus{},
},
}
err := cluster.startup(nodes)
assert.Nil(t, err)
dataNodes := cluster.dataManager.getDataNodes(true)
assert.EqualValues(t, 1, len(dataNodes))
assert.EqualValues(t, "localhost:8080", dataNodes[addr].Address)
}
func TestRegister(t *testing.T) {
cPolicy := newMockStartupPolicy()
registerPolicy := newDoNothingRegisterPolicy()
cluster := createCluster(t, withStartupPolicy(cPolicy), withRegisterPolicy(registerPolicy))
addr := "localhost:8080"
err := cluster.startup(nil)
assert.Nil(t, err)
cluster.register(&datapb.DataNodeInfo{
Address: addr,
Version: 1,
Channels: []*datapb.ChannelStatus{},
})
dataNodes := cluster.dataManager.getDataNodes(true)
assert.EqualValues(t, 1, len(dataNodes))
assert.EqualValues(t, "localhost:8080", dataNodes[addr].Address)
}
func TestUnregister(t *testing.T) {
cPolicy := newMockStartupPolicy()
unregisterPolicy := newDoNothingUnregisterPolicy()
cluster := createCluster(t, withStartupPolicy(cPolicy), withUnregistorPolicy(unregisterPolicy))
addr := "localhost:8080"
nodes := []*datapb.DataNodeInfo{
{
Address: addr,
Version: 1,
Channels: []*datapb.ChannelStatus{},
},
}
err := cluster.startup(nodes)
assert.Nil(t, err)
dataNodes := cluster.dataManager.getDataNodes(true)
assert.EqualValues(t, 1, len(dataNodes))
assert.EqualValues(t, "localhost:8080", dataNodes[addr].Address)
cluster.unregister(&datapb.DataNodeInfo{
Address: addr,
Version: 1,
Channels: []*datapb.ChannelStatus{},
})
dataNodes = cluster.dataManager.getDataNodes(false)
assert.EqualValues(t, 1, len(dataNodes))
assert.EqualValues(t, offline, cluster.dataManager.dataNodes[addr].status)
assert.EqualValues(t, "localhost:8080", dataNodes[addr].Address)
}
func TestWatchIfNeeded(t *testing.T) {
cPolicy := newMockStartupPolicy()
cluster := createCluster(t, withStartupPolicy(cPolicy))
addr := "localhost:8080"
nodes := []*datapb.DataNodeInfo{
{
Address: addr,
Version: 1,
Channels: []*datapb.ChannelStatus{},
},
}
err := cluster.startup(nodes)
assert.Nil(t, err)
dataNodes := cluster.dataManager.getDataNodes(true)
assert.EqualValues(t, 1, len(dataNodes))
assert.EqualValues(t, "localhost:8080", dataNodes[addr].Address)
chName := "ch1"
cluster.watchIfNeeded(chName)
dataNodes = cluster.dataManager.getDataNodes(true)
assert.EqualValues(t, 1, len(dataNodes[addr].Channels))
assert.EqualValues(t, chName, dataNodes[addr].Channels[0].Name)
cluster.watchIfNeeded(chName)
assert.EqualValues(t, 1, len(dataNodes[addr].Channels))
assert.EqualValues(t, chName, dataNodes[addr].Channels[0].Name)
}
func TestFlushSegments(t *testing.T) {
cPolicy := newMockStartupPolicy()
cluster := createCluster(t, withStartupPolicy(cPolicy))
addr := "localhost:8080"
nodes := []*datapb.DataNodeInfo{
{
Address: addr,
Version: 1,
Channels: []*datapb.ChannelStatus{},
},
}
err := cluster.startup(nodes)
assert.Nil(t, err)
segments := []*datapb.SegmentInfo{
{
ID: 0,
CollectionID: 0,
InsertChannel: "ch1",
},
}
cluster.flush(segments)
}
func createCluster(t *testing.T, options ...clusterOption) *cluster {
kv := memkv.NewMemoryKV()
sessionManager := newMockSessionManager()
dataManager, err := newClusterNodeManager(kv)
assert.Nil(t, err)
return newCluster(dataManager, sessionManager, options...)
}


@ -26,12 +26,13 @@ func (s *Server) GetComponentStates(ctx context.Context) (*internalpb.ComponentS
ErrorCode: commonpb.ErrorCode_UnexpectedError,
},
}
dataNodeStates, err := s.cluster.GetDataNodeStates(ctx)
if err != nil {
resp.Status.Reason = err.Error()
return resp, nil
}
resp.SubcomponentStates = dataNodeStates
// todo GetComponentStates need to be removed
//dataNodeStates, err := s.cluster.GetDataNodeStates(ctx)
//if err != nil {
//resp.Status.Reason = err.Error()
//return resp, nil
//}
//resp.SubcomponentStates = dataNodeStates
resp.Status.ErrorCode = commonpb.ErrorCode_Success
return resp, nil
}
@ -55,58 +56,9 @@ func (s *Server) GetStatisticsChannel(ctx context.Context) (*milvuspb.StringResp
}
func (s *Server) RegisterNode(ctx context.Context, req *datapb.RegisterNodeRequest) (*datapb.RegisterNodeResponse, error) {
ret := &datapb.RegisterNodeResponse{
Status: &commonpb.Status{
ErrorCode: commonpb.ErrorCode_UnexpectedError,
},
}
log.Debug("DataService: RegisterNode:",
zap.String("IP", req.Address.Ip),
zap.Int64("Port", req.Address.Port))
node, err := s.newDataNode(req.Address.Ip, req.Address.Port, req.Base.SourceID)
if err != nil {
ret.Status.Reason = err.Error()
return ret, nil
}
resp, err := node.client.WatchDmChannels(s.ctx, &datapb.WatchDmChannelsRequest{
Base: &commonpb.MsgBase{
MsgType: 0,
MsgID: 0,
Timestamp: 0,
SourceID: Params.NodeID,
},
// ChannelNames: s.insertChannels, // TODO
})
if err = VerifyResponse(resp, err); err != nil {
ret.Status.Reason = err.Error()
return ret, nil
}
if err := s.getDDChannel(); err != nil {
ret.Status.Reason = err.Error()
return ret, nil
}
if err = s.cluster.Register(node); err != nil {
ret.Status.Reason = err.Error()
return ret, nil
}
ret.Status.ErrorCode = commonpb.ErrorCode_Success
ret.InitParams = &internalpb.InitParams{
NodeID: Params.NodeID,
StartParams: []*commonpb.KeyValuePair{
{Key: "DDChannelName", Value: s.ddChannelMu.name},
{Key: "SegmentStatisticsChannelName", Value: Params.StatisticsChannelName},
{Key: "TimeTickChannelName", Value: Params.TimeTickChannelName},
{Key: "CompleteFlushChannelName", Value: Params.SegmentInfoChannelName},
},
}
return ret, nil
return nil, nil
}
func (s *Server) Flush(ctx context.Context, req *datapb.FlushRequest) (*commonpb.Status, error) {
if !s.checkStateIsHealthy() {
return &commonpb.Status{
@ -192,6 +144,7 @@ func (s *Server) AssignSegmentID(ctx context.Context, req *datapb.AssignSegmentI
SegIDAssignments: assigns,
}, nil
}
func (s *Server) ShowSegments(ctx context.Context, req *datapb.ShowSegmentsRequest) (*datapb.ShowSegmentsResponse, error) {
resp := &datapb.ShowSegmentsResponse{
Status: &commonpb.Status{
@ -280,7 +233,7 @@ func (s *Server) GetInsertChannels(ctx context.Context, req *datapb.GetInsertCha
Status: &commonpb.Status{
ErrorCode: commonpb.ErrorCode_Success,
},
Values: s.insertChannels,
Values: []string{},
}, nil
}


@ -70,10 +70,11 @@ type mockDataNodeClient struct {
ch chan interface{}
}
func newMockDataNodeClient(id int64) (*mockDataNodeClient, error) {
func newMockDataNodeClient(id int64, ch chan interface{}) (*mockDataNodeClient, error) {
return &mockDataNodeClient{
id: id,
state: internalpb.StateCode_Initializing,
ch: ch,
}, nil
}
@ -301,12 +302,21 @@ func (p *mockStartupPolicy) apply(oldCluster map[string]*datapb.DataNodeInfo, de
}
type mockSessionManager struct {
ch chan interface{}
}
func newMockSessionManager() sessionManager {
return &mockSessionManager{}
func newMockSessionManager(ch chan interface{}) sessionManager {
return &mockSessionManager{
ch: ch,
}
}
func (m *mockSessionManager) sendRequest(addr string, executor func(node types.DataNode) error) error {
return nil
func (m *mockSessionManager) getOrCreateSession(addr string) (types.DataNode, error) {
return newMockDataNodeClient(0, m.ch)
}
func (m *mockSessionManager) releaseSession(addr string) {
}
func (m *mockSessionManager) release() {
}
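The mock session manager now returns a mock DataNode wired to the test's receive channel; TestDataNodeTtChannel below blocks on that channel to learn that a flush reached the (mock) node. A helper sketch for a dataservice _test.go file, assuming the mock client signals on ch when FlushSegments is invoked (consistent with how the test reads from it):

// Wait for the mock DataNode to report a FlushSegments call, or fail the test.
func waitForMockFlush(t *testing.T, ch <-chan interface{}, timeout time.Duration) {
    select {
    case <-ch:
        // the mock client signalled that FlushSegments was invoked
    case <-time.After(timeout):
        t.Fatal("timed out waiting for mock FlushSegments")
    }
}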


@ -12,7 +12,6 @@ package dataservice
import (
"context"
"errors"
"fmt"
"math/rand"
"strconv"
@ -20,9 +19,9 @@ import (
"sync/atomic"
"time"
grpcdatanodeclient "github.com/milvus-io/milvus/internal/distributed/datanode/client"
"github.com/milvus-io/milvus/internal/logutil"
grpcdatanodeclient "github.com/milvus-io/milvus/internal/distributed/datanode/client"
etcdkv "github.com/milvus-io/milvus/internal/kv/etcd"
"github.com/milvus-io/milvus/internal/log"
"github.com/milvus-io/milvus/internal/msgstream"
@ -51,39 +50,43 @@ type Server struct {
serverLoopCancel context.CancelFunc
serverLoopWg sync.WaitGroup
state atomic.Value
kvClient *etcdkv.EtcdKV
meta *meta
segAllocator segmentAllocatorInterface
statsHandler *statsHandler
allocator allocatorInterface
cluster *dataNodeCluster
masterClient types.MasterService
ddChannelMu struct {
initOnce sync.Once
startOnce sync.Once
stopOnce sync.Once
kvClient *etcdkv.EtcdKV
meta *meta
segmentInfoStream msgstream.MsgStream
segAllocator segmentAllocatorInterface
statsHandler *statsHandler
allocator allocatorInterface
cluster *cluster
masterClient types.MasterService
ddChannelMu struct {
sync.Mutex
name string
}
session *sessionutil.Session
flushMsgStream msgstream.MsgStream
insertChannels []string
msFactory msgstream.Factory
createDataNodeClient func(addr string) (types.DataNode, error)
flushMsgStream msgstream.MsgStream
msFactory msgstream.Factory
session *sessionutil.Session
activeCh <-chan bool
watchCh <-chan *sessionutil.SessionEvent
dataClientCreator func(addr string) (types.DataNode, error)
}
func CreateServer(ctx context.Context, factory msgstream.Factory) (*Server, error) {
rand.Seed(time.Now().UnixNano())
s := &Server{
ctx: ctx,
cluster: newDataNodeCluster(),
msFactory: factory,
}
s.insertChannels = s.getInsertChannels()
s.createDataNodeClient = func(addr string) (types.DataNode, error) {
node, err := grpcdatanodeclient.NewClient(addr, 10*time.Second)
if err != nil {
return nil, err
}
return node, nil
s.dataClientCreator = func(addr string) (types.DataNode, error) {
return grpcdatanodeclient.NewClient(addr)
}
s.UpdateStateCode(internalpb.StateCode_Abnormal)
log.Debug("DataService", zap.Any("State", s.state.Load()))
return s, nil
@ -104,63 +107,116 @@ func (s *Server) SetMasterClient(masterClient types.MasterService) {
// Register register data service at etcd
func (s *Server) Register() error {
s.session = sessionutil.NewSession(s.ctx, Params.MetaRootPath, []string{Params.EtcdAddress})
s.session.Init(typeutil.DataServiceRole, Params.IP, true)
s.activeCh = s.session.Init(typeutil.DataServiceRole, Params.IP, true)
Params.NodeID = s.session.ServerID
return nil
}
func (s *Server) Init() error {
s.initOnce.Do(func() {
s.session = sessionutil.NewSession(s.ctx, []string{Params.EtcdAddress})
})
return nil
}
var startOnce sync.Once
func (s *Server) Start() error {
var err error
m := map[string]interface{}{
"PulsarAddress": Params.PulsarAddress,
"ReceiveBufSize": 1024,
"PulsarBufSize": 1024}
err = s.msFactory.SetParams(m)
s.startOnce.Do(func() {
m := map[string]interface{}{
"PulsarAddress": Params.PulsarAddress,
"ReceiveBufSize": 1024,
"PulsarBufSize": 1024}
err = s.msFactory.SetParams(m)
if err != nil {
return
}
if err = s.initMeta(); err != nil {
return
}
if err = s.initCluster(); err != nil {
return
}
if err = s.initSegmentInfoChannel(); err != nil {
return
}
s.allocator = newAllocator(s.masterClient)
s.startSegmentAllocator()
s.statsHandler = newStatsHandler(s.meta)
if err = s.initFlushMsgStream(); err != nil {
return
}
if err = s.initServiceDiscovery(); err != nil {
return
}
s.startServerLoop()
s.UpdateStateCode(internalpb.StateCode_Healthy)
log.Debug("start success")
})
return err
}
func (s *Server) initCluster() error {
dManager, err := newClusterNodeManager(s.kvClient)
if err != nil {
return err
}
sManager := newClusterSessionManager(s.dataClientCreator)
s.cluster = newCluster(s.ctx, dManager, sManager)
return nil
}
if err := s.initMeta(); err != nil {
func (s *Server) initServiceDiscovery() error {
sessions, rev, err := s.session.GetSessions(typeutil.DataNodeRole)
if err != nil {
log.Debug("DataService initMeta failed", zap.Error(err))
return err
}
log.Debug("registered sessions", zap.Any("sessions", sessions))
s.allocator = newAllocator(s.masterClient)
datanodes := make([]*datapb.DataNodeInfo, 0, len(sessions))
for _, session := range sessions {
datanodes = append(datanodes, &datapb.DataNodeInfo{
Address: session.Address,
Version: session.ServerID,
Channels: []*datapb.ChannelStatus{},
})
}
s.startSegmentAllocator()
s.statsHandler = newStatsHandler(s.meta)
if err = s.loadMetaFromMaster(); err != nil {
if err := s.cluster.startup(datanodes); err != nil {
log.Debug("DataService loadMetaFromMaster failed", zap.Error(err))
return err
}
if err = s.initMsgProducer(); err != nil {
log.Debug("DataService initMsgProducer failed", zap.Error(err))
return err
}
s.startServerLoop()
s.UpdateStateCode(internalpb.StateCode_Healthy)
log.Debug("start success")
log.Debug("DataService", zap.Any("State", s.state.Load()))
s.watchCh = s.session.WatchServices(typeutil.DataNodeRole, rev)
return nil
}
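Fragments of the removed loadMetaFromMaster are interleaved with the new lines in this hunk, so here is the new discovery flow restated as a compact sketch, using only the calls shown above (GetSessions/WatchServices on sessionutil.Session and cluster.startup); the method name is illustrative:

// Seed the cluster from the DataNode sessions already registered in etcd,
// then keep a watch channel for later add/delete events.
func (s *Server) discoverDataNodes() error {
    sessions, rev, err := s.session.GetSessions(typeutil.DataNodeRole)
    if err != nil {
        return err
    }
    datanodes := make([]*datapb.DataNodeInfo, 0, len(sessions))
    for _, session := range sessions {
        datanodes = append(datanodes, &datapb.DataNodeInfo{
            Address:  session.Address,
            Version:  session.ServerID,
            Channels: []*datapb.ChannelStatus{},
        })
    }
    if err := s.cluster.startup(datanodes); err != nil {
        return err
    }
    s.watchCh = s.session.WatchServices(typeutil.DataNodeRole, rev)
    return nil
}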
func (s *Server) startSegmentAllocator() {
stream := s.initSegmentInfoChannel()
helper := createNewSegmentHelper(stream)
helper := createNewSegmentHelper(s.segmentInfoStream)
s.segAllocator = newSegmentAllocator(s.meta, s.allocator, withAllocHelper(helper))
}
func (s *Server) initSegmentInfoChannel() msgstream.MsgStream {
segmentInfoStream, _ := s.msFactory.NewMsgStream(s.ctx)
segmentInfoStream.AsProducer([]string{Params.SegmentInfoChannelName})
func (s *Server) initSegmentInfoChannel() error {
var err error
s.segmentInfoStream, err = s.msFactory.NewMsgStream(s.ctx)
if err != nil {
return err
}
s.segmentInfoStream.AsProducer([]string{Params.SegmentInfoChannelName})
log.Debug("DataService AsProducer: " + Params.SegmentInfoChannelName)
segmentInfoStream.Start()
return segmentInfoStream
s.segmentInfoStream.Start()
return nil
}
func (s *Server) UpdateStateCode(code internalpb.StateCode) {
@ -189,7 +245,7 @@ func (s *Server) initMeta() error {
return retry.Retry(100000, time.Millisecond*200, connectEtcdFn)
}
func (s *Server) initMsgProducer() error {
func (s *Server) initFlushMsgStream() error {
var err error
// segment flush stream
s.flushMsgStream, err = s.msFactory.NewMsgStream(s.ctx)
@ -203,72 +259,6 @@ func (s *Server) initMsgProducer() error {
return nil
}
func (s *Server) loadMetaFromMaster() error {
ctx := context.Background()
log.Debug("loading collection meta from master")
var err error
if err = s.checkMasterIsHealthy(); err != nil {
return err
}
if err = s.getDDChannel(); err != nil {
return err
}
collections, err := s.masterClient.ShowCollections(ctx, &milvuspb.ShowCollectionsRequest{
Base: &commonpb.MsgBase{
MsgType: commonpb.MsgType_ShowCollections,
MsgID: -1, // todo add msg id
Timestamp: 0, // todo
SourceID: Params.NodeID,
},
DbName: "",
})
if err = VerifyResponse(collections, err); err != nil {
return err
}
for _, collectionName := range collections.CollectionNames {
collection, err := s.masterClient.DescribeCollection(ctx, &milvuspb.DescribeCollectionRequest{
Base: &commonpb.MsgBase{
MsgType: commonpb.MsgType_DescribeCollection,
MsgID: -1, // todo
Timestamp: 0, // todo
SourceID: Params.NodeID,
},
DbName: "",
CollectionName: collectionName,
})
if err = VerifyResponse(collection, err); err != nil {
log.Error("describe collection error", zap.String("collectionName", collectionName), zap.Error(err))
continue
}
partitions, err := s.masterClient.ShowPartitions(ctx, &milvuspb.ShowPartitionsRequest{
Base: &commonpb.MsgBase{
MsgType: commonpb.MsgType_ShowPartitions,
MsgID: -1, // todo
Timestamp: 0, // todo
SourceID: Params.NodeID,
},
DbName: "",
CollectionName: collectionName,
CollectionID: collection.CollectionID,
})
if err = VerifyResponse(partitions, err); err != nil {
log.Error("show partitions error", zap.String("collectionName", collectionName), zap.Int64("collectionID", collection.CollectionID), zap.Error(err))
continue
}
err = s.meta.AddCollection(&datapb.CollectionInfo{
ID: collection.CollectionID,
Schema: collection.Schema,
Partitions: partitions.PartitionIDs,
})
if err != nil {
log.Error("add collection to meta error", zap.Int64("collectionID", collection.CollectionID), zap.Error(err))
continue
}
}
log.Debug("load collection meta from master complete")
return nil
}
func (s *Server) getDDChannel() error {
s.ddChannelMu.Lock()
defer s.ddChannelMu.Unlock()
@ -282,37 +272,13 @@ func (s *Server) getDDChannel() error {
return nil
}
func (s *Server) checkMasterIsHealthy() error {
ticker := time.NewTicker(300 * time.Millisecond)
ctx, cancel := context.WithTimeout(s.ctx, 30*time.Second)
defer func() {
ticker.Stop()
cancel()
}()
for {
var resp *internalpb.ComponentStates
var err error
select {
case <-ctx.Done():
return errors.New("master is not healthy")
case <-ticker.C:
resp, err = s.masterClient.GetComponentStates(ctx)
if err = VerifyResponse(resp, err); err != nil {
return err
}
}
if resp.State.StateCode == internalpb.StateCode_Healthy {
break
}
}
return nil
}
func (s *Server) startServerLoop() {
s.serverLoopCtx, s.serverLoopCancel = context.WithCancel(s.ctx)
s.serverLoopWg.Add(2)
s.serverLoopWg.Add(4)
go s.startStatsChannel(s.serverLoopCtx)
go s.startDataNodeTtLoop(s.serverLoopCtx)
go s.startWatchService(s.serverLoopCtx)
go s.startActiveCheck(s.serverLoopCtx)
}
func (s *Server) startStatsChannel(ctx context.Context) {
@ -404,7 +370,6 @@ func (s *Server) startDataNodeTtLoop(ctx context.Context) {
}
ttMsg := msg.(*msgstream.DataNodeTtMsg)
coll2Segs := make(map[UniqueID][]UniqueID)
ch := ttMsg.ChannelName
ts := ttMsg.Timestamp
segments, err := s.segAllocator.GetFlushableSegments(ctx, ch, ts)
@ -412,6 +377,9 @@ func (s *Server) startDataNodeTtLoop(ctx context.Context) {
log.Warn("get flushable segments failed", zap.Error(err))
continue
}
log.Debug("flushable segments", zap.Any("segments", segments))
segmentInfos := make([]*datapb.SegmentInfo, 0, len(segments))
for _, id := range segments {
sInfo, err := s.meta.GetSegment(id)
if err != nil {
@ -419,35 +387,74 @@ func (s *Server) startDataNodeTtLoop(ctx context.Context) {
zap.Error(err))
continue
}
collID, segID := sInfo.CollectionID, sInfo.ID
coll2Segs[collID] = append(coll2Segs[collID], segID)
segmentInfos = append(segmentInfos, sInfo)
}
for collID, segIDs := range coll2Segs {
s.cluster.FlushSegment(&datapb.FlushSegmentsRequest{
Base: &commonpb.MsgBase{
MsgType: commonpb.MsgType_Flush,
MsgID: -1, // todo add msg id
Timestamp: 0, // todo
SourceID: Params.NodeID,
},
CollectionID: collID,
SegmentIDs: segIDs,
})
s.cluster.flush(segmentInfos)
}
}
}
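With channel ownership now tracked inside the cluster, the timetick loop above only has to collect the flushable SegmentInfos and hand them to cluster.flush; the old per-collection FlushSegment fan-out is gone. A compact sketch of that new tail of the loop (illustrative method name; Timestamp is the package alias already used by GetFlushableSegments):

// Turn the flushable segment IDs for a channel/timestamp into SegmentInfos
// and let the cluster route one flush request per (channel, collection).
func (s *Server) flushSegmentsForTt(ctx context.Context, ch string, ts Timestamp) {
    segments, err := s.segAllocator.GetFlushableSegments(ctx, ch, ts)
    if err != nil {
        return
    }
    infos := make([]*datapb.SegmentInfo, 0, len(segments))
    for _, id := range segments {
        sInfo, err := s.meta.GetSegment(id)
        if err != nil {
            continue // skip segments whose meta lookup fails, as above
        }
        infos = append(infos, sInfo)
    }
    s.cluster.flush(infos)
}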
func (s *Server) startWatchService(ctx context.Context) {
defer s.serverLoopWg.Done()
for {
select {
case <-ctx.Done():
log.Debug("watch service shutdown")
return
case event := <-s.watchCh:
datanode := &datapb.DataNodeInfo{
Address: event.Session.Address,
Version: event.Session.ServerID,
Channels: []*datapb.ChannelStatus{},
}
switch event.EventType {
case sessionutil.SessionAddEvent:
s.cluster.register(datanode)
case sessionutil.SessionDelEvent:
s.cluster.unregister(datanode)
default:
log.Warn("receive unknown service event type",
zap.Any("type", event.EventType))
}
}
}
}
func (s *Server) startActiveCheck(ctx context.Context) {
defer s.serverLoopWg.Done()
for {
select {
case _, ok := <-s.activeCh:
if ok {
continue
}
s.Stop()
log.Debug("disconnect with etcd")
return
case <-ctx.Done():
log.Debug("connection check shutdown")
return
}
}
}
var stopOnce sync.Once
func (s *Server) Stop() error {
s.cluster.ShutDownClients()
s.flushMsgStream.Close()
s.stopServerLoop()
s.stopOnce.Do(func() {
s.cluster.releaseSessions()
s.segmentInfoStream.Close()
s.flushMsgStream.Close()
s.stopServerLoop()
})
return nil
}
// CleanMeta only for test
func (s *Server) CleanMeta() error {
log.Debug("clean meta", zap.Any("kv", s.kvClient))
return s.kvClient.RemoveWithPrefix("")
}
@ -456,29 +463,6 @@ func (s *Server) stopServerLoop() {
s.serverLoopWg.Wait()
}
func (s *Server) newDataNode(ip string, port int64, id UniqueID) (*dataNode, error) {
client, err := s.createDataNodeClient(fmt.Sprintf("%s:%d", ip, port))
if err != nil {
return nil, err
}
if err := client.Init(); err != nil {
return nil, err
}
if err := client.Start(); err != nil {
return nil, err
}
return &dataNode{
id: id,
address: struct {
ip string
port int64
}{ip: ip, port: port},
client: client,
channelNum: 0,
}, nil
}
//func (s *Server) validateAllocRequest(collID UniqueID, partID UniqueID, channelName string) error {
// if !s.meta.HasCollection(collID) {
// return fmt.Errorf("can not find collection %d", collID)


@ -32,32 +32,8 @@ import (
"go.uber.org/zap"
)
func TestRegisterNode(t *testing.T) {
svr := newTestServer(t)
defer closeTestServer(t, svr)
t.Run("register node", func(t *testing.T) {
resp, err := svr.RegisterNode(context.TODO(), &datapb.RegisterNodeRequest{
Base: &commonpb.MsgBase{
MsgType: 0,
MsgID: 0,
Timestamp: 0,
SourceID: 1000,
},
Address: &commonpb.Address{
Ip: "localhost",
Port: 1000,
},
})
assert.Nil(t, err)
assert.EqualValues(t, commonpb.ErrorCode_Success, resp.Status.ErrorCode)
assert.EqualValues(t, 1, svr.cluster.GetNumOfNodes())
assert.EqualValues(t, []int64{1000}, svr.cluster.GetNodeIDs())
})
}
func TestGetSegmentInfoChannel(t *testing.T) {
svr := newTestServer(t)
svr := newTestServer(t, nil)
defer closeTestServer(t, svr)
t.Run("get segment info channel", func(t *testing.T) {
resp, err := svr.GetSegmentInfoChannel(context.TODO())
@ -67,26 +43,6 @@ func TestGetSegmentInfoChannel(t *testing.T) {
})
}
func TestGetInsertChannels(t *testing.T) {
svr := newTestServer(t)
defer closeTestServer(t, svr)
t.Run("get insert channels", func(t *testing.T) {
resp, err := svr.GetInsertChannels(context.TODO(), &datapb.GetInsertChannelsRequest{
Base: &commonpb.MsgBase{
MsgType: 0,
MsgID: 0,
Timestamp: 0,
SourceID: 1000,
},
DbID: 0,
CollectionID: 0,
})
assert.Nil(t, err)
assert.EqualValues(t, commonpb.ErrorCode_Success, resp.Status.ErrorCode)
assert.EqualValues(t, svr.getInsertChannels(), resp.Values)
})
}
func TestAssignSegmentID(t *testing.T) {
const collID = 100
const collIDInvalid = 101
@ -94,7 +50,7 @@ func TestAssignSegmentID(t *testing.T) {
const channel0 = "channel0"
const channel1 = "channel1"
svr := newTestServer(t)
svr := newTestServer(t, nil)
defer closeTestServer(t, svr)
schema := newTestSchema()
svr.meta.AddCollection(&datapb.CollectionInfo{
@ -151,7 +107,7 @@ func TestAssignSegmentID(t *testing.T) {
}
func TestShowSegments(t *testing.T) {
svr := newTestServer(t)
svr := newTestServer(t, nil)
defer closeTestServer(t, svr)
segments := []struct {
id UniqueID
@ -202,7 +158,7 @@ func TestShowSegments(t *testing.T) {
}
func TestFlush(t *testing.T) {
svr := newTestServer(t)
svr := newTestServer(t, nil)
defer closeTestServer(t, svr)
schema := newTestSchema()
err := svr.meta.AddCollection(&datapb.CollectionInfo{
@ -231,40 +187,39 @@ func TestFlush(t *testing.T) {
assert.EqualValues(t, segID, ids[0])
}
func TestGetComponentStates(t *testing.T) {
svr := newTestServer(t)
defer closeTestServer(t, svr)
cli, err := newMockDataNodeClient(1)
assert.Nil(t, err)
err = cli.Init()
assert.Nil(t, err)
err = cli.Start()
assert.Nil(t, err)
//func TestGetComponentStates(t *testing.T) {
//svr := newTestServer(t)
//defer closeTestServer(t, svr)
//cli := newMockDataNodeClient(1)
//err := cli.Init()
//assert.Nil(t, err)
//err = cli.Start()
//assert.Nil(t, err)
err = svr.cluster.Register(&dataNode{
id: 1,
address: struct {
ip string
port int64
}{
ip: "",
port: 0,
},
client: cli,
channelNum: 0,
})
assert.Nil(t, err)
//err = svr.cluster.Register(&dataNode{
//id: 1,
//address: struct {
//ip string
//port int64
//}{
//ip: "",
//port: 0,
//},
//client: cli,
//channelNum: 0,
//})
//assert.Nil(t, err)
resp, err := svr.GetComponentStates(context.TODO())
assert.Nil(t, err)
assert.EqualValues(t, commonpb.ErrorCode_Success, resp.Status.ErrorCode)
assert.EqualValues(t, internalpb.StateCode_Healthy, resp.State.StateCode)
assert.EqualValues(t, 1, len(resp.SubcomponentStates))
assert.EqualValues(t, internalpb.StateCode_Healthy, resp.SubcomponentStates[0].StateCode)
}
//resp, err := svr.GetComponentStates(context.TODO())
//assert.Nil(t, err)
//assert.EqualValues(t, commonpb.ErrorCode_Success, resp.Status.ErrorCode)
//assert.EqualValues(t, internalpb.StateCode_Healthy, resp.State.StateCode)
//assert.EqualValues(t, 1, len(resp.SubcomponentStates))
//assert.EqualValues(t, internalpb.StateCode_Healthy, resp.SubcomponentStates[0].StateCode)
//}
func TestGetTimeTickChannel(t *testing.T) {
svr := newTestServer(t)
svr := newTestServer(t, nil)
defer closeTestServer(t, svr)
resp, err := svr.GetTimeTickChannel(context.TODO())
assert.Nil(t, err)
@ -273,7 +228,7 @@ func TestGetTimeTickChannel(t *testing.T) {
}
func TestGetStatisticsChannel(t *testing.T) {
svr := newTestServer(t)
svr := newTestServer(t, nil)
defer closeTestServer(t, svr)
resp, err := svr.GetStatisticsChannel(context.TODO())
assert.Nil(t, err)
@ -282,7 +237,7 @@ func TestGetStatisticsChannel(t *testing.T) {
}
func TestGetSegmentStates(t *testing.T) {
svr := newTestServer(t)
svr := newTestServer(t, nil)
defer closeTestServer(t, svr)
err := svr.meta.AddSegment(&datapb.SegmentInfo{
ID: 1000,
@ -339,7 +294,7 @@ func TestGetSegmentStates(t *testing.T) {
}
func TestGetInsertBinlogPaths(t *testing.T) {
svr := newTestServer(t)
svr := newTestServer(t, nil)
defer closeTestServer(t, svr)
req := &datapb.GetInsertBinlogPathsRequest{
@ -351,7 +306,7 @@ func TestGetInsertBinlogPaths(t *testing.T) {
}
func TestGetCollectionStatistics(t *testing.T) {
svr := newTestServer(t)
svr := newTestServer(t, nil)
defer closeTestServer(t, svr)
req := &datapb.GetCollectionStatisticsRequest{
@ -363,7 +318,7 @@ func TestGetCollectionStatistics(t *testing.T) {
}
func TestGetSegmentInfo(t *testing.T) {
svr := newTestServer(t)
svr := newTestServer(t, nil)
defer closeTestServer(t, svr)
segInfo := &datapb.SegmentInfo{
@ -380,7 +335,7 @@ func TestGetSegmentInfo(t *testing.T) {
}
func TestChannel(t *testing.T) {
svr := newTestServer(t)
svr := newTestServer(t, nil)
defer closeTestServer(t, svr)
t.Run("Test StatsChannel", func(t *testing.T) {
@ -491,7 +446,7 @@ func TestChannel(t *testing.T) {
}
func TestSaveBinlogPaths(t *testing.T) {
svr := newTestServer(t)
svr := newTestServer(t, nil)
defer closeTestServer(t, svr)
collections := []struct {
@ -613,7 +568,8 @@ func TestSaveBinlogPaths(t *testing.T) {
}
func TestDataNodeTtChannel(t *testing.T) {
svr := newTestServer(t)
ch := make(chan interface{}, 1)
svr := newTestServer(t, ch)
defer closeTestServer(t, svr)
svr.meta.AddCollection(&datapb.CollectionInfo{
@ -622,14 +578,6 @@ func TestDataNodeTtChannel(t *testing.T) {
Partitions: []int64{0},
})
ch := make(chan interface{}, 1)
svr.createDataNodeClient = func(addr string, serverID int64) (types.DataNode, error) {
cli, err := newMockDataNodeClient(0)
assert.Nil(t, err)
cli.ch = ch
return cli, nil
}
ttMsgStream, err := svr.msFactory.NewMsgStream(context.TODO())
assert.Nil(t, err)
ttMsgStream.AsProducer([]string{Params.TimeTickChannelName})
@ -654,20 +602,16 @@ func TestDataNodeTtChannel(t *testing.T) {
}
}
resp, err := svr.RegisterNode(context.TODO(), &datapb.RegisterNodeRequest{
Base: &commonpb.MsgBase{
MsgType: 0,
MsgID: 0,
Timestamp: 0,
SourceID: 0,
},
Address: &commonpb.Address{
Ip: "localhost:7777",
Port: 8080,
svr.cluster.register(&datapb.DataNodeInfo{
Address: "localhost:7777",
Version: 0,
Channels: []*datapb.ChannelStatus{
{
Name: "ch-1",
State: datapb.ChannelWatchState_Complete,
},
},
})
assert.Nil(t, err)
assert.EqualValues(t, commonpb.ErrorCode_Success, resp.Status.ErrorCode)
t.Run("Test segment flush after tt", func(t *testing.T) {
resp, err := svr.AssignSegmentID(context.TODO(), &datapb.AssignSegmentIDRequest{
@ -688,6 +632,7 @@ func TestDataNodeTtChannel(t *testing.T) {
assert.EqualValues(t, 1, len(resp.SegIDAssignments))
assign := resp.SegIDAssignments[0]
log.Debug("xxxxxxxxxxxxx", zap.Any("assign", assign))
resp2, err := svr.Flush(context.TODO(), &datapb.FlushRequest{
Base: &commonpb.MsgBase{
MsgType: commonpb.MsgType_Flush,
@ -720,7 +665,7 @@ func TestResumeChannel(t *testing.T) {
segmentIDs := make([]int64, 0, 1000)
t.Run("Prepare Resume test set", func(t *testing.T) {
svr := newTestServer(t)
svr := newTestServer(t, nil)
defer svr.Stop()
i := int64(-1)
@ -743,7 +688,7 @@ func TestResumeChannel(t *testing.T) {
})
t.Run("Test ResumeSegmentStatsChannel", func(t *testing.T) {
svr := newTestServer(t)
svr := newTestServer(t, nil)
segRows := rand.Int63n(1000)
@ -792,7 +737,7 @@ func TestResumeChannel(t *testing.T) {
svr.Stop()
time.Sleep(time.Millisecond * 50)
svr = newTestServer(t)
svr = newTestServer(t, nil)
defer svr.Stop()
<-ch
@ -812,7 +757,7 @@ func TestResumeChannel(t *testing.T) {
})
t.Run("Clean up test segments", func(t *testing.T) {
svr := newTestServer(t)
svr := newTestServer(t, nil)
defer closeTestServer(t, svr)
var err error
for _, segID := range segmentIDs {
@ -822,7 +767,7 @@ func TestResumeChannel(t *testing.T) {
})
}
func newTestServer(t *testing.T) *Server {
func newTestServer(t *testing.T, receiveCh chan interface{}) *Server {
Params.Init()
var err error
factory := msgstream.NewPmsFactory()
@ -849,8 +794,8 @@ func newTestServer(t *testing.T) *Server {
assert.Nil(t, err)
defer ms.Stop()
svr.SetMasterClient(ms)
svr.createDataNodeClient = func(addr string) (types.DataNode, error) {
return newMockDataNodeClient(0)
svr.dataClientCreator = func(addr string) (types.DataNode, error) {
return newMockDataNodeClient(0, receiveCh)
}
assert.Nil(t, err)
err = svr.Init()