Add unittest to proxy service

Signed-off-by: dragondriver <jiquan.long@zilliz.com>
pull/4973/head^2
dragondriver 2021-04-06 14:12:57 +08:00 committed by yefu.chen
parent fe3fa8dd1d
commit decc80a525
20 changed files with 711 additions and 267 deletions

View File

@ -2,6 +2,8 @@ package proxynode
import (
"strconv"
"github.com/zilliztech/milvus-distributed/internal/util/funcutil"
)
const (
@ -103,7 +105,7 @@ func CheckStrByValues(params map[string]string, key string, container []string)
return false
}
return SliceContain(container, value)
return funcutil.SliceContain(container, value)
}
type BaseConfAdapter struct {
@ -187,7 +189,7 @@ func (adapter *IVFPQConfAdapter) checkGPUPQParams(dimension, m, nbits int) bool
*/
subDim := dimension / m
return SliceContain(supportSubQuantizer, m) && SliceContain(supportDimPerSubQuantizer, subDim) && nbits == 8
return funcutil.SliceContain(supportSubQuantizer, m) && funcutil.SliceContain(supportDimPerSubQuantizer, subDim) && nbits == 8
}
func (adapter *IVFPQConfAdapter) checkCPUPQParams(dimension, m int) bool {

View File

@ -7,6 +7,8 @@ import (
"sort"
"sync"
"github.com/zilliztech/milvus-distributed/internal/util/funcutil"
"github.com/zilliztech/milvus-distributed/internal/log"
"github.com/zilliztech/milvus-distributed/internal/msgstream"
"go.uber.org/zap"
@ -35,7 +37,7 @@ func (m *InsertChannelsMap) createInsertMsgStream(collID UniqueID, channels []st
return channels[i] <= channels[j]
})
for loc, existedChannels := range m.insertChannels {
if m.droppedBitMap[loc] == 0 && SortedSliceEqual(existedChannels, channels) {
if m.droppedBitMap[loc] == 0 && funcutil.SortedSliceEqual(existedChannels, channels) {
m.collectionID2InsertChannels[collID] = loc
m.usageHistogram[loc]++
return nil

View File

@ -4,7 +4,6 @@ import (
"encoding/json"
"io/ioutil"
"net/http"
"reflect"
"time"
"go.uber.org/zap"
@ -43,61 +42,6 @@ func GetPulsarConfig(protocol, ip, port, url string) (map[string]interface{}, er
return ret, nil
}
func SliceContain(s interface{}, item interface{}) bool {
ss := reflect.ValueOf(s)
if ss.Kind() != reflect.Slice {
panic("SliceContain expect a slice")
}
for i := 0; i < ss.Len(); i++ {
if ss.Index(i).Interface() == item {
return true
}
}
return false
}
func SliceSetEqual(s1 interface{}, s2 interface{}) bool {
ss1 := reflect.ValueOf(s1)
ss2 := reflect.ValueOf(s2)
if ss1.Kind() != reflect.Slice {
panic("expect a slice")
}
if ss2.Kind() != reflect.Slice {
panic("expect a slice")
}
if ss1.Len() != ss2.Len() {
return false
}
for i := 0; i < ss1.Len(); i++ {
if !SliceContain(s2, ss1.Index(i).Interface()) {
return false
}
}
return true
}
func SortedSliceEqual(s1 interface{}, s2 interface{}) bool {
ss1 := reflect.ValueOf(s1)
ss2 := reflect.ValueOf(s2)
if ss1.Kind() != reflect.Slice {
panic("expect a slice")
}
if ss2.Kind() != reflect.Slice {
panic("expect a slice")
}
if ss1.Len() != ss2.Len() {
return false
}
for i := 0; i < ss1.Len(); i++ {
if ss2.Index(i).Interface() != ss1.Index(i).Interface() {
return false
}
}
return true
}
func getMax(a, b int) int {
if a > b {
return a

View File

@ -3,7 +3,6 @@ package proxynode
import (
"fmt"
"net/http"
"sort"
"strconv"
"testing"
@ -37,99 +36,3 @@ func TestGetPulsarConfig(t *testing.T) {
assert.Equal(t, fmt.Sprintf("%v", value), fmt.Sprintf("%v", runtimeConfig[key]))
}
}
func TestSliceContain(t *testing.T) {
strSlice := []string{"test", "for", "SliceContain"}
intSlice := []int{1, 2, 3}
cases := []struct {
s interface{}
item interface{}
want bool
}{
{strSlice, "test", true},
{strSlice, "for", true},
{strSlice, "SliceContain", true},
{strSlice, "tests", false},
{intSlice, 1, true},
{intSlice, 2, true},
{intSlice, 3, true},
{intSlice, 4, false},
}
for _, test := range cases {
if got := SliceContain(test.s, test.item); got != test.want {
t.Errorf("SliceContain(%v, %v) = %v", test.s, test.item, test.want)
}
}
}
func TestSliceSetEqual(t *testing.T) {
cases := []struct {
s1 interface{}
s2 interface{}
want bool
}{
{[]int{}, []int{}, true},
{[]string{}, []string{}, true},
{[]int{1, 2, 3}, []int{3, 2, 1}, true},
{[]int{1, 2, 3}, []int{1, 2, 3}, true},
{[]int{1, 2, 3}, []int{}, false},
{[]int{1, 2, 3}, []int{1, 2}, false},
{[]int{1, 2, 3}, []int{4, 5, 6}, false},
{[]string{"test", "for", "SliceSetEqual"}, []string{"SliceSetEqual", "test", "for"}, true},
{[]string{"test", "for", "SliceSetEqual"}, []string{"test", "for", "SliceSetEqual"}, true},
{[]string{"test", "for", "SliceSetEqual"}, []string{"test", "for"}, false},
{[]string{"test", "for", "SliceSetEqual"}, []string{}, false},
{[]string{"test", "for", "SliceSetEqual"}, []string{"test", "for", "SliceContain"}, false},
}
for _, test := range cases {
if got := SliceSetEqual(test.s1, test.s2); got != test.want {
t.Errorf("SliceSetEqual(%v, %v) = %v", test.s1, test.s2, test.want)
}
}
}
func TestSortedSliceEqual(t *testing.T) {
sortSlice := func(slice interface{}, less func(i, j int) bool) {
sort.Slice(slice, less)
}
intSliceAfterSort := func(slice []int) []int {
sortSlice(slice, func(i, j int) bool {
return slice[i] <= slice[j]
})
return slice
}
stringSliceAfterSort := func(slice []string) []string {
sortSlice(slice, func(i, j int) bool {
return slice[i] <= slice[j]
})
return slice
}
cases := []struct {
s1 interface{}
s2 interface{}
want bool
}{
{intSliceAfterSort([]int{}), intSliceAfterSort([]int{}), true},
{stringSliceAfterSort([]string{}), stringSliceAfterSort([]string{}), true},
{intSliceAfterSort([]int{1, 2, 3}), intSliceAfterSort([]int{3, 2, 1}), true},
{intSliceAfterSort([]int{1, 2, 3}), intSliceAfterSort([]int{1, 2, 3}), true},
{intSliceAfterSort([]int{1, 2, 3}), intSliceAfterSort([]int{}), false},
{intSliceAfterSort([]int{1, 2, 3}), intSliceAfterSort([]int{1, 2}), false},
{intSliceAfterSort([]int{1, 2, 3}), intSliceAfterSort([]int{4, 5, 6}), false},
{stringSliceAfterSort([]string{"test", "for", "SliceSetEqual"}), stringSliceAfterSort([]string{"SliceSetEqual", "test", "for"}), true},
{stringSliceAfterSort([]string{"test", "for", "SliceSetEqual"}), stringSliceAfterSort([]string{"test", "for", "SliceSetEqual"}), true},
{stringSliceAfterSort([]string{"test", "for", "SliceSetEqual"}), stringSliceAfterSort([]string{"test", "for"}), false},
{stringSliceAfterSort([]string{"test", "for", "SliceSetEqual"}), stringSliceAfterSort([]string{}), false},
{stringSliceAfterSort([]string{"test", "for", "SliceSetEqual"}), stringSliceAfterSort([]string{"test", "for", "SliceContain"}), false},
}
for _, test := range cases {
if got := SortedSliceEqual(test.s1, test.s2); got != test.want {
t.Errorf("SliceSetEqual(%v, %v) = %v", test.s1, test.s2, test.want)
}
}
}

View File

@ -199,9 +199,9 @@ func (s *ProxyService) GetStatisticsChannel(ctx context.Context) (*milvuspb.Stri
func (s *ProxyService) RegisterLink(ctx context.Context) (*milvuspb.RegisterLinkResponse, error) {
log.Debug("register link")
t := &RegisterLinkTask{
t := &registerLinkTask{
ctx: ctx,
Condition: NewTaskCondition(ctx),
Condition: newTaskCondition(ctx),
nodeInfos: s.nodeInfos,
}
@ -237,11 +237,11 @@ func (s *ProxyService) RegisterNode(ctx context.Context, request *proxypb.Regist
zap.String("ip", request.Address.Ip),
zap.Int64("port", request.Address.Port))
t := &RegisterNodeTask{
t := &registerNodeTask{
ctx: ctx,
request: request,
startParams: s.nodeStartParams,
Condition: NewTaskCondition(ctx),
Condition: newTaskCondition(ctx),
allocator: s.allocator,
nodeInfos: s.nodeInfos,
}
@ -278,10 +278,10 @@ func (s *ProxyService) InvalidateCollectionMetaCache(ctx context.Context, reques
zap.String("db", request.DbName),
zap.String("collection", request.CollectionName))
t := &InvalidateCollectionMetaCacheTask{
t := &invalidateCollectionMetaCacheTask{
ctx: ctx,
request: request,
Condition: NewTaskCondition(ctx),
Condition: newTaskCondition(ctx),
nodeInfos: s.nodeInfos,
}

View File

@ -7,31 +7,33 @@ import (
"strconv"
"sync"
"github.com/zilliztech/milvus-distributed/internal/util/funcutil"
grpcproxynodeclient "github.com/zilliztech/milvus-distributed/internal/distributed/proxynode/client"
"github.com/zilliztech/milvus-distributed/internal/log"
"github.com/zilliztech/milvus-distributed/internal/types"
)
type NodeInfo struct {
type nodeInfo struct {
ip string
port int64
}
type GlobalNodeInfoTable struct {
type globalNodeInfoTable struct {
mu sync.RWMutex
infos map[UniqueID]*NodeInfo
infos map[UniqueID]*nodeInfo
nodeIDs []UniqueID
// lazy creating, so len(clients) <= len(infos)
ProxyNodes map[UniqueID]types.ProxyNode
}
func (table *GlobalNodeInfoTable) randomPick() UniqueID {
func (table *globalNodeInfoTable) randomPick() UniqueID {
l := len(table.nodeIDs)
choice := rand.Intn(l)
return table.nodeIDs[choice]
}
func (table *GlobalNodeInfoTable) Pick() (*NodeInfo, error) {
func (table *globalNodeInfoTable) Pick() (*nodeInfo, error) {
table.mu.RLock()
defer table.mu.RUnlock()
@ -49,7 +51,7 @@ func (table *GlobalNodeInfoTable) Pick() (*NodeInfo, error) {
return info, nil
}
func (table *GlobalNodeInfoTable) Register(id UniqueID, info *NodeInfo) error {
func (table *globalNodeInfoTable) Register(id UniqueID, info *nodeInfo) error {
table.mu.Lock()
defer table.mu.Unlock()
@ -58,14 +60,14 @@ func (table *GlobalNodeInfoTable) Register(id UniqueID, info *NodeInfo) error {
table.infos[id] = info
}
if !SliceContain(table.nodeIDs, id) {
if !funcutil.SliceContain(table.nodeIDs, id) {
table.nodeIDs = append(table.nodeIDs, id)
}
return nil
}
func (table *GlobalNodeInfoTable) createClients() error {
func (table *globalNodeInfoTable) createClients() error {
if len(table.ProxyNodes) == len(table.infos) {
return nil
}
@ -89,7 +91,7 @@ func (table *GlobalNodeInfoTable) createClients() error {
return nil
}
func (table *GlobalNodeInfoTable) ReleaseAllClients() error {
func (table *globalNodeInfoTable) ReleaseAllClients() error {
table.mu.Lock()
log.Debug("get write lock")
defer func() {
@ -109,7 +111,7 @@ func (table *GlobalNodeInfoTable) ReleaseAllClients() error {
return nil
}
func (table *GlobalNodeInfoTable) ObtainAllClients() (map[UniqueID]types.ProxyNode, error) {
func (table *globalNodeInfoTable) ObtainAllClients() (map[UniqueID]types.ProxyNode, error) {
table.mu.RLock()
defer table.mu.RUnlock()
@ -118,10 +120,10 @@ func (table *GlobalNodeInfoTable) ObtainAllClients() (map[UniqueID]types.ProxyNo
return table.ProxyNodes, err
}
func NewGlobalNodeInfoTable() *GlobalNodeInfoTable {
return &GlobalNodeInfoTable{
func newGlobalNodeInfoTable() *globalNodeInfoTable {
return &globalNodeInfoTable{
nodeIDs: make([]UniqueID, 0),
infos: make(map[UniqueID]*NodeInfo),
infos: make(map[UniqueID]*nodeInfo),
ProxyNodes: make(map[UniqueID]types.ProxyNode),
}
}

View File

@ -0,0 +1,75 @@
package proxyservice
import (
"testing"
"github.com/stretchr/testify/assert"
)
func TestGlobalNodeInfoTable_Register(t *testing.T) {
table := newGlobalNodeInfoTable()
idInfoMaps := map[UniqueID]*nodeInfo{
0: {"localhost", 1080},
1: {"localhost", 1081},
}
var err error
err = table.Register(0, idInfoMaps[0])
assert.Equal(t, nil, err)
err = table.Register(1, idInfoMaps[1])
assert.Equal(t, nil, err)
/************** duplicated register ***************/
err = table.Register(0, idInfoMaps[0])
assert.Equal(t, nil, err)
err = table.Register(1, idInfoMaps[1])
assert.Equal(t, nil, err)
}
func TestGlobalNodeInfoTable_Pick(t *testing.T) {
table := newGlobalNodeInfoTable()
var err error
_, err = table.Pick()
assert.NotEqual(t, nil, err)
idInfoMaps := map[UniqueID]*nodeInfo{
0: {"localhost", 1080},
1: {"localhost", 1081},
}
err = table.Register(0, idInfoMaps[0])
assert.Equal(t, nil, err)
err = table.Register(1, idInfoMaps[1])
assert.Equal(t, nil, err)
num := 10
for i := 0; i < num; i++ {
_, err = table.Pick()
assert.Equal(t, nil, err)
}
}
func TestGlobalNodeInfoTable_ObtainAllClients(t *testing.T) {
table := newGlobalNodeInfoTable()
var err error
clients, err := table.ObtainAllClients()
assert.Equal(t, nil, err)
assert.Equal(t, 0, len(clients))
}
func TestGlobalNodeInfoTable_ReleaseAllClients(t *testing.T) {
table := newGlobalNodeInfoTable()
err := table.ReleaseAllClients()
assert.Equal(t, nil, err)
}

View File

@ -10,17 +10,17 @@ import (
type UniqueID = typeutil.UniqueID
type Timestamp = typeutil.Timestamp
type NodeIDAllocator interface {
type nodeIDAllocator interface {
AllocOne() UniqueID
}
type NaiveNodeIDAllocator struct {
type naiveNodeIDAllocator struct {
allocator *allocator.IDAllocator
now UniqueID
mtx sync.Mutex
}
func (allocator *NaiveNodeIDAllocator) AllocOne() UniqueID {
func (allocator *naiveNodeIDAllocator) AllocOne() UniqueID {
allocator.mtx.Lock()
defer func() {
// allocator.now++
@ -29,8 +29,8 @@ func (allocator *NaiveNodeIDAllocator) AllocOne() UniqueID {
return allocator.now
}
func NewNodeIDAllocator() NodeIDAllocator {
return &NaiveNodeIDAllocator{
func newNodeIDAllocator() *naiveNodeIDAllocator {
return &naiveNodeIDAllocator{
now: 1,
}
}

View File

@ -0,0 +1,18 @@
package proxyservice
import (
"testing"
"github.com/zilliztech/milvus-distributed/internal/log"
"go.uber.org/zap"
)
func TestNaiveNodeIDAllocator_AllocOne(t *testing.T) {
allocator := newNodeIDAllocator()
num := 10
for i := 0; i < num; i++ {
nodeID := allocator.AllocOne()
log.Debug("TestNaiveNodeIDAllocator_AllocOne", zap.Any("node id", nodeID))
}
}

View File

@ -0,0 +1,20 @@
package proxyservice
import (
"testing"
"github.com/zilliztech/milvus-distributed/internal/log"
"go.uber.org/zap"
)
func TestParamTable_Init(t *testing.T) {
Params.Init()
log.Debug("TestParamTable_Init", zap.Any("PulsarAddress", Params.PulsarAddress))
log.Debug("TestParamTable_Init", zap.Any("MasterAddress", Params.MasterAddress))
log.Debug("TestParamTable_Init", zap.Any("NodeTimeTickChannel", Params.NodeTimeTickChannel))
log.Debug("TestParamTable_Init", zap.Any("ServiceTimeTickChannel", Params.ServiceTimeTickChannel))
log.Debug("TestParamTable_Init", zap.Any("DataServiceAddress", Params.DataServiceAddress))
log.Debug("TestParamTable_Init", zap.Any("InsertChannelPrefixName", Params.InsertChannelPrefixName))
log.Debug("TestParamTable_Init", zap.Any("InsertChannelNum", Params.InsertChannelNum))
}

View File

@ -15,10 +15,10 @@ import (
)
type ProxyService struct {
allocator NodeIDAllocator
sched *TaskScheduler
allocator nodeIDAllocator
sched *taskScheduler
tick *TimeTick
nodeInfos *GlobalNodeInfoTable
nodeInfos *globalNodeInfoTable
stateCode internalpb.StateCode
//subStates *internalpb.ComponentStates
@ -40,9 +40,9 @@ func NewProxyService(ctx context.Context, factory msgstream.Factory) (*ProxyServ
msFactory: factory,
}
s.allocator = NewNodeIDAllocator()
s.sched = NewTaskScheduler(ctx1)
s.nodeInfos = NewGlobalNodeInfoTable()
s.allocator = newNodeIDAllocator()
s.sched = newTaskScheduler(ctx1)
s.nodeInfos = newGlobalNodeInfoTable()
s.UpdateStateCode(internalpb.StateCode_Abnormal)
log.Debug("proxyservice", zap.Any("state of proxyservice: ", internalpb.StateCode_Abnormal))

View File

@ -40,12 +40,12 @@ type Condition interface {
Notify(err error)
}
type TaskCondition struct {
type taskCondition struct {
done chan error
ctx context.Context
}
func (c *TaskCondition) WaitToFinish() error {
func (c *taskCondition) WaitToFinish() error {
select {
case <-c.ctx.Done():
return errors.New("timeout")
@ -54,41 +54,41 @@ func (c *TaskCondition) WaitToFinish() error {
}
}
func (c *TaskCondition) Notify(err error) {
func (c *taskCondition) Notify(err error) {
c.done <- err
}
func NewTaskCondition(ctx context.Context) Condition {
return &TaskCondition{
func newTaskCondition(ctx context.Context) Condition {
return &taskCondition{
done: make(chan error),
ctx: ctx,
}
}
type RegisterLinkTask struct {
type registerLinkTask struct {
Condition
ctx context.Context
response *milvuspb.RegisterLinkResponse
nodeInfos *GlobalNodeInfoTable
nodeInfos *globalNodeInfoTable
}
func (t *RegisterLinkTask) Ctx() context.Context {
func (t *registerLinkTask) Ctx() context.Context {
return t.ctx
}
func (t *RegisterLinkTask) ID() UniqueID {
func (t *registerLinkTask) ID() UniqueID {
return 0
}
func (t *RegisterLinkTask) Name() string {
func (t *registerLinkTask) Name() string {
return RegisterLinkTaskName
}
func (t *RegisterLinkTask) PreExecute(ctx context.Context) error {
func (t *registerLinkTask) PreExecute(ctx context.Context) error {
return nil
}
func (t *RegisterLinkTask) Execute(ctx context.Context) error {
func (t *registerLinkTask) Execute(ctx context.Context) error {
info, err := t.nodeInfos.Pick()
if err != nil {
return err
@ -106,39 +106,39 @@ func (t *RegisterLinkTask) Execute(ctx context.Context) error {
return nil
}
func (t *RegisterLinkTask) PostExecute(ctx context.Context) error {
func (t *registerLinkTask) PostExecute(ctx context.Context) error {
return nil
}
type RegisterNodeTask struct {
type registerNodeTask struct {
Condition
ctx context.Context
request *proxypb.RegisterNodeRequest
response *proxypb.RegisterNodeResponse
startParams []*commonpb.KeyValuePair
allocator NodeIDAllocator
nodeInfos *GlobalNodeInfoTable
allocator nodeIDAllocator
nodeInfos *globalNodeInfoTable
}
func (t *RegisterNodeTask) Ctx() context.Context {
func (t *registerNodeTask) Ctx() context.Context {
return t.ctx
}
func (t *RegisterNodeTask) ID() UniqueID {
func (t *registerNodeTask) ID() UniqueID {
return t.request.Base.MsgID
}
func (t *RegisterNodeTask) Name() string {
func (t *registerNodeTask) Name() string {
return RegisterNodeTaskName
}
func (t *RegisterNodeTask) PreExecute(ctx context.Context) error {
func (t *registerNodeTask) PreExecute(ctx context.Context) error {
return nil
}
func (t *RegisterNodeTask) Execute(ctx context.Context) error {
func (t *registerNodeTask) Execute(ctx context.Context) error {
nodeID := t.allocator.AllocOne()
info := NodeInfo{
info := nodeInfo{
ip: t.request.Address.Ip,
port: t.request.Address.Port,
}
@ -157,35 +157,35 @@ func (t *RegisterNodeTask) Execute(ctx context.Context) error {
return err
}
func (t *RegisterNodeTask) PostExecute(ctx context.Context) error {
func (t *registerNodeTask) PostExecute(ctx context.Context) error {
return nil
}
type InvalidateCollectionMetaCacheTask struct {
type invalidateCollectionMetaCacheTask struct {
Condition
ctx context.Context
request *proxypb.InvalidateCollMetaCacheRequest
response *commonpb.Status
nodeInfos *GlobalNodeInfoTable
nodeInfos *globalNodeInfoTable
}
func (t *InvalidateCollectionMetaCacheTask) Ctx() context.Context {
func (t *invalidateCollectionMetaCacheTask) Ctx() context.Context {
return t.ctx
}
func (t *InvalidateCollectionMetaCacheTask) ID() UniqueID {
func (t *invalidateCollectionMetaCacheTask) ID() UniqueID {
return t.request.Base.MsgID
}
func (t *InvalidateCollectionMetaCacheTask) Name() string {
func (t *invalidateCollectionMetaCacheTask) Name() string {
return InvalidateCollectionMetaCacheTaskName
}
func (t *InvalidateCollectionMetaCacheTask) PreExecute(ctx context.Context) error {
func (t *invalidateCollectionMetaCacheTask) PreExecute(ctx context.Context) error {
return nil
}
func (t *InvalidateCollectionMetaCacheTask) Execute(ctx context.Context) error {
func (t *invalidateCollectionMetaCacheTask) Execute(ctx context.Context) error {
var err error
clients, err := t.nodeInfos.ObtainAllClients()
if err != nil {
@ -206,6 +206,6 @@ func (t *InvalidateCollectionMetaCacheTask) Execute(ctx context.Context) error {
return nil
}
func (t *InvalidateCollectionMetaCacheTask) PostExecute(ctx context.Context) error {
func (t *invalidateCollectionMetaCacheTask) PostExecute(ctx context.Context) error {
return nil
}

View File

@ -0,0 +1,90 @@
package proxyservice
import "context"
type mockTask struct {
ctx context.Context
id UniqueID
name string
}
func (t *mockTask) Ctx() context.Context {
return t.ctx
}
func (t *mockTask) ID() UniqueID {
return t.id
}
func (t *mockTask) Name() string {
return t.name
}
func (t *mockTask) PreExecute(ctx context.Context) error {
return nil
}
func (t *mockTask) Execute(ctx context.Context) error {
return nil
}
func (t *mockTask) PostExecute(ctx context.Context) error {
return nil
}
func (t *mockTask) WaitToFinish() error {
return nil
}
func (t *mockTask) Notify(err error) {
}
func newMockTask(ctx context.Context) *mockTask {
return &mockTask{
ctx: ctx,
id: 0,
name: "mockTask",
}
}
type mockRegisterLinkTask struct {
mockTask
}
type mockRegisterNodeTask struct {
mockTask
}
type mockInvalidateCollectionMetaCacheTask struct {
mockTask
}
func newMockRegisterLinkTask(ctx context.Context) *mockRegisterLinkTask {
return &mockRegisterLinkTask{
mockTask: mockTask{
ctx: ctx,
id: 0,
name: "mockRegisterLinkTask",
},
}
}
func newMockRegisterNodeTask(ctx context.Context) *mockRegisterNodeTask {
return &mockRegisterNodeTask{
mockTask: mockTask{
ctx: ctx,
id: 0,
name: "mockRegisterNodeTask",
},
}
}
func newMockInvalidateCollectionMetaCacheTask(ctx context.Context) *mockInvalidateCollectionMetaCacheTask {
return &mockInvalidateCollectionMetaCacheTask{
mockTask: mockTask{
ctx: ctx,
id: 0,
name: "mockInvalidateCollectionMetaCacheTask",
},
}
}

View File

@ -8,7 +8,7 @@ import (
"github.com/zilliztech/milvus-distributed/internal/log"
)
type TaskQueue interface {
type taskQueue interface {
Chan() <-chan int
Empty() bool
Full() bool
@ -18,7 +18,7 @@ type TaskQueue interface {
Enqueue(t task) error
}
type BaseTaskQueue struct {
type baseTaskQueue struct {
tasks *list.List
mtx sync.Mutex
@ -28,19 +28,19 @@ type BaseTaskQueue struct {
bufChan chan int // to block scheduler
}
func (queue *BaseTaskQueue) Chan() <-chan int {
func (queue *baseTaskQueue) Chan() <-chan int {
return queue.bufChan
}
func (queue *BaseTaskQueue) Empty() bool {
func (queue *baseTaskQueue) Empty() bool {
return queue.tasks.Len() <= 0
}
func (queue *BaseTaskQueue) Full() bool {
func (queue *baseTaskQueue) Full() bool {
return int64(queue.tasks.Len()) >= queue.maxTaskNum
}
func (queue *BaseTaskQueue) addTask(t task) error {
func (queue *baseTaskQueue) addTask(t task) error {
queue.mtx.Lock()
defer queue.mtx.Unlock()
@ -52,7 +52,7 @@ func (queue *BaseTaskQueue) addTask(t task) error {
return nil
}
func (queue *BaseTaskQueue) FrontTask() task {
func (queue *baseTaskQueue) FrontTask() task {
queue.mtx.Lock()
defer queue.mtx.Unlock()
@ -64,7 +64,7 @@ func (queue *BaseTaskQueue) FrontTask() task {
return queue.tasks.Front().Value.(task)
}
func (queue *BaseTaskQueue) PopTask() task {
func (queue *baseTaskQueue) PopTask() task {
queue.mtx.Lock()
defer queue.mtx.Unlock()
@ -79,12 +79,12 @@ func (queue *BaseTaskQueue) PopTask() task {
return ft.Value.(task)
}
func (queue *BaseTaskQueue) Enqueue(t task) error {
func (queue *baseTaskQueue) Enqueue(t task) error {
return queue.addTask(t)
}
func NewBaseTaskQueue() TaskQueue {
return &BaseTaskQueue{
func newBaseTaskQueue() *baseTaskQueue {
return &baseTaskQueue{
tasks: list.New(),
maxTaskNum: 1024,
bufChan: make(chan int, 1024),

View File

@ -0,0 +1,149 @@
package proxyservice
import (
"context"
"sync"
"testing"
"go.uber.org/zap"
"github.com/stretchr/testify/assert"
"github.com/zilliztech/milvus-distributed/internal/log"
)
func TestBaseTaskQueue_Enqueue(t *testing.T) {
queue := newBaseTaskQueue()
num := 10
var wg sync.WaitGroup
for i := 0; i < num; i++ {
wg.Add(1)
go func() {
defer wg.Done()
tsk := newMockTask(context.Background())
err := queue.Enqueue(tsk)
assert.Equal(t, nil, err)
}()
}
wg.Wait()
}
func TestBaseTaskQueue_FrontTask(t *testing.T) {
queue := newBaseTaskQueue()
tsk := queue.FrontTask()
assert.Equal(t, nil, tsk)
frontTask := newMockTask(context.Background())
err := queue.Enqueue(frontTask)
assert.Equal(t, nil, err)
tsk = queue.FrontTask()
assert.NotEqual(t, nil, tsk)
assert.Equal(t, frontTask.ID(), tsk.ID())
assert.Equal(t, frontTask.Name(), tsk.Name())
num := 10
for i := 0; i < num; i++ {
tsk := newMockTask(context.Background())
err := queue.Enqueue(tsk)
assert.Equal(t, nil, err)
tskF := queue.FrontTask()
assert.NotEqual(t, nil, tskF)
assert.Equal(t, frontTask.ID(), tskF.ID())
assert.Equal(t, frontTask.Name(), tskF.Name())
}
}
func TestBaseTaskQueue_PopTask(t *testing.T) {
queue := newBaseTaskQueue()
tsk := queue.PopTask()
assert.Equal(t, nil, tsk)
num := 10
for i := 0; i < num; i++ {
tsk := newMockTask(context.Background())
err := queue.Enqueue(tsk)
assert.Equal(t, nil, err)
tskP := queue.PopTask()
assert.NotEqual(t, nil, tskP)
}
tsk = queue.PopTask()
assert.Equal(t, nil, tsk)
}
func TestBaseTaskQueue_Chan(t *testing.T) {
queue := newBaseTaskQueue()
ctx, cancel := context.WithCancel(context.Background())
go func() {
for {
select {
case <-ctx.Done():
log.Debug("TestBaseTaskQueue_Chan exit")
return
case i := <-queue.Chan():
log.Debug("TestBaseTaskQueue_Chan", zap.Any("receive", i))
}
}
}()
num := 10
var wg sync.WaitGroup
for i := 0; i < num; i++ {
wg.Add(1)
go func() {
defer wg.Done()
tsk := newMockTask(context.Background())
err := queue.Enqueue(tsk)
assert.Equal(t, nil, err)
}()
}
wg.Wait()
cancel()
}
func TestBaseTaskQueue_Empty(t *testing.T) {
queue := newBaseTaskQueue()
assert.Equal(t, true, queue.Empty())
num := 10
for i := 0; i < num; i++ {
tsk := newMockTask(context.Background())
err := queue.Enqueue(tsk)
assert.Equal(t, nil, err)
assert.Equal(t, false, queue.Empty())
}
for !queue.Empty() {
assert.Equal(t, false, queue.Empty())
queue.PopTask()
}
assert.Equal(t, true, queue.Empty())
}
func TestBaseTaskQueue_Full(t *testing.T) {
queue := newBaseTaskQueue()
for !queue.Full() {
assert.Equal(t, false, queue.Full())
tsk := newMockTask(context.Background())
err := queue.Enqueue(tsk)
assert.Equal(t, nil, err)
}
assert.Equal(t, true, queue.Full())
}

View File

@ -9,41 +9,41 @@ import (
"github.com/zilliztech/milvus-distributed/internal/util/trace"
)
type TaskScheduler struct {
RegisterLinkTaskQueue TaskQueue
RegisterNodeTaskQueue TaskQueue
InvalidateCollectionMetaCacheTaskQueue TaskQueue
type taskScheduler struct {
RegisterLinkTaskQueue taskQueue
RegisterNodeTaskQueue taskQueue
InvalidateCollectionMetaCacheTaskQueue taskQueue
wg sync.WaitGroup
ctx context.Context
cancel context.CancelFunc
}
func NewTaskScheduler(ctx context.Context) *TaskScheduler {
func newTaskScheduler(ctx context.Context) *taskScheduler {
ctx1, cancel := context.WithCancel(ctx)
return &TaskScheduler{
RegisterLinkTaskQueue: NewBaseTaskQueue(),
RegisterNodeTaskQueue: NewBaseTaskQueue(),
InvalidateCollectionMetaCacheTaskQueue: NewBaseTaskQueue(),
return &taskScheduler{
RegisterLinkTaskQueue: newBaseTaskQueue(),
RegisterNodeTaskQueue: newBaseTaskQueue(),
InvalidateCollectionMetaCacheTaskQueue: newBaseTaskQueue(),
ctx: ctx1,
cancel: cancel,
}
}
func (sched *TaskScheduler) scheduleRegisterLinkTask() task {
func (sched *taskScheduler) scheduleRegisterLinkTask() task {
return sched.RegisterLinkTaskQueue.PopTask()
}
func (sched *TaskScheduler) scheduleRegisterNodeTask() task {
func (sched *taskScheduler) scheduleRegisterNodeTask() task {
return sched.RegisterNodeTaskQueue.PopTask()
}
func (sched *TaskScheduler) scheduleInvalidateCollectionMetaCacheTask() task {
func (sched *taskScheduler) scheduleInvalidateCollectionMetaCacheTask() task {
return sched.InvalidateCollectionMetaCacheTaskQueue.PopTask()
}
func (sched *TaskScheduler) processTask(t task, q TaskQueue) {
func (sched *taskScheduler) processTask(t task, q taskQueue) {
span, ctx := trace.StartSpanFromContext(t.Ctx(),
opentracing.Tags{
"Type": t.Name(),
@ -70,7 +70,7 @@ func (sched *TaskScheduler) processTask(t task, q TaskQueue) {
err = t.PostExecute(ctx)
}
func (sched *TaskScheduler) registerLinkLoop() {
func (sched *taskScheduler) registerLinkLoop() {
defer sched.wg.Done()
for {
select {
@ -85,7 +85,7 @@ func (sched *TaskScheduler) registerLinkLoop() {
}
}
func (sched *TaskScheduler) registerNodeLoop() {
func (sched *taskScheduler) registerNodeLoop() {
defer sched.wg.Done()
for {
select {
@ -100,7 +100,7 @@ func (sched *TaskScheduler) registerNodeLoop() {
}
}
func (sched *TaskScheduler) invalidateCollectionMetaCacheLoop() {
func (sched *taskScheduler) invalidateCollectionMetaCacheLoop() {
defer sched.wg.Done()
for {
select {
@ -115,7 +115,7 @@ func (sched *TaskScheduler) invalidateCollectionMetaCacheLoop() {
}
}
func (sched *TaskScheduler) Start() {
func (sched *taskScheduler) Start() {
sched.wg.Add(1)
go sched.registerLinkLoop()
@ -126,7 +126,7 @@ func (sched *TaskScheduler) Start() {
go sched.invalidateCollectionMetaCacheLoop()
}
func (sched *TaskScheduler) Close() {
func (sched *taskScheduler) Close() {
sched.cancel()
sched.wg.Wait()
}

View File

@ -0,0 +1,100 @@
package proxyservice
import (
"context"
"math/rand"
"sync"
"testing"
"time"
"github.com/stretchr/testify/assert"
)
func TestTaskScheduler_Start(t *testing.T) {
sched := newTaskScheduler(context.Background())
sched.Start()
defer sched.Close()
num := 64
var wg sync.WaitGroup
for i := 0; i < num; i++ {
wg.Add(1)
switch rand.Int() % 3 {
case 0:
go func() {
defer wg.Done()
tsk := newMockRegisterLinkTask(context.Background())
err := sched.RegisterLinkTaskQueue.Enqueue(tsk)
assert.Equal(t, nil, err)
}()
case 1:
go func() {
defer wg.Done()
tsk := newMockRegisterNodeTask(context.Background())
err := sched.RegisterNodeTaskQueue.Enqueue(tsk)
assert.Equal(t, nil, err)
}()
case 2:
go func() {
defer wg.Done()
tsk := newMockInvalidateCollectionMetaCacheTask(context.Background())
err := sched.InvalidateCollectionMetaCacheTaskQueue.Enqueue(tsk)
assert.Equal(t, nil, err)
}()
default:
go func() {
defer wg.Done()
tsk := newMockRegisterLinkTask(context.Background())
err := sched.RegisterLinkTaskQueue.Enqueue(tsk)
assert.Equal(t, nil, err)
}()
}
}
wg.Wait()
time.Sleep(3 * time.Second)
}
func TestTaskScheduler_Close(t *testing.T) {
sched := newTaskScheduler(context.Background())
sched.Start()
defer sched.Close()
num := 64
var wg sync.WaitGroup
for i := 0; i < num; i++ {
wg.Add(1)
switch rand.Int() % 3 {
case 0:
go func() {
defer wg.Done()
tsk := newMockRegisterLinkTask(context.Background())
err := sched.RegisterLinkTaskQueue.Enqueue(tsk)
assert.Equal(t, nil, err)
}()
case 1:
go func() {
defer wg.Done()
tsk := newMockRegisterNodeTask(context.Background())
err := sched.RegisterNodeTaskQueue.Enqueue(tsk)
assert.Equal(t, nil, err)
}()
case 2:
go func() {
defer wg.Done()
tsk := newMockInvalidateCollectionMetaCacheTask(context.Background())
err := sched.InvalidateCollectionMetaCacheTaskQueue.Enqueue(tsk)
assert.Equal(t, nil, err)
}()
default:
go func() {
defer wg.Done()
tsk := newMockRegisterLinkTask(context.Background())
err := sched.RegisterLinkTaskQueue.Enqueue(tsk)
assert.Equal(t, nil, err)
}()
}
}
wg.Wait()
}

View File

@ -1,21 +0,0 @@
package proxyservice
import (
"reflect"
)
// what if golang support generic programming
func SliceContain(s interface{}, item interface{}) bool {
ss := reflect.ValueOf(s)
if ss.Kind() != reflect.Slice {
panic("SliceContain expect a slice")
}
for i := 0; i < ss.Len(); i++ {
if ss.Index(i).Interface() == item {
return true
}
}
return false
}

View File

@ -0,0 +1,58 @@
package funcutil
import "reflect"
func SliceContain(s interface{}, item interface{}) bool {
ss := reflect.ValueOf(s)
if ss.Kind() != reflect.Slice {
panic("SliceContain expect a slice")
}
for i := 0; i < ss.Len(); i++ {
if ss.Index(i).Interface() == item {
return true
}
}
return false
}
func SliceSetEqual(s1 interface{}, s2 interface{}) bool {
ss1 := reflect.ValueOf(s1)
ss2 := reflect.ValueOf(s2)
if ss1.Kind() != reflect.Slice {
panic("expect a slice")
}
if ss2.Kind() != reflect.Slice {
panic("expect a slice")
}
if ss1.Len() != ss2.Len() {
return false
}
for i := 0; i < ss1.Len(); i++ {
if !SliceContain(s2, ss1.Index(i).Interface()) {
return false
}
}
return true
}
func SortedSliceEqual(s1 interface{}, s2 interface{}) bool {
ss1 := reflect.ValueOf(s1)
ss2 := reflect.ValueOf(s2)
if ss1.Kind() != reflect.Slice {
panic("expect a slice")
}
if ss2.Kind() != reflect.Slice {
panic("expect a slice")
}
if ss1.Len() != ss2.Len() {
return false
}
for i := 0; i < ss1.Len(); i++ {
if ss2.Index(i).Interface() != ss1.Index(i).Interface() {
return false
}
}
return true
}

View File

@ -0,0 +1,102 @@
package funcutil
import (
"sort"
"testing"
)
func TestSliceContain(t *testing.T) {
strSlice := []string{"test", "for", "SliceContain"}
intSlice := []int{1, 2, 3}
cases := []struct {
s interface{}
item interface{}
want bool
}{
{strSlice, "test", true},
{strSlice, "for", true},
{strSlice, "SliceContain", true},
{strSlice, "tests", false},
{intSlice, 1, true},
{intSlice, 2, true},
{intSlice, 3, true},
{intSlice, 4, false},
}
for _, test := range cases {
if got := SliceContain(test.s, test.item); got != test.want {
t.Errorf("SliceContain(%v, %v) = %v", test.s, test.item, test.want)
}
}
}
func TestSliceSetEqual(t *testing.T) {
cases := []struct {
s1 interface{}
s2 interface{}
want bool
}{
{[]int{}, []int{}, true},
{[]string{}, []string{}, true},
{[]int{1, 2, 3}, []int{3, 2, 1}, true},
{[]int{1, 2, 3}, []int{1, 2, 3}, true},
{[]int{1, 2, 3}, []int{}, false},
{[]int{1, 2, 3}, []int{1, 2}, false},
{[]int{1, 2, 3}, []int{4, 5, 6}, false},
{[]string{"test", "for", "SliceSetEqual"}, []string{"SliceSetEqual", "test", "for"}, true},
{[]string{"test", "for", "SliceSetEqual"}, []string{"test", "for", "SliceSetEqual"}, true},
{[]string{"test", "for", "SliceSetEqual"}, []string{"test", "for"}, false},
{[]string{"test", "for", "SliceSetEqual"}, []string{}, false},
{[]string{"test", "for", "SliceSetEqual"}, []string{"test", "for", "SliceContain"}, false},
}
for _, test := range cases {
if got := SliceSetEqual(test.s1, test.s2); got != test.want {
t.Errorf("SliceSetEqual(%v, %v) = %v", test.s1, test.s2, test.want)
}
}
}
func TestSortedSliceEqual(t *testing.T) {
sortSlice := func(slice interface{}, less func(i, j int) bool) {
sort.Slice(slice, less)
}
intSliceAfterSort := func(slice []int) []int {
sortSlice(slice, func(i, j int) bool {
return slice[i] <= slice[j]
})
return slice
}
stringSliceAfterSort := func(slice []string) []string {
sortSlice(slice, func(i, j int) bool {
return slice[i] <= slice[j]
})
return slice
}
cases := []struct {
s1 interface{}
s2 interface{}
want bool
}{
{intSliceAfterSort([]int{}), intSliceAfterSort([]int{}), true},
{stringSliceAfterSort([]string{}), stringSliceAfterSort([]string{}), true},
{intSliceAfterSort([]int{1, 2, 3}), intSliceAfterSort([]int{3, 2, 1}), true},
{intSliceAfterSort([]int{1, 2, 3}), intSliceAfterSort([]int{1, 2, 3}), true},
{intSliceAfterSort([]int{1, 2, 3}), intSliceAfterSort([]int{}), false},
{intSliceAfterSort([]int{1, 2, 3}), intSliceAfterSort([]int{1, 2}), false},
{intSliceAfterSort([]int{1, 2, 3}), intSliceAfterSort([]int{4, 5, 6}), false},
{stringSliceAfterSort([]string{"test", "for", "SliceSetEqual"}), stringSliceAfterSort([]string{"SliceSetEqual", "test", "for"}), true},
{stringSliceAfterSort([]string{"test", "for", "SliceSetEqual"}), stringSliceAfterSort([]string{"test", "for", "SliceSetEqual"}), true},
{stringSliceAfterSort([]string{"test", "for", "SliceSetEqual"}), stringSliceAfterSort([]string{"test", "for"}), false},
{stringSliceAfterSort([]string{"test", "for", "SliceSetEqual"}), stringSliceAfterSort([]string{}), false},
{stringSliceAfterSort([]string{"test", "for", "SliceSetEqual"}), stringSliceAfterSort([]string{"test", "for", "SliceContain"}), false},
}
for _, test := range cases {
if got := SortedSliceEqual(test.s1, test.s2); got != test.want {
t.Errorf("SliceSetEqual(%v, %v) = %v", test.s1, test.s2, test.want)
}
}
}