Add ut for querycoord's meta (#7968)

Signed-off-by: xige-16 <xi.ge@zilliz.com>
pull/8542/head
xige-16 2021-09-29 09:56:04 +08:00 committed by GitHub
parent d6f96ec9fc
commit 9400694867
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
6 changed files with 340 additions and 44 deletions

View File

@ -165,26 +165,28 @@ func TestGrpcRequest(t *testing.T) {
}) })
t.Run("Test AddQueryChannel", func(t *testing.T) { t.Run("Test AddQueryChannel", func(t *testing.T) {
reqChannel, resChannel := cluster.clusterMeta.GetQueryChannel(defaultCollectionID) reqChannel, resChannel, err := cluster.clusterMeta.GetQueryChannel(defaultCollectionID)
assert.Nil(t, err)
addQueryChannelReq := &querypb.AddQueryChannelRequest{ addQueryChannelReq := &querypb.AddQueryChannelRequest{
NodeID: nodeID, NodeID: nodeID,
CollectionID: defaultCollectionID, CollectionID: defaultCollectionID,
RequestChannelID: reqChannel, RequestChannelID: reqChannel,
ResultChannelID: resChannel, ResultChannelID: resChannel,
} }
err := cluster.addQueryChannel(baseCtx, nodeID, addQueryChannelReq) err = cluster.addQueryChannel(baseCtx, nodeID, addQueryChannelReq)
assert.Nil(t, err) assert.Nil(t, err)
}) })
t.Run("Test RemoveQueryChannel", func(t *testing.T) { t.Run("Test RemoveQueryChannel", func(t *testing.T) {
reqChannel, resChannel := cluster.clusterMeta.GetQueryChannel(defaultCollectionID) reqChannel, resChannel, err := cluster.clusterMeta.GetQueryChannel(defaultCollectionID)
assert.Nil(t, err)
removeQueryChannelReq := &querypb.RemoveQueryChannelRequest{ removeQueryChannelReq := &querypb.RemoveQueryChannelRequest{
NodeID: nodeID, NodeID: nodeID,
CollectionID: defaultCollectionID, CollectionID: defaultCollectionID,
RequestChannelID: reqChannel, RequestChannelID: reqChannel,
ResultChannelID: resChannel, ResultChannelID: resChannel,
} }
err := cluster.removeQueryChannel(baseCtx, nodeID, removeQueryChannelReq) err = cluster.removeQueryChannel(baseCtx, nodeID, removeQueryChannelReq)
assert.Nil(t, err) assert.Nil(t, err)
}) })

View File

