Proxy ut: cover the case where the search task processes the fields data (#7874)

Signed-off-by: dragondriver <jiquan.long@zilliz.com>
pull/7881/head
dragondriver 2021-09-14 14:09:05 +08:00 committed by GitHub
parent 77719a0ddc
commit f122e93383
4 changed files with 346 additions and 14 deletions


@@ -85,6 +85,7 @@ func newMockIDAllocatorInterface() idAllocatorInterface {
 type mockGetChannelsService struct {
     collectionID2Channels map[UniqueID]map[vChan]pChan
+    f                     getChannelsFuncType
 }

 func newMockGetChannelsService() *mockGetChannelsService {
@@ -94,6 +95,10 @@ func newMockGetChannelsService() *mockGetChannelsService {
 }

 func (m *mockGetChannelsService) GetChannels(collectionID UniqueID) (map[vChan]pChan, error) {
+    if m.f != nil {
+        return m.f(collectionID)
+    }
+
     channels, ok := m.collectionID2Channels[collectionID]
     if ok {
         return channels, nil
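The new f field lets a test swap in arbitrary GetChannels behavior instead of populating collectionID2Channels; TestSearchTask_Execute further down in this commit uses exactly this hook to force an error from the channel manager. Below is a minimal self-contained sketch of the pattern, with placeholder type aliases standing in for the proxy package's internal UniqueID/vChan/pChan types; the real getChannelsFuncType declaration lives elsewhere in the package and may differ.

package main

import (
	"errors"
	"fmt"
)

// Placeholder aliases standing in for the proxy package's internal types.
type (
	UniqueID = int64
	vChan    = string
	pChan    = string
)

// Assumed shape of getChannelsFuncType; the proxy's real declaration may differ.
type getChannelsFuncType = func(collectionID UniqueID) (map[vChan]pChan, error)

// Trimmed-down copy of the mock shown in the hunk above.
type mockGetChannelsService struct {
	collectionID2Channels map[UniqueID]map[vChan]pChan
	f                     getChannelsFuncType
}

func (m *mockGetChannelsService) GetChannels(collectionID UniqueID) (map[vChan]pChan, error) {
	if m.f != nil {
		return m.f(collectionID) // injected behavior takes priority over the static map
	}
	if channels, ok := m.collectionID2Channels[collectionID]; ok {
		return channels, nil
	}
	return nil, fmt.Errorf("collection %d not found", collectionID)
}

func main() {
	m := &mockGetChannelsService{}
	m.f = func(collectionID UniqueID) (map[vChan]pChan, error) {
		return nil, errors.New("mock")
	}
	_, err := m.GetChannels(1)
	fmt.Println(err) // prints "mock"
}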


@@ -394,8 +394,6 @@ func (coord *RootCoordMock) ShowCollections(ctx context.Context, req *milvuspb.S
             CollectionNames: nil,
         }, nil
     }
-    coord.collMtx.RLock()
-    defer coord.collMtx.RUnlock()
     coord.collMtx.RLock()
     defer coord.collMtx.RUnlock()


@@ -1656,14 +1656,14 @@ func (st *searchTask) Execute(ctx context.Context) error {
         }
     }
     err = stream.Produce(&msgPack)
-    log.Debug("proxy", zap.Int("length of searchMsg", len(msgPack.Msgs)))
-    log.Debug("proxy sent one searchMsg",
-        zap.Any("collectionID", st.CollectionID),
-        zap.Any("msgID", tsMsg.ID()),
-    )
     if err != nil {
         log.Debug("proxy", zap.String("send search request failed", err.Error()))
     }
+    log.Debug("proxy sent one searchMsg",
+        zap.Any("collectionID", st.CollectionID),
+        zap.Any("msgID", tsMsg.ID()),
+        zap.Int("length of search msg", len(msgPack.Msgs)),
+    )
     return err
 }
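The reordering above makes the summary log fire after Produce returns and folds the message count into it. A standalone sketch of the same pattern using go.uber.org/zap directly; the produce helper here is a stand-in for illustration, not the real msgstream API.

package main

import (
	"errors"

	"go.uber.org/zap"
)

// produce stands in for stream.Produce; it is a placeholder, not the real msgstream call.
func produce(msgs []string) error {
	if len(msgs) == 0 {
		return errors.New("empty msg pack")
	}
	return nil
}

func main() {
	logger, _ := zap.NewDevelopment()
	msgs := []string{"searchMsg-1", "searchMsg-2"}

	err := produce(msgs)
	if err != nil {
		logger.Debug("proxy", zap.String("send search request failed", err.Error()))
	}
	// One summary entry regardless of outcome, including the batch size.
	logger.Debug("proxy sent one searchMsg",
		zap.Int("length of search msg", len(msgs)),
	)
	_ = logger.Sync()
}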


@@ -72,6 +72,110 @@ func constructCollectionSchema(
     }
 }
+
+func constructCollectionSchemaWithAllType(
+    boolField, int32Field, int64Field, floatField, doubleField string,
+    floatVecField, binaryVecField string,
+    dim int,
+    collectionName string,
+) *schemapb.CollectionSchema {
+    b := &schemapb.FieldSchema{
+        FieldID:      0,
+        Name:         boolField,
+        IsPrimaryKey: false,
+        Description:  "",
+        DataType:     schemapb.DataType_Bool,
+        TypeParams:   nil,
+        IndexParams:  nil,
+        AutoID:       false,
+    }
+    i32 := &schemapb.FieldSchema{
+        FieldID:      0,
+        Name:         int32Field,
+        IsPrimaryKey: false,
+        Description:  "",
+        DataType:     schemapb.DataType_Int32,
+        TypeParams:   nil,
+        IndexParams:  nil,
+        AutoID:       false,
+    }
+    i64 := &schemapb.FieldSchema{
+        FieldID:      0,
+        Name:         int64Field,
+        IsPrimaryKey: true,
+        Description:  "",
+        DataType:     schemapb.DataType_Int64,
+        TypeParams:   nil,
+        IndexParams:  nil,
+        AutoID:       false,
+    }
+    f := &schemapb.FieldSchema{
+        FieldID:      0,
+        Name:         floatField,
+        IsPrimaryKey: false,
+        Description:  "",
+        DataType:     schemapb.DataType_Float,
+        TypeParams:   nil,
+        IndexParams:  nil,
+        AutoID:       false,
+    }
+    d := &schemapb.FieldSchema{
+        FieldID:      0,
+        Name:         doubleField,
+        IsPrimaryKey: false,
+        Description:  "",
+        DataType:     schemapb.DataType_Double,
+        TypeParams:   nil,
+        IndexParams:  nil,
+        AutoID:       false,
+    }
+    fVec := &schemapb.FieldSchema{
+        FieldID:      0,
+        Name:         floatVecField,
+        IsPrimaryKey: false,
+        Description:  "",
+        DataType:     schemapb.DataType_FloatVector,
+        TypeParams: []*commonpb.KeyValuePair{
+            {
+                Key:   "dim",
+                Value: strconv.Itoa(dim),
+            },
+        },
+        IndexParams: nil,
+        AutoID:      false,
+    }
+    bVec := &schemapb.FieldSchema{
+        FieldID:      0,
+        Name:         binaryVecField,
+        IsPrimaryKey: false,
+        Description:  "",
+        DataType:     schemapb.DataType_BinaryVector,
+        TypeParams: []*commonpb.KeyValuePair{
+            {
+                Key:   "dim",
+                Value: strconv.Itoa(dim),
+            },
+        },
+        IndexParams: nil,
+        AutoID:      false,
+    }
+    return &schemapb.CollectionSchema{
+        Name:        collectionName,
+        Description: "",
+        AutoID:      false,
+        Fields: []*schemapb.FieldSchema{
+            b,
+            i32,
+            i64,
+            f,
+            d,
+            fVec,
+            bVec,
+        },
+    }
+}

 func constructPlaceholderGroup(
     nq, dim int,
 ) *milvuspb.PlaceholderGroup {
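The new helper is consumed later in this commit by TestSearchTask_all. As a usage illustration, a hypothetical test like the following (reusing the helper, proto, schemapb, and assert imports already present in this file; field and collection names are arbitrary) round-trips the schema through proto, the same way the create-collection path consumes it:

// Usage sketch only; not part of the original commit.
func TestConstructCollectionSchemaWithAllType_sketch(t *testing.T) {
	schema := constructCollectionSchemaWithAllType(
		"bool", "int32", "int64", "float", "double",
		"fvec", "bvec", 128, "demo_collection")

	marshaled, err := proto.Marshal(schema)
	assert.NoError(t, err)

	decoded := &schemapb.CollectionSchema{}
	assert.NoError(t, proto.Unmarshal(marshaled, decoded))
	assert.Equal(t, 7, len(decoded.GetFields())) // five scalar fields plus two vector fields
}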
@@ -1549,15 +1653,23 @@ func TestSearchTask_all(t *testing.T) {
     prefix := "TestSearchTask_all"
     dbName := ""
     collectionName := prefix + funcutil.GenRandomStr()
+    boolField := "bool"
+    int32Field := "int32"
     int64Field := "int64"
+    floatField := "float"
+    doubleField := "double"
     floatVecField := "fvec"
+    binaryVecField := "bvec"
+    fieldsLen := len([]string{boolField, int32Field, int64Field, floatField, doubleField, floatVecField, binaryVecField})
     dim := 128
     expr := fmt.Sprintf("%s > 0", int64Field)
     nq := 10
     topk := 10
     nprobe := 10
-    schema := constructCollectionSchema(int64Field, floatVecField, dim, collectionName)
+    schema := constructCollectionSchemaWithAllType(
+        boolField, int32Field, int64Field, floatField, doubleField,
+        floatVecField, binaryVecField, dim, collectionName)

     marshaledSchema, err := proto.Marshal(schema)
     assert.NoError(t, err)
@@ -1669,7 +1781,7 @@ func TestSearchTask_all(t *testing.T) {
     resultData := &schemapb.SearchResultData{
         NumQueries: int64(nq),
         TopK:       int64(topk),
-        FieldsData: nil,
+        FieldsData: make([]*schemapb.FieldData, fieldsLen),
         Scores:     make([]float32, nq*topk),
         Ids: &schemapb.IDs{
             IdField: &schemapb.IDs_IntId{
@@ -1681,10 +1793,110 @@ func TestSearchTask_all(t *testing.T) {
         Topks: make([]int64, nq),
     }
-    // ids := make([]int64, topk)
-    // for i := 0; i < topk; i++ {
-    //     ids[i] = int64(uniquegenerator.GetUniqueIntGeneratorIns().GetInt())
-    // }
+    resultData.FieldsData[0] = &schemapb.FieldData{
+        Type:      schemapb.DataType_Bool,
+        FieldName: boolField,
+        Field: &schemapb.FieldData_Scalars{
+            Scalars: &schemapb.ScalarField{
+                Data: &schemapb.ScalarField_BoolData{
+                    BoolData: &schemapb.BoolArray{
+                        Data: generateBoolArray(nq * topk),
+                    },
+                },
+            },
+        },
+        FieldId: common.StartOfUserFieldID + 0,
+    }
+    resultData.FieldsData[1] = &schemapb.FieldData{
+        Type:      schemapb.DataType_Int32,
+        FieldName: int32Field,
+        Field: &schemapb.FieldData_Scalars{
+            Scalars: &schemapb.ScalarField{
+                Data: &schemapb.ScalarField_IntData{
+                    IntData: &schemapb.IntArray{
+                        Data: generateInt32Array(nq * topk),
+                    },
+                },
+            },
+        },
+        FieldId: common.StartOfUserFieldID + 1,
+    }
+    resultData.FieldsData[2] = &schemapb.FieldData{
+        Type:      schemapb.DataType_Int64,
+        FieldName: int64Field,
+        Field: &schemapb.FieldData_Scalars{
+            Scalars: &schemapb.ScalarField{
+                Data: &schemapb.ScalarField_LongData{
+                    LongData: &schemapb.LongArray{
+                        Data: generateInt64Array(nq * topk),
+                    },
+                },
+            },
+        },
+        FieldId: common.StartOfUserFieldID + 2,
+    }
+    resultData.FieldsData[3] = &schemapb.FieldData{
+        Type:      schemapb.DataType_Float,
+        FieldName: floatField,
+        Field: &schemapb.FieldData_Scalars{
+            Scalars: &schemapb.ScalarField{
+                Data: &schemapb.ScalarField_FloatData{
+                    FloatData: &schemapb.FloatArray{
+                        Data: generateFloat32Array(nq * topk),
+                    },
+                },
+            },
+        },
+        FieldId: common.StartOfUserFieldID + 3,
+    }
+    resultData.FieldsData[4] = &schemapb.FieldData{
+        Type:      schemapb.DataType_Double,
+        FieldName: doubleField,
+        Field: &schemapb.FieldData_Scalars{
+            Scalars: &schemapb.ScalarField{
+                Data: &schemapb.ScalarField_DoubleData{
+                    DoubleData: &schemapb.DoubleArray{
+                        Data: generateFloat64Array(nq * topk),
+                    },
+                },
+            },
+        },
+        FieldId: common.StartOfUserFieldID + 4,
+    }
+    resultData.FieldsData[5] = &schemapb.FieldData{
+        Type:      schemapb.DataType_FloatVector,
+        FieldName: floatVecField,
+        Field: &schemapb.FieldData_Vectors{
+            Vectors: &schemapb.VectorField{
+                Dim: int64(dim),
+                Data: &schemapb.VectorField_FloatVector{
+                    FloatVector: &schemapb.FloatArray{
+                        Data: generateFloatVectors(nq*topk, dim),
+                    },
+                },
+            },
+        },
+        FieldId: common.StartOfUserFieldID + 5,
+    }
+    resultData.FieldsData[6] = &schemapb.FieldData{
+        Type:      schemapb.DataType_BinaryVector,
+        FieldName: binaryVecField,
+        Field: &schemapb.FieldData_Vectors{
+            Vectors: &schemapb.VectorField{
+                Dim: int64(dim),
+                Data: &schemapb.VectorField_BinaryVector{
+                    BinaryVector: generateBinaryVectors(nq*topk, dim),
+                },
+            },
+        },
+        FieldId: common.StartOfUserFieldID + 6,
+    }

     for i := 0; i < nq; i++ {
         for j := 0; j < topk; j++ {
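The generateBoolArray, generateInt32Array, generateInt64Array, generateFloat32Array, generateFloat64Array, generateFloatVectors, and generateBinaryVectors helpers used above are pre-existing test utilities in the proxy package; their implementations are not part of this diff. For orientation, plausible shapes for two of them might look like the following (a sketch only, using math/rand; the real helpers may differ):

// Hypothetical shapes for two of the row-data helpers referenced above.
func generateBoolArray(numRows int) []bool {
	ret := make([]bool, 0, numRows)
	for i := 0; i < numRows; i++ {
		ret = append(ret, rand.Int()%2 == 0)
	}
	return ret
}

func generateFloatVectors(numRows, dim int) []float32 {
	// Float vectors are flattened row-major, so the slice holds numRows*dim values.
	ret := make([]float32, 0, numRows*dim)
	for i := 0; i < numRows*dim; i++ {
		ret = append(ret, rand.Float32())
	}
	return ret
}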
@@ -1727,8 +1939,32 @@ func TestSearchTask_all(t *testing.T) {
     assert.NoError(t, err)
     result1.SlicedBlob = sliceBlob
+    // result2.SlicedBlob = nil, will be skipped in decode stage
+    result2 := &internalpb.SearchResults{
+        Base: &commonpb.MsgBase{
+            MsgType:   commonpb.MsgType_SearchResult,
+            MsgID:     0,
+            Timestamp: 0,
+            SourceID:  0,
+        },
+        Status: &commonpb.Status{
+            ErrorCode: commonpb.ErrorCode_Success,
+            Reason:    "",
+        },
+        ResultChannelID:          "",
+        MetricType:               distance.L2,
+        NumQueries:               int64(nq),
+        TopK:                     int64(topk),
+        SealedSegmentIDsSearched: nil,
+        ChannelIDsSearched:       nil,
+        GlobalSealedSegmentIDs:   nil,
+        SlicedBlob:               nil,
+        SlicedNumCount:           1,
+        SlicedOffset:             0,
+    }

     // send search result
-    task.resultBuf <- []*internalpb.SearchResults{result1}
+    task.resultBuf <- []*internalpb.SearchResults{result1, result2}
 }
 }
 }
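result2 is sent with a nil SlicedBlob on purpose: as the comment in the hunk notes, such results are expected to be skipped when the proxy decodes search results. A rough sketch of what that skip could look like, assuming a decode helper along these lines (the function name and structure are assumptions, not the proxy's exact code):

// Sketch: filter out results whose blob is nil before unmarshalling.
func decodeSearchResultsSketch(results []*internalpb.SearchResults) ([]*schemapb.SearchResultData, error) {
	out := make([]*schemapb.SearchResultData, 0, len(results))
	for _, partial := range results {
		if partial.GetSlicedBlob() == nil {
			continue // e.g. result2 above contributes nothing
		}
		data := &schemapb.SearchResultData{}
		if err := proto.Unmarshal(partial.GetSlicedBlob(), data); err != nil {
			return nil, err
		}
		out = append(out, data)
	}
	return out, nil
}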
@@ -2079,3 +2315,96 @@ func TestSearchTask_PreExecute(t *testing.T) {
     // TODO(dragondriver): test partition-related error
 }
+
+func TestSearchTask_Execute(t *testing.T) {
+    var err error
+
+    Params.Init()
+    Params.SearchResultChannelNames = []string{funcutil.GenRandomStr()}
+
+    rc := NewRootCoordMock()
+    rc.Start()
+    defer rc.Stop()
+
+    qc := NewQueryCoordMock()
+    qc.Start()
+    defer qc.Stop()
+
+    ctx := context.Background()
+
+    err = InitMetaCache(rc)
+    assert.NoError(t, err)
+
+    dmlChannelsFunc := getDmlChannelsFunc(ctx, rc)
+    query := newMockGetChannelsService()
+    factory := newSimpleMockMsgStreamFactory()
+    chMgr := newChannelsMgrImpl(dmlChannelsFunc, nil, query.GetChannels, nil, factory)
+    defer chMgr.removeAllDMLStream()
+    defer chMgr.removeAllDQLStream()
+
+    prefix := "TestSearchTask_Execute"
+    collectionName := prefix + funcutil.GenRandomStr()
+    shardsNum := int32(2)
+    dbName := ""
+    int64Field := "int64"
+    floatVecField := "fvec"
+    dim := 128
+
+    task := &searchTask{
+        ctx: ctx,
+        SearchRequest: &internalpb.SearchRequest{
+            Base: &commonpb.MsgBase{
+                MsgType:   commonpb.MsgType_Search,
+                MsgID:     0,
+                Timestamp: uint64(time.Now().UnixNano()),
+                SourceID:  0,
+            },
+        },
+        query: &milvuspb.SearchRequest{
+            CollectionName: collectionName,
+        },
+        result: &milvuspb.SearchResults{
+            Status:  &commonpb.Status{},
+            Results: nil,
+        },
+        chMgr: chMgr,
+        qc:    qc,
+    }
+    assert.NoError(t, task.OnEnqueue())
+
+    // collection not exist
+    assert.Error(t, task.PreExecute(ctx))
+
+    schema := constructCollectionSchema(int64Field, floatVecField, dim, collectionName)
+    marshaledSchema, err := proto.Marshal(schema)
+    assert.NoError(t, err)
+
+    createColT := &createCollectionTask{
+        Condition: NewTaskCondition(ctx),
+        CreateCollectionRequest: &milvuspb.CreateCollectionRequest{
+            Base:           nil,
+            DbName:         dbName,
+            CollectionName: collectionName,
+            Schema:         marshaledSchema,
+            ShardsNum:      shardsNum,
+        },
+        ctx:       ctx,
+        rootCoord: rc,
+        result:    nil,
+        schema:    nil,
+    }
+    assert.NoError(t, createColT.OnEnqueue())
+    assert.NoError(t, createColT.PreExecute(ctx))
+    assert.NoError(t, createColT.Execute(ctx))
+    assert.NoError(t, createColT.PostExecute(ctx))
+
+    assert.NoError(t, task.Execute(ctx))
+
+    _ = chMgr.removeAllDQLStream()
+    query.f = func(collectionID UniqueID) (map[vChan]pChan, error) {
+        return nil, errors.New("mock")
+    }
+    assert.Error(t, task.Execute(ctx))
+
+    // TODO(dragondriver): cover getDQLStream
+}