mirror of https://github.com/milvus-io/milvus.git
parent
b758c305a7
commit
872721e3ec
|
@ -14,6 +14,7 @@ import (
|
||||||
"github.com/golang/protobuf/proto"
|
"github.com/golang/protobuf/proto"
|
||||||
"github.com/milvus-io/milvus/internal/util/indexparamcheck"
|
"github.com/milvus-io/milvus/internal/util/indexparamcheck"
|
||||||
"github.com/milvus-io/milvus/internal/util/paramtable"
|
"github.com/milvus-io/milvus/internal/util/paramtable"
|
||||||
|
"github.com/milvus-io/milvus/internal/util/typeutil"
|
||||||
"github.com/stretchr/testify/assert"
|
"github.com/stretchr/testify/assert"
|
||||||
"github.com/stretchr/testify/mock"
|
"github.com/stretchr/testify/mock"
|
||||||
"github.com/stretchr/testify/require"
|
"github.com/stretchr/testify/require"
|
||||||
|
@ -30,7 +31,6 @@ import (
|
||||||
"github.com/milvus-io/milvus/internal/util/distance"
|
"github.com/milvus-io/milvus/internal/util/distance"
|
||||||
"github.com/milvus-io/milvus/internal/util/funcutil"
|
"github.com/milvus-io/milvus/internal/util/funcutil"
|
||||||
"github.com/milvus-io/milvus/internal/util/timerecord"
|
"github.com/milvus-io/milvus/internal/util/timerecord"
|
||||||
"github.com/milvus-io/milvus/internal/util/typeutil"
|
|
||||||
)
|
)
|
||||||
|
|
||||||
const (
|
const (
|
||||||
|
@ -133,10 +133,8 @@ func TestSearchTask_PreExecute(t *testing.T) {
|
||||||
|
|
||||||
var (
|
var (
|
||||||
rc = NewRootCoordMock()
|
rc = NewRootCoordMock()
|
||||||
qc = getQueryCoord()
|
qc = types.NewMockQueryCoord(t)
|
||||||
ctx = context.TODO()
|
ctx = context.TODO()
|
||||||
|
|
||||||
collectionName = t.Name() + funcutil.GenRandomStr()
|
|
||||||
)
|
)
|
||||||
successStatus := commonpb.Status{ErrorCode: commonpb.ErrorCode_Success}
|
successStatus := commonpb.Status{ErrorCode: commonpb.ErrorCode_Success}
|
||||||
|
|
||||||
|
@ -147,14 +145,11 @@ func TestSearchTask_PreExecute(t *testing.T) {
|
||||||
err = InitMetaCache(ctx, rc, qc, mgr)
|
err = InitMetaCache(ctx, rc, qc, mgr)
|
||||||
require.NoError(t, err)
|
require.NoError(t, err)
|
||||||
|
|
||||||
err = qc.Start()
|
|
||||||
defer qc.Stop()
|
|
||||||
require.NoError(t, err)
|
|
||||||
|
|
||||||
getSearchTask := func(t *testing.T, collName string) *searchTask {
|
getSearchTask := func(t *testing.T, collName string) *searchTask {
|
||||||
task := &searchTask{
|
task := &searchTask{
|
||||||
ctx: ctx,
|
ctx: ctx,
|
||||||
SearchRequest: &internalpb.SearchRequest{},
|
collectionName: collName,
|
||||||
|
SearchRequest: &internalpb.SearchRequest{},
|
||||||
request: &milvuspb.SearchRequest{
|
request: &milvuspb.SearchRequest{
|
||||||
CollectionName: collName,
|
CollectionName: collName,
|
||||||
Nq: 1,
|
Nq: 1,
|
||||||
|
@ -166,12 +161,13 @@ func TestSearchTask_PreExecute(t *testing.T) {
|
||||||
return task
|
return task
|
||||||
}
|
}
|
||||||
|
|
||||||
getSearchTaskWithNq := func(t *testing.T, nq int64) *searchTask {
|
getSearchTaskWithNq := func(t *testing.T, collName string, nq int64) *searchTask {
|
||||||
task := &searchTask{
|
task := &searchTask{
|
||||||
ctx: ctx,
|
ctx: ctx,
|
||||||
SearchRequest: &internalpb.SearchRequest{},
|
collectionName: collName,
|
||||||
|
SearchRequest: &internalpb.SearchRequest{},
|
||||||
request: &milvuspb.SearchRequest{
|
request: &milvuspb.SearchRequest{
|
||||||
CollectionName: "collection name",
|
CollectionName: collName,
|
||||||
Nq: nq,
|
Nq: nq,
|
||||||
},
|
},
|
||||||
qc: qc,
|
qc: qc,
|
||||||
|
@ -181,37 +177,62 @@ func TestSearchTask_PreExecute(t *testing.T) {
|
||||||
return task
|
return task
|
||||||
}
|
}
|
||||||
|
|
||||||
|
mockShowCollectionSuccess := func() *mock.Call {
|
||||||
|
return qc.On("ShowCollections", mock.Anything, mock.Anything).Return(
|
||||||
|
func(ctx context.Context, req *querypb.ShowCollectionsRequest) *querypb.ShowCollectionsResponse {
|
||||||
|
return &querypb.ShowCollectionsResponse{
|
||||||
|
Status: &successStatus,
|
||||||
|
CollectionIDs: req.CollectionIDs,
|
||||||
|
InMemoryPercentages: []int64{100},
|
||||||
|
}
|
||||||
|
}, nil)
|
||||||
|
}
|
||||||
|
|
||||||
|
mockShowCollectionFail := func() *mock.Call {
|
||||||
|
return qc.On("ShowCollections", mock.Anything, mock.Anything).Return(&querypb.ShowCollectionsResponse{
|
||||||
|
Status: &commonpb.Status{
|
||||||
|
ErrorCode: commonpb.ErrorCode_UnexpectedError,
|
||||||
|
Reason: "mock",
|
||||||
|
},
|
||||||
|
}, nil)
|
||||||
|
}
|
||||||
|
|
||||||
t.Run("bad nq 0", func(t *testing.T) {
|
t.Run("bad nq 0", func(t *testing.T) {
|
||||||
|
call := mockShowCollectionSuccess()
|
||||||
|
defer call.Unset()
|
||||||
|
|
||||||
|
collName := "test_bad_nq0_error" + funcutil.GenRandomStr()
|
||||||
|
createColl(t, collName, rc)
|
||||||
// Nq must be in range [1, 16384].
|
// Nq must be in range [1, 16384].
|
||||||
task := getSearchTaskWithNq(t, 0)
|
task := getSearchTaskWithNq(t, collName, 0)
|
||||||
err = task.PreExecute(ctx)
|
err = task.PreExecute(ctx)
|
||||||
assert.Error(t, err)
|
assert.Error(t, err)
|
||||||
})
|
})
|
||||||
|
|
||||||
t.Run("bad nq 16385", func(t *testing.T) {
|
t.Run("bad nq 16385", func(t *testing.T) {
|
||||||
|
call := mockShowCollectionSuccess()
|
||||||
|
defer call.Unset()
|
||||||
|
collName := "test_bad_nq16385_error" + funcutil.GenRandomStr()
|
||||||
|
createColl(t, collName, rc)
|
||||||
|
|
||||||
// Nq must be in range [1, 16384].
|
// Nq must be in range [1, 16384].
|
||||||
task := getSearchTaskWithNq(t, 16384+1)
|
task := getSearchTaskWithNq(t, collName, 16384+1)
|
||||||
err = task.PreExecute(ctx)
|
err = task.PreExecute(ctx)
|
||||||
assert.Error(t, err)
|
assert.Error(t, err)
|
||||||
})
|
})
|
||||||
|
|
||||||
t.Run("collection not exist", func(t *testing.T) {
|
t.Run("collection not exist", func(t *testing.T) {
|
||||||
task := getSearchTask(t, collectionName)
|
collName := "test_collection_not_exist" + funcutil.GenRandomStr()
|
||||||
|
task := getSearchTask(t, collName)
|
||||||
err = task.PreExecute(ctx)
|
err = task.PreExecute(ctx)
|
||||||
assert.Error(t, err)
|
assert.Error(t, err)
|
||||||
})
|
})
|
||||||
|
|
||||||
t.Run("invalid IgnoreGrowing param", func(t *testing.T) {
|
t.Run("invalid IgnoreGrowing param", func(t *testing.T) {
|
||||||
|
call := mockShowCollectionSuccess()
|
||||||
|
defer call.Unset()
|
||||||
collName := "test_invalid_param" + funcutil.GenRandomStr()
|
collName := "test_invalid_param" + funcutil.GenRandomStr()
|
||||||
createColl(t, collName, rc)
|
createColl(t, collName, rc)
|
||||||
collID, err := globalMetaCache.GetCollectionID(context.TODO(), collName)
|
|
||||||
require.NoError(t, err)
|
|
||||||
qc.EXPECT().ShowCollections(
|
|
||||||
mock.Anything, mock.MatchedBy(func(req *querypb.ShowCollectionsRequest) bool { return req.CollectionIDs[0] == collID })).Return(&querypb.ShowCollectionsResponse{
|
|
||||||
Status: &successStatus,
|
|
||||||
CollectionIDs: []int64{collID},
|
|
||||||
InMemoryPercentages: []int64{100},
|
|
||||||
}, nil).Times(1)
|
|
||||||
|
|
||||||
task := getSearchTask(t, collName)
|
task := getSearchTask(t, collName)
|
||||||
task.request.SearchParams = getInvalidSearchParams(IgnoreGrowingKey)
|
task.request.SearchParams = getInvalidSearchParams(IgnoreGrowingKey)
|
||||||
|
@ -219,87 +240,23 @@ func TestSearchTask_PreExecute(t *testing.T) {
|
||||||
assert.Error(t, err)
|
assert.Error(t, err)
|
||||||
})
|
})
|
||||||
|
|
||||||
t.Run("invalid collection name", func(t *testing.T) {
|
|
||||||
task := getSearchTask(t, collectionName)
|
|
||||||
createColl(t, collectionName, rc)
|
|
||||||
|
|
||||||
invalidCollNameTests := []struct {
|
|
||||||
inCollName string
|
|
||||||
description string
|
|
||||||
}{
|
|
||||||
{"$", "invalid collection name $"},
|
|
||||||
{"0", "invalid collection name 0"},
|
|
||||||
}
|
|
||||||
|
|
||||||
for _, test := range invalidCollNameTests {
|
|
||||||
t.Run(test.description, func(t *testing.T) {
|
|
||||||
task.request.CollectionName = test.inCollName
|
|
||||||
assert.Error(t, task.PreExecute(context.TODO()))
|
|
||||||
})
|
|
||||||
}
|
|
||||||
})
|
|
||||||
|
|
||||||
t.Run("invalid partition names", func(t *testing.T) {
|
|
||||||
task := getSearchTask(t, collectionName)
|
|
||||||
createColl(t, collectionName, rc)
|
|
||||||
|
|
||||||
invalidCollNameTests := []struct {
|
|
||||||
inPartNames []string
|
|
||||||
description string
|
|
||||||
}{
|
|
||||||
{[]string{"$"}, "invalid partition name $"},
|
|
||||||
{[]string{"0"}, "invalid collection name 0"},
|
|
||||||
{[]string{"default", "$"}, "invalid empty partition name"},
|
|
||||||
}
|
|
||||||
|
|
||||||
for _, test := range invalidCollNameTests {
|
|
||||||
t.Run(test.description, func(t *testing.T) {
|
|
||||||
task.request.PartitionNames = test.inPartNames
|
|
||||||
assert.Error(t, task.PreExecute(context.TODO()))
|
|
||||||
})
|
|
||||||
}
|
|
||||||
})
|
|
||||||
|
|
||||||
t.Run("test checkIfLoaded error", func(t *testing.T) {
|
t.Run("test checkIfLoaded error", func(t *testing.T) {
|
||||||
collName := "test_checkIfLoaded_error" + funcutil.GenRandomStr()
|
collName := "test_checkIfLoaded_error" + funcutil.GenRandomStr()
|
||||||
createColl(t, collName, rc)
|
createColl(t, collName, rc)
|
||||||
_, err := globalMetaCache.GetCollectionID(context.TODO(), collName)
|
|
||||||
require.NoError(t, err)
|
|
||||||
task := getSearchTask(t, collName)
|
task := getSearchTask(t, collName)
|
||||||
task.collectionName = collName
|
|
||||||
|
|
||||||
t.Run("show collection status unexpected error", func(t *testing.T) {
|
t.Run("show collection status unexpected error", func(t *testing.T) {
|
||||||
qc.EXPECT().ShowCollections(mock.Anything, mock.Anything).Return(&querypb.ShowCollectionsResponse{
|
call := mockShowCollectionFail()
|
||||||
Status: &commonpb.Status{
|
defer call.Unset()
|
||||||
ErrorCode: commonpb.ErrorCode_UnexpectedError,
|
|
||||||
Reason: "mock",
|
|
||||||
},
|
|
||||||
}, nil).Times(1)
|
|
||||||
|
|
||||||
assert.Error(t, task.PreExecute(ctx))
|
assert.Error(t, task.PreExecute(ctx))
|
||||||
})
|
})
|
||||||
})
|
})
|
||||||
|
|
||||||
t.Run("search with timeout", func(t *testing.T) {
|
t.Run("search with timeout", func(t *testing.T) {
|
||||||
|
call := mockShowCollectionSuccess()
|
||||||
|
defer call.Unset()
|
||||||
collName := "search_with_timeout" + funcutil.GenRandomStr()
|
collName := "search_with_timeout" + funcutil.GenRandomStr()
|
||||||
createColl(t, collName, rc)
|
createColl(t, collName, rc)
|
||||||
collID, err := globalMetaCache.GetCollectionID(context.TODO(), collName)
|
|
||||||
require.NoError(t, err)
|
|
||||||
|
|
||||||
qc.EXPECT().LoadCollection(mock.Anything, mock.Anything).Return(&successStatus, nil)
|
|
||||||
qc.EXPECT().ShowCollections(mock.Anything, mock.Anything).Return(&querypb.ShowCollectionsResponse{
|
|
||||||
Status: &successStatus,
|
|
||||||
CollectionIDs: []int64{collID},
|
|
||||||
InMemoryPercentages: []int64{100},
|
|
||||||
}, nil)
|
|
||||||
status, err := qc.LoadCollection(ctx, &querypb.LoadCollectionRequest{
|
|
||||||
Base: &commonpb.MsgBase{
|
|
||||||
MsgType: commonpb.MsgType_LoadCollection,
|
|
||||||
},
|
|
||||||
CollectionID: collID,
|
|
||||||
})
|
|
||||||
require.NoError(t, err)
|
|
||||||
require.Equal(t, commonpb.ErrorCode_Success, status.GetErrorCode())
|
|
||||||
|
|
||||||
task := getSearchTask(t, collName)
|
task := getSearchTask(t, collName)
|
||||||
task.request.SearchParams = getValidSearchParams()
|
task.request.SearchParams = getValidSearchParams()
|
||||||
|
|
Loading…
Reference in New Issue