@ -429,7 +429,15 @@ func (qc *QueryCoord) CreateQueryChannel(ctx context.Context, req *querypb.Creat
} }
collectionID := req.CollectionID collectionID := req.CollectionID
queryChannel, queryResultChannel := qc.meta.GetQueryChannel(collectionID) queryChannel, queryResultChannel, err := qc.meta.GetQueryChannel(collectionID)
if err != nil {
status.ErrorCode = commonpb.ErrorCode_UnexpectedError
status.Reason = err.Error()
log.Debug("createQueryChannel end with error")
return &querypb.CreateQueryChannelResponse{
Status: status,
}, err
}
return &querypb.CreateQueryChannelResponse{ return &querypb.CreateQueryChannelResponse{
Status: status, Status: status,

View File

@ -553,3 +553,22 @@ func TestGrpcTaskBeforeHealthy(t *testing.T) {
err = removeAllSession() err = removeAllSession()
assert.Nil(t, err) assert.Nil(t, err)
} }
func Test_GrpcGetQueryChannelFail(t *testing.T) {
kv := &testKv{
returnFn: failedResult,
}
meta, err := newMeta(kv)
assert.Nil(t, err)
queryCoord := &QueryCoord{
meta: meta,
}
queryCoord.stateCode.Store(internalpb.StateCode_Healthy)
res, err := queryCoord.CreateQueryChannel(context.Background(), &querypb.CreateQueryChannelRequest{
CollectionID: defaultCollectionID,
})
assert.NotNil(t, err)
assert.Equal(t, commonpb.ErrorCode_UnexpectedError, res.Status.ErrorCode)
}

View File

@ -21,7 +21,7 @@ import (
"github.com/golang/protobuf/proto" "github.com/golang/protobuf/proto"
"go.uber.org/zap" "go.uber.org/zap"
etcdkv "github.com/milvus-io/milvus/internal/kv/etcd" "github.com/milvus-io/milvus/internal/kv"
"github.com/milvus-io/milvus/internal/log" "github.com/milvus-io/milvus/internal/log"
"github.com/milvus-io/milvus/internal/proto/querypb" "github.com/milvus-io/milvus/internal/proto/querypb"
"github.com/milvus-io/milvus/internal/proto/schemapb" "github.com/milvus-io/milvus/internal/proto/schemapb"
@ -63,16 +63,16 @@ type Meta interface {
removeDmChannel(collectionID UniqueID, nodeID int64, channels []string) error removeDmChannel(collectionID UniqueID, nodeID int64, channels []string) error
getQueryChannelInfoByID(collectionID UniqueID) (*querypb.QueryChannelInfo, error) getQueryChannelInfoByID(collectionID UniqueID) (*querypb.QueryChannelInfo, error)
GetQueryChannel(collectionID UniqueID) (string, string) GetQueryChannel(collectionID UniqueID) (string, string, error)
setLoadType(collectionID UniqueID, loadType querypb.LoadType) error setLoadType(collectionID UniqueID, loadType querypb.LoadType) error
getLoadType(collectionID UniqueID) (querypb.LoadType, error) getLoadType(collectionID UniqueID) (querypb.LoadType, error)
setLoadPercentage(collectionID UniqueID, partitionID UniqueID, percentage int64, loadType querypb.LoadType) error setLoadPercentage(collectionID UniqueID, partitionID UniqueID, percentage int64, loadType querypb.LoadType) error
printMeta() //printMeta()
} }
type MetaReplica struct { type MetaReplica struct {
client *etcdkv.EtcdKV // client of a reliable kv service, i.e. etcd client client kv.MetaKv // client of a reliable kv service, i.e. etcd client
sync.RWMutex sync.RWMutex
collectionInfos map[UniqueID]*querypb.CollectionInfo collectionInfos map[UniqueID]*querypb.CollectionInfo
@ -82,7 +82,7 @@ type MetaReplica struct {
//partitionStates map[UniqueID]*querypb.PartitionStates //partitionStates map[UniqueID]*querypb.PartitionStates
} }
func newMeta(kv *etcdkv.EtcdKV) (Meta, error) { func newMeta(kv kv.MetaKv) (Meta, error) {
collectionInfos := make(map[UniqueID]*querypb.CollectionInfo) collectionInfos := make(map[UniqueID]*querypb.CollectionInfo)
segmentInfos := make(map[UniqueID]*querypb.SegmentInfo) segmentInfos := make(map[UniqueID]*querypb.SegmentInfo)
queryChannelInfos := make(map[UniqueID]*querypb.QueryChannelInfo) queryChannelInfos := make(map[UniqueID]*querypb.QueryChannelInfo)
@ -579,14 +579,14 @@ func (m *MetaReplica) removeDmChannel(collectionID UniqueID, nodeID int64, chann
return errors.New("addDmChannels: can't find collection in collectionInfos") return errors.New("addDmChannels: can't find collection in collectionInfos")
} }
func (m *MetaReplica) GetQueryChannel(collectionID UniqueID) (string, string) { func (m *MetaReplica) GetQueryChannel(collectionID UniqueID) (string, string, error) {
m.Lock() m.Lock()
defer m.Unlock() defer m.Unlock()
//TODO::to remove //TODO::to remove
collectionID = 0 collectionID = 0
if info, ok := m.queryChannelInfos[collectionID]; ok { if info, ok := m.queryChannelInfos[collectionID]; ok {
return info.QueryChannelID, info.QueryResultChannelID return info.QueryChannelID, info.QueryResultChannelID, nil
} }
searchPrefix := Params.SearchChannelPrefix searchPrefix := Params.SearchChannelPrefix
@ -600,9 +600,14 @@ func (m *MetaReplica) GetQueryChannel(collectionID UniqueID) (string, string) {
QueryChannelID: allocatedQueryChannel, QueryChannelID: allocatedQueryChannel,
QueryResultChannelID: allocatedQueryResultChannel, QueryResultChannelID: allocatedQueryResultChannel,
} }
err := saveQueryChannelInfo(collectionID, queryChannelInfo, m.client)
if err != nil {
log.Error("GetQueryChannel: save channel to etcd error", zap.Error(err))
return "", "", err
}
m.queryChannelInfos[collectionID] = queryChannelInfo m.queryChannelInfos[collectionID] = queryChannelInfo
//TODO::return channel according collectionID //TODO::return channel according collectionID
return allocatedQueryChannel, allocatedQueryResultChannel return allocatedQueryChannel, allocatedQueryResultChannel, nil
} }
func (m *MetaReplica) setLoadType(collectionID UniqueID, loadType querypb.LoadType) error { func (m *MetaReplica) setLoadType(collectionID UniqueID, loadType querypb.LoadType) error {
@ -680,54 +685,54 @@ func (m *MetaReplica) setLoadPercentage(collectionID UniqueID, partitionID Uniqu
return nil return nil
} }
func (m *MetaReplica) printMeta() { //func (m *MetaReplica) printMeta() {
m.RLock() // m.RLock()
defer m.RUnlock() // defer m.RUnlock()
for id, info := range m.collectionInfos { // for id, info := range m.collectionInfos {
log.Debug("query coordinator MetaReplica: collectionInfo", zap.Int64("collectionID", id), zap.Any("info", info)) // log.Debug("query coordinator MetaReplica: collectionInfo", zap.Int64("collectionID", id), zap.Any("info", info))
} // }
//
// for id, info := range m.segmentInfos {
// log.Debug("query coordinator MetaReplica: segmentInfo", zap.Int64("segmentID", id), zap.Any("info", info))
// }
//
// for id, info := range m.queryChannelInfos {
// log.Debug("query coordinator MetaReplica: queryChannelInfo", zap.Int64("collectionID", id), zap.Any("info", info))
// }
//}
for id, info := range m.segmentInfos { func saveGlobalCollectionInfo(collectionID UniqueID, info *querypb.CollectionInfo, kv kv.MetaKv) error {
log.Debug("query coordinator MetaReplica: segmentInfo", zap.Int64("segmentID", id), zap.Any("info", info))
}
for id, info := range m.queryChannelInfos {
log.Debug("query coordinator MetaReplica: queryChannelInfo", zap.Int64("collectionID", id), zap.Any("info", info))
}
}
func saveGlobalCollectionInfo(collectionID UniqueID, info *querypb.CollectionInfo, kv *etcdkv.EtcdKV) error {
infoBytes := proto.MarshalTextString(info) infoBytes := proto.MarshalTextString(info)
key := fmt.Sprintf("%s/%d", collectionMetaPrefix, collectionID) key := fmt.Sprintf("%s/%d", collectionMetaPrefix, collectionID)
return kv.Save(key, infoBytes) return kv.Save(key, infoBytes)
} }
func removeGlobalCollectionInfo(collectionID UniqueID, kv *etcdkv.EtcdKV) error { func removeGlobalCollectionInfo(collectionID UniqueID, kv kv.MetaKv) error {
key := fmt.Sprintf("%s/%d", collectionMetaPrefix, collectionID) key := fmt.Sprintf("%s/%d", collectionMetaPrefix, collectionID)
return kv.Remove(key) return kv.Remove(key)
} }
func saveSegmentInfo(segmentID UniqueID, info *querypb.SegmentInfo, kv *etcdkv.EtcdKV) error { func saveSegmentInfo(segmentID UniqueID, info *querypb.SegmentInfo, kv kv.MetaKv) error {
infoBytes := proto.MarshalTextString(info) infoBytes := proto.MarshalTextString(info)
key := fmt.Sprintf("%s/%d", segmentMetaPrefix, segmentID) key := fmt.Sprintf("%s/%d", segmentMetaPrefix, segmentID)
return kv.Save(key, infoBytes) return kv.Save(key, infoBytes)
} }
func removeSegmentInfo(segmentID UniqueID, kv *etcdkv.EtcdKV) error { func removeSegmentInfo(segmentID UniqueID, kv kv.MetaKv) error {
key := fmt.Sprintf("%s/%d", segmentMetaPrefix, segmentID) key := fmt.Sprintf("%s/%d", segmentMetaPrefix, segmentID)
return kv.Remove(key) return kv.Remove(key)
} }
func saveQueryChannelInfo(collectionID UniqueID, info *querypb.QueryChannelInfo, kv *etcdkv.EtcdKV) error { func saveQueryChannelInfo(collectionID UniqueID, info *querypb.QueryChannelInfo, kv kv.MetaKv) error {
infoBytes := proto.MarshalTextString(info) infoBytes := proto.MarshalTextString(info)
key := fmt.Sprintf("%s/%d", queryChannelMetaPrefix, collectionID) key := fmt.Sprintf("%s/%d", queryChannelMetaPrefix, collectionID)
return kv.Save(key, infoBytes) return kv.Save(key, infoBytes)
} }
func removeQueryChannelInfo(collectionID UniqueID, kv *etcdkv.EtcdKV) error { //func removeQueryChannelInfo(collectionID UniqueID, kv *etcdkv.EtcdKV) error {
key := fmt.Sprintf("%s/%d", queryChannelMetaPrefix, collectionID) // key := fmt.Sprintf("%s/%d", queryChannelMetaPrefix, collectionID)
return kv.Remove(key) // return kv.Remove(key)
} //}

View File

@ -12,6 +12,7 @@
package querycoord package querycoord
import ( import (
"errors"
"fmt" "fmt"
"testing" "testing"
@ -19,10 +20,31 @@ import (
"github.com/stretchr/testify/assert" "github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require" "github.com/stretchr/testify/require"
"github.com/milvus-io/milvus/internal/kv"
etcdkv "github.com/milvus-io/milvus/internal/kv/etcd" etcdkv "github.com/milvus-io/milvus/internal/kv/etcd"
"github.com/milvus-io/milvus/internal/proto/querypb" "github.com/milvus-io/milvus/internal/proto/querypb"
) )
func successResult() error { return nil }
func failedResult() error { return errors.New("") }
type testKv struct {
kv.MetaKv
returnFn func() error
}
func (tk *testKv) Save(key, value string) error {
return tk.returnFn()
}
func (tk *testKv) Remove(key string) error {
return tk.returnFn()
}
func (tk *testKv) LoadWithPrefix(key string) ([]string, []string, error) {
return nil, nil, nil
}
func TestReplica_Release(t *testing.T) { func TestReplica_Release(t *testing.T) {
refreshParams() refreshParams()
etcdKV, err := etcdkv.NewEtcdKV(Params.EtcdEndpoints, Params.MetaRootPath) etcdKV, err := etcdkv.NewEtcdKV(Params.EtcdEndpoints, Params.MetaRootPath)
@ -53,6 +75,220 @@ func TestReplica_Release(t *testing.T) {
meta.releaseCollection(1) meta.releaseCollection(1)
} }
func TestMetaFunc(t *testing.T) {
refreshParams()
kv, err := etcdkv.NewEtcdKV(Params.EtcdEndpoints, Params.MetaRootPath)
assert.Nil(t, err)
meta := &MetaReplica{
client: kv,
collectionInfos: map[UniqueID]*querypb.CollectionInfo{},
segmentInfos: map[UniqueID]*querypb.SegmentInfo{},
queryChannelInfos: map[UniqueID]*querypb.QueryChannelInfo{},
}
nodeID := int64(100)
dmChannels := []string{"testDm1", "testDm2"}
t.Run("Test ShowPartitionFail", func(t *testing.T) {
res, err := meta.showPartitions(defaultCollectionID)
assert.NotNil(t, err)
assert.Nil(t, res)
})
t.Run("Test HasCollectionFalse", func(t *testing.T) {
hasCollection := meta.hasCollection(defaultCollectionID)
assert.Equal(t, false, hasCollection)
})
t.Run("Test HasPartitionFalse", func(t *testing.T) {
hasPartition := meta.hasPartition(defaultCollectionID, defaultPartitionID)
assert.Equal(t, false, hasPartition)
})
t.Run("Test HasReleasePartitionFalse", func(t *testing.T) {
hasReleasePartition := meta.hasReleasePartition(defaultCollectionID, defaultPartitionID)
assert.Equal(t, false, hasReleasePartition)
})
t.Run("Test HasSegmentInfoFalse", func(t *testing.T) {
hasSegmentInfo := meta.hasSegmentInfo(defaultSegmentID)
assert.Equal(t, false, hasSegmentInfo)
})
t.Run("Test GetSegmentInfoByIDFail", func(t *testing.T) {
res, err := meta.getSegmentInfoByID(defaultSegmentID)
assert.NotNil(t, err)
assert.Nil(t, res)
})
t.Run("Test GetCollectionInfoByIDFail", func(t *testing.T) {
res, err := meta.getCollectionInfoByID(defaultCollectionID)
assert.Nil(t, res)
assert.NotNil(t, err)
})
t.Run("Test GetQueryChannelInfoByIDFail", func(t *testing.T) {
res, err := meta.getQueryChannelInfoByID(defaultCollectionID)
assert.NotNil(t, err)
assert.Nil(t, res)
})
t.Run("Test GetPartitionStatesByIDFail", func(t *testing.T) {
res, err := meta.getPartitionStatesByID(defaultCollectionID, defaultPartitionID)
assert.Nil(t, res)
assert.NotNil(t, err)
})
t.Run("Test GetDmChannelsByNodeIDFail", func(t *testing.T) {
res, err := meta.getDmChannelsByNodeID(defaultCollectionID, nodeID)
assert.NotNil(t, err)
assert.Nil(t, res)
})
t.Run("Test AddDmChannelFail", func(t *testing.T) {
err := meta.addDmChannel(defaultCollectionID, nodeID, dmChannels)
assert.NotNil(t, err)
})
t.Run("Test SetLoadTypeFail", func(t *testing.T) {
err := meta.setLoadType(defaultCollectionID, querypb.LoadType_loadCollection)
assert.NotNil(t, err)
})
t.Run("Test SetLoadPercentageFail", func(t *testing.T) {
err := meta.setLoadPercentage(defaultCollectionID, defaultPartitionID, 100, querypb.LoadType_loadCollection)
assert.NotNil(t, err)
})
t.Run("Test AddCollection", func(t *testing.T) {
schema := genCollectionSchema(defaultCollectionID, false)
err := meta.addCollection(defaultCollectionID, schema)
assert.Nil(t, err)
})
t.Run("Test HasCollection", func(t *testing.T) {
hasCollection := meta.hasCollection(defaultCollectionID)
assert.Equal(t, true, hasCollection)
})
t.Run("Test AddPartition", func(t *testing.T) {
err := meta.addPartition(defaultCollectionID, defaultPartitionID)
assert.Nil(t, err)
})
t.Run("Test HasPartition", func(t *testing.T) {
hasPartition := meta.hasPartition(defaultCollectionID, defaultPartitionID)
assert.Equal(t, true, hasPartition)
})
t.Run("Test ShowCollections", func(t *testing.T) {
info := meta.showCollections()
assert.Equal(t, 1, len(info))
})
t.Run("Test ShowPartitions", func(t *testing.T) {
states, err := meta.showPartitions(defaultCollectionID)
assert.Nil(t, err)
assert.Equal(t, 1, len(states))
})
t.Run("Test GetCollectionInfoByID", func(t *testing.T) {
info, err := meta.getCollectionInfoByID(defaultCollectionID)
assert.Nil(t, err)
assert.Equal(t, defaultCollectionID, info.CollectionID)
})
t.Run("Test GetPartitionStatesByID", func(t *testing.T) {
state, err := meta.getPartitionStatesByID(defaultCollectionID, defaultPartitionID)
assert.Nil(t, err)
assert.Equal(t, defaultPartitionID, state.PartitionID)
})
t.Run("Test AddDmChannel", func(t *testing.T) {
err := meta.addDmChannel(defaultCollectionID, nodeID, dmChannels)
assert.Nil(t, err)
})
t.Run("Test GetDmChannelsByNodeID", func(t *testing.T) {
channels, err := meta.getDmChannelsByNodeID(defaultCollectionID, nodeID)
assert.Nil(t, err)
assert.Equal(t, 2, len(channels))
})
t.Run("Test SetSegmentInfo", func(t *testing.T) {
info := &querypb.SegmentInfo{
SegmentID: defaultSegmentID,
PartitionID: defaultPartitionID,
CollectionID: defaultCollectionID,
NodeID: nodeID,
}
err := meta.setSegmentInfo(defaultSegmentID, info)
assert.Nil(t, err)
})
t.Run("Test ShowSegmentInfo", func(t *testing.T) {
infos := meta.showSegmentInfos(defaultCollectionID, []UniqueID{defaultPartitionID})
assert.Equal(t, 1, len(infos))
assert.Equal(t, defaultSegmentID, infos[0].SegmentID)
})
t.Run("Test GetQueryChannel", func(t *testing.T) {
reqChannel, resChannel, err := meta.GetQueryChannel(defaultCollectionID)
assert.NotNil(t, reqChannel)
assert.NotNil(t, resChannel)
assert.Nil(t, err)
})
t.Run("Test GetSegmentInfoByID", func(t *testing.T) {
info, err := meta.getSegmentInfoByID(defaultSegmentID)
assert.Nil(t, err)
assert.Equal(t, defaultSegmentID, info.SegmentID)
})
t.Run("Test SetLoadType", func(t *testing.T) {
err := meta.setLoadType(defaultCollectionID, querypb.LoadType_loadCollection)
assert.Nil(t, err)
})
t.Run("Test SetLoadPercentage", func(t *testing.T) {
err := meta.setLoadPercentage(defaultCollectionID, defaultPartitionID, 100, querypb.LoadType_LoadPartition)
assert.Nil(t, err)
state, err := meta.getPartitionStatesByID(defaultCollectionID, defaultPartitionID)
assert.Nil(t, err)
assert.Equal(t, int64(100), state.InMemoryPercentage)
err = meta.setLoadPercentage(defaultCollectionID, defaultPartitionID, 100, querypb.LoadType_loadCollection)
assert.Nil(t, err)
info, err := meta.getCollectionInfoByID(defaultCollectionID)
assert.Nil(t, err)
assert.Equal(t, int64(100), info.InMemoryPercentage)
})
t.Run("Test RemoveDmChannel", func(t *testing.T) {
err := meta.removeDmChannel(defaultCollectionID, nodeID, dmChannels)
assert.Nil(t, err)
channels, err := meta.getDmChannelsByNodeID(defaultCollectionID, nodeID)
assert.Nil(t, err)
assert.Equal(t, 0, len(channels))
})
t.Run("Test DeleteSegmentInfoByNodeID", func(t *testing.T) {
err := meta.deleteSegmentInfoByNodeID(nodeID)
assert.Nil(t, err)
_, err = meta.getSegmentInfoByID(defaultSegmentID)
assert.NotNil(t, err)
})
t.Run("Test ReleasePartition", func(t *testing.T) {
err := meta.releasePartition(defaultCollectionID, defaultPartitionID)
assert.Nil(t, err)
})
t.Run("Test ReleaseCollection", func(t *testing.T) {
err := meta.releaseCollection(defaultCollectionID)
assert.Nil(t, err)
})
}
func TestReloadMetaFromKV(t *testing.T) { func TestReloadMetaFromKV(t *testing.T) {
refreshParams() refreshParams()
kv, err := etcdkv.NewEtcdKV(Params.EtcdEndpoints, Params.MetaRootPath) kv, err := etcdkv.NewEtcdKV(Params.EtcdEndpoints, Params.MetaRootPath)

View File

@ -302,7 +302,12 @@ func (lct *LoadCollectionTask) Execute(ctx context.Context) error {
} }
} }
assignInternalTask(ctx, collectionID, lct, lct.meta, lct.cluster, loadSegmentReqs, watchDmChannelReqs) err = assignInternalTask(ctx, collectionID, lct, lct.meta, lct.cluster, loadSegmentReqs, watchDmChannelReqs)
if err != nil {
status.Reason = err.Error()
lct.result = status
return err
}
log.Debug("loadCollectionTask: assign child task done", zap.Int64("collectionID", collectionID)) log.Debug("loadCollectionTask: assign child task done", zap.Int64("collectionID", collectionID))
log.Debug("LoadCollection execute done", log.Debug("LoadCollection execute done",
@ -581,7 +586,12 @@ func (lpt *LoadPartitionTask) Execute(ctx context.Context) error {
log.Debug("LoadPartitionTask: set watchDmChannelsRequests", zap.Any("request", watchDmRequest), zap.Int64("collectionID", collectionID)) log.Debug("LoadPartitionTask: set watchDmChannelsRequests", zap.Any("request", watchDmRequest), zap.Int64("collectionID", collectionID))
} }
} }
assignInternalTask(ctx, collectionID, lpt, lpt.meta, lpt.cluster, loadSegmentReqs, watchDmReqs) err := assignInternalTask(ctx, collectionID, lpt, lpt.meta, lpt.cluster, loadSegmentReqs, watchDmReqs)
if err != nil {
status.Reason = err.Error()
lpt.result = status
return err
}
log.Debug("LoadPartitionTask: assign child task done", zap.Int64("collectionID", collectionID), zap.Int64s("partitionIDs", partitionIDs)) log.Debug("LoadPartitionTask: assign child task done", zap.Int64("collectionID", collectionID), zap.Int64s("partitionIDs", partitionIDs))
log.Debug("LoadPartitionTask Execute done", log.Debug("LoadPartitionTask Execute done",
@ -878,7 +888,10 @@ func (lst *LoadSegmentTask) Reschedule() ([]task, error) {
hasWatchQueryChannel := lst.cluster.hasWatchedQueryChannel(lst.ctx, nodeID, collectionID) hasWatchQueryChannel := lst.cluster.hasWatchedQueryChannel(lst.ctx, nodeID, collectionID)
if !hasWatchQueryChannel { if !hasWatchQueryChannel {
queryChannel, queryResultChannel := lst.meta.GetQueryChannel(collectionID) queryChannel, queryResultChannel, err := lst.meta.GetQueryChannel(collectionID)
if err != nil {
return nil, err
}
msgBase := proto.Clone(lst.Base).(*commonpb.MsgBase) msgBase := proto.Clone(lst.Base).(*commonpb.MsgBase)
msgBase.MsgType = commonpb.MsgType_WatchQueryChannels msgBase.MsgType = commonpb.MsgType_WatchQueryChannels
@ -1089,7 +1102,10 @@ func (wdt *WatchDmChannelTask) Reschedule() ([]task, error) {
hasWatchQueryChannel := wdt.cluster.hasWatchedQueryChannel(wdt.ctx, nodeID, collectionID) hasWatchQueryChannel := wdt.cluster.hasWatchedQueryChannel(wdt.ctx, nodeID, collectionID)
if !hasWatchQueryChannel { if !hasWatchQueryChannel {
queryChannel, queryResultChannel := wdt.meta.GetQueryChannel(collectionID) queryChannel, queryResultChannel, err := wdt.meta.GetQueryChannel(collectionID)
if err != nil {
return nil, err
}
msgBase := proto.Clone(wdt.Base).(*commonpb.MsgBase) msgBase := proto.Clone(wdt.Base).(*commonpb.MsgBase)
msgBase.MsgType = commonpb.MsgType_WatchQueryChannels msgBase.MsgType = commonpb.MsgType_WatchQueryChannels
@ -1348,7 +1364,12 @@ func (lbt *LoadBalanceTask) Execute(ctx context.Context) error {
} }
} }
} }
assignInternalTask(ctx, collectionID, lbt, lbt.meta, lbt.cluster, loadSegmentReqs, watchDmChannelReqs) err = assignInternalTask(ctx, collectionID, lbt, lbt.meta, lbt.cluster, loadSegmentReqs, watchDmChannelReqs)
if err != nil {
status.Reason = err.Error()
lbt.result = status
return err
}
log.Debug("loadBalanceTask: assign child task done", zap.Int64("collectionID", collectionID), zap.Int64s("partitionIDs", partitionIDs)) log.Debug("loadBalanceTask: assign child task done", zap.Int64("collectionID", collectionID), zap.Int64s("partitionIDs", partitionIDs))
} }
} }
@ -1529,7 +1550,7 @@ func assignInternalTask(ctx context.Context,
meta Meta, meta Meta,
cluster Cluster, cluster Cluster,
loadSegmentRequests []*querypb.LoadSegmentsRequest, loadSegmentRequests []*querypb.LoadSegmentsRequest,
watchDmChannelRequests []*querypb.WatchDmChannelsRequest) { watchDmChannelRequests []*querypb.WatchDmChannelsRequest) error {
sp, _ := trace.StartSpanFromContext(ctx) sp, _ := trace.StartSpanFromContext(ctx)
defer sp.Finish() defer sp.Finish()
@ -1607,7 +1628,10 @@ func assignInternalTask(ctx context.Context,
for nodeID, watched := range watchQueryChannelInfo { for nodeID, watched := range watchQueryChannelInfo {
if !watched { if !watched {
ctx = opentracing.ContextWithSpan(context.Background(), sp) ctx = opentracing.ContextWithSpan(context.Background(), sp)
queryChannel, queryResultChannel := meta.GetQueryChannel(collectionID) queryChannel, queryResultChannel, err := meta.GetQueryChannel(collectionID)
if err != nil {
return err
}
msgBase := proto.Clone(parentTask.MsgBase()).(*commonpb.MsgBase) msgBase := proto.Clone(parentTask.MsgBase()).(*commonpb.MsgBase)
msgBase.MsgType = commonpb.MsgType_WatchQueryChannels msgBase.MsgType = commonpb.MsgType_WatchQueryChannels
@ -1632,4 +1656,6 @@ func assignInternalTask(ctx context.Context,
log.Debug("assignInternalTask: add a watchQueryChannelTask childTask", zap.Any("task", watchQueryChannelTask)) log.Debug("assignInternalTask: add a watchQueryChannelTask childTask", zap.Any("task", watchQueryChannelTask))
} }
} }
return nil
} }