Refactor querynode unittest (#15929)

Signed-off-by: zhenshan.cao <zhenshan.cao@zilliz.com>
pull/15941/head
zhenshan.cao 2022-03-08 17:39:58 +08:00 committed by GitHub
parent 74f66dce3b
commit dff08dbf47
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
7 changed files with 123 additions and 49 deletions

View File

@ -150,6 +150,10 @@ test-datanode:
@echo "Running go unittests..."
go test -race -coverpkg=./... -coverprofile=profile.out -covermode=atomic -timeout 5m github.com/milvus-io/milvus/internal/datanode -v
test-querynode:
@echo "Running go unittests..."
go test -race -coverpkg=./... -coverprofile=profile.out -covermode=atomic -timeout 5m github.com/milvus-io/milvus/internal/querynode -v
test-querycoord:
@echo "Running go unittests..."
go test -race -coverpkg=./... -coverprofile=profile.out -covermode=atomic -timeout 5m github.com/milvus-io/milvus/internal/querycoord -v

View File

@ -78,7 +78,8 @@ func benchmarkQueryCollectionSearch(nq int, b *testing.B) {
assert.Equal(b, seg.getMemSize(), int64(expectSize))
// warming up
msgTmp, err := genSearchMsg(10)
msgTmp, err := genSearchMsg(10, IndexFaissIDMap)
assert.NoError(b, err)
for j := 0; j < 10000; j++ {
err = queryCollection.search(msgTmp)
@ -87,7 +88,7 @@ func benchmarkQueryCollectionSearch(nq int, b *testing.B) {
msgs := make([]*msgstream2.SearchMsg, maxNQ/nq)
for i := 0; i < maxNQ/nq; i++ {
msg, err := genSearchMsg(nq)
msg, err := genSearchMsg(nq, IndexFaissIDMap)
assert.NoError(b, err)
msgs[i] = msg
}
@ -154,7 +155,7 @@ func benchmarkQueryCollectionSearchIndex(nq int, indexType string, b *testing.B)
assert.Equal(b, seg.getMemSize(), int64(expectSize))
// warming up
msgTmp, err := genSearchMsg(10)
msgTmp, err := genSearchMsg(10, indexType)
assert.NoError(b, err)
for j := 0; j < 10000; j++ {
err = queryCollection.search(msgTmp)
@ -163,7 +164,7 @@ func benchmarkQueryCollectionSearchIndex(nq int, indexType string, b *testing.B)
msgs := make([]*msgstream2.SearchMsg, maxNQ/nq)
for i := 0; i < maxNQ/nq; i++ {
msg, err := genSearchMsg(nq)
msg, err := genSearchMsg(nq, indexType)
assert.NoError(b, err)
msgs[i] = msg
}
@ -187,15 +188,22 @@ func benchmarkQueryCollectionSearchIndex(nq int, indexType string, b *testing.B)
}
}
func BenchmarkSearch_NQ1(b *testing.B) { benchmarkQueryCollectionSearch(1, b) }
func BenchmarkSearch_NQ10(b *testing.B) { benchmarkQueryCollectionSearch(10, b) }
func BenchmarkSearch_NQ100(b *testing.B) { benchmarkQueryCollectionSearch(100, b) }
func BenchmarkSearch_NQ1000(b *testing.B) { benchmarkQueryCollectionSearch(1000, b) }
func BenchmarkSearch_NQ10000(b *testing.B) { benchmarkQueryCollectionSearch(10000, b) }
func BenchmarkSearch_NQ1(b *testing.B) { benchmarkQueryCollectionSearch(1, b) }
//func BenchmarkSearch_NQ10(b *testing.B) { benchmarkQueryCollectionSearch(10, b) }
//func BenchmarkSearch_NQ100(b *testing.B) { benchmarkQueryCollectionSearch(100, b) }
//func BenchmarkSearch_NQ1000(b *testing.B) { benchmarkQueryCollectionSearch(1000, b) }
//func BenchmarkSearch_NQ10000(b *testing.B) { benchmarkQueryCollectionSearch(10000, b) }
func BenchmarkSearch_HNSW_NQ1(b *testing.B) {
benchmarkQueryCollectionSearchIndex(1, IndexHNSW, b)
}
func BenchmarkSearch_IVFFLAT_NQ1(b *testing.B) {
benchmarkQueryCollectionSearchIndex(1, IndexFaissIVFFlat, b)
}
/*
func BenchmarkSearch_IVFFLAT_NQ10(b *testing.B) {
benchmarkQueryCollectionSearchIndex(10, IndexFaissIVFFlat, b)
}
@ -208,3 +216,4 @@ func BenchmarkSearch_IVFFLAT_NQ1000(b *testing.B) {
func BenchmarkSearch_IVFFLAT_NQ10000(b *testing.B) {
benchmarkQueryCollectionSearchIndex(10000, IndexFaissIVFFlat, b)
}
*/

View File

@ -32,7 +32,7 @@ func TestHistorical_Search(t *testing.T) {
his, err := genSimpleHistorical(ctx, tSafe)
assert.NoError(t, err)
plan, searchReqs, err := genSimpleSearchPlanAndRequests()
plan, searchReqs, err := genSimpleSearchPlanAndRequests(IndexFaissIDMap)
assert.NoError(t, err)
_, _, _, err = his.search(searchReqs, defaultCollectionID, []UniqueID{defaultPartitionID}, plan, Timestamp(0))
@ -44,7 +44,7 @@ func TestHistorical_Search(t *testing.T) {
his, err := genSimpleHistorical(ctx, tSafe)
assert.NoError(t, err)
plan, searchReqs, err := genSimpleSearchPlanAndRequests()
plan, searchReqs, err := genSimpleSearchPlanAndRequests(IndexFaissIDMap)
assert.NoError(t, err)
err = his.replica.removeCollection(defaultCollectionID)
@ -59,7 +59,7 @@ func TestHistorical_Search(t *testing.T) {
his, err := genSimpleHistorical(ctx, tSafe)
assert.NoError(t, err)
plan, searchReqs, err := genSimpleSearchPlanAndRequests()
plan, searchReqs, err := genSimpleSearchPlanAndRequests(IndexFaissIDMap)
assert.NoError(t, err)
err = his.replica.removeCollection(defaultCollectionID)
@ -74,7 +74,7 @@ func TestHistorical_Search(t *testing.T) {
his, err := genSimpleHistorical(ctx, tSafe)
assert.NoError(t, err)
plan, searchReqs, err := genSimpleSearchPlanAndRequests()
plan, searchReqs, err := genSimpleSearchPlanAndRequests(IndexFaissIDMap)
assert.NoError(t, err)
col, err := his.replica.getCollectionByID(defaultCollectionID)
@ -93,7 +93,7 @@ func TestHistorical_Search(t *testing.T) {
his, err := genSimpleHistorical(ctx, tSafe)
assert.NoError(t, err)
plan, searchReqs, err := genSimpleSearchPlanAndRequests()
plan, searchReqs, err := genSimpleSearchPlanAndRequests(IndexFaissIDMap)
assert.NoError(t, err)
err = his.replica.removePartition(defaultPartitionID)
@ -110,7 +110,7 @@ func TestHistorical_Search(t *testing.T) {
his, err := genSimpleHistorical(ctx, tSafe)
assert.NoError(t, err)
plan, searchReqs, err := genSimpleSearchPlanAndRequests()
plan, searchReqs, err := genSimpleSearchPlanAndRequests(IndexFaissIDMap)
assert.NoError(t, err)
col, err := his.replica.getCollectionByID(defaultCollectionID)

View File

@ -63,6 +63,7 @@ const (
defaultRoundDecimal = int64(6)
defaultDim = 128
defaultNProb = 10
defaultEf = 10
defaultMetricType = L2
defaultNQ = 10
@ -600,6 +601,7 @@ func genSimpleInsertDataSchema() *schemapb.CollectionSchema {
fieldTimestamp := genConstantField(timestampField)
fieldVec := genFloatVectorField(simpleVecField)
fieldInt := genConstantField(simpleConstField)
fieldPK := genPKField(simplePKField)
schema := schemapb.CollectionSchema{ // schema for insertData
Name: defaultCollectionName,
@ -607,6 +609,7 @@ func genSimpleInsertDataSchema() *schemapb.CollectionSchema {
Fields: []*schemapb.FieldSchema{
fieldUID,
fieldTimestamp,
fieldPK,
fieldVec,
fieldInt,
},
@ -1244,7 +1247,7 @@ func genSimpleStreaming(ctx context.Context, tSafeReplica TSafeReplicaInterface)
// ---------- unittest util functions ----------
// functions of messages and requests
func genDSL(schema *schemapb.CollectionSchema, nProb int, topK int64, roundDecimal int64) (string, error) {
func genIVFFlatDSL(schema *schemapb.CollectionSchema, nProb int, topK int64, roundDecimal int64) (string, error) {
var vecFieldName string
var metricType string
nProbStr := strconv.Itoa(nProb)
@ -1272,9 +1275,74 @@ func genDSL(schema *schemapb.CollectionSchema, nProb int, topK int64, roundDecim
"\n } \n } \n } \n }", nil
}
func genSimpleDSL() (string, error) {
func genHNSWDSL(schema *schemapb.CollectionSchema, ef int, topK int64, roundDecimal int64) (string, error) {
var vecFieldName string
var metricType string
efStr := strconv.Itoa(ef)
topKStr := strconv.FormatInt(topK, 10)
roundDecimalStr := strconv.FormatInt(roundDecimal, 10)
for _, f := range schema.Fields {
if f.DataType == schemapb.DataType_FloatVector {
vecFieldName = f.Name
for _, p := range f.IndexParams {
if p.Key == metricTypeKey {
metricType = p.Value
}
}
}
}
if vecFieldName == "" || metricType == "" {
err := errors.New("invalid vector field name or metric type")
return "", err
}
return "{\"bool\": { \n\"vector\": {\n \"" + vecFieldName +
"\": {\n \"metric_type\": \"" + metricType +
"\", \n \"params\": {\n \"ef\": " + efStr + " \n},\n \"query\": \"$0\",\n \"topk\": " + topKStr +
" \n,\"round_decimal\": " + roundDecimalStr +
"\n } \n } \n } \n }", nil
}
func genBruteForceDSL(schema *schemapb.CollectionSchema, topK int64, roundDecimal int64) (string, error) {
var vecFieldName string
var metricType string
topKStr := strconv.FormatInt(topK, 10)
nProbStr := strconv.Itoa(defaultNProb)
roundDecimalStr := strconv.FormatInt(roundDecimal, 10)
for _, f := range schema.Fields {
if f.DataType == schemapb.DataType_FloatVector {
vecFieldName = f.Name
for _, p := range f.IndexParams {
if p.Key == metricTypeKey {
metricType = p.Value
}
}
}
}
if vecFieldName == "" || metricType == "" {
err := errors.New("invalid vector field name or metric type")
return "", err
}
return "{\"bool\": { \n\"vector\": {\n \"" + vecFieldName +
"\": {\n \"metric_type\": \"" + metricType +
"\", \n \"params\": {\n \"nprobe\": " + nProbStr + " \n},\n \"query\": \"$0\",\n \"topk\": " + topKStr +
" \n,\"round_decimal\": " + roundDecimalStr +
"\n } \n } \n } \n }", nil
}
func genDSLByIndexType(indexType string) (string, error) {
schema := genSimpleSegCoreSchema()
return genDSL(schema, defaultNProb, defaultTopK, defaultRoundDecimal)
if indexType == IndexFaissIDMap { // float vector
return genBruteForceDSL(schema, defaultTopK, defaultRoundDecimal)
} else if indexType == IndexFaissBinIDMap {
return genBruteForceDSL(schema, defaultTopK, defaultRoundDecimal)
} else if indexType == IndexFaissIVFFlat {
return genIVFFlatDSL(schema, defaultNProb, defaultTopK, defaultRoundDecimal)
} else if indexType == IndexFaissBinIVFFlat { // binary vector
return genIVFFlatDSL(schema, defaultNProb, defaultTopK, defaultRoundDecimal)
} else if indexType == IndexHNSW {
return genHNSWDSL(schema, defaultEf, defaultTopK, defaultRoundDecimal)
}
return "", fmt.Errorf("Invalid indexType")
}
func genPlaceHolderGroup(nq int) ([]byte, error) {
@ -1312,13 +1380,13 @@ func genSimplePlaceHolderGroup() ([]byte, error) {
return genPlaceHolderGroup(defaultNQ)
}
func genSimpleSearchPlanAndRequests() (*SearchPlan, []*searchRequest, error) {
func genSimpleSearchPlanAndRequests(indexType string) (*SearchPlan, []*searchRequest, error) {
schema := genSimpleSegCoreSchema()
collection := newCollection(defaultCollectionID, schema)
var plan *SearchPlan
var err error
sm, err := genSimpleSearchMsg()
sm, err := genSimpleSearchMsg(indexType)
if err != nil {
return nil, nil, err
}
@ -1402,12 +1470,12 @@ func genSimpleRetrievePlan() (*RetrievePlan, error) {
return plan, err
}
func genSearchRequest(nq int) (*internalpb.SearchRequest, error) {
func genSearchRequest(nq int, indexType string) (*internalpb.SearchRequest, error) {
placeHolder, err := genPlaceHolderGroup(nq)
if err != nil {
return nil, err
}
simpleDSL, err := genSimpleDSL()
simpleDSL, err := genDSLByIndexType(indexType)
if err != nil {
return nil, err
}
@ -1421,8 +1489,8 @@ func genSearchRequest(nq int) (*internalpb.SearchRequest, error) {
}, nil
}
func genSimpleSearchRequest() (*internalpb.SearchRequest, error) {
return genSearchRequest(defaultNQ)
func genSimpleSearchRequest(indexType string) (*internalpb.SearchRequest, error) {
return genSearchRequest(defaultNQ, indexType)
}
func genSimpleRetrieveRequest() (*internalpb.RetrieveRequest, error) {
@ -1444,8 +1512,8 @@ func genSimpleRetrieveRequest() (*internalpb.RetrieveRequest, error) {
}, nil
}
func genSearchMsg(nq int) (*msgstream.SearchMsg, error) {
req, err := genSearchRequest(nq)
func genSearchMsg(nq int, indexType string) (*msgstream.SearchMsg, error) {
req, err := genSearchRequest(nq, indexType)
if err != nil {
return nil, err
}
@ -1457,8 +1525,8 @@ func genSearchMsg(nq int) (*msgstream.SearchMsg, error) {
return msg, nil
}
func genSimpleSearchMsg() (*msgstream.SearchMsg, error) {
req, err := genSimpleSearchRequest()
func genSimpleSearchMsg(indexType string) (*msgstream.SearchMsg, error) {
req, err := genSimpleSearchRequest(indexType)
if err != nil {
return nil, err
}
@ -1501,7 +1569,7 @@ func produceSimpleSearchMsg(ctx context.Context, queryChannel Channel) error {
stream.AsProducer([]string{queryChannel})
stream.Start()
defer stream.Close()
msg, err := genSimpleSearchMsg()
msg, err := genSimpleSearchMsg(IndexFaissIDMap)
if err != nil {
return err
}

View File

@ -252,7 +252,7 @@ func TestQueryCollection_unsolvedMsg(t *testing.T) {
queryCollection, err := genSimpleQueryCollection(ctx, cancel)
assert.NoError(t, err)
qm, err := genSimpleSearchMsg()
qm, err := genSimpleSearchMsg(IndexFaissIDMap)
assert.NoError(t, err)
queryCollection.addToUnsolvedMsg(qm)
@ -298,7 +298,7 @@ func TestQueryCollection_consumeQuery(t *testing.T) {
}
t.Run("consume search", func(t *testing.T) {
msg, err := genSimpleSearchMsg()
msg, err := genSimpleSearchMsg(IndexFaissIDMap)
assert.NoError(t, err)
runConsumeQuery(msg)
})
@ -598,7 +598,7 @@ func TestQueryCollection_doUnsolvedQueryMsg(t *testing.T) {
go queryCollection.doUnsolvedQueryMsg()
msg, err := genSimpleSearchMsg()
msg, err := genSimpleSearchMsg(IndexFaissIDMap)
assert.NoError(t, err)
queryCollection.addToUnsolvedMsg(msg)
@ -622,7 +622,7 @@ func TestQueryCollection_doUnsolvedQueryMsg(t *testing.T) {
go queryCollection.doUnsolvedQueryMsg()
msg, err := genSimpleSearchMsg()
msg, err := genSimpleSearchMsg(IndexFaissIDMap)
assert.NoError(t, err)
msg.TimeoutTimestamp = tsoutil.GetCurrentTime() - Timestamp(time.Second<<18)
queryCollection.addToUnsolvedMsg(msg)
@ -653,7 +653,7 @@ func TestQueryCollection_search(t *testing.T) {
err = queryCollection.historical.replica.removeSegment(defaultSegmentID)
assert.NoError(t, err)
msg, err := genSimpleSearchMsg()
msg, err := genSimpleSearchMsg(IndexFaissIDMap)
assert.NoError(t, err)
err = queryCollection.search(msg)
@ -805,7 +805,7 @@ func TestQueryCollection_search_while_release(t *testing.T) {
})
queryCollection.sessionManager = sessionManager
msg, err := genSimpleSearchMsg()
msg, err := genSimpleSearchMsg(IndexFaissIDMap)
assert.NoError(t, err)
// To prevent data race in search trackCtx
@ -850,7 +850,7 @@ func TestQueryCollection_search_while_release(t *testing.T) {
})
queryCollection.sessionManager = sessionManager
msg, err := genSimpleSearchMsg()
msg, err := genSimpleSearchMsg(IndexFaissIDMap)
assert.NoError(t, err)
// To prevent data race in search trackCtx

View File

@ -374,13 +374,6 @@ func TestSegmentLoader_testLoadGrowingAndSealed(t *testing.T) {
defer cancel()
schema := genSimpleInsertDataSchema()
schema.Fields = append(schema.Fields, &schemapb.FieldSchema{
FieldID: UniqueID(102),
Name: "pk",
IsPrimaryKey: true,
DataType: schemapb.DataType_Int64,
})
fieldBinlog, err := saveBinLog(ctx, defaultCollectionID, defaultPartitionID, defaultSegmentID, defaultMsgLength, schema)
assert.NoError(t, err)

View File

@ -45,7 +45,7 @@ func TestStreaming_search(t *testing.T) {
assert.NoError(t, err)
defer streaming.close()
plan, searchReqs, err := genSimpleSearchPlanAndRequests()
plan, searchReqs, err := genSimpleSearchPlanAndRequests(IndexFaissIDMap)
assert.NoError(t, err)
res, _, _, err := streaming.search(searchReqs,
@ -64,7 +64,7 @@ func TestStreaming_search(t *testing.T) {
assert.NoError(t, err)
defer streaming.close()
plan, searchReqs, err := genSimpleSearchPlanAndRequests()
plan, searchReqs, err := genSimpleSearchPlanAndRequests(IndexFaissIDMap)
assert.NoError(t, err)
res, _, _, err := streaming.search(searchReqs,
@ -83,7 +83,7 @@ func TestStreaming_search(t *testing.T) {
assert.NoError(t, err)
defer streaming.close()
plan, searchReqs, err := genSimpleSearchPlanAndRequests()
plan, searchReqs, err := genSimpleSearchPlanAndRequests(IndexFaissIDMap)
assert.NoError(t, err)
col, err := streaming.replica.getCollectionByID(defaultCollectionID)
@ -109,7 +109,7 @@ func TestStreaming_search(t *testing.T) {
assert.NoError(t, err)
defer streaming.close()
plan, searchReqs, err := genSimpleSearchPlanAndRequests()
plan, searchReqs, err := genSimpleSearchPlanAndRequests(IndexFaissIDMap)
assert.NoError(t, err)
col, err := streaming.replica.getCollectionByID(defaultCollectionID)
@ -134,7 +134,7 @@ func TestStreaming_search(t *testing.T) {
assert.NoError(t, err)
defer streaming.close()
plan, searchReqs, err := genSimpleSearchPlanAndRequests()
plan, searchReqs, err := genSimpleSearchPlanAndRequests(IndexFaissIDMap)
assert.NoError(t, err)
err = streaming.replica.removePartition(defaultPartitionID)
@ -156,7 +156,7 @@ func TestStreaming_search(t *testing.T) {
assert.NoError(t, err)
defer streaming.close()
plan, searchReqs, err := genSimpleSearchPlanAndRequests()
plan, searchReqs, err := genSimpleSearchPlanAndRequests(IndexFaissIDMap)
assert.NoError(t, err)
seg, err := streaming.replica.getSegmentByID(defaultSegmentID)