Add ut for querycoord's meta (#7968)

Signed-off-by: xige-16 <xi.ge@zilliz.com>
pull/8542/head
xige-16 2021-09-29 09:56:04 +08:00 committed by GitHub
parent d6f96ec9fc
commit 9400694867
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
6 changed files with 340 additions and 44 deletions

View File

@ -165,26 +165,28 @@ func TestGrpcRequest(t *testing.T) {
})
t.Run("Test AddQueryChannel", func(t *testing.T) {
reqChannel, resChannel := cluster.clusterMeta.GetQueryChannel(defaultCollectionID)
reqChannel, resChannel, err := cluster.clusterMeta.GetQueryChannel(defaultCollectionID)
assert.Nil(t, err)
addQueryChannelReq := &querypb.AddQueryChannelRequest{
NodeID: nodeID,
CollectionID: defaultCollectionID,
RequestChannelID: reqChannel,
ResultChannelID: resChannel,
}
err := cluster.addQueryChannel(baseCtx, nodeID, addQueryChannelReq)
err = cluster.addQueryChannel(baseCtx, nodeID, addQueryChannelReq)
assert.Nil(t, err)
})
t.Run("Test RemoveQueryChannel", func(t *testing.T) {
reqChannel, resChannel := cluster.clusterMeta.GetQueryChannel(defaultCollectionID)
reqChannel, resChannel, err := cluster.clusterMeta.GetQueryChannel(defaultCollectionID)
assert.Nil(t, err)
removeQueryChannelReq := &querypb.RemoveQueryChannelRequest{
NodeID: nodeID,
CollectionID: defaultCollectionID,
RequestChannelID: reqChannel,
ResultChannelID: resChannel,
}
err := cluster.removeQueryChannel(baseCtx, nodeID, removeQueryChannelReq)
err = cluster.removeQueryChannel(baseCtx, nodeID, removeQueryChannelReq)
assert.Nil(t, err)
})

View File

