Add msg id to log in query path (#17677)

/kind improvement

Signed-off-by: longjiquan <jiquan.long@zilliz.com>
pull/17728/head
Jiquan Long 2022-06-23 10:46:13 +08:00 committed by GitHub
parent 1fbdafc943
commit 216e2f80aa
11 changed files with 258 additions and 437 deletions
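
Almost every hunk below applies the same pattern: query-path logs replace the loosely typed zap.Any("requestID", t.Base.MsgID) field with zap.Int64("msgID", t.ID()), and the message ID is also threaded into the QueryNode retrieve plan so segment-level logs can carry it. A minimal sketch of the logging change, using plain zap rather than Milvus's log wrapper and a hypothetical demoTask type standing in for the proxy task types:

```go
package main

import "go.uber.org/zap"

// demoTask is a hypothetical stand-in for the proxy task types touched by this PR;
// in Milvus the ID() accessor returns the task's MsgID.
type demoTask struct{ msgID int64 }

func (t *demoTask) ID() int64 { return t.msgID }

func main() {
	logger, _ := zap.NewDevelopment()
	t := &demoTask{msgID: 42}

	// Before: loosely typed "requestID" field read straight from the embedded MsgBase.
	logger.Info("Query PreExecute done.",
		zap.Any("requestID", t.msgID), zap.Any("requestType", "query"))

	// After: typed "msgID" field read through the task's ID() accessor.
	logger.Info("Query PreExecute done.",
		zap.Int64("msgID", t.ID()), zap.String("requestType", "query"))
}
```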

View File

@@ -9,11 +9,13 @@ import (
type getCollectionIDFunc func(ctx context.Context, collectionName string) (typeutil.UniqueID, error)
type getCollectionSchemaFunc func(ctx context.Context, collectionName string) (*schemapb.CollectionSchema, error)
type getCollectionInfoFunc func(ctx context.Context, collectionName string) (*collectionInfo, error)
type mockCache struct {
Cache
getIDFunc getCollectionIDFunc
getSchemaFunc getCollectionSchemaFunc
getInfoFunc getCollectionInfoFunc
}
func (m *mockCache) GetCollectionID(ctx context.Context, collectionName string) (typeutil.UniqueID, error) {
@@ -30,6 +32,13 @@ func (m *mockCache) GetCollectionSchema(ctx context.Context, collectionName stri
return nil, nil
}
func (m *mockCache) GetCollectionInfo(ctx context.Context, collectionName string) (*collectionInfo, error) {
if m.getInfoFunc != nil {
return m.getInfoFunc(ctx, collectionName)
}
return nil, nil
}
func (m *mockCache) RemoveCollection(ctx context.Context, collectionName string) {
}
@@ -41,6 +50,10 @@ func (m *mockCache) setGetSchemaFunc(f getCollectionSchemaFunc) {
m.getSchemaFunc = f
}
func (m *mockCache) setGetInfoFunc(f getCollectionInfoFunc) {
m.getInfoFunc = f
}
func newMockCache() *mockCache {
return &mockCache{}
}
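
The new getInfoFunc hook lets tests stub GetCollectionInfo without a live meta cache; the Test_checkIfLoaded cases later in this page use it exactly this way. A sketch of the wiring, assuming a proxy-package test with the usual context/testify imports (the collection name is illustrative):

```go
func TestMockCacheInjection_sketch(t *testing.T) {
	cache := newMockCache()
	cache.setGetInfoFunc(func(ctx context.Context, collectionName string) (*collectionInfo, error) {
		// Pretend the collection is already fully loaded into QueryNode.
		return &collectionInfo{isLoaded: true}, nil
	})
	globalMetaCache = cache // swap the package-level meta cache for the stub

	// With isLoaded == true, checkIfLoaded returns early and never touches the
	// QueryCoord, so a nil coordinator is acceptable in this sketch.
	loaded, err := checkIfLoaded(context.Background(), nil, "demo_collection", nil)
	assert.NoError(t, err)
	assert.True(t, loaded)
}
```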

View File

@@ -243,20 +243,10 @@ func (coord *QueryCoordMock) ResetShowPartitionsFunc() {
}
func (coord *QueryCoordMock) ShowPartitions(ctx context.Context, req *querypb.ShowPartitionsRequest) (*querypb.ShowPartitionsResponse, error) {
if !coord.healthy() {
return &querypb.ShowPartitionsResponse{
Status: &commonpb.Status{
ErrorCode: commonpb.ErrorCode_UnexpectedError,
Reason: "unhealthy",
},
}, nil
}
if coord.showPartitionsFunc != nil {
return coord.showPartitionsFunc(ctx, req)
}
panic("implement me")
return nil, nil
}
func (coord *QueryCoordMock) LoadPartitions(ctx context.Context, req *querypb.LoadPartitionsRequest) (*commonpb.Status, error) {

View File

@@ -4,7 +4,6 @@ import (
"context"
"errors"
"fmt"
"regexp"
"strconv"
"strings"
"sync"
@@ -56,6 +55,41 @@ type queryTask struct {
shardMgr *shardClientMgr
}
// translateToOutputFieldIDs translates output field names into output field IDs.
func translateToOutputFieldIDs(outputFields []string, schema *schemapb.CollectionSchema) ([]UniqueID, error) {
outputFieldIDs := make([]UniqueID, 0, len(outputFields))
if len(outputFields) == 0 {
for _, field := range schema.Fields {
if field.FieldID >= 100 && field.DataType != schemapb.DataType_FloatVector && field.DataType != schemapb.DataType_BinaryVector {
outputFieldIDs = append(outputFieldIDs, field.FieldID)
}
}
} else {
addPrimaryKey := false
for _, reqField := range outputFields {
findField := false
for _, field := range schema.Fields {
if reqField == field.Name {
if field.IsPrimaryKey {
addPrimaryKey = true
}
findField = true
outputFieldIDs = append(outputFieldIDs, field.FieldID)
} else {
if field.IsPrimaryKey && !addPrimaryKey {
outputFieldIDs = append(outputFieldIDs, field.FieldID)
addPrimaryKey = true
}
}
}
if !findField {
return nil, fmt.Errorf("field %s not exist", reqField)
}
}
}
return outputFieldIDs, nil
}
func (t *queryTask) PreExecute(ctx context.Context) error {
if t.queryShardPolicy == nil {
t.queryShardPolicy = roundRobinPolicy
@@ -68,73 +102,46 @@ func (t *queryTask) PreExecute(ctx context.Context) error {
t.collectionName = collectionName
if err := validateCollectionName(collectionName); err != nil {
log.Warn("Invalid collection name.", zap.String("collectionName", collectionName),
zap.Int64("requestID", t.Base.MsgID), zap.String("requestType", "query"))
zap.Int64("msgID", t.ID()), zap.String("requestType", "query"))
return err
}
log.Info("Validate collection name.", zap.Any("collectionName", collectionName),
zap.Any("requestID", t.Base.MsgID), zap.Any("requestType", "query"))
zap.Int64("msgID", t.ID()), zap.Any("requestType", "query"))
collID, err := globalMetaCache.GetCollectionID(ctx, collectionName)
if err != nil {
log.Debug("Failed to get collection id.", zap.Any("collectionName", collectionName),
zap.Any("requestID", t.Base.MsgID), zap.Any("requestType", "query"))
zap.Int64("msgID", t.ID()), zap.Any("requestType", "query"))
return err
}
t.CollectionID = collID
log.Info("Get collection ID by name",
zap.Int64("collectionID", t.CollectionID), zap.String("collection name", collectionName),
zap.Any("requestID", t.Base.MsgID), zap.Any("requestType", "query"))
zap.Int64("msgID", t.ID()), zap.Any("requestType", "query"))
for _, tag := range t.request.PartitionNames {
if err := validatePartitionTag(tag, false); err != nil {
log.Warn("invalid partition name", zap.String("partition name", tag),
zap.Any("requestID", t.Base.MsgID), zap.Any("requestType", "query"))
zap.Int64("msgID", t.ID()), zap.Any("requestType", "query"))
return err
}
}
log.Debug("Validate partition names.",
zap.Any("requestID", t.Base.MsgID), zap.Any("requestType", "query"))
zap.Int64("msgID", t.ID()), zap.Any("requestType", "query"))
t.PartitionIDs = make([]UniqueID, 0)
partitionsMap, err := globalMetaCache.GetPartitions(ctx, collectionName)
t.RetrieveRequest.PartitionIDs, err = getPartitionIDs(ctx, collectionName, t.request.GetPartitionNames())
if err != nil {
log.Warn("failed to get partitions in collection.", zap.String("collection name", collectionName),
zap.Any("requestID", t.Base.MsgID), zap.Any("requestType", "query"))
zap.Error(err),
zap.Int64("msgID", t.ID()), zap.Any("requestType", "query"))
return err
}
log.Debug("Get partitions in collection.", zap.Any("collectionName", collectionName),
zap.Any("requestID", t.Base.MsgID), zap.Any("requestType", "query"))
zap.Int64("msgID", t.ID()), zap.Any("requestType", "query"))
// Check if partitions are valid partitions in collection
partitionsRecord := make(map[UniqueID]bool)
for _, partitionName := range t.request.PartitionNames {
pattern := fmt.Sprintf("^%s$", partitionName)
re, err := regexp.Compile(pattern)
if err != nil {
log.Debug("failed to compile partition name regex expression.", zap.Any("partition name", partitionName),
zap.Any("requestID", t.Base.MsgID), zap.Any("requestType", "query"))
return errors.New("invalid partition names")
}
found := false
for name, pID := range partitionsMap {
if re.MatchString(name) {
if _, exist := partitionsRecord[pID]; !exist {
t.PartitionIDs = append(t.PartitionIDs, pID)
partitionsRecord[pID] = true
}
found = true
}
}
if !found {
// FIXME(wxyu): undefined behavior
errMsg := fmt.Sprintf("partition name: %s not found", partitionName)
return errors.New(errMsg)
}
}
loaded, err := t.checkIfLoaded(collID, t.PartitionIDs)
loaded, err := checkIfLoaded(ctx, t.qc, collectionName, t.RetrieveRequest.GetPartitionIDs())
if err != nil {
return fmt.Errorf("checkIfLoaded failed when query, collection:%v, partitions:%v, err = %s", collectionName, t.request.GetPartitionNames(), err)
}
@@ -167,41 +174,16 @@ func (t *queryTask) PreExecute(ctx context.Context) error {
return err
}
log.Debug("translate output fields", zap.Any("OutputFields", t.request.OutputFields),
zap.Any("requestID", t.Base.MsgID), zap.Any("requestType", "query"))
zap.Int64("msgID", t.ID()), zap.Any("requestType", "query"))
if len(t.request.OutputFields) == 0 {
for _, field := range schema.Fields {
if field.FieldID >= 100 && field.DataType != schemapb.DataType_FloatVector && field.DataType != schemapb.DataType_BinaryVector {
t.OutputFieldsId = append(t.OutputFieldsId, field.FieldID)
}
}
} else {
addPrimaryKey := false
for _, reqField := range t.request.OutputFields {
findField := false
for _, field := range schema.Fields {
if reqField == field.Name {
if field.IsPrimaryKey {
addPrimaryKey = true
}
findField = true
t.OutputFieldsId = append(t.OutputFieldsId, field.FieldID)
plan.OutputFieldIds = append(plan.OutputFieldIds, field.FieldID)
} else {
if field.IsPrimaryKey && !addPrimaryKey {
t.OutputFieldsId = append(t.OutputFieldsId, field.FieldID)
plan.OutputFieldIds = append(plan.OutputFieldIds, field.FieldID)
addPrimaryKey = true
}
}
}
if !findField {
return fmt.Errorf("field %s not exist", reqField)
}
}
outputFieldIDs, err := translateToOutputFieldIDs(t.request.GetOutputFields(), schema)
if err != nil {
return err
}
t.RetrieveRequest.OutputFieldsId = outputFieldIDs
plan.OutputFieldIds = outputFieldIDs
log.Debug("translate output fields to field ids", zap.Any("OutputFieldsID", t.OutputFieldsId),
zap.Any("requestID", t.Base.MsgID), zap.Any("requestType", "query"))
zap.Int64("msgID", t.ID()), zap.Any("requestType", "query"))
t.RetrieveRequest.SerializedExprPlan, err = proto.Marshal(plan)
if err != nil {
@@ -229,7 +211,9 @@ func (t *queryTask) PreExecute(ctx context.Context) error {
t.DbID = 0 // TODO
log.Info("Query PreExecute done.",
zap.Any("requestID", t.Base.MsgID), zap.Any("requestType", "query"))
zap.Int64("msgID", t.ID()), zap.Any("requestType", "query"),
zap.Uint64("guarantee_ts", guaranteeTs), zap.Uint64("travel_ts", t.GetTravelTimestamp()),
zap.Uint64("timeout_ts", t.GetTimeoutTimestamp()))
return nil
}
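
PreExecute now delegates output-field resolution to the translateToOutputFieldIDs helper added earlier in this file's diff: an empty request selects every non-vector user field, an explicit request always gets the primary key appended, and unknown names are rejected. A sketch of that behavior as a proxy-package test (field names and IDs are illustrative; schemapb and testify imports assumed):

```go
func TestTranslateToOutputFieldIDs_sketch(t *testing.T) {
	schema := &schemapb.CollectionSchema{
		Fields: []*schemapb.FieldSchema{
			{FieldID: 100, Name: "pk", IsPrimaryKey: true, DataType: schemapb.DataType_Int64},
			{FieldID: 101, Name: "age", DataType: schemapb.DataType_Int64},
			{FieldID: 102, Name: "vec", DataType: schemapb.DataType_FloatVector},
		},
	}

	// Empty request: every non-vector user field (FieldID >= 100) is selected.
	ids, err := translateToOutputFieldIDs(nil, schema)
	assert.NoError(t, err)
	assert.ElementsMatch(t, []UniqueID{100, 101}, ids)

	// Explicit request: the primary key is appended even though it was not asked for.
	ids, err = translateToOutputFieldIDs([]string{"age"}, schema)
	assert.NoError(t, err)
	assert.ElementsMatch(t, []UniqueID{100, 101}, ids)

	// Unknown field names are rejected.
	_, err = translateToOutputFieldIDs([]string{"not_there"}, schema)
	assert.Error(t, err)
}
```
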
@@ -250,6 +234,7 @@ func (t *queryTask) Execute(ctx context.Context) error {
leaders := leaders
t.runningGroup.Go(func() error {
log.Debug("proxy starting to query one shard",
zap.Int64("msgID", t.ID()),
zap.Int64("collectionID", t.CollectionID),
zap.String("collection name", t.collectionName),
zap.String("shard channel", channelID),
@@ -269,7 +254,8 @@ func (t *queryTask) Execute(ctx context.Context) error {
err := executeQuery(WithCache)
if errors.Is(err, errInvalidShardLeaders) || funcutil.IsGrpcErr(err) || errors.Is(err, grpcclient.ErrConnect) {
log.Warn("invalid shard leaders cache, updating shardleader caches and retry search")
log.Warn("invalid shard leaders cache, updating shardleader caches and retry search",
zap.Int64("msgID", t.ID()), zap.Error(err))
return executeQuery(WithoutCache)
}
if err != nil {
@@ -277,7 +263,7 @@ func (t *queryTask) Execute(ctx context.Context) error {
}
log.Info("Query Execute done.",
zap.Any("requestID", t.Base.MsgID), zap.Any("requestType", "query"))
zap.Int64("msgID", t.ID()), zap.Any("requestType", "query"))
return nil
}
@@ -294,14 +280,14 @@ func (t *queryTask) PostExecute(ctx context.Context) error {
for {
select {
case <-t.TraceCtx().Done():
log.Warn("proxy", zap.Int64("Query: wait to finish failed, timeout!, taskID:", t.ID()))
log.Warn("proxy", zap.Int64("Query: wait to finish failed, timeout!, msgID:", t.ID()))
return
case <-t.runningGroupCtx.Done():
log.Debug("all queries are finished or canceled", zap.Any("taskID", t.ID()))
log.Debug("all queries are finished or canceled", zap.Int64("msgID", t.ID()))
close(t.resultBuf)
for res := range t.resultBuf {
t.toReduceResults = append(t.toReduceResults, res)
log.Debug("proxy receives one query result", zap.Int64("sourceID", res.GetBase().GetSourceID()), zap.Any("taskID", t.ID()))
log.Debug("proxy receives one query result", zap.Int64("sourceID", res.GetBase().GetSourceID()), zap.Any("msgID", t.ID()))
}
wg.Done()
return
@@ -325,7 +311,7 @@ func (t *queryTask) PostExecute(ctx context.Context) error {
ErrorCode: commonpb.ErrorCode_Success,
}
} else {
log.Info("Query result is nil", zap.Any("requestID", t.Base.MsgID), zap.Any("requestType", "query"))
log.Info("Query result is nil", zap.Int64("msgID", t.ID()), zap.Any("requestType", "query"))
t.result.Status = &commonpb.Status{
ErrorCode: commonpb.ErrorCode_EmptyCollection,
Reason: "emptly collection", // TODO
@@ -346,7 +332,7 @@ func (t *queryTask) PostExecute(ctx context.Context) error {
}
}
}
log.Info("Query PostExecute done", zap.Any("requestID", t.Base.MsgID), zap.String("requestType", "query"))
log.Info("Query PostExecute done", zap.Int64("msgID", t.ID()), zap.String("requestType", "query"))
return nil
}
@@ -360,93 +346,37 @@ func (t *queryTask) queryShard(ctx context.Context, leaders []nodeInfo, channelI
result, err := qn.Query(ctx, req)
if err != nil {
log.Warn("QueryNode query return error", zap.Int64("nodeID", nodeID), zap.String("channel", channelID),
zap.Error(err))
log.Warn("QueryNode query return error", zap.Int64("msgID", t.ID()),
zap.Int64("nodeID", nodeID), zap.String("channel", channelID), zap.Error(err))
return err
}
if result.GetStatus().GetErrorCode() == commonpb.ErrorCode_NotShardLeader {
log.Warn("QueryNode is not shardLeader", zap.Int64("nodeID", nodeID), zap.String("channel", channelID))
log.Warn("QueryNode is not shardLeader", zap.Int64("msgID", t.ID()),
zap.Int64("nodeID", nodeID), zap.String("channel", channelID))
return errInvalidShardLeaders
}
if result.GetStatus().GetErrorCode() != commonpb.ErrorCode_Success {
log.Warn("QueryNode query result error", zap.Int64("nodeID", nodeID),
zap.String("reason", result.GetStatus().GetReason()))
log.Warn("QueryNode query result error", zap.Int64("msgID", t.ID()),
zap.Int64("nodeID", nodeID), zap.String("reason", result.GetStatus().GetReason()))
return fmt.Errorf("fail to Query, QueryNode ID = %d, reason=%s", nodeID, result.GetStatus().GetReason())
}
log.Debug("get query result", zap.Int64("nodeID", nodeID), zap.String("channelID", channelID))
log.Debug("get query result", zap.Int64("msgID", t.ID()),
zap.Int64("nodeID", nodeID), zap.String("channelID", channelID))
t.resultBuf <- result
return nil
}
err := t.queryShardPolicy(t.TraceCtx(), t.shardMgr, query, leaders)
if err != nil {
log.Warn("fail to Query to all shard leaders", zap.Int64("taskID", t.ID()), zap.Any("shard leaders", leaders))
log.Warn("fail to Query to all shard leaders", zap.Int64("msgID", t.ID()),
zap.Int64("taskID", t.ID()), zap.Any("shard leaders", leaders))
return err
}
return nil
}
func (t *queryTask) checkIfLoaded(collectionID UniqueID, queryPartitionIDs []UniqueID) (bool, error) {
// check if collection was loaded into QueryNode
info, err := globalMetaCache.GetCollectionInfo(t.ctx, t.collectionName)
if err != nil {
return false, fmt.Errorf("GetCollectionInfo failed, collectionID = %d, err = %s", collectionID, err)
}
if info.isLoaded {
return true, nil
}
// If request to query partitions
if len(queryPartitionIDs) > 0 {
resp, err := t.qc.ShowPartitions(t.ctx, &querypb.ShowPartitionsRequest{
Base: &commonpb.MsgBase{
MsgType: commonpb.MsgType_ShowCollections,
MsgID: t.Base.MsgID,
Timestamp: t.Base.Timestamp,
SourceID: Params.ProxyCfg.GetNodeID(),
},
CollectionID: collectionID,
PartitionIDs: queryPartitionIDs,
})
if err != nil {
return false, fmt.Errorf("showPartitions failed, collectionID = %d, partitionIDs = %v, err = %s", collectionID, queryPartitionIDs, err)
}
if resp.GetStatus().GetErrorCode() != commonpb.ErrorCode_Success {
return false, fmt.Errorf("showPartitions failed, collectionID = %d, partitionIDs = %v, reason = %s", collectionID, queryPartitionIDs, resp.GetStatus().GetReason())
}
// Current logic: show partitions won't return error if the given partitions are all loaded
return true, nil
}
// If request to query collection and collection is not fully loaded
resp, err := t.qc.ShowPartitions(t.ctx, &querypb.ShowPartitionsRequest{
Base: &commonpb.MsgBase{
MsgType: commonpb.MsgType_ShowCollections,
MsgID: t.Base.MsgID,
Timestamp: t.Base.Timestamp,
SourceID: Params.ProxyCfg.GetNodeID(),
},
CollectionID: collectionID,
})
if err != nil {
return false, fmt.Errorf("showPartitions failed, collectionID = %d, partitionIDs = %v, err = %s", collectionID, queryPartitionIDs, err)
}
if resp.GetStatus().GetErrorCode() != commonpb.ErrorCode_Success {
return false, fmt.Errorf("showPartitions failed, collectionID = %d, partitionIDs = %v, reason = %s", collectionID, queryPartitionIDs, resp.GetStatus().GetReason())
}
if len(resp.GetPartitionIDs()) > 0 {
log.Warn("collection not fully loaded, query on these partitions",
zap.Int64("collectionID", collectionID),
zap.Int64s("partitionIDs", resp.GetPartitionIDs()))
return true, nil
}
return false, nil
}
// IDs2Expr converts ids slices to bool expresion with specified field name
func IDs2Expr(fieldName string, ids *schemapb.IDs) string {
var idsStr string

View File

@@ -6,11 +6,12 @@ import (
"testing"
"time"
"github.com/milvus-io/milvus/internal/common"
"github.com/golang/protobuf/proto"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
"github.com/milvus-io/milvus/internal/common"
"github.com/milvus-io/milvus/internal/proto/commonpb"
"github.com/milvus-io/milvus/internal/proto/internalpb"
"github.com/milvus-io/milvus/internal/proto/milvuspb"
@@ -129,9 +130,6 @@ func TestQueryTask_all(t *testing.T) {
queryShardPolicy: roundRobinPolicy,
shardMgr: mgr,
}
for i := 0; i < len(fieldName2Types); i++ {
task.RetrieveRequest.OutputFieldsId[i] = int64(common.StartOfUserFieldID + i)
}
assert.NoError(t, task.OnEnqueue())
@@ -157,6 +155,11 @@
},
}
outputFieldIDs := make([]UniqueID, 0, len(fieldName2Types))
for i := 0; i < len(fieldName2Types); i++ {
outputFieldIDs = append(outputFieldIDs, int64(common.StartOfUserFieldID+i))
}
task.RetrieveRequest.OutputFieldsId = outputFieldIDs
for fieldName, dataType := range fieldName2Types {
result1.FieldsData = append(result1.FieldsData, generateFieldData(dataType, fieldName, hitNum))
}
@@ -168,159 +171,3 @@ func TestQueryTask_all(t *testing.T) {
assert.NoError(t, task.PostExecute(ctx))
}
func TestCheckIfLoaded(t *testing.T) {
var err error
Params.Init()
var (
rc = NewRootCoordMock()
qc = NewQueryCoordMock()
ctx = context.TODO()
)
err = rc.Start()
defer rc.Stop()
require.NoError(t, err)
mgr := newShardClientMgr()
err = InitMetaCache(rc, qc, mgr)
require.NoError(t, err)
err = qc.Start()
defer qc.Stop()
require.NoError(t, err)
getQueryTask := func(t *testing.T, collName string) *queryTask {
task := &queryTask{
ctx: ctx,
RetrieveRequest: &internalpb.RetrieveRequest{
Base: &commonpb.MsgBase{
MsgType: commonpb.MsgType_Retrieve,
},
},
request: &milvuspb.QueryRequest{
Base: &commonpb.MsgBase{
MsgType: commonpb.MsgType_Retrieve,
},
CollectionName: collName,
},
qc: qc,
}
require.NoError(t, task.OnEnqueue())
return task
}
t.Run("test checkIfLoaded error", func(t *testing.T) {
collName := "test_checkIfLoaded_error" + funcutil.GenRandomStr()
createColl(t, collName, rc)
collID, err := globalMetaCache.GetCollectionID(context.TODO(), collName)
require.NoError(t, err)
task := getQueryTask(t, collName)
task.collectionName = collName
t.Run("show collection err", func(t *testing.T) {
qc.SetShowCollectionsFunc(func(ctx context.Context, request *querypb.ShowCollectionsRequest) (*querypb.ShowCollectionsResponse, error) {
return nil, fmt.Errorf("mock")
})
loaded, err := task.checkIfLoaded(collID, []UniqueID{})
assert.Error(t, err)
assert.False(t, loaded)
})
t.Run("show collection status unexpected error", func(t *testing.T) {
qc.SetShowCollectionsFunc(func(ctx context.Context, request *querypb.ShowCollectionsRequest) (*querypb.ShowCollectionsResponse, error) {
return &querypb.ShowCollectionsResponse{
Status: &commonpb.Status{
ErrorCode: commonpb.ErrorCode_UnexpectedError,
Reason: "mock",
},
}, nil
})
loaded, err := task.checkIfLoaded(collID, []UniqueID{})
assert.Error(t, err)
assert.False(t, loaded)
assert.Error(t, task.PreExecute(ctx))
qc.ResetShowCollectionsFunc()
})
t.Run("show partition error", func(t *testing.T) {
qc.SetShowPartitionsFunc(func(ctx context.Context, req *querypb.ShowPartitionsRequest) (*querypb.ShowPartitionsResponse, error) {
return &querypb.ShowPartitionsResponse{
Status: &commonpb.Status{
ErrorCode: commonpb.ErrorCode_UnexpectedError,
Reason: "mock",
},
}, nil
})
loaded, err := task.checkIfLoaded(collID, []UniqueID{1})
assert.Error(t, err)
assert.False(t, loaded)
})
t.Run("show partition status unexpected error", func(t *testing.T) {
qc.SetShowPartitionsFunc(func(ctx context.Context, req *querypb.ShowPartitionsRequest) (*querypb.ShowPartitionsResponse, error) {
return nil, fmt.Errorf("mock error")
})
loaded, err := task.checkIfLoaded(collID, []UniqueID{1})
assert.Error(t, err)
assert.False(t, loaded)
})
t.Run("show partitions success", func(t *testing.T) {
qc.SetShowPartitionsFunc(func(ctx context.Context, req *querypb.ShowPartitionsRequest) (*querypb.ShowPartitionsResponse, error) {
return &querypb.ShowPartitionsResponse{
Status: &commonpb.Status{
ErrorCode: commonpb.ErrorCode_Success,
},
}, nil
})
loaded, err := task.checkIfLoaded(collID, []UniqueID{1})
assert.NoError(t, err)
assert.True(t, loaded)
qc.ResetShowPartitionsFunc()
})
t.Run("show collection success but not loaded", func(t *testing.T) {
qc.SetShowCollectionsFunc(func(ctx context.Context, request *querypb.ShowCollectionsRequest) (*querypb.ShowCollectionsResponse, error) {
return &querypb.ShowCollectionsResponse{
Status: &commonpb.Status{
ErrorCode: commonpb.ErrorCode_Success,
},
CollectionIDs: []UniqueID{collID},
InMemoryPercentages: []int64{0},
}, nil
})
qc.SetShowPartitionsFunc(func(ctx context.Context, req *querypb.ShowPartitionsRequest) (*querypb.ShowPartitionsResponse, error) {
return nil, fmt.Errorf("mock error")
})
loaded, err := task.checkIfLoaded(collID, []UniqueID{})
assert.Error(t, err)
assert.False(t, loaded)
qc.SetShowPartitionsFunc(func(ctx context.Context, req *querypb.ShowPartitionsRequest) (*querypb.ShowPartitionsResponse, error) {
return nil, fmt.Errorf("mock error")
})
loaded, err = task.checkIfLoaded(collID, []UniqueID{})
assert.Error(t, err)
assert.False(t, loaded)
qc.SetShowPartitionsFunc(func(ctx context.Context, req *querypb.ShowPartitionsRequest) (*querypb.ShowPartitionsResponse, error) {
return &querypb.ShowPartitionsResponse{
Status: &commonpb.Status{
ErrorCode: commonpb.ErrorCode_Success,
},
PartitionIDs: []UniqueID{1},
}, nil
})
loaded, err = task.checkIfLoaded(collID, []UniqueID{})
assert.NoError(t, err)
assert.True(t, loaded)
})
qc.ResetShowCollectionsFunc()
qc.ResetShowPartitionsFunc()
})
}

View File

@@ -202,7 +202,7 @@ func (t *searchTask) PreExecute(ctx context.Context) error {
}
// check if collection/partitions are loaded into query node
loaded, err := t.checkIfLoaded(collID, t.SearchRequest.GetPartitionIDs())
loaded, err := checkIfLoaded(ctx, t.qc, collectionName, t.SearchRequest.GetPartitionIDs())
if err != nil {
return fmt.Errorf("checkIfLoaded failed when search, collection:%v, partitions:%v, err = %s", collectionName, t.request.GetPartitionNames(), err)
}
@@ -472,11 +472,11 @@ func (t *searchTask) searchShard(ctx context.Context, leaders []nodeInfo, channe
return nil
}
func (t *searchTask) checkIfLoaded(collectionID UniqueID, searchPartitionIDs []UniqueID) (bool, error) {
// check if collection was loaded into QueryNode
info, err := globalMetaCache.GetCollectionInfo(t.ctx, t.collectionName)
// checkIfLoaded check if collection was loaded into QueryNode
func checkIfLoaded(ctx context.Context, qc types.QueryCoord, collectionName string, searchPartitionIDs []UniqueID) (bool, error) {
info, err := globalMetaCache.GetCollectionInfo(ctx, collectionName)
if err != nil {
return false, fmt.Errorf("GetCollectionInfo failed, collectionID = %d, err = %s", collectionID, err)
return false, fmt.Errorf("GetCollectionInfo failed, collection = %s, err = %s", collectionName, err)
}
if info.isLoaded {
return true, nil
@@ -484,48 +484,45 @@ func (t *searchTask) checkIfLoaded(collectionID UniqueID, searchPartitionIDs []U
// If request to search partitions
if len(searchPartitionIDs) > 0 {
resp, err := t.qc.ShowPartitions(t.ctx, &querypb.ShowPartitionsRequest{
resp, err := qc.ShowPartitions(ctx, &querypb.ShowPartitionsRequest{
Base: &commonpb.MsgBase{
MsgType: commonpb.MsgType_ShowCollections,
MsgID: t.Base.MsgID,
Timestamp: t.Base.Timestamp,
SourceID: Params.ProxyCfg.GetNodeID(),
MsgType: commonpb.MsgType_ShowCollections,
SourceID: Params.ProxyCfg.GetNodeID(),
},
CollectionID: collectionID,
CollectionID: info.collID,
PartitionIDs: searchPartitionIDs,
})
if err != nil {
return false, fmt.Errorf("showPartitions failed, collectionID = %d, partitionIDs = %v, err = %s", collectionID, searchPartitionIDs, err)
return false, fmt.Errorf("showPartitions failed, collection = %s, partitionIDs = %v, err = %s", collectionName, searchPartitionIDs, err)
}
if resp.Status.ErrorCode != commonpb.ErrorCode_Success {
return false, fmt.Errorf("showPartitions failed, collectionID = %d, partitionIDs = %v, reason = %s", collectionID, searchPartitionIDs, resp.GetStatus().GetReason())
return false, fmt.Errorf("showPartitions failed, collection = %s, partitionIDs = %v, reason = %s", collectionName, searchPartitionIDs, resp.GetStatus().GetReason())
}
// Current logic: show partitions won't return error if the given partitions are all loaded
return true, nil
}
// If request to search collection and collection is not fully loaded
resp, err := t.qc.ShowPartitions(t.ctx, &querypb.ShowPartitionsRequest{
resp, err := qc.ShowPartitions(ctx, &querypb.ShowPartitionsRequest{
Base: &commonpb.MsgBase{
MsgType: commonpb.MsgType_ShowCollections,
MsgID: t.Base.MsgID,
Timestamp: t.Base.Timestamp,
SourceID: Params.ProxyCfg.GetNodeID(),
MsgType: commonpb.MsgType_ShowCollections,
SourceID: Params.ProxyCfg.GetNodeID(),
},
CollectionID: collectionID,
CollectionID: info.collID,
})
if err != nil {
return false, fmt.Errorf("showPartitions failed, collectionID = %d, partitionIDs = %v, err = %s", collectionID, searchPartitionIDs, err)
return false, fmt.Errorf("showPartitions failed, collection = %s, partitionIDs = %v, err = %s", collectionName, searchPartitionIDs, err)
}
if resp.Status.ErrorCode != commonpb.ErrorCode_Success {
return false, fmt.Errorf("showPartitions failed, collectionID = %d, partitionIDs = %v, reason = %s", collectionID, searchPartitionIDs, resp.GetStatus().GetReason())
return false, fmt.Errorf("showPartitions failed, collection = %s, partitionIDs = %v, reason = %s", collectionName, searchPartitionIDs, resp.GetStatus().GetReason())
}
if len(resp.GetPartitionIDs()) > 0 {
log.Warn("collection not fully loaded, search on these partitions", zap.Int64("msgID", t.ID()),
zap.Int64("collectionID", collectionID), zap.Int64s("partitionIDs", resp.GetPartitionIDs()))
log.Warn("collection not fully loaded, search on these partitions",
zap.String("collection", collectionName),
zap.Int64("collectionID", info.collID), zap.Int64s("partitionIDs", resp.GetPartitionIDs()))
return true, nil
}
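
checkIfLoaded is now a package-level helper shared by searchTask and queryTask, taking the QueryCoord client and collection name explicitly instead of reading them off a task receiver. A sketch of a call site, mirroring the two PreExecute hunks (the not-loaded branch is not part of the shown hunks and is only illustrative):

```go
// Sketch (package proxy): the shared load check used by both search and query PreExecute.
loaded, err := checkIfLoaded(ctx, t.qc, collectionName, t.SearchRequest.GetPartitionIDs())
if err != nil {
	return fmt.Errorf("checkIfLoaded failed when search, collection:%v, partitions:%v, err = %s",
		collectionName, t.request.GetPartitionNames(), err)
}
if !loaded {
	// The real tasks return an error here; the exact message is outside this diff.
	return errors.New("collection not loaded into QueryNode")
}
```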

View File

@@ -196,21 +196,11 @@ func TestSearchTask_PreExecute(t *testing.T) {
t.Run("test checkIfLoaded error", func(t *testing.T) {
collName := "test_checkIfLoaded_error" + funcutil.GenRandomStr()
createColl(t, collName, rc)
collID, err := globalMetaCache.GetCollectionID(context.TODO(), collName)
_, err := globalMetaCache.GetCollectionID(context.TODO(), collName)
require.NoError(t, err)
task := getSearchTask(t, collName)
task.collectionName = collName
t.Run("show collection err", func(t *testing.T) {
qc.SetShowCollectionsFunc(func(ctx context.Context, request *querypb.ShowCollectionsRequest) (*querypb.ShowCollectionsResponse, error) {
return nil, errors.New("mock")
})
loaded, err := task.checkIfLoaded(collID, []UniqueID{})
assert.Error(t, err)
assert.False(t, loaded)
})
t.Run("show collection status unexpected error", func(t *testing.T) {
qc.SetShowCollectionsFunc(func(ctx context.Context, request *querypb.ShowCollectionsRequest) (*querypb.ShowCollectionsResponse, error) {
return &querypb.ShowCollectionsResponse{
@@ -221,88 +211,10 @@ func TestSearchTask_PreExecute(t *testing.T) {
}, nil
})
loaded, err := task.checkIfLoaded(collID, []UniqueID{})
assert.Error(t, err)
assert.False(t, loaded)
assert.Error(t, task.PreExecute(ctx))
qc.ResetShowCollectionsFunc()
})
t.Run("show partition error", func(t *testing.T) {
qc.SetShowPartitionsFunc(func(ctx context.Context, req *querypb.ShowPartitionsRequest) (*querypb.ShowPartitionsResponse, error) {
return &querypb.ShowPartitionsResponse{
Status: &commonpb.Status{
ErrorCode: commonpb.ErrorCode_UnexpectedError,
Reason: "mock",
},
}, nil
})
loaded, err := task.checkIfLoaded(collID, []UniqueID{1})
assert.Error(t, err)
assert.False(t, loaded)
})
t.Run("show partition status unexpected error", func(t *testing.T) {
qc.SetShowPartitionsFunc(func(ctx context.Context, req *querypb.ShowPartitionsRequest) (*querypb.ShowPartitionsResponse, error) {
return nil, errors.New("mock error")
})
loaded, err := task.checkIfLoaded(collID, []UniqueID{1})
assert.Error(t, err)
assert.False(t, loaded)
})
t.Run("show partitions success", func(t *testing.T) {
qc.SetShowPartitionsFunc(func(ctx context.Context, req *querypb.ShowPartitionsRequest) (*querypb.ShowPartitionsResponse, error) {
return &querypb.ShowPartitionsResponse{
Status: &commonpb.Status{
ErrorCode: commonpb.ErrorCode_Success,
},
}, nil
})
loaded, err := task.checkIfLoaded(collID, []UniqueID{1})
assert.NoError(t, err)
assert.True(t, loaded)
qc.ResetShowPartitionsFunc()
})
t.Run("show collection success but not loaded", func(t *testing.T) {
qc.SetShowCollectionsFunc(func(ctx context.Context, request *querypb.ShowCollectionsRequest) (*querypb.ShowCollectionsResponse, error) {
return &querypb.ShowCollectionsResponse{
Status: &commonpb.Status{
ErrorCode: commonpb.ErrorCode_Success,
},
CollectionIDs: []UniqueID{collID},
InMemoryPercentages: []int64{0},
}, nil
})
qc.SetShowPartitionsFunc(func(ctx context.Context, req *querypb.ShowPartitionsRequest) (*querypb.ShowPartitionsResponse, error) {
return nil, errors.New("mock error")
})
loaded, err := task.checkIfLoaded(collID, []UniqueID{})
assert.Error(t, err)
assert.False(t, loaded)
qc.SetShowPartitionsFunc(func(ctx context.Context, req *querypb.ShowPartitionsRequest) (*querypb.ShowPartitionsResponse, error) {
return nil, errors.New("mock error")
})
loaded, err = task.checkIfLoaded(collID, []UniqueID{})
assert.Error(t, err)
assert.False(t, loaded)
qc.SetShowPartitionsFunc(func(ctx context.Context, req *querypb.ShowPartitionsRequest) (*querypb.ShowPartitionsResponse, error) {
return &querypb.ShowPartitionsResponse{
Status: &commonpb.Status{
ErrorCode: commonpb.ErrorCode_Success,
},
PartitionIDs: []UniqueID{1},
}, nil
})
loaded, err = task.checkIfLoaded(collID, []UniqueID{})
assert.NoError(t, err)
assert.True(t, loaded)
})
qc.ResetShowCollectionsFunc()
qc.ResetShowPartitionsFunc()
})
@@ -1553,3 +1465,129 @@ func Test_reduceSearchResultData_str(t *testing.T) {
// hard to compare floating point value.
// TODO: compare scores.
}
func Test_checkIfLoaded(t *testing.T) {
t.Run("failed to get collection info", func(t *testing.T) {
cache := newMockCache()
cache.setGetInfoFunc(func(ctx context.Context, collectionName string) (*collectionInfo, error) {
return nil, errors.New("mock")
})
globalMetaCache = cache
var qc types.QueryCoord
_, err := checkIfLoaded(context.Background(), qc, "test", []UniqueID{})
assert.Error(t, err)
})
t.Run("collection loaded", func(t *testing.T) {
cache := newMockCache()
cache.setGetInfoFunc(func(ctx context.Context, collectionName string) (*collectionInfo, error) {
return &collectionInfo{isLoaded: true}, nil
})
globalMetaCache = cache
var qc types.QueryCoord
loaded, err := checkIfLoaded(context.Background(), qc, "test", []UniqueID{})
assert.NoError(t, err)
assert.True(t, loaded)
})
t.Run("show partitions failed", func(t *testing.T) {
cache := newMockCache()
cache.setGetInfoFunc(func(ctx context.Context, collectionName string) (*collectionInfo, error) {
return &collectionInfo{isLoaded: false}, nil
})
globalMetaCache = cache
qc := NewQueryCoordMock()
qc.SetShowPartitionsFunc(func(ctx context.Context, request *querypb.ShowPartitionsRequest) (*querypb.ShowPartitionsResponse, error) {
return nil, errors.New("mock")
})
_, err := checkIfLoaded(context.Background(), qc, "test", []UniqueID{1, 2})
assert.Error(t, err)
})
t.Run("show partitions but didn't success", func(t *testing.T) {
cache := newMockCache()
cache.setGetInfoFunc(func(ctx context.Context, collectionName string) (*collectionInfo, error) {
return &collectionInfo{isLoaded: false}, nil
})
globalMetaCache = cache
qc := NewQueryCoordMock()
qc.SetShowPartitionsFunc(func(ctx context.Context, request *querypb.ShowPartitionsRequest) (*querypb.ShowPartitionsResponse, error) {
return &querypb.ShowPartitionsResponse{Status: &commonpb.Status{ErrorCode: commonpb.ErrorCode_CollectionNotExists}}, nil
})
_, err := checkIfLoaded(context.Background(), qc, "test", []UniqueID{1, 2})
assert.Error(t, err)
})
t.Run("partitions loaded", func(t *testing.T) {
cache := newMockCache()
cache.setGetInfoFunc(func(ctx context.Context, collectionName string) (*collectionInfo, error) {
return &collectionInfo{isLoaded: false}, nil
})
globalMetaCache = cache
qc := NewQueryCoordMock()
qc.SetShowPartitionsFunc(func(ctx context.Context, request *querypb.ShowPartitionsRequest) (*querypb.ShowPartitionsResponse, error) {
return &querypb.ShowPartitionsResponse{Status: &commonpb.Status{ErrorCode: commonpb.ErrorCode_Success}}, nil
})
loaded, err := checkIfLoaded(context.Background(), qc, "test", []UniqueID{1, 2})
assert.NoError(t, err)
assert.True(t, loaded)
})
t.Run("no specified partitions, show partitions failed", func(t *testing.T) {
cache := newMockCache()
cache.setGetInfoFunc(func(ctx context.Context, collectionName string) (*collectionInfo, error) {
return &collectionInfo{isLoaded: false}, nil
})
globalMetaCache = cache
qc := NewQueryCoordMock()
qc.SetShowPartitionsFunc(func(ctx context.Context, request *querypb.ShowPartitionsRequest) (*querypb.ShowPartitionsResponse, error) {
return nil, errors.New("mock")
})
_, err := checkIfLoaded(context.Background(), qc, "test", []UniqueID{})
assert.Error(t, err)
})
t.Run("no specified partitions, show partitions but didn't succeed", func(t *testing.T) {
cache := newMockCache()
cache.setGetInfoFunc(func(ctx context.Context, collectionName string) (*collectionInfo, error) {
return &collectionInfo{isLoaded: false}, nil
})
globalMetaCache = cache
qc := NewQueryCoordMock()
qc.SetShowPartitionsFunc(func(ctx context.Context, request *querypb.ShowPartitionsRequest) (*querypb.ShowPartitionsResponse, error) {
return &querypb.ShowPartitionsResponse{Status: &commonpb.Status{ErrorCode: commonpb.ErrorCode_CollectionNotExists}}, nil
})
_, err := checkIfLoaded(context.Background(), qc, "test", []UniqueID{})
assert.Error(t, err)
})
t.Run("not fully loaded", func(t *testing.T) {
cache := newMockCache()
cache.setGetInfoFunc(func(ctx context.Context, collectionName string) (*collectionInfo, error) {
return &collectionInfo{isLoaded: false}, nil
})
globalMetaCache = cache
qc := NewQueryCoordMock()
qc.SetShowPartitionsFunc(func(ctx context.Context, request *querypb.ShowPartitionsRequest) (*querypb.ShowPartitionsResponse, error) {
return &querypb.ShowPartitionsResponse{Status: &commonpb.Status{ErrorCode: commonpb.ErrorCode_Success}, PartitionIDs: []UniqueID{1, 2}}, nil
})
loaded, err := checkIfLoaded(context.Background(), qc, "test", []UniqueID{})
assert.NoError(t, err)
assert.True(t, loaded)
})
t.Run("not loaded", func(t *testing.T) {
cache := newMockCache()
cache.setGetInfoFunc(func(ctx context.Context, collectionName string) (*collectionInfo, error) {
return &collectionInfo{isLoaded: false}, nil
})
globalMetaCache = cache
qc := NewQueryCoordMock()
qc.SetShowPartitionsFunc(func(ctx context.Context, request *querypb.ShowPartitionsRequest) (*querypb.ShowPartitionsResponse, error) {
return &querypb.ShowPartitionsResponse{Status: &commonpb.Status{ErrorCode: commonpb.ErrorCode_Success}, PartitionIDs: []UniqueID{}}, nil
})
loaded, err := checkIfLoaded(context.Background(), qc, "test", []UniqueID{})
assert.NoError(t, err)
assert.False(t, loaded)
})
}

View File

@@ -1505,7 +1505,7 @@ func genSimpleRetrievePlan(collection *Collection) (*RetrievePlan, error) {
}
timestamp := retrieveMsg.RetrieveRequest.TravelTimestamp
plan, err2 := createRetrievePlanByExpr(collection, retrieveMsg.SerializedExprPlan, timestamp)
plan, err2 := createRetrievePlanByExpr(collection, retrieveMsg.SerializedExprPlan, timestamp, 100)
return plan, err2
}

View File

@@ -136,7 +136,7 @@ func newSearchRequest(collection *Collection, req *querypb.SearchRequest, placeh
plan: plan,
cPlaceholderGroup: cPlaceholderGroup,
timestamp: req.Req.GetTravelTimestamp(),
msgID: req.GetReq().GetReqID(),
msgID: req.GetReq().GetBase().GetMsgID(),
}
return ret, nil
@@ -175,9 +175,10 @@ func parseSearchRequest(plan *SearchPlan, searchRequestBlob []byte) (*searchRequ
type RetrievePlan struct {
cRetrievePlan C.CRetrievePlan
Timestamp Timestamp
msgID UniqueID // only used to debug.
}
func createRetrievePlanByExpr(col *Collection, expr []byte, timestamp Timestamp) (*RetrievePlan, error) {
func createRetrievePlanByExpr(col *Collection, expr []byte, timestamp Timestamp, msgID UniqueID) (*RetrievePlan, error) {
var cPlan C.CRetrievePlan
status := C.CreateRetrievePlanByExpr(col.collectionPtr, unsafe.Pointer(&expr[0]), (C.int64_t)(len(expr)), &cPlan)
@@ -189,6 +190,7 @@ func createRetrievePlanByExpr(col *Collection, expr []byte, timestamp Timestamp)
var newPlan = &RetrievePlan{
cRetrievePlan: cPlan,
Timestamp: timestamp,
msgID: msgID,
}
return newPlan, nil
}
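
On the QueryNode side, createRetrievePlanByExpr gains a msgID parameter so the retrieve plan can carry the originating request ID down to segment-level logs. A sketch of the updated call, mirroring queryOnStreaming/queryOnHistorical below (error handling trimmed; plan.delete() is the cleanup already used in the tests):

```go
// Sketch (package querynode): thread the request's message ID into the retrieve plan.
plan, err := createRetrievePlanByExpr(q.QS.collection, q.iReq.GetSerializedExprPlan(), q.TravelTimestamp, q.ID())
if err != nil {
	return err
}
defer plan.delete()
// Segment.retrieve(plan) can now emit zap.Int64("msgID", plan.msgID) next to the segment ID.
```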

View File

@@ -303,7 +303,9 @@ func (s *Segment) retrieve(plan *RetrievePlan) (*segcorepb.RetrieveResults, erro
status := C.Retrieve(s.segmentPtr, plan.cRetrievePlan, ts, &retrieveResult.cRetrieveResult)
metrics.QueryNodeSQSegmentLatencyInCore.WithLabelValues(fmt.Sprint(Params.QueryNodeCfg.GetNodeID()),
metrics.QueryLabel).Observe(float64(tr.ElapseSpan().Milliseconds()))
log.Debug("do retrieve on segment", zap.Int64("segmentID", s.segmentID), zap.String("segmentType", s.segmentType.String()))
log.Debug("do retrieve on segment",
zap.Int64("msgID", plan.msgID),
zap.Int64("segmentID", s.segmentID), zap.String("segmentType", s.segmentType.String()))
if err := HandleCStatus(&status, "Retrieve failed"); err != nil {
return nil, err
}

View File

@@ -200,7 +200,7 @@ func TestSegment_retrieve(t *testing.T) {
// }
planExpr, err := proto.Marshal(planNode)
assert.NoError(t, err)
plan, err := createRetrievePlanByExpr(collection, planExpr, 100)
plan, err := createRetrievePlanByExpr(collection, planExpr, 100, 100)
defer plan.delete()
assert.NoError(t, err)

View File

@@ -66,12 +66,13 @@ func (q *queryTask) queryOnStreaming() error {
q.QS.collection.RLock() // locks the collectionPtr
defer q.QS.collection.RUnlock()
if _, released := q.QS.collection.getReleaseTime(); released {
log.Debug("collection release before search", zap.Int64("collectionID", q.CollectionID))
log.Debug("collection release before search", zap.Int64("msgID", q.ID()),
zap.Int64("collectionID", q.CollectionID))
return fmt.Errorf("retrieve failed, collection has been released, collectionID = %d", q.CollectionID)
}
// deserialize query plan
plan, err := createRetrievePlanByExpr(q.QS.collection, q.iReq.GetSerializedExprPlan(), q.TravelTimestamp)
plan, err := createRetrievePlanByExpr(q.QS.collection, q.iReq.GetSerializedExprPlan(), q.TravelTimestamp, q.ID())
if err != nil {
return err
}
@@ -113,12 +114,13 @@
defer q.QS.collection.RUnlock()
if _, released := q.QS.collection.getReleaseTime(); released {
log.Debug("collection release before search", zap.Int64("collectionID", q.CollectionID))
log.Debug("collection release before search", zap.Int64("msgID", q.ID()),
zap.Int64("collectionID", q.CollectionID))
return fmt.Errorf("retrieve failed, collection has been released, collectionID = %d", q.CollectionID)
}
// deserialize query plan
plan, err := createRetrievePlanByExpr(q.QS.collection, q.iReq.GetSerializedExprPlan(), q.TravelTimestamp)
plan, err := createRetrievePlanByExpr(q.QS.collection, q.iReq.GetSerializedExprPlan(), q.TravelTimestamp, q.ID())
if err != nil {
return err
}