Refactor invalid unit test (#22347)

Signed-off-by: aoiasd <zhicheng.yue@zilliz.com>
pull/22043/head
aoiasd 2023-02-27 14:27:47 +08:00 committed by GitHub
parent b758c305a7
commit 872721e3ec
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
1 changed files with 50 additions and 93 deletions

View File

@ -14,6 +14,7 @@ import (
"github.com/golang/protobuf/proto"
"github.com/milvus-io/milvus/internal/util/indexparamcheck"
"github.com/milvus-io/milvus/internal/util/paramtable"
"github.com/milvus-io/milvus/internal/util/typeutil"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/mock"
"github.com/stretchr/testify/require"
@ -30,7 +31,6 @@ import (
"github.com/milvus-io/milvus/internal/util/distance"
"github.com/milvus-io/milvus/internal/util/funcutil"
"github.com/milvus-io/milvus/internal/util/timerecord"
"github.com/milvus-io/milvus/internal/util/typeutil"
)
const (
@ -133,10 +133,8 @@ func TestSearchTask_PreExecute(t *testing.T) {
var (
rc = NewRootCoordMock()
qc = getQueryCoord()
qc = types.NewMockQueryCoord(t)
ctx = context.TODO()
collectionName = t.Name() + funcutil.GenRandomStr()
)
successStatus := commonpb.Status{ErrorCode: commonpb.ErrorCode_Success}
@ -147,14 +145,11 @@ func TestSearchTask_PreExecute(t *testing.T) {
err = InitMetaCache(ctx, rc, qc, mgr)
require.NoError(t, err)
err = qc.Start()
defer qc.Stop()
require.NoError(t, err)
getSearchTask := func(t *testing.T, collName string) *searchTask {
task := &searchTask{
ctx: ctx,
SearchRequest: &internalpb.SearchRequest{},
ctx: ctx,
collectionName: collName,
SearchRequest: &internalpb.SearchRequest{},
request: &milvuspb.SearchRequest{
CollectionName: collName,
Nq: 1,
@ -166,12 +161,13 @@ func TestSearchTask_PreExecute(t *testing.T) {
return task
}
getSearchTaskWithNq := func(t *testing.T, nq int64) *searchTask {
getSearchTaskWithNq := func(t *testing.T, collName string, nq int64) *searchTask {
task := &searchTask{
ctx: ctx,
SearchRequest: &internalpb.SearchRequest{},
ctx: ctx,
collectionName: collName,
SearchRequest: &internalpb.SearchRequest{},
request: &milvuspb.SearchRequest{
CollectionName: "collection name",
CollectionName: collName,
Nq: nq,
},
qc: qc,
@ -181,37 +177,62 @@ func TestSearchTask_PreExecute(t *testing.T) {
return task
}
mockShowCollectionSuccess := func() *mock.Call {
return qc.On("ShowCollections", mock.Anything, mock.Anything).Return(
func(ctx context.Context, req *querypb.ShowCollectionsRequest) *querypb.ShowCollectionsResponse {
return &querypb.ShowCollectionsResponse{
Status: &successStatus,
CollectionIDs: req.CollectionIDs,
InMemoryPercentages: []int64{100},
}
}, nil)
}
mockShowCollectionFail := func() *mock.Call {
return qc.On("ShowCollections", mock.Anything, mock.Anything).Return(&querypb.ShowCollectionsResponse{
Status: &commonpb.Status{
ErrorCode: commonpb.ErrorCode_UnexpectedError,
Reason: "mock",
},
}, nil)
}
t.Run("bad nq 0", func(t *testing.T) {
call := mockShowCollectionSuccess()
defer call.Unset()
collName := "test_bad_nq0_error" + funcutil.GenRandomStr()
createColl(t, collName, rc)
// Nq must be in range [1, 16384].
task := getSearchTaskWithNq(t, 0)
task := getSearchTaskWithNq(t, collName, 0)
err = task.PreExecute(ctx)
assert.Error(t, err)
})
t.Run("bad nq 16385", func(t *testing.T) {
call := mockShowCollectionSuccess()
defer call.Unset()
collName := "test_bad_nq16385_error" + funcutil.GenRandomStr()
createColl(t, collName, rc)
// Nq must be in range [1, 16384].
task := getSearchTaskWithNq(t, 16384+1)
task := getSearchTaskWithNq(t, collName, 16384+1)
err = task.PreExecute(ctx)
assert.Error(t, err)
})
t.Run("collection not exist", func(t *testing.T) {
task := getSearchTask(t, collectionName)
collName := "test_collection_not_exist" + funcutil.GenRandomStr()
task := getSearchTask(t, collName)
err = task.PreExecute(ctx)
assert.Error(t, err)
})
t.Run("invalid IgnoreGrowing param", func(t *testing.T) {
call := mockShowCollectionSuccess()
defer call.Unset()
collName := "test_invalid_param" + funcutil.GenRandomStr()
createColl(t, collName, rc)
collID, err := globalMetaCache.GetCollectionID(context.TODO(), collName)
require.NoError(t, err)
qc.EXPECT().ShowCollections(
mock.Anything, mock.MatchedBy(func(req *querypb.ShowCollectionsRequest) bool { return req.CollectionIDs[0] == collID })).Return(&querypb.ShowCollectionsResponse{
Status: &successStatus,
CollectionIDs: []int64{collID},
InMemoryPercentages: []int64{100},
}, nil).Times(1)
task := getSearchTask(t, collName)
task.request.SearchParams = getInvalidSearchParams(IgnoreGrowingKey)
@ -219,87 +240,23 @@ func TestSearchTask_PreExecute(t *testing.T) {
assert.Error(t, err)
})
t.Run("invalid collection name", func(t *testing.T) {
task := getSearchTask(t, collectionName)
createColl(t, collectionName, rc)
invalidCollNameTests := []struct {
inCollName string
description string
}{
{"$", "invalid collection name $"},
{"0", "invalid collection name 0"},
}
for _, test := range invalidCollNameTests {
t.Run(test.description, func(t *testing.T) {
task.request.CollectionName = test.inCollName
assert.Error(t, task.PreExecute(context.TODO()))
})
}
})
t.Run("invalid partition names", func(t *testing.T) {
task := getSearchTask(t, collectionName)
createColl(t, collectionName, rc)
invalidCollNameTests := []struct {
inPartNames []string
description string
}{
{[]string{"$"}, "invalid partition name $"},
{[]string{"0"}, "invalid collection name 0"},
{[]string{"default", "$"}, "invalid empty partition name"},
}
for _, test := range invalidCollNameTests {
t.Run(test.description, func(t *testing.T) {
task.request.PartitionNames = test.inPartNames
assert.Error(t, task.PreExecute(context.TODO()))
})
}
})
t.Run("test checkIfLoaded error", func(t *testing.T) {
collName := "test_checkIfLoaded_error" + funcutil.GenRandomStr()
createColl(t, collName, rc)
_, err := globalMetaCache.GetCollectionID(context.TODO(), collName)
require.NoError(t, err)
task := getSearchTask(t, collName)
task.collectionName = collName
t.Run("show collection status unexpected error", func(t *testing.T) {
qc.EXPECT().ShowCollections(mock.Anything, mock.Anything).Return(&querypb.ShowCollectionsResponse{
Status: &commonpb.Status{
ErrorCode: commonpb.ErrorCode_UnexpectedError,
Reason: "mock",
},
}, nil).Times(1)
call := mockShowCollectionFail()
defer call.Unset()
assert.Error(t, task.PreExecute(ctx))
})
})
t.Run("search with timeout", func(t *testing.T) {
call := mockShowCollectionSuccess()
defer call.Unset()
collName := "search_with_timeout" + funcutil.GenRandomStr()
createColl(t, collName, rc)
collID, err := globalMetaCache.GetCollectionID(context.TODO(), collName)
require.NoError(t, err)
qc.EXPECT().LoadCollection(mock.Anything, mock.Anything).Return(&successStatus, nil)
qc.EXPECT().ShowCollections(mock.Anything, mock.Anything).Return(&querypb.ShowCollectionsResponse{
Status: &successStatus,
CollectionIDs: []int64{collID},
InMemoryPercentages: []int64{100},
}, nil)
status, err := qc.LoadCollection(ctx, &querypb.LoadCollectionRequest{
Base: &commonpb.MsgBase{
MsgType: commonpb.MsgType_LoadCollection,
},
CollectionID: collID,
})
require.NoError(t, err)
require.Equal(t, commonpb.ErrorCode_Success, status.GetErrorCode())
task := getSearchTask(t, collName)
task.request.SearchParams = getValidSearchParams()