Check if collection was loaded before search (#5853)

Signed-off-by: dragondriver <jiquan.long@zilliz.com>
pull/5874/head
dragondriver 2021-06-18 16:32:07 +08:00 committed by GitHub
parent 78ffc95d70
commit ad23c47ff5
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
2 changed files with 71 additions and 1 deletions

View File

@ -1151,6 +1151,7 @@ func (node *ProxyNode) Search(ctx context.Context, request *milvuspb.SearchReque
resultBuf: make(chan []*internalpb.SearchResults),
query: request,
chMgr: node.chMgr,
qs: node.queryService,
}
err := node.sched.DqQueue.Enqueue(qt)
@ -1222,6 +1223,7 @@ func (node *ProxyNode) Retrieve(ctx context.Context, request *milvuspb.RetrieveR
},
resultBuf: make(chan []*internalpb.RetrieveResults),
retrieve: request,
qs: node.queryService,
}
err := node.sched.DqQueue.Enqueue(rt)
@ -1373,6 +1375,7 @@ func (node *ProxyNode) Query(ctx context.Context, request *milvuspb.QueryRequest
resultBuf: make(chan []*internalpb.RetrieveResults),
retrieve: retrieveRequest,
chMgr: node.chMgr,
qs: node.queryService,
}
err := node.sched.DqQueue.Enqueue(rt)

View File

@ -993,6 +993,7 @@ type SearchTask struct {
result *milvuspb.SearchResults
query *milvuspb.SearchRequest
chMgr channelsMgr
qs types.QueryService
}
func (st *SearchTask) TraceCtx() context.Context {
@ -1063,7 +1064,7 @@ func (st *SearchTask) PreExecute(ctx context.Context) error {
st.Base.SourceID = Params.ProxyID
collectionName := st.query.CollectionName
_, err := globalMetaCache.GetCollectionID(ctx, collectionName)
collID, err := globalMetaCache.GetCollectionID(ctx, collectionName)
if err != nil { // err is not nil if collection not exists
return err
}
@ -1077,6 +1078,39 @@ func (st *SearchTask) PreExecute(ctx context.Context) error {
return err
}
}
// check if collection was already loaded into query node
showResp, err := st.qs.ShowCollections(st.ctx, &querypb.ShowCollectionsRequest{
Base: &commonpb.MsgBase{
MsgType: commonpb.MsgType_ShowCollections,
MsgID: st.Base.MsgID,
Timestamp: st.Base.Timestamp,
SourceID: Params.ProxyID,
},
DbID: 0, // TODO(dragondriver)
})
if err != nil {
return err
}
if showResp.Status.ErrorCode != commonpb.ErrorCode_Success {
return errors.New(showResp.Status.Reason)
}
log.Debug("query service show collections",
zap.Any("collections", showResp.CollectionIDs),
zap.Any("collID", collID))
collectionLoaded := false
for _, collectionID := range showResp.CollectionIDs {
if collectionID == collID {
collectionLoaded = true
break
}
}
if !collectionLoaded {
return fmt.Errorf("collection %v was not loaded into memory", collectionName)
}
// TODO(dragondriver): necessary to check if partition was loaded into query node?
st.Base.MsgType = commonpb.MsgType_Search
if st.query.GetDslType() == commonpb.DslType_BoolExprV1 {
@ -1491,6 +1525,7 @@ type RetrieveTask struct {
result *milvuspb.RetrieveResults
retrieve *milvuspb.RetrieveRequest
chMgr channelsMgr
qs types.QueryService
}
func (rt *RetrieveTask) TraceCtx() context.Context {
@ -1588,6 +1623,38 @@ func (rt *RetrieveTask) PreExecute(ctx context.Context) error {
log.Info("Validate partition names.",
zap.Any("requestID", rt.Base.MsgID), zap.Any("requestType", "retrieve"))
// check if collection was already loaded into query node
showResp, err := rt.qs.ShowCollections(rt.ctx, &querypb.ShowCollectionsRequest{
Base: &commonpb.MsgBase{
MsgType: commonpb.MsgType_ShowCollections,
MsgID: rt.Base.MsgID,
Timestamp: rt.Base.Timestamp,
SourceID: Params.ProxyID,
},
DbID: 0, // TODO(dragondriver)
})
if err != nil {
return err
}
if showResp.Status.ErrorCode != commonpb.ErrorCode_Success {
return errors.New(showResp.Status.Reason)
}
log.Debug("query service show collections",
zap.Any("collections", showResp.CollectionIDs),
zap.Any("collID", collectionID))
collectionLoaded := false
for _, collID := range showResp.CollectionIDs {
if collectionID == collID {
collectionLoaded = true
break
}
}
if !collectionLoaded {
return fmt.Errorf("collection %v was not loaded into memory", collectionName)
}
// TODO(dragondriver): necessary to check if partition was loaded into query node?
rt.Base.MsgType = commonpb.MsgType_Retrieve
if rt.retrieve.Ids == nil {
errMsg := "Retrieve ids is nil"