mirror of https://github.com/milvus-io/milvus.git
Add unittest for task scheduler (#7508)
Signed-off-by: dragondriver <jiquan.long@zilliz.com>pull/7518/head
parent
7025a6e925
commit
42b687bf48
|
@ -1,3 +1,14 @@
|
|||
// Copyright (C) 2019-2020 Zilliz. All rights reserved.
|
||||
//
|
||||
// Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance
|
||||
// with the License. You may obtain a copy of the License at
|
||||
//
|
||||
// http://www.apache.org/licenses/LICENSE-2.0
|
||||
//
|
||||
// Unless required by applicable law or agreed to in writing, software distributed under the License
|
||||
// is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express
|
||||
// or implied. See the License for the specific language governing permissions and limitations under the License.
|
||||
|
||||
package grpcconfigs
|
||||
|
||||
import "math"
|
||||
|
|
|
@ -14,7 +14,6 @@ package proxy
|
|||
import (
|
||||
"context"
|
||||
"fmt"
|
||||
"math/rand"
|
||||
"runtime"
|
||||
"sort"
|
||||
"sync"
|
||||
|
@ -37,83 +36,9 @@ type channelsMgr interface {
|
|||
removeAllDMLStream() error
|
||||
}
|
||||
|
||||
type (
|
||||
uniqueIntGenerator interface {
|
||||
get() int
|
||||
}
|
||||
naiveUniqueIntGenerator struct {
|
||||
now int
|
||||
mtx sync.Mutex
|
||||
}
|
||||
)
|
||||
|
||||
func (generator *naiveUniqueIntGenerator) get() int {
|
||||
generator.mtx.Lock()
|
||||
defer func() {
|
||||
generator.now++
|
||||
generator.mtx.Unlock()
|
||||
}()
|
||||
return generator.now
|
||||
}
|
||||
|
||||
func newNaiveUniqueIntGenerator() *naiveUniqueIntGenerator {
|
||||
return &naiveUniqueIntGenerator{
|
||||
now: 0,
|
||||
}
|
||||
}
|
||||
|
||||
var uniqueIntGeneratorIns uniqueIntGenerator
|
||||
var getUniqueIntGeneratorInsOnce sync.Once
|
||||
|
||||
func getUniqueIntGeneratorIns() uniqueIntGenerator {
|
||||
getUniqueIntGeneratorInsOnce.Do(func() {
|
||||
uniqueIntGeneratorIns = newNaiveUniqueIntGenerator()
|
||||
})
|
||||
return uniqueIntGeneratorIns
|
||||
}
|
||||
|
||||
type getChannelsFuncType = func(collectionID UniqueID) (map[vChan]pChan, error)
|
||||
type repackFuncType = func(tsMsgs []msgstream.TsMsg, hashKeys [][]int32) (map[int32]*msgstream.MsgPack, error)
|
||||
|
||||
type getChannelsService interface {
|
||||
GetChannels(collectionID UniqueID) (map[vChan]pChan, error)
|
||||
}
|
||||
|
||||
type mockGetChannelsService struct {
|
||||
collectionID2Channels map[UniqueID]map[vChan]pChan
|
||||
}
|
||||
|
||||
func newMockGetChannelsService() *mockGetChannelsService {
|
||||
return &mockGetChannelsService{
|
||||
collectionID2Channels: make(map[UniqueID]map[vChan]pChan),
|
||||
}
|
||||
}
|
||||
|
||||
func genUniqueStr() string {
|
||||
l := rand.Uint64()%100 + 1
|
||||
b := make([]byte, l)
|
||||
if _, err := rand.Read(b); err != nil {
|
||||
return ""
|
||||
}
|
||||
return fmt.Sprintf("%X", b)
|
||||
}
|
||||
|
||||
func (m *mockGetChannelsService) GetChannels(collectionID UniqueID) (map[vChan]pChan, error) {
|
||||
channels, ok := m.collectionID2Channels[collectionID]
|
||||
if ok {
|
||||
return channels, nil
|
||||
}
|
||||
|
||||
channels = make(map[vChan]pChan)
|
||||
l := rand.Uint64()%10 + 1
|
||||
for i := 0; uint64(i) < l; i++ {
|
||||
channels[genUniqueStr()] = genUniqueStr()
|
||||
}
|
||||
|
||||
m.collectionID2Channels[collectionID] = channels
|
||||
return channels, nil
|
||||
}
|
||||
|
||||
type streamType int
|
||||
|
||||
const (
|
||||
|
|
|
@ -19,20 +19,6 @@ import (
|
|||
"github.com/stretchr/testify/assert"
|
||||
)
|
||||
|
||||
func TestNaiveUniqueIntGenerator_get(t *testing.T) {
|
||||
exists := make(map[int]bool)
|
||||
num := 10
|
||||
|
||||
generator := newNaiveUniqueIntGenerator()
|
||||
|
||||
for i := 0; i < num; i++ {
|
||||
g := generator.get()
|
||||
_, ok := exists[g]
|
||||
assert.False(t, ok)
|
||||
exists[g] = true
|
||||
}
|
||||
}
|
||||
|
||||
func TestChannelsMgrImpl_getChannels(t *testing.T) {
|
||||
master := newMockGetChannelsService()
|
||||
query := newMockGetChannelsService()
|
||||
|
|
|
@ -24,14 +24,6 @@ import (
|
|||
// ticker can update ts only when the minTs greater than the ts of ticker, we can use maxTs to update current later
|
||||
type getPChanStatisticsFuncType func() (map[pChan]*pChanStatistics, error)
|
||||
|
||||
// use interface tsoAllocator to keep channelsTimeTickerImpl testable
|
||||
type tsoAllocator interface {
|
||||
//Start() error
|
||||
AllocOne() (Timestamp, error)
|
||||
//Alloc(count uint32) ([]Timestamp, error)
|
||||
//ClearCache()
|
||||
}
|
||||
|
||||
type channelsTimeTicker interface {
|
||||
start() error
|
||||
close() error
|
||||
|
|
|
@ -25,17 +25,6 @@ import (
|
|||
"github.com/stretchr/testify/assert"
|
||||
)
|
||||
|
||||
type mockTsoAllocator struct {
|
||||
}
|
||||
|
||||
func (tso *mockTsoAllocator) AllocOne() (Timestamp, error) {
|
||||
return Timestamp(time.Now().UnixNano()), nil
|
||||
}
|
||||
|
||||
func newMockTsoAllocator() *mockTsoAllocator {
|
||||
return &mockTsoAllocator{}
|
||||
}
|
||||
|
||||
func newGetStatisticsFunc(pchans []pChan) getPChanStatisticsFuncType {
|
||||
totalPchan := len(pchans)
|
||||
pchanNum := rand.Uint64()%(uint64(totalPchan)) + 1
|
||||
|
|
|
@ -134,7 +134,7 @@ func (node *Proxy) CreateCollection(ctx context.Context, request *milvuspb.Creat
|
|||
zap.String("db", request.DbName),
|
||||
zap.String("collection", request.CollectionName),
|
||||
zap.Any("schema", request.Schema))
|
||||
err := node.sched.DdQueue.Enqueue(cct)
|
||||
err := node.sched.ddQueue.Enqueue(cct)
|
||||
if err != nil {
|
||||
return &commonpb.Status{
|
||||
ErrorCode: commonpb.ErrorCode_UnexpectedError,
|
||||
|
@ -188,7 +188,7 @@ func (node *Proxy) DropCollection(ctx context.Context, request *milvuspb.DropCol
|
|||
zap.String("role", Params.RoleName),
|
||||
zap.String("db", request.DbName),
|
||||
zap.String("collection", request.CollectionName))
|
||||
err := node.sched.DdQueue.Enqueue(dct)
|
||||
err := node.sched.ddQueue.Enqueue(dct)
|
||||
if err != nil {
|
||||
return &commonpb.Status{
|
||||
ErrorCode: commonpb.ErrorCode_UnexpectedError,
|
||||
|
@ -240,7 +240,7 @@ func (node *Proxy) HasCollection(ctx context.Context, request *milvuspb.HasColle
|
|||
zap.String("role", Params.RoleName),
|
||||
zap.String("db", request.DbName),
|
||||
zap.String("collection", request.CollectionName))
|
||||
err := node.sched.DdQueue.Enqueue(hct)
|
||||
err := node.sched.ddQueue.Enqueue(hct)
|
||||
if err != nil {
|
||||
return &milvuspb.BoolResponse{
|
||||
Status: &commonpb.Status{
|
||||
|
@ -294,7 +294,7 @@ func (node *Proxy) LoadCollection(ctx context.Context, request *milvuspb.LoadCol
|
|||
zap.String("role", Params.RoleName),
|
||||
zap.String("db", request.DbName),
|
||||
zap.String("collection", request.CollectionName))
|
||||
err := node.sched.DdQueue.Enqueue(lct)
|
||||
err := node.sched.ddQueue.Enqueue(lct)
|
||||
if err != nil {
|
||||
return &commonpb.Status{
|
||||
ErrorCode: commonpb.ErrorCode_UnexpectedError,
|
||||
|
@ -345,7 +345,7 @@ func (node *Proxy) ReleaseCollection(ctx context.Context, request *milvuspb.Rele
|
|||
zap.String("role", Params.RoleName),
|
||||
zap.String("db", request.DbName),
|
||||
zap.String("collection", request.CollectionName))
|
||||
err := node.sched.DdQueue.Enqueue(rct)
|
||||
err := node.sched.ddQueue.Enqueue(rct)
|
||||
if err != nil {
|
||||
return &commonpb.Status{
|
||||
ErrorCode: commonpb.ErrorCode_UnexpectedError,
|
||||
|
@ -397,7 +397,7 @@ func (node *Proxy) DescribeCollection(ctx context.Context, request *milvuspb.Des
|
|||
zap.String("role", Params.RoleName),
|
||||
zap.String("db", request.DbName),
|
||||
zap.String("collection", request.CollectionName))
|
||||
err := node.sched.DdQueue.Enqueue(dct)
|
||||
err := node.sched.ddQueue.Enqueue(dct)
|
||||
if err != nil {
|
||||
return &milvuspb.DescribeCollectionResponse{
|
||||
Status: &commonpb.Status{
|
||||
|
@ -453,7 +453,7 @@ func (node *Proxy) GetCollectionStatistics(ctx context.Context, request *milvusp
|
|||
zap.String("role", Params.RoleName),
|
||||
zap.String("db", request.DbName),
|
||||
zap.String("collection", request.CollectionName))
|
||||
err := node.sched.DdQueue.Enqueue(g)
|
||||
err := node.sched.ddQueue.Enqueue(g)
|
||||
if err != nil {
|
||||
return &milvuspb.GetCollectionStatisticsResponse{
|
||||
Status: &commonpb.Status{
|
||||
|
@ -509,7 +509,7 @@ func (node *Proxy) ShowCollections(ctx context.Context, request *milvuspb.ShowCo
|
|||
log.Debug("ShowCollections enqueue",
|
||||
zap.String("role", Params.RoleName),
|
||||
zap.Any("request", request))
|
||||
err := node.sched.DdQueue.Enqueue(sct)
|
||||
err := node.sched.ddQueue.Enqueue(sct)
|
||||
if err != nil {
|
||||
return &milvuspb.ShowCollectionsResponse{
|
||||
Status: &commonpb.Status{
|
||||
|
@ -560,7 +560,7 @@ func (node *Proxy) CreatePartition(ctx context.Context, request *milvuspb.Create
|
|||
zap.String("db", request.DbName),
|
||||
zap.String("collection", request.CollectionName),
|
||||
zap.String("partition", request.PartitionName))
|
||||
err := node.sched.DdQueue.Enqueue(cpt)
|
||||
err := node.sched.ddQueue.Enqueue(cpt)
|
||||
if err != nil {
|
||||
return &commonpb.Status{
|
||||
ErrorCode: commonpb.ErrorCode_UnexpectedError,
|
||||
|
@ -613,7 +613,7 @@ func (node *Proxy) DropPartition(ctx context.Context, request *milvuspb.DropPart
|
|||
zap.String("db", request.DbName),
|
||||
zap.String("collection", request.CollectionName),
|
||||
zap.String("partition", request.PartitionName))
|
||||
err := node.sched.DdQueue.Enqueue(dpt)
|
||||
err := node.sched.ddQueue.Enqueue(dpt)
|
||||
if err != nil {
|
||||
return &commonpb.Status{
|
||||
ErrorCode: commonpb.ErrorCode_UnexpectedError,
|
||||
|
@ -668,7 +668,7 @@ func (node *Proxy) HasPartition(ctx context.Context, request *milvuspb.HasPartit
|
|||
zap.String("db", request.DbName),
|
||||
zap.String("collection", request.CollectionName),
|
||||
zap.String("partition", request.PartitionName))
|
||||
err := node.sched.DdQueue.Enqueue(hpt)
|
||||
err := node.sched.ddQueue.Enqueue(hpt)
|
||||
if err != nil {
|
||||
return &milvuspb.BoolResponse{
|
||||
Status: &commonpb.Status{
|
||||
|
@ -726,7 +726,7 @@ func (node *Proxy) LoadPartitions(ctx context.Context, request *milvuspb.LoadPar
|
|||
zap.String("db", request.DbName),
|
||||
zap.String("collection", request.CollectionName),
|
||||
zap.Any("partitions", request.PartitionNames))
|
||||
err := node.sched.DdQueue.Enqueue(lpt)
|
||||
err := node.sched.ddQueue.Enqueue(lpt)
|
||||
if err != nil {
|
||||
return &commonpb.Status{
|
||||
ErrorCode: commonpb.ErrorCode_UnexpectedError,
|
||||
|
@ -779,7 +779,7 @@ func (node *Proxy) ReleasePartitions(ctx context.Context, request *milvuspb.Rele
|
|||
zap.String("db", request.DbName),
|
||||
zap.String("collection", request.CollectionName),
|
||||
zap.Any("partitions", request.PartitionNames))
|
||||
err := node.sched.DdQueue.Enqueue(rpt)
|
||||
err := node.sched.ddQueue.Enqueue(rpt)
|
||||
if err != nil {
|
||||
return &commonpb.Status{
|
||||
ErrorCode: commonpb.ErrorCode_UnexpectedError,
|
||||
|
@ -834,7 +834,7 @@ func (node *Proxy) GetPartitionStatistics(ctx context.Context, request *milvuspb
|
|||
zap.String("db", request.DbName),
|
||||
zap.String("collection", request.CollectionName),
|
||||
zap.String("partition", request.PartitionName))
|
||||
err := node.sched.DdQueue.Enqueue(g)
|
||||
err := node.sched.ddQueue.Enqueue(g)
|
||||
if err != nil {
|
||||
return &milvuspb.GetPartitionStatisticsResponse{
|
||||
Status: &commonpb.Status{
|
||||
|
@ -893,7 +893,7 @@ func (node *Proxy) ShowPartitions(ctx context.Context, request *milvuspb.ShowPar
|
|||
log.Debug("ShowPartitions enqueue",
|
||||
zap.String("role", Params.RoleName),
|
||||
zap.Any("request", request))
|
||||
err := node.sched.DdQueue.Enqueue(spt)
|
||||
err := node.sched.ddQueue.Enqueue(spt)
|
||||
if err != nil {
|
||||
return &milvuspb.ShowPartitionsResponse{
|
||||
Status: &commonpb.Status{
|
||||
|
@ -943,7 +943,7 @@ func (node *Proxy) CreateIndex(ctx context.Context, request *milvuspb.CreateInde
|
|||
zap.String("collection", request.CollectionName),
|
||||
zap.String("field", request.FieldName),
|
||||
zap.Any("extra_params", request.ExtraParams))
|
||||
err := node.sched.DdQueue.Enqueue(cit)
|
||||
err := node.sched.ddQueue.Enqueue(cit)
|
||||
if err != nil {
|
||||
return &commonpb.Status{
|
||||
ErrorCode: commonpb.ErrorCode_UnexpectedError,
|
||||
|
@ -1001,7 +1001,7 @@ func (node *Proxy) DescribeIndex(ctx context.Context, request *milvuspb.Describe
|
|||
zap.String("collection", request.CollectionName),
|
||||
zap.String("field", request.FieldName),
|
||||
zap.String("index name", request.IndexName))
|
||||
err := node.sched.DdQueue.Enqueue(dit)
|
||||
err := node.sched.ddQueue.Enqueue(dit)
|
||||
if err != nil {
|
||||
return &milvuspb.DescribeIndexResponse{
|
||||
Status: &commonpb.Status{
|
||||
|
@ -1065,7 +1065,7 @@ func (node *Proxy) DropIndex(ctx context.Context, request *milvuspb.DropIndexReq
|
|||
zap.String("collection", request.CollectionName),
|
||||
zap.String("field", request.FieldName),
|
||||
zap.String("index name", request.IndexName))
|
||||
err := node.sched.DdQueue.Enqueue(dit)
|
||||
err := node.sched.ddQueue.Enqueue(dit)
|
||||
|
||||
if err != nil {
|
||||
return &commonpb.Status{
|
||||
|
@ -1127,7 +1127,7 @@ func (node *Proxy) GetIndexBuildProgress(ctx context.Context, request *milvuspb.
|
|||
zap.String("collection", request.CollectionName),
|
||||
zap.String("field", request.FieldName),
|
||||
zap.String("index name", request.IndexName))
|
||||
err := node.sched.DdQueue.Enqueue(gibpt)
|
||||
err := node.sched.ddQueue.Enqueue(gibpt)
|
||||
if err != nil {
|
||||
return &milvuspb.GetIndexBuildProgressResponse{
|
||||
Status: &commonpb.Status{
|
||||
|
@ -1192,7 +1192,7 @@ func (node *Proxy) GetIndexState(ctx context.Context, request *milvuspb.GetIndex
|
|||
zap.String("collection", request.CollectionName),
|
||||
zap.String("field", request.FieldName),
|
||||
zap.String("index name", request.IndexName))
|
||||
err := node.sched.DdQueue.Enqueue(dipt)
|
||||
err := node.sched.ddQueue.Enqueue(dipt)
|
||||
if err != nil {
|
||||
return &milvuspb.GetIndexStateResponse{
|
||||
Status: &commonpb.Status{
|
||||
|
@ -1299,7 +1299,7 @@ func (node *Proxy) Insert(ctx context.Context, request *milvuspb.InsertRequest)
|
|||
ErrorCode: commonpb.ErrorCode_Success,
|
||||
},
|
||||
}
|
||||
err = node.sched.DmQueue.Enqueue(it)
|
||||
err = node.sched.dmQueue.Enqueue(it)
|
||||
|
||||
log.Debug("Insert Task Enqueue",
|
||||
zap.Int64("msgID", it.BaseInsertTask.InsertRequest.Base.MsgID),
|
||||
|
@ -1366,7 +1366,7 @@ func (node *Proxy) Delete(ctx context.Context, request *milvuspb.DeleteRequest)
|
|||
zap.String("collection", request.CollectionName),
|
||||
zap.String("partition", request.PartitionName),
|
||||
zap.String("expr", request.Expr))
|
||||
err := node.sched.DmQueue.Enqueue(dt)
|
||||
err := node.sched.dmQueue.Enqueue(dt)
|
||||
if err != nil {
|
||||
return &milvuspb.MutationResult{
|
||||
Status: &commonpb.Status{
|
||||
|
@ -1439,7 +1439,7 @@ func (node *Proxy) Search(ctx context.Context, request *milvuspb.SearchRequest)
|
|||
zap.Any("dsl", request.Dsl),
|
||||
zap.Any("len(PlaceholderGroup)", len(request.PlaceholderGroup)),
|
||||
zap.Any("OutputFields", request.OutputFields))
|
||||
err := node.sched.DqQueue.Enqueue(qt)
|
||||
err := node.sched.dqQueue.Enqueue(qt)
|
||||
if err != nil {
|
||||
return &milvuspb.SearchResults{
|
||||
Status: &commonpb.Status{
|
||||
|
@ -1519,7 +1519,7 @@ func (node *Proxy) Flush(ctx context.Context, request *milvuspb.FlushRequest) (*
|
|||
zap.String("role", Params.RoleName),
|
||||
zap.String("db", request.DbName),
|
||||
zap.Any("collections", request.CollectionNames))
|
||||
err := node.sched.DdQueue.Enqueue(ft)
|
||||
err := node.sched.ddQueue.Enqueue(ft)
|
||||
if err != nil {
|
||||
resp.Status.Reason = err.Error()
|
||||
return resp, nil
|
||||
|
@ -1587,7 +1587,7 @@ func (node *Proxy) Query(ctx context.Context, request *milvuspb.QueryRequest) (*
|
|||
zap.String("collection", queryRequest.CollectionName),
|
||||
zap.Any("partitions", queryRequest.PartitionNames))
|
||||
|
||||
err := node.sched.DqQueue.Enqueue(qt)
|
||||
err := node.sched.dqQueue.Enqueue(qt)
|
||||
if err != nil {
|
||||
return &milvuspb.QueryResults{
|
||||
Status: &commonpb.Status{
|
||||
|
@ -1669,7 +1669,7 @@ func (node *Proxy) CalcDistance(ctx context.Context, request *milvuspb.CalcDista
|
|||
ids: ids.IdArray,
|
||||
}
|
||||
|
||||
err := node.sched.DqQueue.Enqueue(qt)
|
||||
err := node.sched.dqQueue.Enqueue(qt)
|
||||
if err != nil {
|
||||
return &milvuspb.QueryResults{
|
||||
Status: &commonpb.Status{
|
||||
|
|
|
@ -0,0 +1,40 @@
|
|||
// Copyright (C) 2019-2020 Zilliz. All rights reserved.
|
||||
//
|
||||
// Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance
|
||||
// with the License. You may obtain a copy of the License at
|
||||
//
|
||||
// http://www.apache.org/licenses/LICENSE-2.0
|
||||
//
|
||||
// Unless required by applicable law or agreed to in writing, software distributed under the License
|
||||
// is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express
|
||||
// or implied. See the License for the specific language governing permissions and limitations under the License.
|
||||
|
||||
package proxy
|
||||
|
||||
import (
|
||||
"context"
|
||||
|
||||
"github.com/milvus-io/milvus/internal/proto/rootcoordpb"
|
||||
)
|
||||
|
||||
// use interface tsoAllocator to keep other components testable
|
||||
// include: channelsTimeTickerImpl, baseTaskQueue, taskScheduler
|
||||
type tsoAllocator interface {
|
||||
AllocOne() (Timestamp, error)
|
||||
}
|
||||
|
||||
// use interface idAllocatorInterface to keep other components testable
|
||||
// include: baseTaskQueue, taskScheduler
|
||||
type idAllocatorInterface interface {
|
||||
AllocOne() (UniqueID, error)
|
||||
}
|
||||
|
||||
// use timestampAllocatorInterface to keep other components testable
|
||||
// include: TimestampAllocator
|
||||
type timestampAllocatorInterface interface {
|
||||
AllocTimestamp(ctx context.Context, req *rootcoordpb.AllocTimestampRequest) (*rootcoordpb.AllocTimestampResponse, error)
|
||||
}
|
||||
|
||||
type getChannelsService interface {
|
||||
GetChannels(collectionID UniqueID) (map[vChan]pChan, error)
|
||||
}
|
|
@ -13,6 +13,7 @@ package proxy
|
|||
|
||||
import (
|
||||
"context"
|
||||
"math/rand"
|
||||
"time"
|
||||
|
||||
"github.com/milvus-io/milvus/internal/proto/commonpb"
|
||||
|
@ -37,3 +38,190 @@ func (tso *mockTimestampAllocatorInterface) AllocTimestamp(ctx context.Context,
|
|||
func newMockTimestampAllocatorInterface() timestampAllocatorInterface {
|
||||
return &mockTimestampAllocatorInterface{}
|
||||
}
|
||||
|
||||
type mockTsoAllocator struct {
|
||||
}
|
||||
|
||||
func (tso *mockTsoAllocator) AllocOne() (Timestamp, error) {
|
||||
return Timestamp(time.Now().UnixNano()), nil
|
||||
}
|
||||
|
||||
func newMockTsoAllocator() tsoAllocator {
|
||||
return &mockTsoAllocator{}
|
||||
}
|
||||
|
||||
type mockIDAllocatorInterface struct {
|
||||
}
|
||||
|
||||
func (m *mockIDAllocatorInterface) AllocOne() (UniqueID, error) {
|
||||
return UniqueID(getUniqueIntGeneratorIns().get()), nil
|
||||
}
|
||||
|
||||
func newMockIDAllocatorInterface() idAllocatorInterface {
|
||||
return &mockIDAllocatorInterface{}
|
||||
}
|
||||
|
||||
type mockGetChannelsService struct {
|
||||
collectionID2Channels map[UniqueID]map[vChan]pChan
|
||||
}
|
||||
|
||||
func newMockGetChannelsService() *mockGetChannelsService {
|
||||
return &mockGetChannelsService{
|
||||
collectionID2Channels: make(map[UniqueID]map[vChan]pChan),
|
||||
}
|
||||
}
|
||||
|
||||
func (m *mockGetChannelsService) GetChannels(collectionID UniqueID) (map[vChan]pChan, error) {
|
||||
channels, ok := m.collectionID2Channels[collectionID]
|
||||
if ok {
|
||||
return channels, nil
|
||||
}
|
||||
|
||||
channels = make(map[vChan]pChan)
|
||||
l := rand.Uint64()%10 + 1
|
||||
for i := 0; uint64(i) < l; i++ {
|
||||
channels[genUniqueStr()] = genUniqueStr()
|
||||
}
|
||||
|
||||
m.collectionID2Channels[collectionID] = channels
|
||||
return channels, nil
|
||||
}
|
||||
|
||||
type mockTask struct {
|
||||
*TaskCondition
|
||||
id UniqueID
|
||||
name string
|
||||
tType commonpb.MsgType
|
||||
ts Timestamp
|
||||
}
|
||||
|
||||
func (m *mockTask) TraceCtx() context.Context {
|
||||
return m.TaskCondition.ctx
|
||||
}
|
||||
|
||||
func (m *mockTask) ID() UniqueID {
|
||||
return m.id
|
||||
}
|
||||
|
||||
func (m *mockTask) SetID(uid UniqueID) {
|
||||
m.id = uid
|
||||
}
|
||||
|
||||
func (m *mockTask) Name() string {
|
||||
return m.name
|
||||
}
|
||||
|
||||
func (m *mockTask) Type() commonpb.MsgType {
|
||||
return m.tType
|
||||
}
|
||||
|
||||
func (m *mockTask) BeginTs() Timestamp {
|
||||
return m.ts
|
||||
}
|
||||
|
||||
func (m *mockTask) EndTs() Timestamp {
|
||||
return m.ts
|
||||
}
|
||||
|
||||
func (m *mockTask) SetTs(ts Timestamp) {
|
||||
m.ts = ts
|
||||
}
|
||||
|
||||
func (m *mockTask) OnEnqueue() error {
|
||||
return nil
|
||||
}
|
||||
|
||||
func (m *mockTask) PreExecute(ctx context.Context) error {
|
||||
return nil
|
||||
}
|
||||
|
||||
func (m *mockTask) Execute(ctx context.Context) error {
|
||||
return nil
|
||||
}
|
||||
|
||||
func (m *mockTask) PostExecute(ctx context.Context) error {
|
||||
return nil
|
||||
}
|
||||
|
||||
func newMockTask(ctx context.Context) *mockTask {
|
||||
return &mockTask{
|
||||
TaskCondition: NewTaskCondition(ctx),
|
||||
id: UniqueID(getUniqueIntGeneratorIns().get()),
|
||||
name: genUniqueStr(),
|
||||
tType: commonpb.MsgType_Undefined,
|
||||
ts: Timestamp(time.Now().Nanosecond()),
|
||||
}
|
||||
}
|
||||
|
||||
func newDefaultMockTask() *mockTask {
|
||||
return newMockTask(context.Background())
|
||||
}
|
||||
|
||||
type mockDdlTask struct {
|
||||
*mockTask
|
||||
}
|
||||
|
||||
func newMockDdlTask(ctx context.Context) *mockDdlTask {
|
||||
return &mockDdlTask{
|
||||
mockTask: newMockTask(ctx),
|
||||
}
|
||||
}
|
||||
|
||||
func newDefaultMockDdlTask() *mockDdlTask {
|
||||
return newMockDdlTask(context.Background())
|
||||
}
|
||||
|
||||
type mockDmlTask struct {
|
||||
*mockTask
|
||||
vchans []vChan
|
||||
pchans []pChan
|
||||
}
|
||||
|
||||
func (m *mockDmlTask) getChannels() ([]vChan, error) {
|
||||
return m.vchans, nil
|
||||
}
|
||||
|
||||
func (m *mockDmlTask) getPChanStats() (map[pChan]pChanStatistics, error) {
|
||||
ret := make(map[pChan]pChanStatistics)
|
||||
for _, pchan := range m.pchans {
|
||||
ret[pchan] = pChanStatistics{
|
||||
minTs: m.ts,
|
||||
maxTs: m.ts,
|
||||
}
|
||||
}
|
||||
return ret, nil
|
||||
}
|
||||
|
||||
func newMockDmlTask(ctx context.Context) *mockDmlTask {
|
||||
shardNum := 2
|
||||
|
||||
vchans := make([]vChan, 0, shardNum)
|
||||
pchans := make([]pChan, 0, shardNum)
|
||||
|
||||
for i := 0; i < shardNum; i++ {
|
||||
vchans = append(vchans, genUniqueStr())
|
||||
pchans = append(pchans, genUniqueStr())
|
||||
}
|
||||
|
||||
return &mockDmlTask{
|
||||
mockTask: newMockTask(ctx),
|
||||
}
|
||||
}
|
||||
|
||||
func newDefaultMockDmlTask() *mockDmlTask {
|
||||
return newMockDmlTask(context.Background())
|
||||
}
|
||||
|
||||
type mockDqlTask struct {
|
||||
*mockTask
|
||||
}
|
||||
|
||||
func newMockDqlTask(ctx context.Context) *mockDqlTask {
|
||||
return &mockDqlTask{
|
||||
mockTask: newMockTask(ctx),
|
||||
}
|
||||
}
|
||||
|
||||
func newDefaultMockDqlTask() *mockDqlTask {
|
||||
return newMockDqlTask(context.Background())
|
||||
}
|
||||
|
|
|
@ -0,0 +1,49 @@
|
|||
// Copyright (C) 2019-2020 Zilliz. All rights reserved.
|
||||
//
|
||||
// Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance
|
||||
// with the License. You may obtain a copy of the License at
|
||||
//
|
||||
// http://www.apache.org/licenses/LICENSE-2.0
|
||||
//
|
||||
// Unless required by applicable law or agreed to in writing, software distributed under the License
|
||||
// is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express
|
||||
// or implied. See the License for the specific language governing permissions and limitations under the License.
|
||||
|
||||
package proxy
|
||||
|
||||
import "sync"
|
||||
|
||||
type (
|
||||
uniqueIntGenerator interface {
|
||||
get() int
|
||||
}
|
||||
naiveUniqueIntGenerator struct {
|
||||
now int
|
||||
mtx sync.Mutex
|
||||
}
|
||||
)
|
||||
|
||||
func (generator *naiveUniqueIntGenerator) get() int {
|
||||
generator.mtx.Lock()
|
||||
defer func() {
|
||||
generator.now++
|
||||
generator.mtx.Unlock()
|
||||
}()
|
||||
return generator.now
|
||||
}
|
||||
|
||||
func newNaiveUniqueIntGenerator() *naiveUniqueIntGenerator {
|
||||
return &naiveUniqueIntGenerator{
|
||||
now: 0,
|
||||
}
|
||||
}
|
||||
|
||||
var uniqueIntGeneratorIns uniqueIntGenerator
|
||||
var getUniqueIntGeneratorInsOnce sync.Once
|
||||
|
||||
func getUniqueIntGeneratorIns() uniqueIntGenerator {
|
||||
getUniqueIntGeneratorInsOnce.Do(func() {
|
||||
uniqueIntGeneratorIns = newNaiveUniqueIntGenerator()
|
||||
})
|
||||
return uniqueIntGeneratorIns
|
||||
}
|
|
@ -0,0 +1,32 @@
|
|||
// Copyright (C) 2019-2020 Zilliz. All rights reserved.
|
||||
//
|
||||
// Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance
|
||||
// with the License. You may obtain a copy of the License at
|
||||
//
|
||||
// http://www.apache.org/licenses/LICENSE-2.0
|
||||
//
|
||||
// Unless required by applicable law or agreed to in writing, software distributed under the License
|
||||
// is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express
|
||||
// or implied. See the License for the specific language governing permissions and limitations under the License.
|
||||
|
||||
package proxy
|
||||
|
||||
import (
|
||||
"testing"
|
||||
|
||||
"github.com/stretchr/testify/assert"
|
||||
)
|
||||
|
||||
func TestNaiveUniqueIntGenerator_get(t *testing.T) {
|
||||
exists := make(map[int]bool)
|
||||
num := 10
|
||||
|
||||
generator := newNaiveUniqueIntGenerator()
|
||||
|
||||
for i := 0; i < num; i++ {
|
||||
g := generator.get()
|
||||
_, ok := exists[g]
|
||||
assert.False(t, ok)
|
||||
exists[g] = true
|
||||
}
|
||||
}
|
|
@ -63,7 +63,7 @@ type Proxy struct {
|
|||
|
||||
chMgr channelsMgr
|
||||
|
||||
sched *TaskScheduler
|
||||
sched *taskScheduler
|
||||
tick *timeTick
|
||||
|
||||
chTicker channelsTimeTicker
|
||||
|
@ -256,7 +256,7 @@ func (node *Proxy) Init() error {
|
|||
chMgr := newChannelsMgrImpl(getDmlChannelsFunc, defaultInsertRepackFunc, getDqlChannelsFunc, nil, node.msFactory)
|
||||
node.chMgr = chMgr
|
||||
|
||||
node.sched, err = NewTaskScheduler(node.ctx, node.idAllocator, node.tsoAllocator, node.msFactory)
|
||||
node.sched, err = newTaskScheduler(node.ctx, node.idAllocator, node.tsoAllocator, node.msFactory)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
|
|
@ -100,7 +100,6 @@ type dmlTask interface {
|
|||
task
|
||||
getChannels() ([]vChan, error)
|
||||
getPChanStats() (map[pChan]pChanStatistics, error)
|
||||
getChannelsTimerTicker() channelsTimeTicker
|
||||
}
|
||||
|
||||
type BaseInsertTask = msgstream.InsertMsg
|
||||
|
@ -155,10 +154,6 @@ func (it *InsertTask) EndTs() Timestamp {
|
|||
return it.EndTimestamp
|
||||
}
|
||||
|
||||
func (it *InsertTask) getChannelsTimerTicker() channelsTimeTicker {
|
||||
return it.chTicker
|
||||
}
|
||||
|
||||
func (it *InsertTask) getPChanStats() (map[pChan]pChanStatistics, error) {
|
||||
ret := make(map[pChan]pChanStatistics)
|
||||
|
||||
|
@ -192,6 +187,17 @@ func (it *InsertTask) getChannels() ([]pChan, error) {
|
|||
return nil, err
|
||||
}
|
||||
channels, err = it.chMgr.getChannels(collID)
|
||||
if err == nil {
|
||||
for _, pchan := range channels {
|
||||
err := it.chTicker.addPChan(pchan)
|
||||
if err != nil {
|
||||
log.Warn("failed to add pchan to channels time ticker",
|
||||
zap.Error(err),
|
||||
zap.Int64("collection id", collID),
|
||||
zap.String("pchan", pchan))
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
return channels, err
|
||||
}
|
||||
|
@ -1023,6 +1029,17 @@ func (it *InsertTask) Execute(ctx context.Context) error {
|
|||
it.result.Status.Reason = err.Error()
|
||||
return err
|
||||
}
|
||||
channels, err := it.chMgr.getChannels(collID)
|
||||
if err == nil {
|
||||
for _, pchan := range channels {
|
||||
err := it.chTicker.addPChan(pchan)
|
||||
if err != nil {
|
||||
log.Warn("failed to add pchan to channels time ticker",
|
||||
zap.Error(err),
|
||||
zap.String("pchan", pchan))
|
||||
}
|
||||
}
|
||||
}
|
||||
stream, err = it.chMgr.getDMLStream(collID)
|
||||
if err != nil {
|
||||
it.result.Status.ErrorCode = commonpb.ErrorCode_UnexpectedError
|
||||
|
|
|
@ -19,9 +19,10 @@ import (
|
|||
"strconv"
|
||||
"sync"
|
||||
|
||||
"github.com/milvus-io/milvus/internal/util/funcutil"
|
||||
|
||||
"go.uber.org/zap"
|
||||
|
||||
"github.com/milvus-io/milvus/internal/allocator"
|
||||
"github.com/milvus-io/milvus/internal/log"
|
||||
"github.com/milvus-io/milvus/internal/msgstream"
|
||||
"github.com/milvus-io/milvus/internal/proto/commonpb"
|
||||
|
@ -31,7 +32,7 @@ import (
|
|||
oplog "github.com/opentracing/opentracing-go/log"
|
||||
)
|
||||
|
||||
type TaskQueue interface {
|
||||
type taskQueue interface {
|
||||
utChan() <-chan int
|
||||
utEmpty() bool
|
||||
utFull() bool
|
||||
|
@ -45,7 +46,10 @@ type TaskQueue interface {
|
|||
Enqueue(t task) error
|
||||
}
|
||||
|
||||
type BaseTaskQueue struct {
|
||||
// TODO(dragondriver): load from config
|
||||
const maxTaskNum = 1024
|
||||
|
||||
type baseTaskQueue struct {
|
||||
unissuedTasks *list.List
|
||||
activeTasks map[UniqueID]task
|
||||
utLock sync.RWMutex
|
||||
|
@ -56,24 +60,25 @@ type BaseTaskQueue struct {
|
|||
|
||||
utBufChan chan int // to block scheduler
|
||||
|
||||
sched *TaskScheduler
|
||||
tsoAllocatorIns tsoAllocator
|
||||
idAllocatorIns idAllocatorInterface
|
||||
}
|
||||
|
||||
func (queue *BaseTaskQueue) utChan() <-chan int {
|
||||
func (queue *baseTaskQueue) utChan() <-chan int {
|
||||
return queue.utBufChan
|
||||
}
|
||||
|
||||
func (queue *BaseTaskQueue) utEmpty() bool {
|
||||
func (queue *baseTaskQueue) utEmpty() bool {
|
||||
queue.utLock.RLock()
|
||||
defer queue.utLock.RUnlock()
|
||||
return queue.unissuedTasks.Len() == 0
|
||||
}
|
||||
|
||||
func (queue *BaseTaskQueue) utFull() bool {
|
||||
func (queue *baseTaskQueue) utFull() bool {
|
||||
return int64(queue.unissuedTasks.Len()) >= queue.maxTaskNum
|
||||
}
|
||||
|
||||
func (queue *BaseTaskQueue) addUnissuedTask(t task) error {
|
||||
func (queue *baseTaskQueue) addUnissuedTask(t task) error {
|
||||
queue.utLock.Lock()
|
||||
defer queue.utLock.Unlock()
|
||||
|
||||
|
@ -85,7 +90,7 @@ func (queue *BaseTaskQueue) addUnissuedTask(t task) error {
|
|||
return nil
|
||||
}
|
||||
|
||||
func (queue *BaseTaskQueue) FrontUnissuedTask() task {
|
||||
func (queue *baseTaskQueue) FrontUnissuedTask() task {
|
||||
queue.utLock.RLock()
|
||||
defer queue.utLock.RUnlock()
|
||||
|
||||
|
@ -97,7 +102,7 @@ func (queue *BaseTaskQueue) FrontUnissuedTask() task {
|
|||
return queue.unissuedTasks.Front().Value.(task)
|
||||
}
|
||||
|
||||
func (queue *BaseTaskQueue) PopUnissuedTask() task {
|
||||
func (queue *baseTaskQueue) PopUnissuedTask() task {
|
||||
queue.utLock.Lock()
|
||||
defer queue.utLock.Unlock()
|
||||
|
||||
|
@ -112,7 +117,7 @@ func (queue *BaseTaskQueue) PopUnissuedTask() task {
|
|||
return ft.Value.(task)
|
||||
}
|
||||
|
||||
func (queue *BaseTaskQueue) AddActiveTask(t task) {
|
||||
func (queue *baseTaskQueue) AddActiveTask(t task) {
|
||||
queue.atLock.Lock()
|
||||
defer queue.atLock.Unlock()
|
||||
tID := t.ID()
|
||||
|
@ -124,7 +129,7 @@ func (queue *BaseTaskQueue) AddActiveTask(t task) {
|
|||
queue.activeTasks[tID] = t
|
||||
}
|
||||
|
||||
func (queue *BaseTaskQueue) PopActiveTask(tID UniqueID) task {
|
||||
func (queue *baseTaskQueue) PopActiveTask(tID UniqueID) task {
|
||||
queue.atLock.Lock()
|
||||
defer queue.atLock.Unlock()
|
||||
t, ok := queue.activeTasks[tID]
|
||||
|
@ -137,7 +142,7 @@ func (queue *BaseTaskQueue) PopActiveTask(tID UniqueID) task {
|
|||
return t
|
||||
}
|
||||
|
||||
func (queue *BaseTaskQueue) getTaskByReqID(reqID UniqueID) task {
|
||||
func (queue *baseTaskQueue) getTaskByReqID(reqID UniqueID) task {
|
||||
queue.utLock.RLock()
|
||||
defer queue.utLock.RUnlock()
|
||||
for e := queue.unissuedTasks.Front(); e != nil; e = e.Next() {
|
||||
|
@ -157,7 +162,7 @@ func (queue *BaseTaskQueue) getTaskByReqID(reqID UniqueID) task {
|
|||
return nil
|
||||
}
|
||||
|
||||
func (queue *BaseTaskQueue) TaskDoneTest(ts Timestamp) bool {
|
||||
func (queue *baseTaskQueue) TaskDoneTest(ts Timestamp) bool {
|
||||
queue.utLock.RLock()
|
||||
defer queue.utLock.RUnlock()
|
||||
for e := queue.unissuedTasks.Front(); e != nil; e = e.Next() {
|
||||
|
@ -177,19 +182,19 @@ func (queue *BaseTaskQueue) TaskDoneTest(ts Timestamp) bool {
|
|||
return true
|
||||
}
|
||||
|
||||
func (queue *BaseTaskQueue) Enqueue(t task) error {
|
||||
func (queue *baseTaskQueue) Enqueue(t task) error {
|
||||
err := t.OnEnqueue()
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
ts, err := queue.sched.tsoAllocator.AllocOne()
|
||||
ts, err := queue.tsoAllocatorIns.AllocOne()
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
t.SetTs(ts)
|
||||
|
||||
reqID, err := queue.sched.idAllocator.AllocOne()
|
||||
reqID, err := queue.idAllocatorIns.AllocOne()
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
@ -198,8 +203,21 @@ func (queue *BaseTaskQueue) Enqueue(t task) error {
|
|||
return queue.addUnissuedTask(t)
|
||||
}
|
||||
|
||||
type DdTaskQueue struct {
|
||||
BaseTaskQueue
|
||||
func newBaseTaskQueue(tsoAllocatorIns tsoAllocator, idAllocatorIns idAllocatorInterface) *baseTaskQueue {
|
||||
return &baseTaskQueue{
|
||||
unissuedTasks: list.New(),
|
||||
activeTasks: make(map[UniqueID]task),
|
||||
utLock: sync.RWMutex{},
|
||||
atLock: sync.RWMutex{},
|
||||
maxTaskNum: maxTaskNum,
|
||||
utBufChan: make(chan int, maxTaskNum),
|
||||
tsoAllocatorIns: tsoAllocatorIns,
|
||||
idAllocatorIns: idAllocatorIns,
|
||||
}
|
||||
}
|
||||
|
||||
type ddTaskQueue struct {
|
||||
*baseTaskQueue
|
||||
lock sync.Mutex
|
||||
}
|
||||
|
||||
|
@ -208,19 +226,19 @@ type pChanStatInfo struct {
|
|||
tsSet map[Timestamp]struct{}
|
||||
}
|
||||
|
||||
type DmTaskQueue struct {
|
||||
BaseTaskQueue
|
||||
type dmTaskQueue struct {
|
||||
*baseTaskQueue
|
||||
lock sync.Mutex
|
||||
|
||||
statsLock sync.RWMutex
|
||||
pChanStatisticsInfos map[pChan]*pChanStatInfo
|
||||
}
|
||||
|
||||
func (queue *DmTaskQueue) Enqueue(t task) error {
|
||||
func (queue *dmTaskQueue) Enqueue(t task) error {
|
||||
queue.lock.Lock()
|
||||
defer queue.lock.Unlock()
|
||||
|
||||
err := queue.BaseTaskQueue.Enqueue(t)
|
||||
err := queue.baseTaskQueue.Enqueue(t)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
@ -229,13 +247,13 @@ func (queue *DmTaskQueue) Enqueue(t task) error {
|
|||
return nil
|
||||
}
|
||||
|
||||
func (queue *DmTaskQueue) PopActiveTask(tID UniqueID) task {
|
||||
func (queue *dmTaskQueue) PopActiveTask(tID UniqueID) task {
|
||||
queue.atLock.Lock()
|
||||
defer queue.atLock.Unlock()
|
||||
t, ok := queue.activeTasks[tID]
|
||||
if ok {
|
||||
delete(queue.activeTasks, tID)
|
||||
log.Debug("Proxy DmTaskQueue popPChanStats", zap.Any("tID", t.ID()))
|
||||
log.Debug("Proxy dmTaskQueue popPChanStats", zap.Any("tID", t.ID()))
|
||||
queue.popPChanStats(t)
|
||||
} else {
|
||||
log.Debug("Proxy task not in active task list!", zap.Any("tID", tID))
|
||||
|
@ -243,11 +261,11 @@ func (queue *DmTaskQueue) PopActiveTask(tID UniqueID) task {
|
|||
return t
|
||||
}
|
||||
|
||||
func (queue *DmTaskQueue) addPChanStats(t task) error {
|
||||
func (queue *dmTaskQueue) addPChanStats(t task) error {
|
||||
if dmT, ok := t.(dmlTask); ok {
|
||||
stats, err := dmT.getPChanStats()
|
||||
if err != nil {
|
||||
log.Debug("Proxy DmTaskQueue addPChanStats", zap.Any("tID", t.ID()),
|
||||
log.Debug("Proxy dmTaskQueue addPChanStats", zap.Any("tID", t.ID()),
|
||||
zap.Any("stats", stats), zap.Error(err))
|
||||
return err
|
||||
}
|
||||
|
@ -262,7 +280,6 @@ func (queue *DmTaskQueue) addPChanStats(t task) error {
|
|||
},
|
||||
}
|
||||
queue.pChanStatisticsInfos[cName] = info
|
||||
dmT.getChannelsTimerTicker().addPChan(cName)
|
||||
} else {
|
||||
if info.minTs > stat.minTs {
|
||||
queue.pChanStatisticsInfos[cName].minTs = stat.minTs
|
||||
|
@ -280,7 +297,7 @@ func (queue *DmTaskQueue) addPChanStats(t task) error {
|
|||
return nil
|
||||
}
|
||||
|
||||
func (queue *DmTaskQueue) popPChanStats(t task) error {
|
||||
func (queue *dmTaskQueue) popPChanStats(t task) error {
|
||||
if dmT, ok := t.(dmlTask); ok {
|
||||
channels, err := dmT.getChannels()
|
||||
if err != nil {
|
||||
|
@ -306,12 +323,12 @@ func (queue *DmTaskQueue) popPChanStats(t task) error {
|
|||
}
|
||||
queue.statsLock.Unlock()
|
||||
} else {
|
||||
return fmt.Errorf("Proxy DmTaskQueue popPChanStats reflect to dmlTask failed, tID:%v", t.ID())
|
||||
return fmt.Errorf("Proxy dmTaskQueue popPChanStats reflect to dmlTask failed, tID:%v", t.ID())
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func (queue *DmTaskQueue) getPChanStatsInfo() (map[pChan]*pChanStatistics, error) {
|
||||
func (queue *dmTaskQueue) getPChanStatsInfo() (map[pChan]*pChanStatistics, error) {
|
||||
|
||||
ret := make(map[pChan]*pChanStatistics)
|
||||
queue.statsLock.RLock()
|
||||
|
@ -325,60 +342,39 @@ func (queue *DmTaskQueue) getPChanStatsInfo() (map[pChan]*pChanStatistics, error
|
|||
return ret, nil
|
||||
}
|
||||
|
||||
type DqTaskQueue struct {
|
||||
BaseTaskQueue
|
||||
type dqTaskQueue struct {
|
||||
*baseTaskQueue
|
||||
}
|
||||
|
||||
func (queue *DdTaskQueue) Enqueue(t task) error {
|
||||
func (queue *ddTaskQueue) Enqueue(t task) error {
|
||||
queue.lock.Lock()
|
||||
defer queue.lock.Unlock()
|
||||
return queue.BaseTaskQueue.Enqueue(t)
|
||||
return queue.baseTaskQueue.Enqueue(t)
|
||||
}
|
||||
|
||||
func NewDdTaskQueue(sched *TaskScheduler) *DdTaskQueue {
|
||||
return &DdTaskQueue{
|
||||
BaseTaskQueue: BaseTaskQueue{
|
||||
unissuedTasks: list.New(),
|
||||
activeTasks: make(map[UniqueID]task),
|
||||
maxTaskNum: 1024,
|
||||
utBufChan: make(chan int, 1024),
|
||||
sched: sched,
|
||||
},
|
||||
func newDdTaskQueue(tsoAllocatorIns tsoAllocator, idAllocatorIns idAllocatorInterface) *ddTaskQueue {
|
||||
return &ddTaskQueue{
|
||||
baseTaskQueue: newBaseTaskQueue(tsoAllocatorIns, idAllocatorIns),
|
||||
}
|
||||
}
|
||||
|
||||
func NewDmTaskQueue(sched *TaskScheduler) *DmTaskQueue {
|
||||
return &DmTaskQueue{
|
||||
BaseTaskQueue: BaseTaskQueue{
|
||||
unissuedTasks: list.New(),
|
||||
activeTasks: make(map[UniqueID]task),
|
||||
maxTaskNum: 1024,
|
||||
utBufChan: make(chan int, 1024),
|
||||
sched: sched,
|
||||
},
|
||||
func newDmTaskQueue(tsoAllocatorIns tsoAllocator, idAllocatorIns idAllocatorInterface) *dmTaskQueue {
|
||||
return &dmTaskQueue{
|
||||
baseTaskQueue: newBaseTaskQueue(tsoAllocatorIns, idAllocatorIns),
|
||||
pChanStatisticsInfos: make(map[pChan]*pChanStatInfo),
|
||||
}
|
||||
}
|
||||
|
||||
func NewDqTaskQueue(sched *TaskScheduler) *DqTaskQueue {
|
||||
return &DqTaskQueue{
|
||||
BaseTaskQueue: BaseTaskQueue{
|
||||
unissuedTasks: list.New(),
|
||||
activeTasks: make(map[UniqueID]task),
|
||||
maxTaskNum: 1024,
|
||||
utBufChan: make(chan int, 1024),
|
||||
sched: sched,
|
||||
},
|
||||
func newDqTaskQueue(tsoAllocatorIns tsoAllocator, idAllocatorIns idAllocatorInterface) *dqTaskQueue {
|
||||
return &dqTaskQueue{
|
||||
baseTaskQueue: newBaseTaskQueue(tsoAllocatorIns, idAllocatorIns),
|
||||
}
|
||||
}
|
||||
|
||||
type TaskScheduler struct {
|
||||
DdQueue TaskQueue
|
||||
DmQueue *DmTaskQueue
|
||||
DqQueue TaskQueue
|
||||
|
||||
idAllocator *allocator.IDAllocator
|
||||
tsoAllocator *TimestampAllocator
|
||||
type taskScheduler struct {
|
||||
ddQueue taskQueue
|
||||
dmQueue *dmTaskQueue
|
||||
dqQueue taskQueue
|
||||
|
||||
wg sync.WaitGroup
|
||||
ctx context.Context
|
||||
|
@ -387,51 +383,49 @@ type TaskScheduler struct {
|
|||
msFactory msgstream.Factory
|
||||
}
|
||||
|
||||
func NewTaskScheduler(ctx context.Context,
|
||||
idAllocator *allocator.IDAllocator,
|
||||
tsoAllocator *TimestampAllocator,
|
||||
factory msgstream.Factory) (*TaskScheduler, error) {
|
||||
func newTaskScheduler(ctx context.Context,
|
||||
idAllocatorIns idAllocatorInterface,
|
||||
tsoAllocatorIns tsoAllocator,
|
||||
factory msgstream.Factory) (*taskScheduler, error) {
|
||||
ctx1, cancel := context.WithCancel(ctx)
|
||||
s := &TaskScheduler{
|
||||
idAllocator: idAllocator,
|
||||
tsoAllocator: tsoAllocator,
|
||||
ctx: ctx1,
|
||||
cancel: cancel,
|
||||
msFactory: factory,
|
||||
s := &taskScheduler{
|
||||
ctx: ctx1,
|
||||
cancel: cancel,
|
||||
msFactory: factory,
|
||||
}
|
||||
s.DdQueue = NewDdTaskQueue(s)
|
||||
s.DmQueue = NewDmTaskQueue(s)
|
||||
s.DqQueue = NewDqTaskQueue(s)
|
||||
s.ddQueue = newDdTaskQueue(tsoAllocatorIns, idAllocatorIns)
|
||||
s.dmQueue = newDmTaskQueue(tsoAllocatorIns, idAllocatorIns)
|
||||
s.dqQueue = newDqTaskQueue(tsoAllocatorIns, idAllocatorIns)
|
||||
|
||||
return s, nil
|
||||
}
|
||||
|
||||
func (sched *TaskScheduler) scheduleDdTask() task {
|
||||
return sched.DdQueue.PopUnissuedTask()
|
||||
func (sched *taskScheduler) scheduleDdTask() task {
|
||||
return sched.ddQueue.PopUnissuedTask()
|
||||
}
|
||||
|
||||
func (sched *TaskScheduler) scheduleDmTask() task {
|
||||
return sched.DmQueue.PopUnissuedTask()
|
||||
func (sched *taskScheduler) scheduleDmTask() task {
|
||||
return sched.dmQueue.PopUnissuedTask()
|
||||
}
|
||||
|
||||
func (sched *TaskScheduler) scheduleDqTask() task {
|
||||
return sched.DqQueue.PopUnissuedTask()
|
||||
func (sched *taskScheduler) scheduleDqTask() task {
|
||||
return sched.dqQueue.PopUnissuedTask()
|
||||
}
|
||||
|
||||
func (sched *TaskScheduler) getTaskByReqID(collMeta UniqueID) task {
|
||||
if t := sched.DdQueue.getTaskByReqID(collMeta); t != nil {
|
||||
func (sched *taskScheduler) getTaskByReqID(collMeta UniqueID) task {
|
||||
if t := sched.ddQueue.getTaskByReqID(collMeta); t != nil {
|
||||
return t
|
||||
}
|
||||
if t := sched.DmQueue.getTaskByReqID(collMeta); t != nil {
|
||||
if t := sched.dmQueue.getTaskByReqID(collMeta); t != nil {
|
||||
return t
|
||||
}
|
||||
if t := sched.DqQueue.getTaskByReqID(collMeta); t != nil {
|
||||
if t := sched.dqQueue.getTaskByReqID(collMeta); t != nil {
|
||||
return t
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func (sched *TaskScheduler) processTask(t task, q TaskQueue) {
|
||||
func (sched *taskScheduler) processTask(t task, q taskQueue) {
|
||||
span, ctx := trace.StartSpanFromContext(t.TraceCtx(),
|
||||
opentracing.Tags{
|
||||
"Type": t.Name(),
|
||||
|
@ -469,47 +463,47 @@ func (sched *TaskScheduler) processTask(t task, q TaskQueue) {
|
|||
err = t.PostExecute(ctx)
|
||||
}
|
||||
|
||||
func (sched *TaskScheduler) definitionLoop() {
|
||||
func (sched *taskScheduler) definitionLoop() {
|
||||
defer sched.wg.Done()
|
||||
for {
|
||||
select {
|
||||
case <-sched.ctx.Done():
|
||||
return
|
||||
case <-sched.DdQueue.utChan():
|
||||
if !sched.DdQueue.utEmpty() {
|
||||
case <-sched.ddQueue.utChan():
|
||||
if !sched.ddQueue.utEmpty() {
|
||||
t := sched.scheduleDdTask()
|
||||
sched.processTask(t, sched.DdQueue)
|
||||
sched.processTask(t, sched.ddQueue)
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func (sched *TaskScheduler) manipulationLoop() {
|
||||
func (sched *taskScheduler) manipulationLoop() {
|
||||
defer sched.wg.Done()
|
||||
for {
|
||||
select {
|
||||
case <-sched.ctx.Done():
|
||||
return
|
||||
case <-sched.DmQueue.utChan():
|
||||
if !sched.DmQueue.utEmpty() {
|
||||
case <-sched.dmQueue.utChan():
|
||||
if !sched.dmQueue.utEmpty() {
|
||||
t := sched.scheduleDmTask()
|
||||
go sched.processTask(t, sched.DmQueue)
|
||||
go sched.processTask(t, sched.dmQueue)
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func (sched *TaskScheduler) queryLoop() {
|
||||
func (sched *taskScheduler) queryLoop() {
|
||||
defer sched.wg.Done()
|
||||
|
||||
for {
|
||||
select {
|
||||
case <-sched.ctx.Done():
|
||||
return
|
||||
case <-sched.DqQueue.utChan():
|
||||
if !sched.DqQueue.utEmpty() {
|
||||
case <-sched.dqQueue.utChan():
|
||||
if !sched.dqQueue.utEmpty() {
|
||||
t := sched.scheduleDqTask()
|
||||
go sched.processTask(t, sched.DqQueue)
|
||||
go sched.processTask(t, sched.dqQueue)
|
||||
} else {
|
||||
log.Debug("query queue is empty ...")
|
||||
}
|
||||
|
@ -561,25 +555,6 @@ func newQueryResultBuf() *queryResultBuf {
|
|||
}
|
||||
}
|
||||
|
||||
func setContain(m1, m2 map[interface{}]struct{}) bool {
|
||||
log.Debug("Proxy task_scheduler setContain", zap.Any("len(m1)", len(m1)),
|
||||
zap.Any("len(m2)", len(m2)))
|
||||
if len(m1) < len(m2) {
|
||||
return false
|
||||
}
|
||||
|
||||
for k2 := range m2 {
|
||||
_, ok := m1[k2]
|
||||
log.Debug("Proxy task_scheduler setContain", zap.Any("k2", fmt.Sprintf("%v", k2)),
|
||||
zap.Any("ok", ok))
|
||||
if !ok {
|
||||
return false
|
||||
}
|
||||
}
|
||||
|
||||
return true
|
||||
}
|
||||
|
||||
func (sr *resultBufHeader) readyToReduce() bool {
|
||||
if sr.haveError {
|
||||
log.Debug("Proxy searchResultBuf readyToReduce", zap.Any("haveError", true))
|
||||
|
@ -608,7 +583,7 @@ func (sr *resultBufHeader) readyToReduce() bool {
|
|||
sealedGlobalSegmentIDsStrMap[x.(int64)] = 1
|
||||
}
|
||||
|
||||
ret1 := setContain(sr.receivedVChansSet, sr.usedVChans)
|
||||
ret1 := funcutil.SetContain(sr.receivedVChansSet, sr.usedVChans)
|
||||
log.Debug("Proxy searchResultBuf readyToReduce", zap.Any("receivedVChansSet", receivedVChansSetStrMap),
|
||||
zap.Any("usedVChans", usedVChansSetStrMap),
|
||||
zap.Any("receivedSealedSegmentIDsSet", sealedSegmentIDsStrMap),
|
||||
|
@ -618,7 +593,7 @@ func (sr *resultBufHeader) readyToReduce() bool {
|
|||
if !ret1 {
|
||||
return false
|
||||
}
|
||||
ret := setContain(sr.receivedSealedSegmentIDsSet, sr.receivedGlobalSegmentIDsSet)
|
||||
ret := funcutil.SetContain(sr.receivedSealedSegmentIDsSet, sr.receivedGlobalSegmentIDsSet)
|
||||
log.Debug("Proxy searchResultBuf readyToReduce", zap.Any("ret", ret))
|
||||
return ret
|
||||
}
|
||||
|
@ -658,7 +633,7 @@ func (qr *queryResultBuf) addPartialResult(result *internalpb.RetrieveResults) {
|
|||
result.GlobalSealedSegmentIDs)
|
||||
}
|
||||
|
||||
func (sched *TaskScheduler) collectResultLoop() {
|
||||
func (sched *taskScheduler) collectResultLoop() {
|
||||
defer sched.wg.Done()
|
||||
|
||||
queryResultMsgStream, _ := sched.msFactory.NewQueryMsgStream(sched.ctx)
|
||||
|
@ -862,7 +837,7 @@ func (sched *TaskScheduler) collectResultLoop() {
|
|||
}
|
||||
}
|
||||
|
||||
func (sched *TaskScheduler) Start() error {
|
||||
func (sched *taskScheduler) Start() error {
|
||||
sched.wg.Add(1)
|
||||
go sched.definitionLoop()
|
||||
|
||||
|
@ -878,17 +853,17 @@ func (sched *TaskScheduler) Start() error {
|
|||
return nil
|
||||
}
|
||||
|
||||
func (sched *TaskScheduler) Close() {
|
||||
func (sched *taskScheduler) Close() {
|
||||
sched.cancel()
|
||||
sched.wg.Wait()
|
||||
}
|
||||
|
||||
func (sched *TaskScheduler) TaskDoneTest(ts Timestamp) bool {
|
||||
ddTaskDone := sched.DdQueue.TaskDoneTest(ts)
|
||||
dmTaskDone := sched.DmQueue.TaskDoneTest(ts)
|
||||
func (sched *taskScheduler) TaskDoneTest(ts Timestamp) bool {
|
||||
ddTaskDone := sched.ddQueue.TaskDoneTest(ts)
|
||||
dmTaskDone := sched.dmQueue.TaskDoneTest(ts)
|
||||
return ddTaskDone && dmTaskDone
|
||||
}
|
||||
|
||||
func (sched *TaskScheduler) getPChanStatistics() (map[pChan]*pChanStatistics, error) {
|
||||
return sched.DmQueue.getPChanStatsInfo()
|
||||
func (sched *taskScheduler) getPChanStatistics() (map[pChan]*pChanStatistics, error) {
|
||||
return sched.dmQueue.getPChanStatsInfo()
|
||||
}
|
||||
|
|
|
@ -0,0 +1,528 @@
|
|||
// Copyright (C) 2019-2020 Zilliz. All rights reserved.
|
||||
//
|
||||
// Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance
|
||||
// with the License. You may obtain a copy of the License at
|
||||
//
|
||||
// http://www.apache.org/licenses/LICENSE-2.0
|
||||
//
|
||||
// Unless required by applicable law or agreed to in writing, software distributed under the License
|
||||
// is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express
|
||||
// or implied. See the License for the specific language governing permissions and limitations under the License.
|
||||
|
||||
package proxy
|
||||
|
||||
import (
|
||||
"context"
|
||||
"math/rand"
|
||||
"sync"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"github.com/milvus-io/milvus/internal/msgstream"
|
||||
|
||||
"github.com/stretchr/testify/assert"
|
||||
)
|
||||
|
||||
func TestBaseTaskQueue(t *testing.T) {
|
||||
var err error
|
||||
var unissuedTask task
|
||||
var activeTask task
|
||||
var done bool
|
||||
|
||||
tsoAllocatorIns := newMockTsoAllocator()
|
||||
idAllocatorIns := newMockIDAllocatorInterface()
|
||||
queue := newBaseTaskQueue(tsoAllocatorIns, idAllocatorIns)
|
||||
assert.NotNil(t, queue)
|
||||
|
||||
assert.True(t, queue.utEmpty())
|
||||
assert.False(t, queue.utFull())
|
||||
|
||||
st := newDefaultMockTask()
|
||||
stID := st.ID()
|
||||
stTs := st.BeginTs()
|
||||
|
||||
// no task in queue
|
||||
|
||||
unissuedTask = queue.FrontUnissuedTask()
|
||||
assert.Nil(t, unissuedTask)
|
||||
|
||||
unissuedTask = queue.getTaskByReqID(stID)
|
||||
assert.Nil(t, unissuedTask)
|
||||
|
||||
unissuedTask = queue.PopUnissuedTask()
|
||||
assert.Nil(t, unissuedTask)
|
||||
|
||||
done = queue.TaskDoneTest(stTs)
|
||||
assert.True(t, done)
|
||||
|
||||
// task enqueue, only one task in queue
|
||||
|
||||
err = queue.Enqueue(st)
|
||||
assert.NoError(t, err)
|
||||
|
||||
assert.False(t, queue.utEmpty())
|
||||
assert.False(t, queue.utFull())
|
||||
assert.Equal(t, 1, queue.unissuedTasks.Len())
|
||||
assert.Equal(t, 1, len(queue.utChan()))
|
||||
|
||||
unissuedTask = queue.FrontUnissuedTask()
|
||||
assert.NotNil(t, unissuedTask)
|
||||
|
||||
unissuedTask = queue.getTaskByReqID(unissuedTask.ID())
|
||||
assert.NotNil(t, unissuedTask)
|
||||
|
||||
done = queue.TaskDoneTest(unissuedTask.BeginTs() + 1)
|
||||
assert.False(t, done)
|
||||
|
||||
unissuedTask = queue.PopUnissuedTask()
|
||||
assert.NotNil(t, unissuedTask)
|
||||
assert.True(t, queue.utEmpty())
|
||||
assert.False(t, queue.utFull())
|
||||
|
||||
done = queue.TaskDoneTest(unissuedTask.BeginTs() + 1)
|
||||
assert.True(t, done)
|
||||
|
||||
// test active list, no task in queue
|
||||
|
||||
activeTask = queue.getTaskByReqID(unissuedTask.ID())
|
||||
assert.Nil(t, activeTask)
|
||||
|
||||
done = queue.TaskDoneTest(unissuedTask.BeginTs() + 1)
|
||||
assert.True(t, done)
|
||||
|
||||
activeTask = queue.PopActiveTask(unissuedTask.ID())
|
||||
assert.Nil(t, activeTask)
|
||||
|
||||
// test active list, no task in unissued list, only one task in active list
|
||||
|
||||
queue.AddActiveTask(unissuedTask)
|
||||
|
||||
activeTask = queue.getTaskByReqID(unissuedTask.ID())
|
||||
assert.NotNil(t, activeTask)
|
||||
|
||||
done = queue.TaskDoneTest(unissuedTask.BeginTs() + 1)
|
||||
assert.False(t, done)
|
||||
|
||||
activeTask = queue.PopActiveTask(unissuedTask.ID())
|
||||
assert.NotNil(t, activeTask)
|
||||
|
||||
done = queue.TaskDoneTest(unissuedTask.BeginTs() + 1)
|
||||
assert.True(t, done)
|
||||
|
||||
// test utFull
|
||||
queue.maxTaskNum = 10 // not accurate, full also means utBufChan block
|
||||
for i := 0; i < int(queue.maxTaskNum); i++ {
|
||||
err = queue.Enqueue(newDefaultMockTask())
|
||||
assert.Nil(t, err)
|
||||
}
|
||||
assert.True(t, queue.utFull())
|
||||
err = queue.Enqueue(newDefaultMockTask())
|
||||
assert.NotNil(t, err)
|
||||
}
|
||||
|
||||
func TestDdTaskQueue(t *testing.T) {
|
||||
var err error
|
||||
var unissuedTask task
|
||||
var activeTask task
|
||||
var done bool
|
||||
|
||||
tsoAllocatorIns := newMockTsoAllocator()
|
||||
idAllocatorIns := newMockIDAllocatorInterface()
|
||||
queue := newDdTaskQueue(tsoAllocatorIns, idAllocatorIns)
|
||||
assert.NotNil(t, queue)
|
||||
|
||||
assert.True(t, queue.utEmpty())
|
||||
assert.False(t, queue.utFull())
|
||||
|
||||
st := newDefaultMockDdlTask()
|
||||
stID := st.ID()
|
||||
stTs := st.BeginTs()
|
||||
|
||||
// no task in queue
|
||||
|
||||
unissuedTask = queue.FrontUnissuedTask()
|
||||
assert.Nil(t, unissuedTask)
|
||||
|
||||
unissuedTask = queue.getTaskByReqID(stID)
|
||||
assert.Nil(t, unissuedTask)
|
||||
|
||||
unissuedTask = queue.PopUnissuedTask()
|
||||
assert.Nil(t, unissuedTask)
|
||||
|
||||
done = queue.TaskDoneTest(stTs)
|
||||
assert.True(t, done)
|
||||
|
||||
// task enqueue, only one task in queue
|
||||
|
||||
err = queue.Enqueue(st)
|
||||
assert.NoError(t, err)
|
||||
|
||||
assert.False(t, queue.utEmpty())
|
||||
assert.False(t, queue.utFull())
|
||||
assert.Equal(t, 1, queue.unissuedTasks.Len())
|
||||
assert.Equal(t, 1, len(queue.utChan()))
|
||||
|
||||
unissuedTask = queue.FrontUnissuedTask()
|
||||
assert.NotNil(t, unissuedTask)
|
||||
|
||||
unissuedTask = queue.getTaskByReqID(unissuedTask.ID())
|
||||
assert.NotNil(t, unissuedTask)
|
||||
|
||||
done = queue.TaskDoneTest(unissuedTask.BeginTs() + 1)
|
||||
assert.False(t, done)
|
||||
|
||||
unissuedTask = queue.PopUnissuedTask()
|
||||
assert.NotNil(t, unissuedTask)
|
||||
assert.True(t, queue.utEmpty())
|
||||
assert.False(t, queue.utFull())
|
||||
|
||||
done = queue.TaskDoneTest(unissuedTask.BeginTs() + 1)
|
||||
assert.True(t, done)
|
||||
|
||||
// test active list, no task in queue
|
||||
|
||||
activeTask = queue.getTaskByReqID(unissuedTask.ID())
|
||||
assert.Nil(t, activeTask)
|
||||
|
||||
done = queue.TaskDoneTest(unissuedTask.BeginTs() + 1)
|
||||
assert.True(t, done)
|
||||
|
||||
activeTask = queue.PopActiveTask(unissuedTask.ID())
|
||||
assert.Nil(t, activeTask)
|
||||
|
||||
// test active list, no task in unissued list, only one task in active list
|
||||
|
||||
queue.AddActiveTask(unissuedTask)
|
||||
|
||||
activeTask = queue.getTaskByReqID(unissuedTask.ID())
|
||||
assert.NotNil(t, activeTask)
|
||||
|
||||
done = queue.TaskDoneTest(unissuedTask.BeginTs() + 1)
|
||||
assert.False(t, done)
|
||||
|
||||
activeTask = queue.PopActiveTask(unissuedTask.ID())
|
||||
assert.NotNil(t, activeTask)
|
||||
|
||||
done = queue.TaskDoneTest(unissuedTask.BeginTs() + 1)
|
||||
assert.True(t, done)
|
||||
|
||||
// test utFull
|
||||
queue.maxTaskNum = 10 // not accurate, full also means utBufChan block
|
||||
for i := 0; i < int(queue.maxTaskNum); i++ {
|
||||
err = queue.Enqueue(newDefaultMockDdlTask())
|
||||
assert.Nil(t, err)
|
||||
}
|
||||
assert.True(t, queue.utFull())
|
||||
err = queue.Enqueue(newDefaultMockDdlTask())
|
||||
assert.NotNil(t, err)
|
||||
}
|
||||
|
||||
// test the logic of queue
|
||||
func TestDmTaskQueue_Basic(t *testing.T) {
|
||||
var err error
|
||||
var unissuedTask task
|
||||
var activeTask task
|
||||
var done bool
|
||||
|
||||
tsoAllocatorIns := newMockTsoAllocator()
|
||||
idAllocatorIns := newMockIDAllocatorInterface()
|
||||
queue := newDmTaskQueue(tsoAllocatorIns, idAllocatorIns)
|
||||
assert.NotNil(t, queue)
|
||||
|
||||
assert.True(t, queue.utEmpty())
|
||||
assert.False(t, queue.utFull())
|
||||
|
||||
st := newDefaultMockDmlTask()
|
||||
stID := st.ID()
|
||||
stTs := st.BeginTs()
|
||||
|
||||
// no task in queue
|
||||
|
||||
unissuedTask = queue.FrontUnissuedTask()
|
||||
assert.Nil(t, unissuedTask)
|
||||
|
||||
unissuedTask = queue.getTaskByReqID(stID)
|
||||
assert.Nil(t, unissuedTask)
|
||||
|
||||
unissuedTask = queue.PopUnissuedTask()
|
||||
assert.Nil(t, unissuedTask)
|
||||
|
||||
done = queue.TaskDoneTest(stTs)
|
||||
assert.True(t, done)
|
||||
|
||||
// task enqueue, only one task in queue
|
||||
|
||||
err = queue.Enqueue(st)
|
||||
assert.NoError(t, err)
|
||||
|
||||
assert.False(t, queue.utEmpty())
|
||||
assert.False(t, queue.utFull())
|
||||
assert.Equal(t, 1, queue.unissuedTasks.Len())
|
||||
assert.Equal(t, 1, len(queue.utChan()))
|
||||
|
||||
unissuedTask = queue.FrontUnissuedTask()
|
||||
assert.NotNil(t, unissuedTask)
|
||||
|
||||
unissuedTask = queue.getTaskByReqID(unissuedTask.ID())
|
||||
assert.NotNil(t, unissuedTask)
|
||||
|
||||
done = queue.TaskDoneTest(unissuedTask.BeginTs() + 1)
|
||||
assert.False(t, done)
|
||||
|
||||
unissuedTask = queue.PopUnissuedTask()
|
||||
assert.NotNil(t, unissuedTask)
|
||||
assert.True(t, queue.utEmpty())
|
||||
assert.False(t, queue.utFull())
|
||||
|
||||
done = queue.TaskDoneTest(unissuedTask.BeginTs() + 1)
|
||||
assert.True(t, done)
|
||||
|
||||
// test active list, no task in queue
|
||||
|
||||
activeTask = queue.getTaskByReqID(unissuedTask.ID())
|
||||
assert.Nil(t, activeTask)
|
||||
|
||||
done = queue.TaskDoneTest(unissuedTask.BeginTs() + 1)
|
||||
assert.True(t, done)
|
||||
|
||||
activeTask = queue.PopActiveTask(unissuedTask.ID())
|
||||
assert.Nil(t, activeTask)
|
||||
|
||||
// test active list, no task in unissued list, only one task in active list
|
||||
|
||||
queue.AddActiveTask(unissuedTask)
|
||||
|
||||
activeTask = queue.getTaskByReqID(unissuedTask.ID())
|
||||
assert.NotNil(t, activeTask)
|
||||
|
||||
done = queue.TaskDoneTest(unissuedTask.BeginTs() + 1)
|
||||
assert.False(t, done)
|
||||
|
||||
activeTask = queue.PopActiveTask(unissuedTask.ID())
|
||||
assert.NotNil(t, activeTask)
|
||||
|
||||
done = queue.TaskDoneTest(unissuedTask.BeginTs() + 1)
|
||||
assert.True(t, done)
|
||||
|
||||
// test utFull
|
||||
queue.maxTaskNum = 10 // not accurate, full also means utBufChan block
|
||||
for i := 0; i < int(queue.maxTaskNum); i++ {
|
||||
err = queue.Enqueue(newDefaultMockDmlTask())
|
||||
assert.Nil(t, err)
|
||||
}
|
||||
assert.True(t, queue.utFull())
|
||||
err = queue.Enqueue(newDefaultMockDmlTask())
|
||||
assert.NotNil(t, err)
|
||||
}
|
||||
|
||||
// test the timestamp statistics
|
||||
func TestDmTaskQueue_TimestampStatistics(t *testing.T) {
|
||||
var err error
|
||||
var unissuedTask task
|
||||
|
||||
tsoAllocatorIns := newMockTsoAllocator()
|
||||
idAllocatorIns := newMockIDAllocatorInterface()
|
||||
queue := newDmTaskQueue(tsoAllocatorIns, idAllocatorIns)
|
||||
assert.NotNil(t, queue)
|
||||
|
||||
st := newDefaultMockDmlTask()
|
||||
stPChans := st.pchans
|
||||
|
||||
err = queue.Enqueue(st)
|
||||
assert.NoError(t, err)
|
||||
|
||||
stats, err := queue.getPChanStatsInfo()
|
||||
assert.NoError(t, err)
|
||||
assert.Equal(t, len(stPChans), len(stats))
|
||||
unissuedTask = queue.FrontUnissuedTask()
|
||||
assert.NotNil(t, unissuedTask)
|
||||
for _, stat := range stats {
|
||||
assert.Equal(t, unissuedTask.BeginTs(), stat.minTs)
|
||||
assert.Equal(t, unissuedTask.EndTs(), stat.maxTs)
|
||||
}
|
||||
|
||||
unissuedTask = queue.PopUnissuedTask()
|
||||
assert.NotNil(t, unissuedTask)
|
||||
assert.True(t, queue.utEmpty())
|
||||
|
||||
queue.AddActiveTask(unissuedTask)
|
||||
|
||||
queue.PopActiveTask(unissuedTask.ID())
|
||||
|
||||
stats, err = queue.getPChanStatsInfo()
|
||||
assert.NoError(t, err)
|
||||
assert.Zero(t, len(stats))
|
||||
}
|
||||
|
||||
func TestDqTaskQueue(t *testing.T) {
|
||||
var err error
|
||||
var unissuedTask task
|
||||
var activeTask task
|
||||
var done bool
|
||||
|
||||
tsoAllocatorIns := newMockTsoAllocator()
|
||||
idAllocatorIns := newMockIDAllocatorInterface()
|
||||
queue := newDqTaskQueue(tsoAllocatorIns, idAllocatorIns)
|
||||
assert.NotNil(t, queue)
|
||||
|
||||
assert.True(t, queue.utEmpty())
|
||||
assert.False(t, queue.utFull())
|
||||
|
||||
st := newDefaultMockDqlTask()
|
||||
stID := st.ID()
|
||||
stTs := st.BeginTs()
|
||||
|
||||
// no task in queue
|
||||
|
||||
unissuedTask = queue.FrontUnissuedTask()
|
||||
assert.Nil(t, unissuedTask)
|
||||
|
||||
unissuedTask = queue.getTaskByReqID(stID)
|
||||
assert.Nil(t, unissuedTask)
|
||||
|
||||
unissuedTask = queue.PopUnissuedTask()
|
||||
assert.Nil(t, unissuedTask)
|
||||
|
||||
done = queue.TaskDoneTest(stTs)
|
||||
assert.True(t, done)
|
||||
|
||||
// task enqueue, only one task in queue
|
||||
|
||||
err = queue.Enqueue(st)
|
||||
assert.NoError(t, err)
|
||||
|
||||
assert.False(t, queue.utEmpty())
|
||||
assert.False(t, queue.utFull())
|
||||
assert.Equal(t, 1, queue.unissuedTasks.Len())
|
||||
assert.Equal(t, 1, len(queue.utChan()))
|
||||
|
||||
unissuedTask = queue.FrontUnissuedTask()
|
||||
assert.NotNil(t, unissuedTask)
|
||||
|
||||
unissuedTask = queue.getTaskByReqID(unissuedTask.ID())
|
||||
assert.NotNil(t, unissuedTask)
|
||||
|
||||
done = queue.TaskDoneTest(unissuedTask.BeginTs() + 1)
|
||||
assert.False(t, done)
|
||||
|
||||
unissuedTask = queue.PopUnissuedTask()
|
||||
assert.NotNil(t, unissuedTask)
|
||||
assert.True(t, queue.utEmpty())
|
||||
assert.False(t, queue.utFull())
|
||||
|
||||
done = queue.TaskDoneTest(unissuedTask.BeginTs() + 1)
|
||||
assert.True(t, done)
|
||||
|
||||
// test active list, no task in queue
|
||||
|
||||
activeTask = queue.getTaskByReqID(unissuedTask.ID())
|
||||
assert.Nil(t, activeTask)
|
||||
|
||||
done = queue.TaskDoneTest(unissuedTask.BeginTs() + 1)
|
||||
assert.True(t, done)
|
||||
|
||||
activeTask = queue.PopActiveTask(unissuedTask.ID())
|
||||
assert.Nil(t, activeTask)
|
||||
|
||||
// test active list, no task in unissued list, only one task in active list
|
||||
|
||||
queue.AddActiveTask(unissuedTask)
|
||||
|
||||
activeTask = queue.getTaskByReqID(unissuedTask.ID())
|
||||
assert.NotNil(t, activeTask)
|
||||
|
||||
done = queue.TaskDoneTest(unissuedTask.BeginTs() + 1)
|
||||
assert.False(t, done)
|
||||
|
||||
activeTask = queue.PopActiveTask(unissuedTask.ID())
|
||||
assert.NotNil(t, activeTask)
|
||||
|
||||
done = queue.TaskDoneTest(unissuedTask.BeginTs() + 1)
|
||||
assert.True(t, done)
|
||||
|
||||
// test utFull
|
||||
queue.maxTaskNum = 10 // not accurate, full also means utBufChan block
|
||||
for i := 0; i < int(queue.maxTaskNum); i++ {
|
||||
err = queue.Enqueue(newDefaultMockDqlTask())
|
||||
assert.Nil(t, err)
|
||||
}
|
||||
assert.True(t, queue.utFull())
|
||||
err = queue.Enqueue(newDefaultMockDqlTask())
|
||||
assert.NotNil(t, err)
|
||||
}
|
||||
|
||||
func TestTaskScheduler(t *testing.T) {
|
||||
var err error
|
||||
|
||||
ctx := context.Background()
|
||||
tsoAllocatorIns := newMockTsoAllocator()
|
||||
idAllocatorIns := newMockIDAllocatorInterface()
|
||||
factory := msgstream.NewSimpleMsgStreamFactory()
|
||||
|
||||
sched, err := newTaskScheduler(ctx, idAllocatorIns, tsoAllocatorIns, factory)
|
||||
assert.NoError(t, err)
|
||||
assert.NotNil(t, sched)
|
||||
|
||||
err = sched.Start()
|
||||
assert.NoError(t, err)
|
||||
defer sched.Close()
|
||||
|
||||
assert.True(t, sched.TaskDoneTest(Timestamp(time.Now().Nanosecond())))
|
||||
|
||||
stats, err := sched.getPChanStatistics()
|
||||
assert.NoError(t, err)
|
||||
assert.Equal(t, 0, len(stats))
|
||||
|
||||
ddNum := rand.Int() % 10
|
||||
dmNum := rand.Int() % 10
|
||||
dqNum := rand.Int() % 10
|
||||
|
||||
var wg sync.WaitGroup
|
||||
|
||||
wg.Add(1)
|
||||
go func() {
|
||||
defer wg.Done()
|
||||
|
||||
for i := 0; i < ddNum; i++ {
|
||||
wg.Add(1)
|
||||
go func() {
|
||||
defer wg.Done()
|
||||
|
||||
err := sched.ddQueue.Enqueue(newDefaultMockDdlTask())
|
||||
assert.NoError(t, err)
|
||||
}()
|
||||
}
|
||||
}()
|
||||
|
||||
wg.Add(1)
|
||||
go func() {
|
||||
defer wg.Done()
|
||||
|
||||
for i := 0; i < dmNum; i++ {
|
||||
wg.Add(1)
|
||||
go func() {
|
||||
defer wg.Done()
|
||||
|
||||
err := sched.dmQueue.Enqueue(newDefaultMockDmlTask())
|
||||
assert.NoError(t, err)
|
||||
}()
|
||||
}
|
||||
}()
|
||||
|
||||
wg.Add(1)
|
||||
go func() {
|
||||
defer wg.Done()
|
||||
|
||||
for i := 0; i < dqNum; i++ {
|
||||
wg.Add(1)
|
||||
go func() {
|
||||
defer wg.Done()
|
||||
|
||||
err := sched.dqQueue.Enqueue(newDefaultMockDqlTask())
|
||||
assert.NoError(t, err)
|
||||
}()
|
||||
}
|
||||
}()
|
||||
|
||||
wg.Wait()
|
||||
}
|
|
@ -1,11 +1,32 @@
|
|||
// Copyright (C) 2019-2020 Zilliz. All rights reserved.
|
||||
//
|
||||
// Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance
|
||||
// with the License. You may obtain a copy of the License at
|
||||
//
|
||||
// http://www.apache.org/licenses/LICENSE-2.0
|
||||
//
|
||||
// Unless required by applicable law or agreed to in writing, software distributed under the License
|
||||
// is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express
|
||||
// or implied. See the License for the specific language governing permissions and limitations under the License.
|
||||
|
||||
package proxy
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"math/rand"
|
||||
|
||||
"github.com/milvus-io/milvus/internal/proto/schemapb"
|
||||
)
|
||||
|
||||
func genUniqueStr() string {
|
||||
l := rand.Uint64()%100 + 1
|
||||
b := make([]byte, l)
|
||||
if _, err := rand.Read(b); err != nil {
|
||||
return ""
|
||||
}
|
||||
return fmt.Sprintf("%X", b)
|
||||
}
|
||||
|
||||
func generateBoolArray(numRows int) []bool {
|
||||
ret := make([]bool, 0, numRows)
|
||||
for i := 0; i < numRows; i++ {
|
||||
|
|
|
@ -20,11 +20,6 @@ import (
|
|||
"github.com/milvus-io/milvus/internal/proto/rootcoordpb"
|
||||
)
|
||||
|
||||
// use timestampAllocatorInterface to keep TimestampAllocator testable
|
||||
type timestampAllocatorInterface interface {
|
||||
AllocTimestamp(ctx context.Context, req *rootcoordpb.AllocTimestampRequest) (*rootcoordpb.AllocTimestampResponse, error)
|
||||
}
|
||||
|
||||
type TimestampAllocator struct {
|
||||
ctx context.Context
|
||||
tso timestampAllocatorInterface
|
||||
|
|
|
@ -0,0 +1,28 @@
|
|||
// Copyright (C) 2019-2020 Zilliz. All rights reserved.
|
||||
//
|
||||
// Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance
|
||||
// with the License. You may obtain a copy of the License at
|
||||
//
|
||||
// http://www.apache.org/licenses/LICENSE-2.0
|
||||
//
|
||||
// Unless required by applicable law or agreed to in writing, software distributed under the License
|
||||
// is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express
|
||||
// or implied. See the License for the specific language governing permissions and limitations under the License.
|
||||
|
||||
package funcutil
|
||||
|
||||
// SetContain returns true if set m1 contains set m2
|
||||
func SetContain(m1, m2 map[interface{}]struct{}) bool {
|
||||
if len(m1) < len(m2) {
|
||||
return false
|
||||
}
|
||||
|
||||
for k2 := range m2 {
|
||||
_, ok := m1[k2]
|
||||
if !ok {
|
||||
return false
|
||||
}
|
||||
}
|
||||
|
||||
return true
|
||||
}
|
|
@ -0,0 +1,40 @@
|
|||
// Copyright (C) 2019-2020 Zilliz. All rights reserved.
|
||||
//
|
||||
// Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance
|
||||
// with the License. You may obtain a copy of the License at
|
||||
//
|
||||
// http://www.apache.org/licenses/LICENSE-2.0
|
||||
//
|
||||
// Unless required by applicable law or agreed to in writing, software distributed under the License
|
||||
// is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express
|
||||
// or implied. See the License for the specific language governing permissions and limitations under the License.
|
||||
|
||||
package funcutil
|
||||
|
||||
import (
|
||||
"testing"
|
||||
|
||||
"github.com/stretchr/testify/assert"
|
||||
)
|
||||
|
||||
func TestSetContain(t *testing.T) {
|
||||
key1 := "key1"
|
||||
key2 := "key2"
|
||||
key3 := "key3"
|
||||
|
||||
// len(m1) < len(m2)
|
||||
m1 := make(map[interface{}]struct{})
|
||||
m2 := make(map[interface{}]struct{})
|
||||
m1[key1] = struct{}{}
|
||||
m2[key1] = struct{}{}
|
||||
m2[key2] = struct{}{}
|
||||
assert.False(t, SetContain(m1, m2))
|
||||
|
||||
// len(m1) >= len(m2), but m2 contains other key not in m1
|
||||
m1[key3] = struct{}{}
|
||||
assert.False(t, SetContain(m1, m2))
|
||||
|
||||
// m1 contains m2
|
||||
m1[key2] = struct{}{}
|
||||
assert.True(t, SetContain(m1, m2))
|
||||
}
|
Loading…
Reference in New Issue