mirror of https://github.com/milvus-io/milvus.git
Add test case for the workflow of query task (#7886)
Signed-off-by: dragondriver <jiquan.long@zilliz.com>pull/7893/head
parent
4bf3c6889c
commit
ec776a30dc
|
@ -2099,6 +2099,14 @@ func (qt *queryTask) getChannels() ([]pChan, error) {
|
||||||
return nil, err
|
return nil, err
|
||||||
}
|
}
|
||||||
|
|
||||||
|
_, err = qt.chMgr.getChannels(collID)
|
||||||
|
if err != nil {
|
||||||
|
err := qt.chMgr.createDMLMsgStream(collID)
|
||||||
|
if err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
return qt.chMgr.getChannels(collID)
|
return qt.chMgr.getChannels(collID)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -2119,6 +2127,7 @@ func (qt *queryTask) getVChannels() ([]vChan, error) {
|
||||||
return qt.chMgr.getVChannels(collID)
|
return qt.chMgr.getVChannels(collID)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
/* not used
|
||||||
func parseIdsFromExpr(exprStr string, schema *typeutil.SchemaHelper) ([]int64, error) {
|
func parseIdsFromExpr(exprStr string, schema *typeutil.SchemaHelper) ([]int64, error) {
|
||||||
expr, err := parseQueryExpr(schema, exprStr)
|
expr, err := parseQueryExpr(schema, exprStr)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
|
@ -2146,6 +2155,7 @@ func parseIdsFromExpr(exprStr string, schema *typeutil.SchemaHelper) ([]int64, e
|
||||||
return nil, errors.New("not top level term")
|
return nil, errors.New("not top level term")
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
*/
|
||||||
|
|
||||||
func IDs2Expr(fieldName string, ids []int64) string {
|
func IDs2Expr(fieldName string, ids []int64) string {
|
||||||
idsStr := strings.Trim(strings.Join(strings.Fields(fmt.Sprint(ids)), ", "), "[]")
|
idsStr := strings.Trim(strings.Join(strings.Fields(fmt.Sprint(ids)), ", "), "[]")
|
||||||
|
@ -2157,14 +2167,6 @@ func (qt *queryTask) PreExecute(ctx context.Context) error {
|
||||||
qt.Base.SourceID = Params.ProxyID
|
qt.Base.SourceID = Params.ProxyID
|
||||||
|
|
||||||
collectionName := qt.query.CollectionName
|
collectionName := qt.query.CollectionName
|
||||||
collectionID, err := globalMetaCache.GetCollectionID(ctx, collectionName)
|
|
||||||
if err != nil {
|
|
||||||
log.Debug("Failed to get collection id.", zap.Any("collectionName", collectionName),
|
|
||||||
zap.Any("requestID", qt.Base.MsgID), zap.Any("requestType", "query"))
|
|
||||||
return err
|
|
||||||
}
|
|
||||||
log.Info("Get collection id by name.", zap.Any("collectionName", collectionName),
|
|
||||||
zap.Any("requestID", qt.Base.MsgID), zap.Any("requestType", "query"))
|
|
||||||
|
|
||||||
if err := ValidateCollectionName(qt.query.CollectionName); err != nil {
|
if err := ValidateCollectionName(qt.query.CollectionName); err != nil {
|
||||||
log.Debug("Invalid collection name.", zap.Any("collectionName", collectionName),
|
log.Debug("Invalid collection name.", zap.Any("collectionName", collectionName),
|
||||||
|
@ -2174,6 +2176,15 @@ func (qt *queryTask) PreExecute(ctx context.Context) error {
|
||||||
log.Info("Validate collection name.", zap.Any("collectionName", collectionName),
|
log.Info("Validate collection name.", zap.Any("collectionName", collectionName),
|
||||||
zap.Any("requestID", qt.Base.MsgID), zap.Any("requestType", "query"))
|
zap.Any("requestID", qt.Base.MsgID), zap.Any("requestType", "query"))
|
||||||
|
|
||||||
|
collectionID, err := globalMetaCache.GetCollectionID(ctx, collectionName)
|
||||||
|
if err != nil {
|
||||||
|
log.Debug("Failed to get collection id.", zap.Any("collectionName", collectionName),
|
||||||
|
zap.Any("requestID", qt.Base.MsgID), zap.Any("requestType", "query"))
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
log.Info("Get collection id by name.", zap.Any("collectionName", collectionName),
|
||||||
|
zap.Any("requestID", qt.Base.MsgID), zap.Any("requestType", "query"))
|
||||||
|
|
||||||
for _, tag := range qt.query.PartitionNames {
|
for _, tag := range qt.query.PartitionNames {
|
||||||
if err := ValidatePartitionTag(tag, false); err != nil {
|
if err := ValidatePartitionTag(tag, false); err != nil {
|
||||||
log.Debug("Invalid partition name.", zap.Any("partitionName", tag),
|
log.Debug("Invalid partition name.", zap.Any("partitionName", tag),
|
||||||
|
@ -2215,10 +2226,7 @@ func (qt *queryTask) PreExecute(ctx context.Context) error {
|
||||||
return fmt.Errorf("collection %v was not loaded into memory", collectionName)
|
return fmt.Errorf("collection %v was not loaded into memory", collectionName)
|
||||||
}
|
}
|
||||||
|
|
||||||
schema, err := globalMetaCache.GetCollectionSchema(ctx, qt.query.CollectionName)
|
schema, _ := globalMetaCache.GetCollectionSchema(ctx, qt.query.CollectionName)
|
||||||
if err != nil { // err is not nil if collection not exists
|
|
||||||
return err
|
|
||||||
}
|
|
||||||
// schemaHelper, err := typeutil.CreateSchemaHelper(schema)
|
// schemaHelper, err := typeutil.CreateSchemaHelper(schema)
|
||||||
// if err != nil {
|
// if err != nil {
|
||||||
// return err
|
// return err
|
||||||
|
|
|
@ -2408,3 +2408,305 @@ func TestSearchTask_Execute(t *testing.T) {
|
||||||
assert.Error(t, task.Execute(ctx))
|
assert.Error(t, task.Execute(ctx))
|
||||||
// TODO(dragondriver): cover getDQLStream
|
// TODO(dragondriver): cover getDQLStream
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func TestQueryTask_all(t *testing.T) {
|
||||||
|
var err error
|
||||||
|
|
||||||
|
Params.Init()
|
||||||
|
Params.RetrieveResultChannelNames = []string{funcutil.GenRandomStr()}
|
||||||
|
|
||||||
|
rc := NewRootCoordMock()
|
||||||
|
rc.Start()
|
||||||
|
defer rc.Stop()
|
||||||
|
|
||||||
|
ctx := context.Background()
|
||||||
|
|
||||||
|
err = InitMetaCache(rc)
|
||||||
|
assert.NoError(t, err)
|
||||||
|
|
||||||
|
shardsNum := int32(2)
|
||||||
|
prefix := "TestQueryTask_all"
|
||||||
|
dbName := ""
|
||||||
|
collectionName := prefix + funcutil.GenRandomStr()
|
||||||
|
boolField := "bool"
|
||||||
|
int32Field := "int32"
|
||||||
|
int64Field := "int64"
|
||||||
|
floatField := "float"
|
||||||
|
doubleField := "double"
|
||||||
|
floatVecField := "fvec"
|
||||||
|
binaryVecField := "bvec"
|
||||||
|
fieldsLen := len([]string{boolField, int32Field, int64Field, floatField, doubleField, floatVecField, binaryVecField})
|
||||||
|
dim := 128
|
||||||
|
expr := fmt.Sprintf("%s > 0", int64Field)
|
||||||
|
hitNum := 10
|
||||||
|
|
||||||
|
schema := constructCollectionSchemaWithAllType(
|
||||||
|
boolField, int32Field, int64Field, floatField, doubleField,
|
||||||
|
floatVecField, binaryVecField, dim, collectionName)
|
||||||
|
marshaledSchema, err := proto.Marshal(schema)
|
||||||
|
assert.NoError(t, err)
|
||||||
|
|
||||||
|
createColT := &createCollectionTask{
|
||||||
|
Condition: NewTaskCondition(ctx),
|
||||||
|
CreateCollectionRequest: &milvuspb.CreateCollectionRequest{
|
||||||
|
Base: nil,
|
||||||
|
DbName: dbName,
|
||||||
|
CollectionName: collectionName,
|
||||||
|
Schema: marshaledSchema,
|
||||||
|
ShardsNum: shardsNum,
|
||||||
|
},
|
||||||
|
ctx: ctx,
|
||||||
|
rootCoord: rc,
|
||||||
|
result: nil,
|
||||||
|
schema: nil,
|
||||||
|
}
|
||||||
|
|
||||||
|
assert.NoError(t, createColT.OnEnqueue())
|
||||||
|
assert.NoError(t, createColT.PreExecute(ctx))
|
||||||
|
assert.NoError(t, createColT.Execute(ctx))
|
||||||
|
assert.NoError(t, createColT.PostExecute(ctx))
|
||||||
|
|
||||||
|
dmlChannelsFunc := getDmlChannelsFunc(ctx, rc)
|
||||||
|
query := newMockGetChannelsService()
|
||||||
|
factory := newSimpleMockMsgStreamFactory()
|
||||||
|
chMgr := newChannelsMgrImpl(dmlChannelsFunc, nil, query.GetChannels, nil, factory)
|
||||||
|
defer chMgr.removeAllDMLStream()
|
||||||
|
defer chMgr.removeAllDQLStream()
|
||||||
|
|
||||||
|
collectionID, err := globalMetaCache.GetCollectionID(ctx, collectionName)
|
||||||
|
assert.NoError(t, err)
|
||||||
|
|
||||||
|
qc := NewQueryCoordMock()
|
||||||
|
qc.Start()
|
||||||
|
defer qc.Stop()
|
||||||
|
status, err := qc.LoadCollection(ctx, &querypb.LoadCollectionRequest{
|
||||||
|
Base: &commonpb.MsgBase{
|
||||||
|
MsgType: commonpb.MsgType_LoadCollection,
|
||||||
|
MsgID: 0,
|
||||||
|
Timestamp: 0,
|
||||||
|
SourceID: Params.ProxyID,
|
||||||
|
},
|
||||||
|
DbID: 0,
|
||||||
|
CollectionID: collectionID,
|
||||||
|
Schema: nil,
|
||||||
|
})
|
||||||
|
assert.NoError(t, err)
|
||||||
|
assert.Equal(t, commonpb.ErrorCode_Success, status.ErrorCode)
|
||||||
|
|
||||||
|
task := &queryTask{
|
||||||
|
Condition: NewTaskCondition(ctx),
|
||||||
|
RetrieveRequest: &internalpb.RetrieveRequest{
|
||||||
|
Base: &commonpb.MsgBase{
|
||||||
|
MsgType: commonpb.MsgType_Retrieve,
|
||||||
|
MsgID: 0,
|
||||||
|
Timestamp: 0,
|
||||||
|
SourceID: Params.ProxyID,
|
||||||
|
},
|
||||||
|
ResultChannelID: strconv.Itoa(int(Params.ProxyID)),
|
||||||
|
DbID: 0,
|
||||||
|
CollectionID: collectionID,
|
||||||
|
PartitionIDs: nil,
|
||||||
|
SerializedExprPlan: nil,
|
||||||
|
OutputFieldsId: make([]int64, fieldsLen),
|
||||||
|
TravelTimestamp: 0,
|
||||||
|
GuaranteeTimestamp: 0,
|
||||||
|
},
|
||||||
|
ctx: ctx,
|
||||||
|
resultBuf: make(chan []*internalpb.RetrieveResults),
|
||||||
|
result: &milvuspb.QueryResults{
|
||||||
|
Status: &commonpb.Status{
|
||||||
|
ErrorCode: commonpb.ErrorCode_Success,
|
||||||
|
},
|
||||||
|
FieldsData: nil,
|
||||||
|
},
|
||||||
|
query: &milvuspb.QueryRequest{
|
||||||
|
Base: &commonpb.MsgBase{
|
||||||
|
MsgType: commonpb.MsgType_Retrieve,
|
||||||
|
MsgID: 0,
|
||||||
|
Timestamp: 0,
|
||||||
|
SourceID: Params.ProxyID,
|
||||||
|
},
|
||||||
|
DbName: dbName,
|
||||||
|
CollectionName: collectionName,
|
||||||
|
Expr: expr,
|
||||||
|
OutputFields: nil,
|
||||||
|
PartitionNames: nil,
|
||||||
|
TravelTimestamp: 0,
|
||||||
|
GuaranteeTimestamp: 0,
|
||||||
|
},
|
||||||
|
chMgr: chMgr,
|
||||||
|
qc: qc,
|
||||||
|
ids: nil,
|
||||||
|
}
|
||||||
|
for i := 0; i < fieldsLen; i++ {
|
||||||
|
task.RetrieveRequest.OutputFieldsId[i] = int64(common.StartOfUserFieldID + i)
|
||||||
|
}
|
||||||
|
|
||||||
|
// simple mock for query node
|
||||||
|
// TODO(dragondriver): should we replace this mock using RocksMq or MemMsgStream?
|
||||||
|
|
||||||
|
err = chMgr.createDQLStream(collectionID)
|
||||||
|
assert.NoError(t, err)
|
||||||
|
stream, err := chMgr.getDQLStream(collectionID)
|
||||||
|
assert.NoError(t, err)
|
||||||
|
|
||||||
|
var wg sync.WaitGroup
|
||||||
|
wg.Add(1)
|
||||||
|
consumeCtx, cancel := context.WithCancel(ctx)
|
||||||
|
go func() {
|
||||||
|
defer wg.Done()
|
||||||
|
for {
|
||||||
|
select {
|
||||||
|
case <-consumeCtx.Done():
|
||||||
|
return
|
||||||
|
case pack := <-stream.Chan():
|
||||||
|
for _, msg := range pack.Msgs {
|
||||||
|
_, ok := msg.(*msgstream.RetrieveMsg)
|
||||||
|
assert.True(t, ok)
|
||||||
|
// TODO(dragondriver): construct result according to the request
|
||||||
|
|
||||||
|
result1 := &internalpb.RetrieveResults{
|
||||||
|
Base: &commonpb.MsgBase{
|
||||||
|
MsgType: commonpb.MsgType_RetrieveResult,
|
||||||
|
MsgID: 0,
|
||||||
|
Timestamp: 0,
|
||||||
|
SourceID: 0,
|
||||||
|
},
|
||||||
|
Status: &commonpb.Status{
|
||||||
|
ErrorCode: commonpb.ErrorCode_Success,
|
||||||
|
Reason: "",
|
||||||
|
},
|
||||||
|
ResultChannelID: strconv.Itoa(int(Params.ProxyID)),
|
||||||
|
Ids: &schemapb.IDs{
|
||||||
|
IdField: &schemapb.IDs_IntId{
|
||||||
|
IntId: &schemapb.LongArray{
|
||||||
|
Data: generateInt64Array(hitNum),
|
||||||
|
},
|
||||||
|
},
|
||||||
|
},
|
||||||
|
FieldsData: make([]*schemapb.FieldData, fieldsLen),
|
||||||
|
SealedSegmentIDsRetrieved: nil,
|
||||||
|
ChannelIDsRetrieved: nil,
|
||||||
|
GlobalSealedSegmentIDs: nil,
|
||||||
|
}
|
||||||
|
|
||||||
|
result1.FieldsData[0] = &schemapb.FieldData{
|
||||||
|
Type: schemapb.DataType_Bool,
|
||||||
|
FieldName: boolField,
|
||||||
|
Field: &schemapb.FieldData_Scalars{
|
||||||
|
Scalars: &schemapb.ScalarField{
|
||||||
|
Data: &schemapb.ScalarField_BoolData{
|
||||||
|
BoolData: &schemapb.BoolArray{
|
||||||
|
Data: generateBoolArray(hitNum),
|
||||||
|
},
|
||||||
|
},
|
||||||
|
},
|
||||||
|
},
|
||||||
|
FieldId: common.StartOfUserFieldID + 0,
|
||||||
|
}
|
||||||
|
|
||||||
|
result1.FieldsData[1] = &schemapb.FieldData{
|
||||||
|
Type: schemapb.DataType_Int32,
|
||||||
|
FieldName: int32Field,
|
||||||
|
Field: &schemapb.FieldData_Scalars{
|
||||||
|
Scalars: &schemapb.ScalarField{
|
||||||
|
Data: &schemapb.ScalarField_IntData{
|
||||||
|
IntData: &schemapb.IntArray{
|
||||||
|
Data: generateInt32Array(hitNum),
|
||||||
|
},
|
||||||
|
},
|
||||||
|
},
|
||||||
|
},
|
||||||
|
FieldId: common.StartOfUserFieldID + 1,
|
||||||
|
}
|
||||||
|
|
||||||
|
result1.FieldsData[2] = &schemapb.FieldData{
|
||||||
|
Type: schemapb.DataType_Int64,
|
||||||
|
FieldName: int64Field,
|
||||||
|
Field: &schemapb.FieldData_Scalars{
|
||||||
|
Scalars: &schemapb.ScalarField{
|
||||||
|
Data: &schemapb.ScalarField_LongData{
|
||||||
|
LongData: &schemapb.LongArray{
|
||||||
|
Data: generateInt64Array(hitNum),
|
||||||
|
},
|
||||||
|
},
|
||||||
|
},
|
||||||
|
},
|
||||||
|
FieldId: common.StartOfUserFieldID + 2,
|
||||||
|
}
|
||||||
|
|
||||||
|
result1.FieldsData[3] = &schemapb.FieldData{
|
||||||
|
Type: schemapb.DataType_Float,
|
||||||
|
FieldName: floatField,
|
||||||
|
Field: &schemapb.FieldData_Scalars{
|
||||||
|
Scalars: &schemapb.ScalarField{
|
||||||
|
Data: &schemapb.ScalarField_FloatData{
|
||||||
|
FloatData: &schemapb.FloatArray{
|
||||||
|
Data: generateFloat32Array(hitNum),
|
||||||
|
},
|
||||||
|
},
|
||||||
|
},
|
||||||
|
},
|
||||||
|
FieldId: common.StartOfUserFieldID + 3,
|
||||||
|
}
|
||||||
|
|
||||||
|
result1.FieldsData[4] = &schemapb.FieldData{
|
||||||
|
Type: schemapb.DataType_Double,
|
||||||
|
FieldName: doubleField,
|
||||||
|
Field: &schemapb.FieldData_Scalars{
|
||||||
|
Scalars: &schemapb.ScalarField{
|
||||||
|
Data: &schemapb.ScalarField_DoubleData{
|
||||||
|
DoubleData: &schemapb.DoubleArray{
|
||||||
|
Data: generateFloat64Array(hitNum),
|
||||||
|
},
|
||||||
|
},
|
||||||
|
},
|
||||||
|
},
|
||||||
|
FieldId: common.StartOfUserFieldID + 4,
|
||||||
|
}
|
||||||
|
|
||||||
|
result1.FieldsData[5] = &schemapb.FieldData{
|
||||||
|
Type: schemapb.DataType_FloatVector,
|
||||||
|
FieldName: doubleField,
|
||||||
|
Field: &schemapb.FieldData_Vectors{
|
||||||
|
Vectors: &schemapb.VectorField{
|
||||||
|
Dim: int64(dim),
|
||||||
|
Data: &schemapb.VectorField_FloatVector{
|
||||||
|
FloatVector: &schemapb.FloatArray{
|
||||||
|
Data: generateFloatVectors(hitNum, dim),
|
||||||
|
},
|
||||||
|
},
|
||||||
|
},
|
||||||
|
},
|
||||||
|
FieldId: common.StartOfUserFieldID + 5,
|
||||||
|
}
|
||||||
|
|
||||||
|
result1.FieldsData[6] = &schemapb.FieldData{
|
||||||
|
Type: schemapb.DataType_BinaryVector,
|
||||||
|
FieldName: doubleField,
|
||||||
|
Field: &schemapb.FieldData_Vectors{
|
||||||
|
Vectors: &schemapb.VectorField{
|
||||||
|
Dim: int64(dim),
|
||||||
|
Data: &schemapb.VectorField_BinaryVector{
|
||||||
|
BinaryVector: generateBinaryVectors(hitNum, dim),
|
||||||
|
},
|
||||||
|
},
|
||||||
|
},
|
||||||
|
FieldId: common.StartOfUserFieldID + 6,
|
||||||
|
}
|
||||||
|
|
||||||
|
// send search result
|
||||||
|
task.resultBuf <- []*internalpb.RetrieveResults{result1}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}()
|
||||||
|
|
||||||
|
assert.NoError(t, task.OnEnqueue())
|
||||||
|
assert.NoError(t, task.PreExecute(ctx))
|
||||||
|
assert.NoError(t, task.Execute(ctx))
|
||||||
|
assert.NoError(t, task.PostExecute(ctx))
|
||||||
|
|
||||||
|
cancel()
|
||||||
|
wg.Wait()
|
||||||
|
}
|
||||||
|
|
Loading…
Reference in New Issue