milvus/internal/proxy/task_query_test.go

327 lines
9.4 KiB
Go

package proxy
import (
"context"
"fmt"
"testing"
"time"
"github.com/golang/protobuf/proto"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
"github.com/milvus-io/milvus/internal/common"
"github.com/milvus-io/milvus/internal/proto/commonpb"
"github.com/milvus-io/milvus/internal/proto/internalpb"
"github.com/milvus-io/milvus/internal/proto/milvuspb"
"github.com/milvus-io/milvus/internal/proto/querypb"
"github.com/milvus-io/milvus/internal/proto/schemapb"
"github.com/milvus-io/milvus/internal/types"
"github.com/milvus-io/milvus/internal/util/funcutil"
"github.com/milvus-io/milvus/internal/util/typeutil"
)
func TestQueryTask_all(t *testing.T) {
Params.Init()
var (
err error
ctx = context.TODO()
rc = NewRootCoordMock()
qc = NewQueryCoordMock(withValidShardLeaders())
qn = &QueryNodeMock{}
shardsNum = int32(2)
collectionName = t.Name() + funcutil.GenRandomStr()
expr = fmt.Sprintf("%s > 0", testInt64Field)
hitNum = 10
)
mockCreator := func(ctx context.Context, address string) (types.QueryNode, error) {
return qn, nil
}
mgr := newShardClientMgr(withShardClientCreator(mockCreator))
rc.Start()
defer rc.Stop()
qc.Start()
defer qc.Stop()
err = InitMetaCache(rc, qc, mgr)
assert.NoError(t, err)
fieldName2Types := map[string]schemapb.DataType{
testBoolField: schemapb.DataType_Bool,
testInt32Field: schemapb.DataType_Int32,
testInt64Field: schemapb.DataType_Int64,
testFloatField: schemapb.DataType_Float,
testDoubleField: schemapb.DataType_Double,
testFloatVecField: schemapb.DataType_FloatVector,
}
if enableMultipleVectorFields {
fieldName2Types[testBinaryVecField] = schemapb.DataType_BinaryVector
}
schema := constructCollectionSchemaByDataType(collectionName, fieldName2Types, testInt64Field, false)
marshaledSchema, err := proto.Marshal(schema)
assert.NoError(t, err)
createColT := &createCollectionTask{
Condition: NewTaskCondition(ctx),
CreateCollectionRequest: &milvuspb.CreateCollectionRequest{
CollectionName: collectionName,
Schema: marshaledSchema,
ShardsNum: shardsNum,
},
ctx: ctx,
rootCoord: rc,
}
require.NoError(t, createColT.OnEnqueue())
require.NoError(t, createColT.PreExecute(ctx))
require.NoError(t, createColT.Execute(ctx))
require.NoError(t, createColT.PostExecute(ctx))
collectionID, err := globalMetaCache.GetCollectionID(ctx, collectionName)
assert.NoError(t, err)
status, err := qc.LoadCollection(ctx, &querypb.LoadCollectionRequest{
Base: &commonpb.MsgBase{
MsgType: commonpb.MsgType_LoadCollection,
SourceID: Params.ProxyCfg.GetNodeID(),
},
CollectionID: collectionID,
})
require.NoError(t, err)
require.Equal(t, commonpb.ErrorCode_Success, status.ErrorCode)
// test begins
task := &queryTask{
Condition: NewTaskCondition(ctx),
RetrieveRequest: &internalpb.RetrieveRequest{
Base: &commonpb.MsgBase{
MsgType: commonpb.MsgType_Retrieve,
SourceID: Params.ProxyCfg.GetNodeID(),
},
CollectionID: collectionID,
OutputFieldsId: make([]int64, len(fieldName2Types)),
},
ctx: ctx,
result: &milvuspb.QueryResults{
Status: &commonpb.Status{
ErrorCode: commonpb.ErrorCode_Success,
},
},
request: &milvuspb.QueryRequest{
Base: &commonpb.MsgBase{
MsgType: commonpb.MsgType_Retrieve,
SourceID: Params.ProxyCfg.GetNodeID(),
},
CollectionName: collectionName,
Expr: expr,
},
qc: qc,
queryShardPolicy: roundRobinPolicy,
shardMgr: mgr,
}
for i := 0; i < len(fieldName2Types); i++ {
task.RetrieveRequest.OutputFieldsId[i] = int64(common.StartOfUserFieldID + i)
}
assert.NoError(t, task.OnEnqueue())
// test query task with timeout
ctx1, cancel1 := context.WithTimeout(ctx, 10*time.Second)
defer cancel1()
// before preExecute
assert.Equal(t, typeutil.ZeroTimestamp, task.TimeoutTimestamp)
task.ctx = ctx1
assert.NoError(t, task.PreExecute(ctx))
// after preExecute
assert.Greater(t, task.TimeoutTimestamp, typeutil.ZeroTimestamp)
result1 := &internalpb.RetrieveResults{
Base: &commonpb.MsgBase{MsgType: commonpb.MsgType_RetrieveResult},
Status: &commonpb.Status{
ErrorCode: commonpb.ErrorCode_Success,
},
Ids: &schemapb.IDs{
IdField: &schemapb.IDs_IntId{
IntId: &schemapb.LongArray{Data: generateInt64Array(hitNum)},
},
},
}
for fieldName, dataType := range fieldName2Types {
result1.FieldsData = append(result1.FieldsData, generateFieldData(dataType, fieldName, hitNum))
}
qn.withQueryResult = result1
task.ctx = ctx
assert.NoError(t, task.Execute(ctx))
assert.NoError(t, task.PostExecute(ctx))
}
func TestCheckIfLoaded(t *testing.T) {
var err error
Params.Init()
var (
rc = NewRootCoordMock()
qc = NewQueryCoordMock()
ctx = context.TODO()
)
err = rc.Start()
defer rc.Stop()
require.NoError(t, err)
mgr := newShardClientMgr()
err = InitMetaCache(rc, qc, mgr)
require.NoError(t, err)
err = qc.Start()
defer qc.Stop()
require.NoError(t, err)
getQueryTask := func(t *testing.T, collName string) *queryTask {
task := &queryTask{
ctx: ctx,
RetrieveRequest: &internalpb.RetrieveRequest{
Base: &commonpb.MsgBase{
MsgType: commonpb.MsgType_Retrieve,
},
},
request: &milvuspb.QueryRequest{
Base: &commonpb.MsgBase{
MsgType: commonpb.MsgType_Retrieve,
},
CollectionName: collName,
},
qc: qc,
}
require.NoError(t, task.OnEnqueue())
return task
}
t.Run("test checkIfLoaded error", func(t *testing.T) {
collName := "test_checkIfLoaded_error" + funcutil.GenRandomStr()
createColl(t, collName, rc)
collID, err := globalMetaCache.GetCollectionID(context.TODO(), collName)
require.NoError(t, err)
task := getQueryTask(t, collName)
task.collectionName = collName
t.Run("show collection err", func(t *testing.T) {
qc.SetShowCollectionsFunc(func(ctx context.Context, request *querypb.ShowCollectionsRequest) (*querypb.ShowCollectionsResponse, error) {
return nil, fmt.Errorf("mock")
})
loaded, err := task.checkIfLoaded(collID, []UniqueID{})
assert.Error(t, err)
assert.False(t, loaded)
})
t.Run("show collection status unexpected error", func(t *testing.T) {
qc.SetShowCollectionsFunc(func(ctx context.Context, request *querypb.ShowCollectionsRequest) (*querypb.ShowCollectionsResponse, error) {
return &querypb.ShowCollectionsResponse{
Status: &commonpb.Status{
ErrorCode: commonpb.ErrorCode_UnexpectedError,
Reason: "mock",
},
}, nil
})
loaded, err := task.checkIfLoaded(collID, []UniqueID{})
assert.Error(t, err)
assert.False(t, loaded)
assert.Error(t, task.PreExecute(ctx))
qc.ResetShowCollectionsFunc()
})
t.Run("show partition error", func(t *testing.T) {
qc.SetShowPartitionsFunc(func(ctx context.Context, req *querypb.ShowPartitionsRequest) (*querypb.ShowPartitionsResponse, error) {
return &querypb.ShowPartitionsResponse{
Status: &commonpb.Status{
ErrorCode: commonpb.ErrorCode_UnexpectedError,
Reason: "mock",
},
}, nil
})
loaded, err := task.checkIfLoaded(collID, []UniqueID{1})
assert.Error(t, err)
assert.False(t, loaded)
})
t.Run("show partition status unexpected error", func(t *testing.T) {
qc.SetShowPartitionsFunc(func(ctx context.Context, req *querypb.ShowPartitionsRequest) (*querypb.ShowPartitionsResponse, error) {
return nil, fmt.Errorf("mock error")
})
loaded, err := task.checkIfLoaded(collID, []UniqueID{1})
assert.Error(t, err)
assert.False(t, loaded)
})
t.Run("show partitions success", func(t *testing.T) {
qc.SetShowPartitionsFunc(func(ctx context.Context, req *querypb.ShowPartitionsRequest) (*querypb.ShowPartitionsResponse, error) {
return &querypb.ShowPartitionsResponse{
Status: &commonpb.Status{
ErrorCode: commonpb.ErrorCode_Success,
},
}, nil
})
loaded, err := task.checkIfLoaded(collID, []UniqueID{1})
assert.NoError(t, err)
assert.True(t, loaded)
qc.ResetShowPartitionsFunc()
})
t.Run("show collection success but not loaded", func(t *testing.T) {
qc.SetShowCollectionsFunc(func(ctx context.Context, request *querypb.ShowCollectionsRequest) (*querypb.ShowCollectionsResponse, error) {
return &querypb.ShowCollectionsResponse{
Status: &commonpb.Status{
ErrorCode: commonpb.ErrorCode_Success,
},
CollectionIDs: []UniqueID{collID},
InMemoryPercentages: []int64{0},
}, nil
})
qc.SetShowPartitionsFunc(func(ctx context.Context, req *querypb.ShowPartitionsRequest) (*querypb.ShowPartitionsResponse, error) {
return nil, fmt.Errorf("mock error")
})
loaded, err := task.checkIfLoaded(collID, []UniqueID{})
assert.Error(t, err)
assert.False(t, loaded)
qc.SetShowPartitionsFunc(func(ctx context.Context, req *querypb.ShowPartitionsRequest) (*querypb.ShowPartitionsResponse, error) {
return nil, fmt.Errorf("mock error")
})
loaded, err = task.checkIfLoaded(collID, []UniqueID{})
assert.Error(t, err)
assert.False(t, loaded)
qc.SetShowPartitionsFunc(func(ctx context.Context, req *querypb.ShowPartitionsRequest) (*querypb.ShowPartitionsResponse, error) {
return &querypb.ShowPartitionsResponse{
Status: &commonpb.Status{
ErrorCode: commonpb.ErrorCode_Success,
},
PartitionIDs: []UniqueID{1},
}, nil
})
loaded, err = task.checkIfLoaded(collID, []UniqueID{})
assert.NoError(t, err)
assert.True(t, loaded)
})
qc.ResetShowCollectionsFunc()
qc.ResetShowPartitionsFunc()
})
}