milvus/internal/allocator/segment.go

270 lines
6.4 KiB
Go

package allocator
import (
"container/list"
"context"
"fmt"
"log"
"sort"
"time"
"github.com/cznic/mathutil"
"github.com/zilliztech/milvus-distributed/internal/proto/commonpb"
"github.com/zilliztech/milvus-distributed/internal/errors"
"github.com/zilliztech/milvus-distributed/internal/proto/internalpb"
)
const (
SegCountPerRPC = 20000
ActiveTimeDuration = 100 // Second
)
type assignInfo struct {
collName string
partitionTag string
channelID int32
segInfo map[UniqueID]uint32 // segmentID->count map
expireTime time.Time
lastInsertTime time.Time
}
func (info *assignInfo) IsExpired(now time.Time) bool {
return now.Sub(info.expireTime) >= 0
}
func (info *assignInfo) IsActive(now time.Time) bool {
return now.Sub(info.lastInsertTime) <= ActiveTimeDuration*time.Second
}
func (info *assignInfo) IsEnough(count uint32) bool {
total := uint32(0)
for _, count := range info.segInfo {
total += count
}
return total >= count
}
type SegIDAssigner struct {
Allocator
assignInfos map[string]*list.List // collectionName -> *list.List
segReqs []*internalpb.SegIDRequest
canDoReqs []request
}
func NewSegIDAssigner(ctx context.Context, masterAddr string) (*SegIDAssigner, error) {
ctx1, cancel := context.WithCancel(ctx)
sa := &SegIDAssigner{
Allocator: Allocator{reqs: make(chan request, maxConcurrentRequests),
ctx: ctx1,
cancel: cancel,
masterAddress: masterAddr,
countPerRPC: SegCountPerRPC,
},
assignInfos: make(map[string]*list.List),
}
sa.tChan = &ticker{
updateInterval: time.Second,
}
sa.Allocator.syncFunc = sa.syncSegments
sa.Allocator.processFunc = sa.processFunc
sa.Allocator.checkFunc = sa.checkFunc
return sa, nil
}
func (sa *SegIDAssigner) collectExpired() {
now := time.Now()
for _, info := range sa.assignInfos {
for e := info.Front(); e != nil; e = e.Next() {
assign := e.Value.(*assignInfo)
if !assign.IsActive(now) || !assign.IsExpired(now) {
continue
}
sa.segReqs = append(sa.segReqs, &internalpb.SegIDRequest{
ChannelID: assign.channelID,
Count: sa.countPerRPC,
CollName: assign.collName,
PartitionTag: assign.partitionTag,
})
}
}
}
func (sa *SegIDAssigner) checkToDoReqs() {
if sa.toDoReqs == nil {
return
}
now := time.Now()
for _, req := range sa.toDoReqs {
segRequest := req.(*segRequest)
assign := sa.getAssign(segRequest.colName, segRequest.partition, segRequest.channelID)
if assign == nil || assign.IsExpired(now) || !assign.IsEnough(segRequest.count) {
sa.segReqs = append(sa.segReqs, &internalpb.SegIDRequest{
ChannelID: segRequest.channelID,
Count: segRequest.count,
CollName: segRequest.colName,
PartitionTag: segRequest.partition,
})
}
}
}
func (sa *SegIDAssigner) removeSegInfo(colName, partition string, channelID int32) {
assignInfos, ok := sa.assignInfos[colName]
if !ok {
return
}
cnt := assignInfos.Len()
if cnt == 0 {
return
}
for e := assignInfos.Front(); e != nil; e = e.Next() {
assign := e.Value.(*assignInfo)
if assign.partitionTag != partition || assign.channelID != channelID {
continue
}
assignInfos.Remove(e)
}
}
func (sa *SegIDAssigner) getAssign(colName, partition string, channelID int32) *assignInfo {
assignInfos, ok := sa.assignInfos[colName]
if !ok {
return nil
}
for e := assignInfos.Front(); e != nil; e = e.Next() {
info := e.Value.(*assignInfo)
if info.partitionTag != partition || info.channelID != channelID {
continue
}
return info
}
return nil
}
func (sa *SegIDAssigner) checkFunc(timeout bool) bool {
if timeout {
sa.collectExpired()
} else {
sa.checkToDoReqs()
}
return len(sa.segReqs) != 0
}
func (sa *SegIDAssigner) syncSegments() {
if len(sa.segReqs) == 0 {
return
}
ctx, cancel := context.WithTimeout(context.Background(), 5*time.Second)
defer cancel()
req := &internalpb.AssignSegIDRequest{
PeerID: 1,
Role: internalpb.PeerRole_Proxy,
PerChannelReq: sa.segReqs,
}
sa.segReqs = sa.segReqs[0:0]
fmt.Println("OOOOO", req.PerChannelReq)
resp, err := sa.masterClient.AssignSegmentID(ctx, req)
if resp.Status.GetErrorCode() != commonpb.ErrorCode_SUCCESS {
log.Println("GRPC AssignSegmentID Failed", resp, err)
return
}
now := time.Now()
expiredTime := now.Add(time.Millisecond * time.Duration(1000))
for _, info := range resp.PerChannelAssignment {
sa.removeSegInfo(info.CollName, info.PartitionTag, info.ChannelID)
}
for _, info := range resp.PerChannelAssignment {
assign := sa.getAssign(info.CollName, info.PartitionTag, info.ChannelID)
if assign == nil {
colInfos, ok := sa.assignInfos[info.CollName]
if !ok {
colInfos = list.New()
}
segInfo := make(map[UniqueID]uint32)
segInfo[info.SegID] = info.Count
newAssign := &assignInfo{
collName: info.CollName,
partitionTag: info.PartitionTag,
channelID: info.ChannelID,
segInfo: segInfo,
}
colInfos.PushBack(newAssign)
sa.assignInfos[info.CollName] = colInfos
} else {
assign.segInfo[info.SegID] = info.Count
assign.expireTime = expiredTime
assign.lastInsertTime = now
}
}
if err != nil {
log.Println("syncSemgnet Failed!!!!!")
return
}
}
func (sa *SegIDAssigner) processFunc(req request) error {
segRequest := req.(*segRequest)
assign := sa.getAssign(segRequest.colName, segRequest.partition, segRequest.channelID)
if assign == nil {
return errors.New("Failed to GetSegmentID")
}
keys := make([]UniqueID, len(assign.segInfo))
i := 0
for key := range assign.segInfo {
keys[i] = key
i++
}
reqCount := segRequest.count
resultSegInfo := make(map[UniqueID]uint32)
sort.Slice(keys, func(i, j int) bool { return keys[i] < keys[j] })
for _, key := range keys {
if reqCount <= 0 {
break
}
cur := assign.segInfo[key]
minCnt := mathutil.MinUint32(cur, reqCount)
resultSegInfo[key] = minCnt
cur -= minCnt
reqCount -= minCnt
if cur <= 0 {
delete(assign.segInfo, key)
} else {
assign.segInfo[key] = cur
}
}
segRequest.segInfo = resultSegInfo
return nil
}
func (sa *SegIDAssigner) GetSegmentID(colName, partition string, channelID int32, count uint32) (map[UniqueID]uint32, error) {
req := &segRequest{
baseRequest: baseRequest{done: make(chan error), valid: false},
colName: colName,
partition: partition,
channelID: channelID,
count: count,
}
sa.reqs <- req
req.Wait()
if !req.IsValid() {
return nil, errors.New("GetSegmentID Failed")
}
return req.segInfo, nil
}