mirror of https://github.com/milvus-io/milvus.git
1093 lines
32 KiB
Go
1093 lines
32 KiB
Go
// Licensed to the LF AI & Data foundation under one
|
|
// or more contributor license agreements. See the NOTICE file
|
|
// distributed with this work for additional information
|
|
// regarding copyright ownership. The ASF licenses this file
|
|
// to you under the Apache License, Version 2.0 (the
|
|
// "License"); you may not use this file except in compliance
|
|
// with the License. You may obtain a copy of the License at
|
|
//
|
|
// http://www.apache.org/licenses/LICENSE-2.0
|
|
//
|
|
// Unless required by applicable law or agreed to in writing, software
|
|
// distributed under the License is distributed on an "AS IS" BASIS,
|
|
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
// See the License for the specific language governing permissions and
|
|
// limitations under the License.
|
|
|
|
package querynode
|
|
|
|
import (
|
|
"context"
|
|
"errors"
|
|
"fmt"
|
|
"runtime"
|
|
"sync"
|
|
|
|
"go.uber.org/atomic"
|
|
"go.uber.org/zap"
|
|
|
|
"github.com/golang/protobuf/proto"
|
|
"github.com/milvus-io/milvus-proto/go-api/commonpb"
|
|
"github.com/milvus-io/milvus/internal/common"
|
|
"github.com/milvus-io/milvus/internal/log"
|
|
"github.com/milvus-io/milvus/internal/proto/internalpb"
|
|
"github.com/milvus-io/milvus/internal/proto/querypb"
|
|
"github.com/milvus-io/milvus/internal/util/funcutil"
|
|
"github.com/milvus-io/milvus/internal/util/typeutil"
|
|
)
|
|
|
|
type shardClusterState int32
|
|
|
|
const (
|
|
available shardClusterState = 1
|
|
unavailable shardClusterState = 2
|
|
)
|
|
|
|
type nodeEventType int32
|
|
|
|
const (
|
|
nodeAdd nodeEventType = 1
|
|
nodeDel nodeEventType = 2
|
|
)
|
|
|
|
type segmentEventType int32
|
|
|
|
const (
|
|
segmentAdd segmentEventType = 1
|
|
segmentDel segmentEventType = 2
|
|
)
|
|
|
|
type segmentState int32
|
|
|
|
const (
|
|
segmentStateNone segmentState = 0
|
|
segmentStateOffline segmentState = 1
|
|
segmentStateLoading segmentState = 2
|
|
segmentStateLoaded segmentState = 3
|
|
)
|
|
|
|
type nodeEvent struct {
|
|
eventType nodeEventType
|
|
nodeID int64
|
|
nodeAddr string
|
|
isLeader bool
|
|
}
|
|
|
|
type segmentEvent struct {
|
|
eventType segmentEventType
|
|
segmentID int64
|
|
partitionID int64
|
|
nodeIDs []int64 // nodes from events
|
|
state segmentState
|
|
}
|
|
|
|
type shardQueryNode interface {
|
|
GetStatistics(context.Context, *querypb.GetStatisticsRequest) (*internalpb.GetStatisticsResponse, error)
|
|
Search(context.Context, *querypb.SearchRequest) (*internalpb.SearchResults, error)
|
|
Query(context.Context, *querypb.QueryRequest) (*internalpb.RetrieveResults, error)
|
|
LoadSegments(ctx context.Context, in *querypb.LoadSegmentsRequest) (*commonpb.Status, error)
|
|
ReleaseSegments(ctx context.Context, in *querypb.ReleaseSegmentsRequest) (*commonpb.Status, error)
|
|
Stop() error
|
|
}
|
|
|
|
type shardNode struct {
|
|
nodeID int64
|
|
nodeAddr string
|
|
client shardQueryNode
|
|
}
|
|
|
|
type shardSegmentInfo struct {
|
|
segmentID int64
|
|
partitionID int64
|
|
nodeID int64
|
|
state segmentState
|
|
version int64
|
|
inUse int32
|
|
}
|
|
|
|
// Closable interface for close.
|
|
type Closable interface {
|
|
Close()
|
|
}
|
|
|
|
// ShardNodeDetector provides method to detect node events
|
|
type ShardNodeDetector interface {
|
|
Closable
|
|
watchNodes(collectionID int64, replicaID int64, vchannelName string) ([]nodeEvent, <-chan nodeEvent)
|
|
}
|
|
|
|
// ShardSegmentDetector provides method to detect segment events
|
|
type ShardSegmentDetector interface {
|
|
Closable
|
|
watchSegments(collectionID int64, replicaID int64, vchannelName string) ([]segmentEvent, <-chan segmentEvent)
|
|
}
|
|
|
|
// ShardNodeBuilder function type to build types.QueryNode from addr and id
|
|
type ShardNodeBuilder func(nodeID int64, addr string) shardQueryNode
|
|
|
|
// withStreaming function type to let search detects corresponding search streaming is done.
|
|
type withStreaming func(ctx context.Context) error
|
|
|
|
// ShardCluster maintains the ShardCluster information and perform shard level operations
|
|
type ShardCluster struct {
|
|
state *atomic.Int32
|
|
|
|
collectionID int64
|
|
replicaID int64
|
|
vchannelName string
|
|
version int64
|
|
|
|
nodeDetector ShardNodeDetector
|
|
segmentDetector ShardSegmentDetector
|
|
nodeBuilder ShardNodeBuilder
|
|
|
|
mut sync.RWMutex
|
|
leader *shardNode // shard leader node instance
|
|
nodes map[int64]*shardNode // online nodes
|
|
segments SegmentsStatus // shard segments
|
|
|
|
mutVersion sync.RWMutex
|
|
versions sync.Map // version id to version
|
|
currentVersion *ShardClusterVersion // current serving segment state version
|
|
nextVersionID *atomic.Int64
|
|
segmentCond *sync.Cond // segment state change condition
|
|
|
|
closeOnce sync.Once
|
|
closeCh chan struct{}
|
|
}
|
|
|
|
// NewShardCluster create a ShardCluster with provided information.
|
|
func NewShardCluster(collectionID int64, replicaID int64, vchannelName string, version int64,
|
|
nodeDetector ShardNodeDetector, segmentDetector ShardSegmentDetector, nodeBuilder ShardNodeBuilder) *ShardCluster {
|
|
sc := &ShardCluster{
|
|
state: atomic.NewInt32(int32(unavailable)),
|
|
|
|
collectionID: collectionID,
|
|
replicaID: replicaID,
|
|
vchannelName: vchannelName,
|
|
version: version,
|
|
|
|
segmentDetector: segmentDetector,
|
|
nodeDetector: nodeDetector,
|
|
nodeBuilder: nodeBuilder,
|
|
|
|
nodes: make(map[int64]*shardNode),
|
|
segments: make(map[int64]shardSegmentInfo),
|
|
nextVersionID: atomic.NewInt64(0),
|
|
|
|
closeCh: make(chan struct{}),
|
|
}
|
|
|
|
m := sync.Mutex{}
|
|
sc.segmentCond = sync.NewCond(&m)
|
|
|
|
sc.init()
|
|
|
|
return sc
|
|
}
|
|
|
|
func (sc *ShardCluster) Close() {
|
|
log := sc.getLogger()
|
|
log.Info("Close shard cluster")
|
|
sc.closeOnce.Do(func() {
|
|
sc.updateShardClusterState(unavailable)
|
|
if sc.nodeDetector != nil {
|
|
sc.nodeDetector.Close()
|
|
}
|
|
if sc.segmentDetector != nil {
|
|
sc.segmentDetector.Close()
|
|
}
|
|
|
|
close(sc.closeCh)
|
|
})
|
|
}
|
|
|
|
func (sc *ShardCluster) getLogger() *zap.Logger {
|
|
return log.With(zap.Int64("collectionID", sc.collectionID),
|
|
zap.String("channel", sc.vchannelName),
|
|
zap.Int64("replicaID", sc.replicaID))
|
|
}
|
|
|
|
// serviceable returns whether shard cluster could provide query service.
|
|
func (sc *ShardCluster) serviceable() bool {
|
|
// all segment in loaded state
|
|
if sc.state.Load() != int32(available) {
|
|
return false
|
|
}
|
|
|
|
sc.mutVersion.RLock()
|
|
defer sc.mutVersion.RUnlock()
|
|
// check there is a working version(SyncSegments called)
|
|
return sc.currentVersion != nil
|
|
}
|
|
|
|
// addNode add a node into cluster
|
|
func (sc *ShardCluster) addNode(evt nodeEvent) {
|
|
log := sc.getLogger()
|
|
log.Info("ShardCluster add node", zap.Int64("nodeID", evt.nodeID))
|
|
sc.mut.Lock()
|
|
defer sc.mut.Unlock()
|
|
|
|
oldNode, ok := sc.nodes[evt.nodeID]
|
|
if ok {
|
|
if oldNode.nodeAddr == evt.nodeAddr {
|
|
log.Warn("ShardCluster add same node, skip", zap.Int64("nodeID", evt.nodeID), zap.String("addr", evt.nodeAddr))
|
|
return
|
|
}
|
|
defer oldNode.client.Stop()
|
|
}
|
|
|
|
node := &shardNode{
|
|
nodeID: evt.nodeID,
|
|
nodeAddr: evt.nodeAddr,
|
|
client: sc.nodeBuilder(evt.nodeID, evt.nodeAddr),
|
|
}
|
|
sc.nodes[evt.nodeID] = node
|
|
if evt.isLeader {
|
|
sc.leader = node
|
|
}
|
|
}
|
|
|
|
// removeNode handles node offline and setup related segments
|
|
func (sc *ShardCluster) removeNode(evt nodeEvent) {
|
|
log := sc.getLogger()
|
|
log.Info("ShardCluster remove node", zap.Int64("nodeID", evt.nodeID))
|
|
sc.mut.Lock()
|
|
defer sc.mut.Unlock()
|
|
|
|
old, ok := sc.nodes[evt.nodeID]
|
|
if !ok {
|
|
log.Warn("ShardCluster removeNode does not belong to it", zap.Int64("nodeID", evt.nodeID), zap.String("addr", evt.nodeAddr))
|
|
return
|
|
}
|
|
|
|
defer old.client.Stop()
|
|
delete(sc.nodes, evt.nodeID)
|
|
|
|
for id, segment := range sc.segments {
|
|
if segment.nodeID == evt.nodeID {
|
|
segment.state = segmentStateOffline
|
|
segment.version = -1
|
|
sc.segments[id] = segment
|
|
sc.updateShardClusterState(unavailable)
|
|
}
|
|
}
|
|
// ignore leader process here
|
|
}
|
|
|
|
// updateSegment apply segment change to shard cluster
|
|
func (sc *ShardCluster) updateSegment(evt shardSegmentInfo) {
|
|
log := sc.getLogger()
|
|
log.Info("ShardCluster update segment", zap.Int64("nodeID", evt.nodeID), zap.Int64("segmentID", evt.segmentID), zap.Int32("state", int32(evt.state)))
|
|
// notify handoff wait online if any
|
|
defer func() {
|
|
sc.segmentCond.Broadcast()
|
|
}()
|
|
|
|
sc.mut.Lock()
|
|
defer sc.mut.Unlock()
|
|
|
|
old, ok := sc.segments[evt.segmentID]
|
|
if !ok { // newly add
|
|
sc.segments[evt.segmentID] = evt
|
|
return
|
|
}
|
|
|
|
sc.transferSegment(old, evt)
|
|
}
|
|
|
|
// SetupFirstVersion initialized first version for shard cluster.
|
|
func (sc *ShardCluster) SetupFirstVersion() {
|
|
sc.mutVersion.Lock()
|
|
defer sc.mutVersion.Unlock()
|
|
version := NewShardClusterVersion(sc.nextVersionID.Inc(), make(SegmentsStatus), nil)
|
|
sc.versions.Store(version.versionID, version)
|
|
sc.currentVersion = version
|
|
}
|
|
|
|
// SyncSegments synchronize segment distribution in batch
|
|
func (sc *ShardCluster) SyncSegments(distribution []*querypb.ReplicaSegmentsInfo, state segmentState) {
|
|
log := sc.getLogger()
|
|
log.Info("ShardCluster sync segments", zap.Any("replica segments", distribution), zap.Int32("state", int32(state)))
|
|
|
|
var currentVersion *ShardClusterVersion
|
|
sc.mutVersion.RLock()
|
|
currentVersion = sc.currentVersion
|
|
sc.mutVersion.RUnlock()
|
|
if currentVersion == nil {
|
|
log.Warn("received SyncSegments call before version setup")
|
|
return
|
|
}
|
|
|
|
sc.mut.Lock()
|
|
for _, line := range distribution {
|
|
for i, segmentID := range line.GetSegmentIds() {
|
|
nodeID := line.GetNodeId()
|
|
version := line.GetVersions()[i]
|
|
// if node id not in replica node list, this line shall be placeholder for segment offline
|
|
_, ok := sc.nodes[nodeID]
|
|
if !ok {
|
|
log.Warn("Sync segment with invalid nodeID", zap.Int64("segmentID", segmentID), zap.Int64("nodeID", line.NodeId))
|
|
nodeID = common.InvalidNodeID
|
|
}
|
|
|
|
old, ok := sc.segments[segmentID]
|
|
if !ok { // newly add
|
|
sc.segments[segmentID] = shardSegmentInfo{
|
|
nodeID: nodeID,
|
|
partitionID: line.GetPartitionId(),
|
|
segmentID: segmentID,
|
|
state: state,
|
|
version: version,
|
|
}
|
|
continue
|
|
}
|
|
|
|
sc.transferSegment(old, shardSegmentInfo{
|
|
nodeID: nodeID,
|
|
partitionID: line.GetPartitionId(),
|
|
segmentID: segmentID,
|
|
state: state,
|
|
version: version,
|
|
})
|
|
}
|
|
}
|
|
|
|
// allocations := sc.segments.Clone(filterNothing)
|
|
sc.mut.Unlock()
|
|
|
|
// notify handoff wait online if any
|
|
sc.segmentCond.Broadcast()
|
|
|
|
sc.mutVersion.Lock()
|
|
defer sc.mutVersion.Unlock()
|
|
|
|
// update shardleader allocation view
|
|
allocations := sc.currentVersion.segments.Clone(filterNothing)
|
|
for _, line := range distribution {
|
|
for _, segmentID := range line.GetSegmentIds() {
|
|
allocations[segmentID] = shardSegmentInfo{nodeID: line.GetNodeId(), segmentID: segmentID, partitionID: line.GetPartitionId(), state: state}
|
|
}
|
|
}
|
|
|
|
version := NewShardClusterVersion(sc.nextVersionID.Inc(), allocations, sc.currentVersion)
|
|
sc.versions.Store(version.versionID, version)
|
|
sc.currentVersion = version
|
|
}
|
|
|
|
// transferSegment apply segment state transition.
|
|
// old\new | Offline | Loading | Loaded
|
|
// Offline | OK | OK | OK
|
|
// Loading | OK | OK | NodeID check
|
|
// Loaded | OK | OK | legacy pending
|
|
func (sc *ShardCluster) transferSegment(old shardSegmentInfo, evt shardSegmentInfo) {
|
|
log := sc.getLogger()
|
|
switch old.state {
|
|
case segmentStateOffline: // safe to update nodeID and state
|
|
old.nodeID = evt.nodeID
|
|
old.state = evt.state
|
|
old.version = evt.version
|
|
sc.segments[old.segmentID] = old
|
|
if evt.state == segmentStateLoaded {
|
|
sc.healthCheck()
|
|
}
|
|
case segmentStateLoading: // to Loaded only when nodeID equal
|
|
if evt.state == segmentStateLoaded && evt.nodeID != old.nodeID {
|
|
log.Warn("transferSegment to loaded failed, nodeID not match", zap.Int64("segmentID", evt.segmentID), zap.Int64("nodeID", old.nodeID), zap.Int64("evtNodeID", evt.nodeID))
|
|
return
|
|
}
|
|
old.nodeID = evt.nodeID
|
|
old.state = evt.state
|
|
old.version = evt.version
|
|
sc.segments[old.segmentID] = old
|
|
if evt.state == segmentStateLoaded {
|
|
sc.healthCheck()
|
|
}
|
|
case segmentStateLoaded:
|
|
// load balance
|
|
old.nodeID = evt.nodeID
|
|
old.state = evt.state
|
|
old.version = evt.version
|
|
sc.segments[old.segmentID] = old
|
|
if evt.state != segmentStateLoaded {
|
|
sc.healthCheck()
|
|
}
|
|
}
|
|
}
|
|
|
|
// removeSegment removes segment from cluster
|
|
// should only applied in hand-off or load balance procedure
|
|
func (sc *ShardCluster) removeSegment(evt shardSegmentInfo) {
|
|
log := sc.getLogger()
|
|
log.Info("ShardCluster remove segment", zap.Int64("nodeID", evt.nodeID), zap.Int64("segmentID", evt.segmentID), zap.Int32("state", int32(evt.state)))
|
|
|
|
sc.mut.Lock()
|
|
defer sc.mut.Unlock()
|
|
|
|
old, ok := sc.segments[evt.segmentID]
|
|
if !ok {
|
|
log.Warn("ShardCluster removeSegment does not belong to it", zap.Int64("nodeID", evt.nodeID), zap.Int64("segmentID", evt.segmentID))
|
|
return
|
|
}
|
|
|
|
if old.nodeID != evt.nodeID {
|
|
log.Warn("ShardCluster removeSegment found node not match", zap.Int64("segmentID", evt.segmentID), zap.Int64("nodeID", old.nodeID), zap.Int64("evtNodeID", evt.nodeID))
|
|
return
|
|
}
|
|
|
|
delete(sc.segments, evt.segmentID)
|
|
sc.healthCheck()
|
|
}
|
|
|
|
// init list all nodes and semgent states ant start watching
|
|
func (sc *ShardCluster) init() {
|
|
// list nodes
|
|
nodes, nodeEvtCh := sc.nodeDetector.watchNodes(sc.collectionID, sc.replicaID, sc.vchannelName)
|
|
for _, node := range nodes {
|
|
sc.addNode(node)
|
|
}
|
|
go sc.watchNodes(nodeEvtCh)
|
|
|
|
// list segments
|
|
segments, segmentEvtCh := sc.segmentDetector.watchSegments(sc.collectionID, sc.replicaID, sc.vchannelName)
|
|
for _, segment := range segments {
|
|
info, ok := sc.pickNode(segment)
|
|
if ok {
|
|
sc.updateSegment(info)
|
|
}
|
|
}
|
|
go sc.watchSegments(segmentEvtCh)
|
|
|
|
sc.healthCheck()
|
|
}
|
|
|
|
// pickNode selects node in the cluster
|
|
func (sc *ShardCluster) pickNode(evt segmentEvent) (shardSegmentInfo, bool) {
|
|
nodeID, has := sc.selectNodeInReplica(evt.nodeIDs)
|
|
if has { // assume one segment shall exist once in one replica
|
|
return shardSegmentInfo{
|
|
segmentID: evt.segmentID,
|
|
partitionID: evt.partitionID,
|
|
nodeID: nodeID,
|
|
state: evt.state,
|
|
}, true
|
|
}
|
|
|
|
return shardSegmentInfo{}, false
|
|
}
|
|
|
|
// selectNodeInReplica returns first node id inside the shard cluster replica.
|
|
// if there is no nodeID found, returns 0.
|
|
func (sc *ShardCluster) selectNodeInReplica(nodeIDs []int64) (int64, bool) {
|
|
for _, nodeID := range nodeIDs {
|
|
_, has := sc.getNode(nodeID)
|
|
if has {
|
|
return nodeID, true
|
|
}
|
|
}
|
|
return 0, false
|
|
}
|
|
|
|
func (sc *ShardCluster) updateShardClusterState(state shardClusterState) {
|
|
log := sc.getLogger()
|
|
old := sc.state.Load()
|
|
sc.state.Store(int32(state))
|
|
|
|
pc, _, _, _ := runtime.Caller(1)
|
|
callerName := runtime.FuncForPC(pc).Name()
|
|
|
|
log.Info("Shard Cluster update state",
|
|
zap.Int32("old state", old), zap.Int32("new state", int32(state)),
|
|
zap.String("caller", callerName))
|
|
}
|
|
|
|
// healthCheck iterate all segments to to check cluster could provide service.
|
|
func (sc *ShardCluster) healthCheck() {
|
|
for _, segment := range sc.segments {
|
|
if segment.state != segmentStateLoaded ||
|
|
segment.nodeID == common.InvalidNodeID { // segment in offline nodes
|
|
sc.updateShardClusterState(unavailable)
|
|
return
|
|
}
|
|
}
|
|
sc.updateShardClusterState(available)
|
|
}
|
|
|
|
// watchNodes handles node events.
|
|
func (sc *ShardCluster) watchNodes(evtCh <-chan nodeEvent) {
|
|
log := sc.getLogger()
|
|
for {
|
|
select {
|
|
case evt, ok := <-evtCh:
|
|
if !ok {
|
|
log.Warn("ShardCluster node channel closed")
|
|
return
|
|
}
|
|
switch evt.eventType {
|
|
case nodeAdd:
|
|
sc.addNode(evt)
|
|
case nodeDel:
|
|
sc.removeNode(evt)
|
|
}
|
|
case <-sc.closeCh:
|
|
log.Info("ShardCluster watchNode quit")
|
|
return
|
|
}
|
|
}
|
|
}
|
|
|
|
// watchSegments handles segment events.
|
|
func (sc *ShardCluster) watchSegments(evtCh <-chan segmentEvent) {
|
|
log := sc.getLogger()
|
|
for {
|
|
select {
|
|
case evt, ok := <-evtCh:
|
|
if !ok {
|
|
log.Warn("ShardCluster segment channel closed")
|
|
return
|
|
}
|
|
info, ok := sc.pickNode(evt)
|
|
if !ok {
|
|
log.Info("No node of event is in cluster, skip to process it",
|
|
zap.Int64s("nodes", evt.nodeIDs))
|
|
continue
|
|
}
|
|
switch evt.eventType {
|
|
case segmentAdd:
|
|
sc.updateSegment(info)
|
|
case segmentDel:
|
|
sc.removeSegment(info)
|
|
}
|
|
case <-sc.closeCh:
|
|
log.Info("ShardCluster watchSegments quit")
|
|
return
|
|
}
|
|
}
|
|
}
|
|
|
|
// getNode returns shallow copy of shardNode
|
|
func (sc *ShardCluster) getNode(nodeID int64) (*shardNode, bool) {
|
|
sc.mut.RLock()
|
|
defer sc.mut.RUnlock()
|
|
node, ok := sc.nodes[nodeID]
|
|
if !ok {
|
|
return nil, false
|
|
}
|
|
return &shardNode{
|
|
nodeID: node.nodeID,
|
|
nodeAddr: node.nodeAddr,
|
|
client: node.client, // shallow copy
|
|
}, true
|
|
}
|
|
|
|
// getSegment returns copy of shardSegmentInfo
|
|
func (sc *ShardCluster) getSegment(segmentID int64) (shardSegmentInfo, bool) {
|
|
sc.mut.RLock()
|
|
defer sc.mut.RUnlock()
|
|
segment, ok := sc.segments[segmentID]
|
|
return segment, ok
|
|
}
|
|
|
|
// segmentAllocations returns node to segments mappings.
|
|
// calling this function also increases the reference count of related segments.
|
|
func (sc *ShardCluster) segmentAllocations(partitionIDs []int64) (map[int64][]int64, int64) {
|
|
// check cluster serviceable
|
|
if !sc.serviceable() {
|
|
log.Warn("request segment allocations when cluster is not serviceable", zap.Int64("collectionID", sc.collectionID), zap.Int64("replicaID", sc.replicaID), zap.String("vchannelName", sc.vchannelName))
|
|
return map[int64][]int64{}, 0
|
|
}
|
|
sc.mutVersion.RLock()
|
|
defer sc.mutVersion.RUnlock()
|
|
// return allocation from current version and version id
|
|
return sc.currentVersion.GetAllocation(partitionIDs), sc.currentVersion.versionID
|
|
}
|
|
|
|
// finishUsage decreases the inUse count of provided segments
|
|
func (sc *ShardCluster) finishUsage(versionID int64) {
|
|
v, ok := sc.versions.Load(versionID)
|
|
if !ok {
|
|
return
|
|
}
|
|
|
|
version := v.(*ShardClusterVersion)
|
|
version.FinishUsage()
|
|
|
|
// cleanup version if expired
|
|
sc.cleanupVersion(version)
|
|
}
|
|
|
|
// LoadSegments loads segments with shardCluster.
|
|
// shard cluster shall try to loadSegments in the follower then update the allocation.
|
|
func (sc *ShardCluster) LoadSegments(ctx context.Context, req *querypb.LoadSegmentsRequest) error {
|
|
log := sc.getLogger()
|
|
// add common log fields
|
|
log = log.With(zap.Int64("dstNodeID", req.GetDstNodeID()))
|
|
|
|
segmentIDs := make([]int64, 0, len(req.Infos))
|
|
for _, info := range req.Infos {
|
|
segmentIDs = append(segmentIDs, info.SegmentID)
|
|
}
|
|
log = log.With(zap.Int64s("segmentIDs", segmentIDs))
|
|
|
|
// notify follower to load segment
|
|
node, ok := sc.getNode(req.GetDstNodeID())
|
|
if !ok {
|
|
log.Warn("node not in cluster")
|
|
return fmt.Errorf("node not in cluster %d", req.GetDstNodeID())
|
|
}
|
|
|
|
req = proto.Clone(req).(*querypb.LoadSegmentsRequest)
|
|
req.Base.TargetID = req.GetDstNodeID()
|
|
resp, err := node.client.LoadSegments(ctx, req)
|
|
if err != nil {
|
|
log.Warn("failed to dispatch load segment request", zap.Error(err))
|
|
return err
|
|
}
|
|
if resp.GetErrorCode() != commonpb.ErrorCode_Success {
|
|
log.Warn("follower load segment failed", zap.String("reason", resp.GetReason()))
|
|
return fmt.Errorf("follower %d failed to load segment, reason %s", req.DstNodeID, resp.GetReason())
|
|
}
|
|
log.Debug("shardCluster has completed loading segments for:",
|
|
zap.Int64("collectionID:", req.CollectionID), zap.Int64("replicaID:", req.ReplicaID),
|
|
zap.Int64("sourceNodeId", req.SourceNodeID))
|
|
|
|
// update allocation
|
|
for _, info := range req.Infos {
|
|
sc.updateSegment(shardSegmentInfo{
|
|
nodeID: req.DstNodeID,
|
|
segmentID: info.SegmentID,
|
|
partitionID: info.PartitionID,
|
|
state: segmentStateLoaded,
|
|
version: req.GetVersion(),
|
|
})
|
|
}
|
|
|
|
// notify handoff wait online if any
|
|
sc.segmentCond.Broadcast()
|
|
|
|
sc.mutVersion.Lock()
|
|
defer sc.mutVersion.Unlock()
|
|
|
|
// update shardleader allocation view
|
|
allocations := sc.currentVersion.segments.Clone(filterNothing)
|
|
for _, info := range req.Infos {
|
|
allocations[info.SegmentID] = shardSegmentInfo{nodeID: req.DstNodeID, segmentID: info.SegmentID, partitionID: info.PartitionID, state: segmentStateLoaded}
|
|
}
|
|
|
|
lastVersion := sc.currentVersion
|
|
version := NewShardClusterVersion(sc.nextVersionID.Inc(), allocations, sc.currentVersion)
|
|
sc.versions.Store(version.versionID, version)
|
|
sc.currentVersion = version
|
|
sc.cleanupVersion(lastVersion)
|
|
|
|
return nil
|
|
}
|
|
|
|
// ReleaseSegments releases segments via ShardCluster.
|
|
// ShardCluster will wait all on-going search until finished, update the current version,
|
|
// then release the segments through follower.
|
|
func (sc *ShardCluster) ReleaseSegments(ctx context.Context, req *querypb.ReleaseSegmentsRequest, force bool) error {
|
|
log := sc.getLogger()
|
|
// add common log fields
|
|
log = log.With(zap.Int64s("segmentIDs", req.GetSegmentIDs()),
|
|
zap.String("scope", req.GetScope().String()),
|
|
zap.Bool("force", force))
|
|
|
|
//shardCluster.forceRemoveSegment(action.GetSegmentID())
|
|
offlineSegments := make(typeutil.UniqueSet)
|
|
if req.Scope != querypb.DataScope_Streaming {
|
|
offlineSegments.Insert(req.GetSegmentIDs()...)
|
|
}
|
|
|
|
var lastVersion *ShardClusterVersion
|
|
var err error
|
|
func() {
|
|
sc.mutVersion.Lock()
|
|
defer sc.mutVersion.Unlock()
|
|
|
|
var allocations SegmentsStatus
|
|
if sc.currentVersion != nil {
|
|
allocations = sc.currentVersion.segments.Clone(func(segmentID UniqueID, nodeID UniqueID) bool {
|
|
return (nodeID == req.NodeID || force) && offlineSegments.Contain(segmentID)
|
|
})
|
|
}
|
|
|
|
// generate a new version
|
|
versionID := sc.nextVersionID.Inc()
|
|
// remove offline segments in next version
|
|
// so incoming request will not have allocation of these segments
|
|
version := NewShardClusterVersion(versionID, allocations, sc.currentVersion)
|
|
sc.versions.Store(versionID, version)
|
|
|
|
// force release means current distribution has error
|
|
if !force {
|
|
// currentVersion shall be not nil
|
|
if sc.currentVersion != nil {
|
|
// wait for last version search done
|
|
<-sc.currentVersion.Expire()
|
|
lastVersion = sc.currentVersion
|
|
}
|
|
}
|
|
|
|
// set current version to new one
|
|
sc.currentVersion = version
|
|
|
|
// force release skips the release call
|
|
if force {
|
|
return
|
|
}
|
|
|
|
// try to release segments from nodes
|
|
node, ok := sc.getNode(req.GetNodeID())
|
|
if !ok {
|
|
log.Warn("node not in cluster", zap.Int64("nodeID", req.NodeID))
|
|
err = fmt.Errorf("node %d not in cluster ", req.NodeID)
|
|
return
|
|
}
|
|
|
|
req = proto.Clone(req).(*querypb.ReleaseSegmentsRequest)
|
|
req.Base.TargetID = req.GetNodeID()
|
|
resp, rerr := node.client.ReleaseSegments(ctx, req)
|
|
if err != nil {
|
|
log.Warn("failed to dispatch release segment request", zap.Error(err))
|
|
err = rerr
|
|
return
|
|
}
|
|
if resp.GetErrorCode() != commonpb.ErrorCode_Success {
|
|
log.Warn("follower release segment failed", zap.String("reason", resp.GetReason()))
|
|
err = fmt.Errorf("follower %d failed to release segment, reason %s", req.NodeID, resp.GetReason())
|
|
}
|
|
}()
|
|
sc.cleanupVersion(lastVersion)
|
|
|
|
sc.mut.Lock()
|
|
// do not delete segment if data scope is streaming
|
|
if req.GetScope() != querypb.DataScope_Streaming {
|
|
for _, segmentID := range req.SegmentIDs {
|
|
info, ok := sc.segments[segmentID]
|
|
if ok {
|
|
// otherwise, segment is on another node, do nothing
|
|
if force || info.nodeID == req.NodeID {
|
|
delete(sc.segments, segmentID)
|
|
}
|
|
}
|
|
}
|
|
}
|
|
sc.healthCheck()
|
|
sc.mut.Unlock()
|
|
|
|
return err
|
|
}
|
|
|
|
// cleanupVersion clean up version from map
|
|
func (sc *ShardCluster) cleanupVersion(version *ShardClusterVersion) {
|
|
// last version nil, just return
|
|
if version == nil {
|
|
return
|
|
}
|
|
|
|
// if version is still current one or still in use, return
|
|
if version.current.Load() || version.inUse.Load() > 0 {
|
|
return
|
|
}
|
|
|
|
sc.versions.Delete(version.versionID)
|
|
}
|
|
|
|
// waitSegmentsOnline waits until all provided segments is loaded.
|
|
func (sc *ShardCluster) waitSegmentsOnline(segments []shardSegmentInfo) {
|
|
sc.segmentCond.L.Lock()
|
|
for !sc.segmentsOnline(segments) {
|
|
sc.segmentCond.Wait()
|
|
}
|
|
sc.segmentCond.L.Unlock()
|
|
}
|
|
|
|
// checkOnline checks whether all segment info provided in online state.
|
|
func (sc *ShardCluster) segmentsOnline(segments []shardSegmentInfo) bool {
|
|
sc.mut.RLock()
|
|
defer sc.mut.RUnlock()
|
|
for _, segInfo := range segments {
|
|
segment, ok := sc.segments[segInfo.segmentID]
|
|
// check segment online on #specified Node#
|
|
if !ok || segment.state != segmentStateLoaded || segment.nodeID != segInfo.nodeID {
|
|
return false
|
|
}
|
|
}
|
|
return true
|
|
}
|
|
|
|
// GetStatistics returns the statistics on the shard cluster.
|
|
func (sc *ShardCluster) GetStatistics(ctx context.Context, req *querypb.GetStatisticsRequest, withStreaming withStreaming) ([]*internalpb.GetStatisticsResponse, error) {
|
|
if !sc.serviceable() {
|
|
return nil, fmt.Errorf("ShardCluster for %s replicaID %d is not available", sc.vchannelName, sc.replicaID)
|
|
}
|
|
if !funcutil.SliceContain(req.GetDmlChannels(), sc.vchannelName) {
|
|
return nil, fmt.Errorf("ShardCluster for %s does not match request channels :%v", sc.vchannelName, req.GetDmlChannels())
|
|
}
|
|
|
|
// get node allocation and maintains the inUse reference count
|
|
segAllocs, versionID := sc.segmentAllocations(req.GetReq().GetPartitionIDs())
|
|
defer sc.finishUsage(versionID)
|
|
|
|
log.Debug("cluster segment distribution", zap.Int("len", len(segAllocs)))
|
|
for nodeID, segmentIDs := range segAllocs {
|
|
log.Debug("segments distribution", zap.Int64("nodeID", nodeID), zap.Int64s("segments", segmentIDs))
|
|
}
|
|
|
|
// concurrent visiting nodes
|
|
var wg sync.WaitGroup
|
|
reqCtx, cancel := context.WithCancel(ctx)
|
|
defer cancel()
|
|
|
|
var err error
|
|
var resultMut sync.Mutex
|
|
results := make([]*internalpb.GetStatisticsResponse, 0, len(segAllocs)) // count(nodes) + 1(growing)
|
|
|
|
// detect corresponding streaming search is done
|
|
wg.Add(1)
|
|
go func() {
|
|
defer wg.Done()
|
|
streamErr := withStreaming(reqCtx)
|
|
resultMut.Lock()
|
|
defer resultMut.Unlock()
|
|
if streamErr != nil {
|
|
cancel()
|
|
// not set cancel error
|
|
if !errors.Is(streamErr, context.Canceled) {
|
|
err = fmt.Errorf("stream operation failed: %w", streamErr)
|
|
}
|
|
}
|
|
}()
|
|
|
|
// dispatch request to followers
|
|
for nodeID, segments := range segAllocs {
|
|
nodeReq := &querypb.GetStatisticsRequest{
|
|
Req: req.GetReq(),
|
|
DmlChannels: req.GetDmlChannels(),
|
|
FromShardLeader: true,
|
|
Scope: querypb.DataScope_Historical,
|
|
SegmentIDs: segments,
|
|
}
|
|
node, ok := sc.getNode(nodeID)
|
|
if !ok { // meta mismatch, report error
|
|
return nil, WrapErrShardNotAvailable(sc.replicaID, sc.vchannelName)
|
|
}
|
|
wg.Add(1)
|
|
go func() {
|
|
defer wg.Done()
|
|
partialResult, nodeErr := node.client.GetStatistics(reqCtx, nodeReq)
|
|
resultMut.Lock()
|
|
defer resultMut.Unlock()
|
|
if nodeErr != nil || partialResult.GetStatus().GetErrorCode() != commonpb.ErrorCode_Success {
|
|
cancel()
|
|
// not set cancel error
|
|
if !errors.Is(nodeErr, context.Canceled) {
|
|
err = fmt.Errorf("GetStatistic %d failed, reason %s err %w", node.nodeID, partialResult.GetStatus().GetReason(), nodeErr)
|
|
}
|
|
return
|
|
}
|
|
results = append(results, partialResult)
|
|
}()
|
|
}
|
|
|
|
wg.Wait()
|
|
if err != nil {
|
|
log.Error(err.Error())
|
|
return nil, err
|
|
}
|
|
|
|
return results, nil
|
|
}
|
|
|
|
// Search preforms search operation on shard cluster.
|
|
func (sc *ShardCluster) Search(ctx context.Context, req *querypb.SearchRequest, withStreaming withStreaming) ([]*internalpb.SearchResults, error) {
|
|
log := log.Ctx(ctx)
|
|
if !sc.serviceable() {
|
|
err := WrapErrShardNotAvailable(sc.replicaID, sc.vchannelName)
|
|
log.Warn("failed to search on shard",
|
|
zap.Int64("replicaID", sc.replicaID),
|
|
zap.String("channel", sc.vchannelName),
|
|
zap.Int32("state", sc.state.Load()),
|
|
zap.Any("version", sc.currentVersion),
|
|
zap.Error(err),
|
|
)
|
|
return nil, err
|
|
}
|
|
if !funcutil.SliceContain(req.GetDmlChannels(), sc.vchannelName) {
|
|
return nil, fmt.Errorf("ShardCluster for %s does not match request channels :%v", sc.vchannelName, req.GetDmlChannels())
|
|
}
|
|
|
|
// get node allocation and maintains the inUse reference count
|
|
segAllocs, versionID := sc.segmentAllocations(req.GetReq().GetPartitionIDs())
|
|
defer sc.finishUsage(versionID)
|
|
|
|
log.Debug("cluster segment distribution", zap.Int("len", len(segAllocs)), zap.Int64s("partitionIDs", req.GetReq().GetPartitionIDs()))
|
|
for nodeID, segmentIDs := range segAllocs {
|
|
log.Debug("segments distribution", zap.Int64("nodeID", nodeID), zap.Int64s("segments", segmentIDs))
|
|
}
|
|
|
|
// concurrent visiting nodes
|
|
var wg sync.WaitGroup
|
|
reqCtx, cancel := context.WithCancel(ctx)
|
|
defer cancel()
|
|
|
|
var err error
|
|
var resultMut sync.Mutex
|
|
results := make([]*internalpb.SearchResults, 0, len(segAllocs)) // count(nodes) + 1(growing)
|
|
|
|
// detect corresponding streaming search is done
|
|
wg.Add(1)
|
|
go func() {
|
|
defer wg.Done()
|
|
|
|
streamErr := withStreaming(reqCtx)
|
|
resultMut.Lock()
|
|
defer resultMut.Unlock()
|
|
if streamErr != nil {
|
|
if err == nil {
|
|
err = fmt.Errorf("stream operation failed: %w", streamErr)
|
|
}
|
|
cancel()
|
|
}
|
|
}()
|
|
|
|
// dispatch request to followers
|
|
for nodeID, segments := range segAllocs {
|
|
internalReq := typeutil.Clone(req.GetReq())
|
|
internalReq.GetBase().TargetID = nodeID
|
|
nodeReq := &querypb.SearchRequest{
|
|
Req: internalReq,
|
|
DmlChannels: req.DmlChannels,
|
|
FromShardLeader: true,
|
|
Scope: querypb.DataScope_Historical,
|
|
SegmentIDs: segments,
|
|
}
|
|
node, ok := sc.getNode(nodeID)
|
|
if !ok { // meta dismatch, report error
|
|
return nil, fmt.Errorf("%w, node %d not found",
|
|
WrapErrShardNotAvailable(sc.replicaID, sc.vchannelName),
|
|
nodeID,
|
|
)
|
|
}
|
|
wg.Add(1)
|
|
go func() {
|
|
defer wg.Done()
|
|
partialResult, nodeErr := node.client.Search(reqCtx, nodeReq)
|
|
resultMut.Lock()
|
|
defer resultMut.Unlock()
|
|
if nodeErr != nil || partialResult.GetStatus().GetErrorCode() != commonpb.ErrorCode_Success {
|
|
if err == nil {
|
|
err = fmt.Errorf("Search %d failed, reason %s err %w", node.nodeID, partialResult.GetStatus().GetReason(), nodeErr)
|
|
}
|
|
cancel()
|
|
return
|
|
}
|
|
results = append(results, partialResult)
|
|
}()
|
|
}
|
|
|
|
wg.Wait()
|
|
if err != nil {
|
|
log.Warn("failed to do search",
|
|
zap.Int64("sourceID", req.GetReq().GetBase().GetSourceID()),
|
|
zap.Strings("channels", req.GetDmlChannels()),
|
|
zap.Int64s("segmentIDs", req.GetSegmentIDs()),
|
|
zap.Error(err))
|
|
return nil, err
|
|
}
|
|
|
|
return results, nil
|
|
}
|
|
|
|
// Query performs query operation on shard cluster.
|
|
func (sc *ShardCluster) Query(ctx context.Context, req *querypb.QueryRequest, withStreaming withStreaming) ([]*internalpb.RetrieveResults, error) {
|
|
if !sc.serviceable() {
|
|
return nil, WrapErrShardNotAvailable(sc.replicaID, sc.vchannelName)
|
|
}
|
|
|
|
// handles only the dml channel part, segment ids is dispatch by cluster itself
|
|
if !funcutil.SliceContain(req.GetDmlChannels(), sc.vchannelName) {
|
|
return nil, fmt.Errorf("ShardCluster for %s does not match to request channels :%v", sc.vchannelName, req.GetDmlChannels())
|
|
}
|
|
|
|
// get node allocation and maintains the inUse reference count
|
|
segAllocs, versionID := sc.segmentAllocations(req.GetReq().GetPartitionIDs())
|
|
defer sc.finishUsage(versionID)
|
|
|
|
// concurrent visiting nodes
|
|
var wg sync.WaitGroup
|
|
reqCtx, cancel := context.WithCancel(ctx)
|
|
defer cancel()
|
|
|
|
var err error
|
|
var resultMut sync.Mutex
|
|
results := make([]*internalpb.RetrieveResults, 0, len(segAllocs)+1) // count(nodes) + 1(growing)
|
|
|
|
// detect corresponding streaming query is done
|
|
wg.Add(1)
|
|
go func() {
|
|
defer wg.Done()
|
|
|
|
streamErr := withStreaming(reqCtx)
|
|
if streamErr != nil {
|
|
if err == nil {
|
|
err = fmt.Errorf("stream operation failed: %w", streamErr)
|
|
}
|
|
cancel()
|
|
}
|
|
}()
|
|
|
|
// dispatch request to followers
|
|
for nodeID, segments := range segAllocs {
|
|
internalReq := typeutil.Clone(req.GetReq())
|
|
internalReq.GetBase().TargetID = nodeID
|
|
nodeReq := &querypb.QueryRequest{
|
|
Req: internalReq,
|
|
FromShardLeader: true,
|
|
SegmentIDs: segments,
|
|
Scope: querypb.DataScope_Historical,
|
|
DmlChannels: req.DmlChannels,
|
|
}
|
|
node, ok := sc.getNode(nodeID)
|
|
if !ok { // meta dismatch, report error
|
|
return nil, WrapErrShardNotAvailable(sc.replicaID, sc.vchannelName)
|
|
}
|
|
wg.Add(1)
|
|
go func() {
|
|
defer wg.Done()
|
|
partialResult, nodeErr := node.client.Query(reqCtx, nodeReq)
|
|
resultMut.Lock()
|
|
defer resultMut.Unlock()
|
|
if nodeErr != nil || partialResult.GetStatus().GetErrorCode() != commonpb.ErrorCode_Success {
|
|
err = fmt.Errorf("Query %d failed, reason %s err %w", node.nodeID, partialResult.GetStatus().GetReason(), nodeErr)
|
|
cancel()
|
|
return
|
|
}
|
|
results = append(results, partialResult)
|
|
}()
|
|
}
|
|
|
|
wg.Wait()
|
|
if err != nil {
|
|
log.Error(err.Error())
|
|
return nil, err
|
|
}
|
|
|
|
return results, nil
|
|
}
|
|
|
|
func (sc *ShardCluster) GetSegmentInfos() []shardSegmentInfo {
|
|
sc.mut.RLock()
|
|
defer sc.mut.RUnlock()
|
|
ret := make([]shardSegmentInfo, 0, len(sc.segments))
|
|
for _, info := range sc.segments {
|
|
ret = append(ret, info)
|
|
}
|
|
return ret
|
|
}
|
|
|
|
func (sc *ShardCluster) getVersion() int64 {
|
|
return sc.version
|
|
}
|