@ -429,7 +429,15 @@ func (qc *QueryCoord) CreateQueryChannel(ctx context.Context, req *querypb.Creat
}
collectionID := req.CollectionID
queryChannel, queryResultChannel := qc.meta.GetQueryChannel(collectionID)
queryChannel, queryResultChannel, err := qc.meta.GetQueryChannel(collectionID)
if err != nil {
status.ErrorCode = commonpb.ErrorCode_UnexpectedError
status.Reason = err.Error()
log.Debug("createQueryChannel end with error")
return &querypb.CreateQueryChannelResponse{
Status: status,
}, err
}
return &querypb.CreateQueryChannelResponse{
Status: status,

View File

@ -553,3 +553,22 @@ func TestGrpcTaskBeforeHealthy(t *testing.T) {
err = removeAllSession()
assert.Nil(t, err)
}
func Test_GrpcGetQueryChannelFail(t *testing.T) {
kv := &testKv{
returnFn: failedResult,
}
meta, err := newMeta(kv)
assert.Nil(t, err)
queryCoord := &QueryCoord{
meta: meta,
}
queryCoord.stateCode.Store(internalpb.StateCode_Healthy)
res, err := queryCoord.CreateQueryChannel(context.Background(), &querypb.CreateQueryChannelRequest{
CollectionID: defaultCollectionID,
})
assert.NotNil(t, err)
assert.Equal(t, commonpb.ErrorCode_UnexpectedError, res.Status.ErrorCode)
}

View File

@ -21,7 +21,7 @@ import (
"github.com/golang/protobuf/proto"
"go.uber.org/zap"
etcdkv "github.com/milvus-io/milvus/internal/kv/etcd"
"github.com/milvus-io/milvus/internal/kv"
"github.com/milvus-io/milvus/internal/log"
"github.com/milvus-io/milvus/internal/proto/querypb"
"github.com/milvus-io/milvus/internal/proto/schemapb"
@ -63,16 +63,16 @@ type Meta interface {
removeDmChannel(collectionID UniqueID, nodeID int64, channels []string) error
getQueryChannelInfoByID(collectionID UniqueID) (*querypb.QueryChannelInfo, error)
GetQueryChannel(collectionID UniqueID) (string, string)
GetQueryChannel(collectionID UniqueID) (string, string, error)
setLoadType(collectionID UniqueID, loadType querypb.LoadType) error
getLoadType(collectionID UniqueID) (querypb.LoadType, error)
setLoadPercentage(collectionID UniqueID, partitionID UniqueID, percentage int64, loadType querypb.LoadType) error
printMeta()
//printMeta()
}
type MetaReplica struct {
client *etcdkv.EtcdKV // client of a reliable kv service, i.e. etcd client
client kv.MetaKv // client of a reliable kv service, i.e. etcd client
sync.RWMutex
collectionInfos map[UniqueID]*querypb.CollectionInfo
@ -82,7 +82,7 @@ type MetaReplica struct {
//partitionStates map[UniqueID]*querypb.PartitionStates
}
func newMeta(kv *etcdkv.EtcdKV) (Meta, error) {
func newMeta(kv kv.MetaKv) (Meta, error) {
collectionInfos := make(map[UniqueID]*querypb.CollectionInfo)
segmentInfos := make(map[UniqueID]*querypb.SegmentInfo)
queryChannelInfos := make(map[UniqueID]*querypb.QueryChannelInfo)
@ -579,14 +579,14 @@ func (m *MetaReplica) removeDmChannel(collectionID UniqueID, nodeID int64, chann
return errors.New("addDmChannels: can't find collection in collectionInfos")
}
func (m *MetaReplica) GetQueryChannel(collectionID UniqueID) (string, string) {
func (m *MetaReplica) GetQueryChannel(collectionID UniqueID) (string, string, error) {
m.Lock()
defer m.Unlock()
//TODO::to remove
collectionID = 0
if info, ok := m.queryChannelInfos[collectionID]; ok {
return info.QueryChannelID, info.QueryResultChannelID
return info.QueryChannelID, info.QueryResultChannelID, nil
}
searchPrefix := Params.SearchChannelPrefix
@ -600,9 +600,14 @@ func (m *MetaReplica) GetQueryChannel(collectionID UniqueID) (string, string) {
QueryChannelID: allocatedQueryChannel,
QueryResultChannelID: allocatedQueryResultChannel,
}
err := saveQueryChannelInfo(collectionID, queryChannelInfo, m.client)
if err != nil {
log.Error("GetQueryChannel: save channel to etcd error", zap.Error(err))
return "", "", err
}
m.queryChannelInfos[collectionID] = queryChannelInfo
//TODO::return channel according collectionID
return allocatedQueryChannel, allocatedQueryResultChannel
return allocatedQueryChannel, allocatedQueryResultChannel, nil
}
func (m *MetaReplica) setLoadType(collectionID UniqueID, loadType querypb.LoadType) error {
@ -680,54 +685,54 @@ func (m *MetaReplica) setLoadPercentage(collectionID UniqueID, partitionID Uniqu
return nil
}
func (m *MetaReplica) printMeta() {
m.RLock()
defer m.RUnlock()
for id, info := range m.collectionInfos {
log.Debug("query coordinator MetaReplica: collectionInfo", zap.Int64("collectionID", id), zap.Any("info", info))
}
//func (m *MetaReplica) printMeta() {
// m.RLock()
// defer m.RUnlock()
// for id, info := range m.collectionInfos {
// log.Debug("query coordinator MetaReplica: collectionInfo", zap.Int64("collectionID", id), zap.Any("info", info))
// }
//
// for id, info := range m.segmentInfos {
// log.Debug("query coordinator MetaReplica: segmentInfo", zap.Int64("segmentID", id), zap.Any("info", info))
// }
//
// for id, info := range m.queryChannelInfos {
// log.Debug("query coordinator MetaReplica: queryChannelInfo", zap.Int64("collectionID", id), zap.Any("info", info))
// }
//}
for id, info := range m.segmentInfos {
log.Debug("query coordinator MetaReplica: segmentInfo", zap.Int64("segmentID", id), zap.Any("info", info))
}
for id, info := range m.queryChannelInfos {
log.Debug("query coordinator MetaReplica: queryChannelInfo", zap.Int64("collectionID", id), zap.Any("info", info))
}
}
func saveGlobalCollectionInfo(collectionID UniqueID, info *querypb.CollectionInfo, kv *etcdkv.EtcdKV) error {
func saveGlobalCollectionInfo(collectionID UniqueID, info *querypb.CollectionInfo, kv kv.MetaKv) error {
infoBytes := proto.MarshalTextString(info)
key := fmt.Sprintf("%s/%d", collectionMetaPrefix, collectionID)
return kv.Save(key, infoBytes)
}
func removeGlobalCollectionInfo(collectionID UniqueID, kv *etcdkv.EtcdKV) error {
func removeGlobalCollectionInfo(collectionID UniqueID, kv kv.MetaKv) error {
key := fmt.Sprintf("%s/%d", collectionMetaPrefix, collectionID)
return kv.Remove(key)
}
func saveSegmentInfo(segmentID UniqueID, info *querypb.SegmentInfo, kv *etcdkv.EtcdKV) error {
func saveSegmentInfo(segmentID UniqueID, info *querypb.SegmentInfo, kv kv.MetaKv) error {
infoBytes := proto.MarshalTextString(info)
key := fmt.Sprintf("%s/%d", segmentMetaPrefix, segmentID)
return kv.Save(key, infoBytes)
}
func removeSegmentInfo(segmentID UniqueID, kv *etcdkv.EtcdKV) error {
func removeSegmentInfo(segmentID UniqueID, kv kv.MetaKv) error {
key := fmt.Sprintf("%s/%d", segmentMetaPrefix, segmentID)
return kv.Remove(key)
}
func saveQueryChannelInfo(collectionID UniqueID, info *querypb.QueryChannelInfo, kv *etcdkv.EtcdKV) error {
func saveQueryChannelInfo(collectionID UniqueID, info *querypb.QueryChannelInfo, kv kv.MetaKv) error {
infoBytes := proto.MarshalTextString(info)
key := fmt.Sprintf("%s/%d", queryChannelMetaPrefix, collectionID)
return kv.Save(key, infoBytes)
}
func removeQueryChannelInfo(collectionID UniqueID, kv *etcdkv.EtcdKV) error {
key := fmt.Sprintf("%s/%d", queryChannelMetaPrefix, collectionID)
return kv.Remove(key)
}
//func removeQueryChannelInfo(collectionID UniqueID, kv *etcdkv.EtcdKV) error {
// key := fmt.Sprintf("%s/%d", queryChannelMetaPrefix, collectionID)
// return kv.Remove(key)
//}

View File

@ -12,6 +12,7 @@
package querycoord
import (
"errors"
"fmt"
"testing"
@ -19,10 +20,31 @@ import (
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
"github.com/milvus-io/milvus/internal/kv"
etcdkv "github.com/milvus-io/milvus/internal/kv/etcd"
"github.com/milvus-io/milvus/internal/proto/querypb"
)
func successResult() error { return nil }
func failedResult() error { return errors.New("") }
type testKv struct {
kv.MetaKv
returnFn func() error
}
func (tk *testKv) Save(key, value string) error {
return tk.returnFn()
}
func (tk *testKv) Remove(key string) error {
return tk.returnFn()
}
func (tk *testKv) LoadWithPrefix(key string) ([]string, []string, error) {
return nil, nil, nil
}
func TestReplica_Release(t *testing.T) {
refreshParams()
etcdKV, err := etcdkv.NewEtcdKV(Params.EtcdEndpoints, Params.MetaRootPath)
@ -53,6 +75,220 @@ func TestReplica_Release(t *testing.T) {
meta.releaseCollection(1)
}
func TestMetaFunc(t *testing.T) {
refreshParams()
kv, err := etcdkv.NewEtcdKV(Params.EtcdEndpoints, Params.MetaRootPath)
assert.Nil(t, err)
meta := &MetaReplica{
client: kv,
collectionInfos: map[UniqueID]*querypb.CollectionInfo{},
segmentInfos: map[UniqueID]*querypb.SegmentInfo{},
queryChannelInfos: map[UniqueID]*querypb.QueryChannelInfo{},
}
nodeID := int64(100)
dmChannels := []string{"testDm1", "testDm2"}
t.Run("Test ShowPartitionFail", func(t *testing.T) {
res, err := meta.showPartitions(defaultCollectionID)
assert.NotNil(t, err)
assert.Nil(t, res)
})
t.Run("Test HasCollectionFalse", func(t *testing.T) {
hasCollection := meta.hasCollection(defaultCollectionID)
assert.Equal(t, false, hasCollection)
})
t.Run("Test HasPartitionFalse", func(t *testing.T) {
hasPartition := meta.hasPartition(defaultCollectionID, defaultPartitionID)
assert.Equal(t, false, hasPartition)
})
t.Run("Test HasReleasePartitionFalse", func(t *testing.T) {
hasReleasePartition := meta.hasReleasePartition(defaultCollectionID, defaultPartitionID)
assert.Equal(t, false, hasReleasePartition)
})
t.Run("Test HasSegmentInfoFalse", func(t *testing.T) {
hasSegmentInfo := meta.hasSegmentInfo(defaultSegmentID)
assert.Equal(t, false, hasSegmentInfo)
})
t.Run("Test GetSegmentInfoByIDFail", func(t *testing.T) {
res, err := meta.getSegmentInfoByID(defaultSegmentID)
assert.NotNil(t, err)
assert.Nil(t, res)
})
t.Run("Test GetCollectionInfoByIDFail", func(t *testing.T) {
res, err := meta.getCollectionInfoByID(defaultCollectionID)
assert.Nil(t, res)
assert.NotNil(t, err)
})
t.Run("Test GetQueryChannelInfoByIDFail", func(t *testing.T) {
res, err := meta.getQueryChannelInfoByID(defaultCollectionID)
assert.NotNil(t, err)
assert.Nil(t, res)
})
t.Run("Test GetPartitionStatesByIDFail", func(t *testing.T) {
res, err := meta.getPartitionStatesByID(defaultCollectionID, defaultPartitionID)
assert.Nil(t, res)
assert.NotNil(t, err)
})
t.Run("Test GetDmChannelsByNodeIDFail", func(t *testing.T) {
res, err := meta.getDmChannelsByNodeID(defaultCollectionID, nodeID)
assert.NotNil(t, err)
assert.Nil(t, res)
})
t.Run("Test AddDmChannelFail", func(t *testing.T) {
err := meta.addDmChannel(defaultCollectionID, nodeID, dmChannels)
assert.NotNil(t, err)
})
t.Run("Test SetLoadTypeFail", func(t *testing.T) {
err := meta.setLoadType(defaultCollectionID, querypb.LoadType_loadCollection)
assert.NotNil(t, err)
})
t.Run("Test SetLoadPercentageFail", func(t *testing.T) {
err := meta.setLoadPercentage(defaultCollectionID, defaultPartitionID, 100, querypb.LoadType_loadCollection)
assert.NotNil(t, err)
})
t.Run("Test AddCollection", func(t *testing.T) {
schema := genCollectionSchema(defaultCollectionID, false)
err := meta.addCollection(defaultCollectionID, schema)
assert.Nil(t, err)
})
t.Run("Test HasCollection", func(t *testing.T) {
hasCollection := meta.hasCollection(defaultCollectionID)
assert.Equal(t, true, hasCollection)
})
t.Run("Test AddPartition", func(t *testing.T) {
err := meta.addPartition(defaultCollectionID, defaultPartitionID)
assert.Nil(t, err)
})
t.Run("Test HasPartition", func(t *testing.T) {
hasPartition := meta.hasPartition(defaultCollectionID, defaultPartitionID)
assert.Equal(t, true, hasPartition)
})
t.Run("Test ShowCollections", func(t *testing.T) {
info := meta.showCollections()
assert.Equal(t, 1, len(info))
})
t.Run("Test ShowPartitions", func(t *testing.T) {
states, err := meta.showPartitions(defaultCollectionID)
assert.Nil(t, err)
assert.Equal(t, 1, len(states))
})
t.Run("Test GetCollectionInfoByID", func(t *testing.T) {
info, err := meta.getCollectionInfoByID(defaultCollectionID)
assert.Nil(t, err)
assert.Equal(t, defaultCollectionID, info.CollectionID)
})
t.Run("Test GetPartitionStatesByID", func(t *testing.T) {
state, err := meta.getPartitionStatesByID(defaultCollectionID, defaultPartitionID)
assert.Nil(t, err)
assert.Equal(t, defaultPartitionID, state.PartitionID)
})
t.Run("Test AddDmChannel", func(t *testing.T) {
err := meta.addDmChannel(defaultCollectionID, nodeID, dmChannels)
assert.Nil(t, err)
})
t.Run("Test GetDmChannelsByNodeID", func(t *testing.T) {
channels, err := meta.getDmChannelsByNodeID(defaultCollectionID, nodeID)
assert.Nil(t, err)
assert.Equal(t, 2, len(channels))
})
t.Run("Test SetSegmentInfo", func(t *testing.T) {
info := &querypb.SegmentInfo{
SegmentID: defaultSegmentID,
PartitionID: defaultPartitionID,
CollectionID: defaultCollectionID,
NodeID: nodeID,
}
err := meta.setSegmentInfo(defaultSegmentID, info)
assert.Nil(t, err)
})
t.Run("Test ShowSegmentInfo", func(t *testing.T) {
infos := meta.showSegmentInfos(defaultCollectionID, []UniqueID{defaultPartitionID})
assert.Equal(t, 1, len(infos))
assert.Equal(t, defaultSegmentID, infos[0].SegmentID)
})
t.Run("Test GetQueryChannel", func(t *testing.T) {
reqChannel, resChannel, err := meta.GetQueryChannel(defaultCollectionID)
assert.NotNil(t, reqChannel)
assert.NotNil(t, resChannel)
assert.Nil(t, err)
})
t.Run("Test GetSegmentInfoByID", func(t *testing.T) {
info, err := meta.getSegmentInfoByID(defaultSegmentID)
assert.Nil(t, err)
assert.Equal(t, defaultSegmentID, info.SegmentID)
})
t.Run("Test SetLoadType", func(t *testing.T) {
err := meta.setLoadType(defaultCollectionID, querypb.LoadType_loadCollection)
assert.Nil(t, err)
})
t.Run("Test SetLoadPercentage", func(t *testing.T) {
err := meta.setLoadPercentage(defaultCollectionID, defaultPartitionID, 100, querypb.LoadType_LoadPartition)
assert.Nil(t, err)
state, err := meta.getPartitionStatesByID(defaultCollectionID, defaultPartitionID)
assert.Nil(t, err)
assert.Equal(t, int64(100), state.InMemoryPercentage)
err = meta.setLoadPercentage(defaultCollectionID, defaultPartitionID, 100, querypb.LoadType_loadCollection)
assert.Nil(t, err)
info, err := meta.getCollectionInfoByID(defaultCollectionID)
assert.Nil(t, err)
assert.Equal(t, int64(100), info.InMemoryPercentage)
})
t.Run("Test RemoveDmChannel", func(t *testing.T) {
err := meta.removeDmChannel(defaultCollectionID, nodeID, dmChannels)
assert.Nil(t, err)
channels, err := meta.getDmChannelsByNodeID(defaultCollectionID, nodeID)
assert.Nil(t, err)
assert.Equal(t, 0, len(channels))
})
t.Run("Test DeleteSegmentInfoByNodeID", func(t *testing.T) {
err := meta.deleteSegmentInfoByNodeID(nodeID)
assert.Nil(t, err)
_, err = meta.getSegmentInfoByID(defaultSegmentID)
assert.NotNil(t, err)
})
t.Run("Test ReleasePartition", func(t *testing.T) {
err := meta.releasePartition(defaultCollectionID, defaultPartitionID)
assert.Nil(t, err)
})
t.Run("Test ReleaseCollection", func(t *testing.T) {
err := meta.releaseCollection(defaultCollectionID)
assert.Nil(t, err)
})
}
func TestReloadMetaFromKV(t *testing.T) {
refreshParams()
kv, err := etcdkv.NewEtcdKV(Params.EtcdEndpoints, Params.MetaRootPath)

View File

@ -302,7 +302,12 @@ func (lct *LoadCollectionTask) Execute(ctx context.Context) error {
}
}
assignInternalTask(ctx, collectionID, lct, lct.meta, lct.cluster, loadSegmentReqs, watchDmChannelReqs)
err = assignInternalTask(ctx, collectionID, lct, lct.meta, lct.cluster, loadSegmentReqs, watchDmChannelReqs)
if err != nil {
status.Reason = err.Error()
lct.result = status
return err
}
log.Debug("loadCollectionTask: assign child task done", zap.Int64("collectionID", collectionID))
log.Debug("LoadCollection execute done",
@ -581,7 +586,12 @@ func (lpt *LoadPartitionTask) Execute(ctx context.Context) error {
log.Debug("LoadPartitionTask: set watchDmChannelsRequests", zap.Any("request", watchDmRequest), zap.Int64("collectionID", collectionID))
}
}
assignInternalTask(ctx, collectionID, lpt, lpt.meta, lpt.cluster, loadSegmentReqs, watchDmReqs)
err := assignInternalTask(ctx, collectionID, lpt, lpt.meta, lpt.cluster, loadSegmentReqs, watchDmReqs)
if err != nil {
status.Reason = err.Error()
lpt.result = status
return err
}
log.Debug("LoadPartitionTask: assign child task done", zap.Int64("collectionID", collectionID), zap.Int64s("partitionIDs", partitionIDs))
log.Debug("LoadPartitionTask Execute done",
@ -878,7 +888,10 @@ func (lst *LoadSegmentTask) Reschedule() ([]task, error) {
hasWatchQueryChannel := lst.cluster.hasWatchedQueryChannel(lst.ctx, nodeID, collectionID)
if !hasWatchQueryChannel {
queryChannel, queryResultChannel := lst.meta.GetQueryChannel(collectionID)
queryChannel, queryResultChannel, err := lst.meta.GetQueryChannel(collectionID)
if err != nil {
return nil, err
}
msgBase := proto.Clone(lst.Base).(*commonpb.MsgBase)
msgBase.MsgType = commonpb.MsgType_WatchQueryChannels
@ -1089,7 +1102,10 @@ func (wdt *WatchDmChannelTask) Reschedule() ([]task, error) {
hasWatchQueryChannel := wdt.cluster.hasWatchedQueryChannel(wdt.ctx, nodeID, collectionID)
if !hasWatchQueryChannel {
queryChannel, queryResultChannel := wdt.meta.GetQueryChannel(collectionID)
queryChannel, queryResultChannel, err := wdt.meta.GetQueryChannel(collectionID)
if err != nil {
return nil, err
}
msgBase := proto.Clone(wdt.Base).(*commonpb.MsgBase)
msgBase.MsgType = commonpb.MsgType_WatchQueryChannels
@ -1348,7 +1364,12 @@ func (lbt *LoadBalanceTask) Execute(ctx context.Context) error {
}
}
}
assignInternalTask(ctx, collectionID, lbt, lbt.meta, lbt.cluster, loadSegmentReqs, watchDmChannelReqs)
err = assignInternalTask(ctx, collectionID, lbt, lbt.meta, lbt.cluster, loadSegmentReqs, watchDmChannelReqs)
if err != nil {
status.Reason = err.Error()
lbt.result = status
return err
}
log.Debug("loadBalanceTask: assign child task done", zap.Int64("collectionID", collectionID), zap.Int64s("partitionIDs", partitionIDs))
}
}
@ -1529,7 +1550,7 @@ func assignInternalTask(ctx context.Context,
meta Meta,
cluster Cluster,
loadSegmentRequests []*querypb.LoadSegmentsRequest,
watchDmChannelRequests []*querypb.WatchDmChannelsRequest) {
watchDmChannelRequests []*querypb.WatchDmChannelsRequest) error {
sp, _ := trace.StartSpanFromContext(ctx)
defer sp.Finish()
@ -1607,7 +1628,10 @@ func assignInternalTask(ctx context.Context,
for nodeID, watched := range watchQueryChannelInfo {
if !watched {
ctx = opentracing.ContextWithSpan(context.Background(), sp)
queryChannel, queryResultChannel := meta.GetQueryChannel(collectionID)
queryChannel, queryResultChannel, err := meta.GetQueryChannel(collectionID)
if err != nil {
return err
}
msgBase := proto.Clone(parentTask.MsgBase()).(*commonpb.MsgBase)
msgBase.MsgType = commonpb.MsgType_WatchQueryChannels
@ -1632,4 +1656,6 @@ func assignInternalTask(ctx context.Context,
log.Debug("assignInternalTask: add a watchQueryChannelTask childTask", zap.Any("task", watchQueryChannelTask))
}
}
return nil
}