milvus/internal/proxy/task_search_test.go


package proxy

import (
	"context"
	"errors"
	"fmt"
	"strconv"
	"sync"
	"testing"
	"time"

	"github.com/golang/protobuf/proto"
	"github.com/stretchr/testify/assert"

	"github.com/milvus-io/milvus/internal/common"
	"github.com/milvus-io/milvus/internal/log"
	"github.com/milvus-io/milvus/internal/mq/msgstream"
	"github.com/milvus-io/milvus/internal/proto/commonpb"
	"github.com/milvus-io/milvus/internal/proto/internalpb"
	"github.com/milvus-io/milvus/internal/proto/milvuspb"
	"github.com/milvus-io/milvus/internal/proto/querypb"
	"github.com/milvus-io/milvus/internal/proto/schemapb"
	"github.com/milvus-io/milvus/internal/util/distance"
	"github.com/milvus-io/milvus/internal/util/funcutil"
	"github.com/milvus-io/milvus/internal/util/timerecord"
	"github.com/milvus-io/milvus/internal/util/typeutil"
	"github.com/milvus-io/milvus/internal/util/uniquegenerator"
)
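
// TestSearchTask covers searchTask.PostExecute: an empty result set, a canceled
// context, a result carrying an error status, and a successful result whose
// SlicedBlob is nil.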
func TestSearchTask(t *testing.T) {
ctx := context.Background()
ctxCancel, cancel := context.WithCancel(ctx)
qt := &searchTask{
ctx: ctxCancel,
Condition: NewTaskCondition(context.TODO()),
SearchRequest: &internalpb.SearchRequest{
Base: &commonpb.MsgBase{
MsgType: commonpb.MsgType_Search,
SourceID: Params.ProxyCfg.ProxyID,
},
ResultChannelID: strconv.FormatInt(Params.ProxyCfg.ProxyID, 10),
},
resultBuf: make(chan []*internalpb.SearchResults),
query: nil,
chMgr: nil,
qc: nil,
tr: timerecord.NewTimeRecorder("search"),
}
// no result
go func() {
qt.resultBuf <- []*internalpb.SearchResults{}
}()
err := qt.PostExecute(context.TODO())
assert.NotNil(t, err)
// test canceled context
cancel()
err = qt.PostExecute(context.TODO())
assert.NotNil(t, err)
// error result
ctx = context.Background()
qt = &searchTask{
ctx: ctx,
Condition: NewTaskCondition(context.TODO()),
SearchRequest: &internalpb.SearchRequest{
Base: &commonpb.MsgBase{
MsgType: commonpb.MsgType_Search,
SourceID: Params.ProxyCfg.ProxyID,
},
ResultChannelID: strconv.FormatInt(Params.ProxyCfg.ProxyID, 10),
},
resultBuf: make(chan []*internalpb.SearchResults),
query: nil,
chMgr: nil,
qc: nil,
tr: timerecord.NewTimeRecorder("search"),
}
// send one result with an error status
go func() {
result := internalpb.SearchResults{
Status: &commonpb.Status{
ErrorCode: commonpb.ErrorCode_UnexpectedError,
Reason: "test",
},
}
results := make([]*internalpb.SearchResults, 1)
results[0] = &result
qt.resultBuf <- results
}()
err = qt.PostExecute(context.TODO())
assert.NotNil(t, err)
log.Debug("PostExecute failed: " + err.Error())
// check result SlicedBlob
ctx = context.Background()
qt = &searchTask{
ctx: ctx,
Condition: NewTaskCondition(context.TODO()),
SearchRequest: &internalpb.SearchRequest{
Base: &commonpb.MsgBase{
MsgType: commonpb.MsgType_Search,
SourceID: Params.ProxyCfg.ProxyID,
},
ResultChannelID: strconv.FormatInt(Params.ProxyCfg.ProxyID, 10),
},
resultBuf: make(chan []*internalpb.SearchResults),
query: nil,
chMgr: nil,
qc: nil,
tr: timerecord.NewTimeRecorder("search"),
}
// send one successful result with a nil SlicedBlob
go func() {
result := internalpb.SearchResults{
Status: &commonpb.Status{
ErrorCode: commonpb.ErrorCode_Success,
Reason: "test",
},
SlicedBlob: nil,
}
results := make([]*internalpb.SearchResults, 1)
results[0] = &result
qt.resultBuf <- results
}()
err = qt.PostExecute(context.TODO())
assert.Nil(t, err)
assert.Equal(t, qt.result.Status.ErrorCode, commonpb.ErrorCode_Success)
// TODO: add tests for decoding and reducing results
}
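
// TestSearchTask_Channels checks getChannels/getVChannels before the collection
// exists, after it is created, and when the channels manager's lookup function
// returns an error.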
func TestSearchTask_Channels(t *testing.T) {
var err error
Params.Init()
rc := NewRootCoordMock()
rc.Start()
defer rc.Stop()
ctx := context.Background()
err = InitMetaCache(rc)
assert.NoError(t, err)
dmlChannelsFunc := getDmlChannelsFunc(ctx, rc)
query := newMockGetChannelsService()
factory := newSimpleMockMsgStreamFactory()
chMgr := newChannelsMgrImpl(dmlChannelsFunc, nil, query.GetChannels, nil, factory)
defer chMgr.removeAllDMLStream()
defer chMgr.removeAllDQLStream()
prefix := "TestSearchTask_Channels"
collectionName := prefix + funcutil.GenRandomStr()
shardsNum := int32(2)
dbName := ""
int64Field := "int64"
floatVecField := "fvec"
dim := 128
task := &searchTask{
ctx: ctx,
query: &milvuspb.SearchRequest{
CollectionName: collectionName,
},
chMgr: chMgr,
tr: timerecord.NewTimeRecorder("search"),
}
// collection does not exist yet
_, err = task.getChannels()
assert.Error(t, err)
_, err = task.getVChannels()
assert.Error(t, err)
schema := constructCollectionSchema(int64Field, floatVecField, dim, collectionName)
marshaledSchema, err := proto.Marshal(schema)
assert.NoError(t, err)
createColT := &createCollectionTask{
Condition: NewTaskCondition(ctx),
CreateCollectionRequest: &milvuspb.CreateCollectionRequest{
Base: nil,
DbName: dbName,
CollectionName: collectionName,
Schema: marshaledSchema,
ShardsNum: shardsNum,
},
ctx: ctx,
rootCoord: rc,
result: nil,
schema: nil,
}
assert.NoError(t, createColT.OnEnqueue())
assert.NoError(t, createColT.PreExecute(ctx))
assert.NoError(t, createColT.Execute(ctx))
assert.NoError(t, createColT.PostExecute(ctx))
_, err = task.getChannels()
assert.NoError(t, err)
_, err = task.getVChannels()
assert.NoError(t, err)
_ = chMgr.removeAllDMLStream()
chMgr.dmlChannelsMgr.getChannelsFunc = func(collectionID UniqueID) (map[vChan]pChan, error) {
return nil, errors.New("mock")
}
_, err = task.getChannels()
assert.Error(t, err)
_, err = task.getVChannels()
assert.Error(t, err)
}
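
// TestSearchTask_PreExecute walks PreExecute through its validation failures
// (missing collection, invalid collection/partition names, QueryCoord errors,
// collection not loaded, malformed search params), then through a successful
// run that fills in TimeoutTimestamp, and finally through output-field and
// partition lookup failures.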
func TestSearchTask_PreExecute(t *testing.T) {
var err error
Params.Init()
Params.ProxyCfg.SearchResultChannelNames = []string{funcutil.GenRandomStr()}
rc := NewRootCoordMock()
rc.Start()
defer rc.Stop()
qc := NewQueryCoordMock()
qc.Start()
defer qc.Stop()
ctx := context.Background()
err = InitMetaCache(rc)
assert.NoError(t, err)
dmlChannelsFunc := getDmlChannelsFunc(ctx, rc)
query := newMockGetChannelsService()
factory := newSimpleMockMsgStreamFactory()
chMgr := newChannelsMgrImpl(dmlChannelsFunc, nil, query.GetChannels, nil, factory)
defer chMgr.removeAllDMLStream()
defer chMgr.removeAllDQLStream()
prefix := "TestSearchTask_PreExecute"
collectionName := prefix + funcutil.GenRandomStr()
shardsNum := int32(2)
dbName := ""
int64Field := "int64"
floatVecField := "fvec"
dim := 128
task := &searchTask{
ctx: ctx,
SearchRequest: &internalpb.SearchRequest{},
query: &milvuspb.SearchRequest{
CollectionName: collectionName,
},
chMgr: chMgr,
qc: qc,
tr: timerecord.NewTimeRecorder("search"),
}
assert.NoError(t, task.OnEnqueue())
// collection does not exist yet
assert.Error(t, task.PreExecute(ctx))
schema := constructCollectionSchema(int64Field, floatVecField, dim, collectionName)
marshaledSchema, err := proto.Marshal(schema)
assert.NoError(t, err)
createColT := &createCollectionTask{
Condition: NewTaskCondition(ctx),
CreateCollectionRequest: &milvuspb.CreateCollectionRequest{
Base: nil,
DbName: dbName,
CollectionName: collectionName,
Schema: marshaledSchema,
ShardsNum: shardsNum,
},
ctx: ctx,
rootCoord: rc,
result: nil,
schema: nil,
}
assert.NoError(t, createColT.OnEnqueue())
assert.NoError(t, createColT.PreExecute(ctx))
assert.NoError(t, createColT.Execute(ctx))
assert.NoError(t, createColT.PostExecute(ctx))
collectionID, _ := globalMetaCache.GetCollectionID(ctx, collectionName)
// validateCollectionName
task.query.CollectionName = "$"
assert.Error(t, task.PreExecute(ctx))
task.query.CollectionName = collectionName
// Validate Partition
task.query.PartitionNames = []string{"$"}
assert.Error(t, task.PreExecute(ctx))
task.query.PartitionNames = nil
// mock show collections of QueryCoord
qc.SetShowCollectionsFunc(func(ctx context.Context, request *querypb.ShowCollectionsRequest) (*querypb.ShowCollectionsResponse, error) {
return nil, errors.New("mock")
})
assert.Error(t, task.PreExecute(ctx))
qc.SetShowCollectionsFunc(func(ctx context.Context, request *querypb.ShowCollectionsRequest) (*querypb.ShowCollectionsResponse, error) {
return &querypb.ShowCollectionsResponse{
Status: &commonpb.Status{
ErrorCode: commonpb.ErrorCode_UnexpectedError,
Reason: "mock",
},
}, nil
})
assert.Error(t, task.PreExecute(ctx))
qc.ResetShowCollectionsFunc()
// collection not loaded
assert.Error(t, task.PreExecute(ctx))
_, _ = qc.LoadCollection(ctx, &querypb.LoadCollectionRequest{
Base: &commonpb.MsgBase{
MsgType: commonpb.MsgType_LoadCollection,
MsgID: 0,
Timestamp: 0,
SourceID: 0,
},
DbID: 0,
CollectionID: collectionID,
Schema: nil,
})
// no anns field
task.query.DslType = commonpb.DslType_BoolExprV1
assert.Error(t, task.PreExecute(ctx))
task.query.SearchParams = []*commonpb.KeyValuePair{
{
Key: AnnsFieldKey,
Value: floatVecField,
},
}
// no topk
assert.Error(t, task.PreExecute(ctx))
task.query.SearchParams = []*commonpb.KeyValuePair{
{
Key: AnnsFieldKey,
Value: floatVecField,
},
{
Key: TopKKey,
Value: "invalid",
},
}
// invalid topk
assert.Error(t, task.PreExecute(ctx))
task.query.SearchParams = []*commonpb.KeyValuePair{
{
Key: AnnsFieldKey,
Value: floatVecField,
},
{
Key: TopKKey,
Value: "10",
},
}
// no metric type
assert.Error(t, task.PreExecute(ctx))
task.query.SearchParams = []*commonpb.KeyValuePair{
{
Key: AnnsFieldKey,
Value: floatVecField,
},
{
Key: TopKKey,
Value: "10",
},
{
Key: MetricTypeKey,
Value: distance.L2,
},
}
// no search params
assert.Error(t, task.PreExecute(ctx))
task.query.SearchParams = []*commonpb.KeyValuePair{
{
Key: AnnsFieldKey,
Value: int64Field,
},
{
Key: TopKKey,
Value: "10",
},
{
Key: MetricTypeKey,
Value: distance.L2,
},
{
Key: SearchParamsKey,
Value: `{"nprobe": 10}`,
},
}
// invalid round_decimal
assert.Error(t, task.PreExecute(ctx))
task.query.SearchParams = []*commonpb.KeyValuePair{
{
Key: AnnsFieldKey,
Value: int64Field,
},
{
Key: TopKKey,
Value: "10",
},
{
Key: MetricTypeKey,
Value: distance.L2,
},
{
Key: SearchParamsKey,
Value: `{"nprobe": 10}`,
},
{
Key: RoundDecimalKey,
Value: "invalid",
},
}
// invalid round_decimal
assert.Error(t, task.PreExecute(ctx))
task.query.SearchParams = []*commonpb.KeyValuePair{
{
Key: AnnsFieldKey,
Value: floatVecField,
},
{
Key: TopKKey,
Value: "10",
},
{
Key: MetricTypeKey,
Value: distance.L2,
},
{
Key: RoundDecimalKey,
Value: "-1",
},
}
// failed to create query plan
assert.Error(t, task.PreExecute(ctx))
task.query.SearchParams = []*commonpb.KeyValuePair{
{
Key: AnnsFieldKey,
Value: floatVecField,
},
{
Key: TopKKey,
Value: "10",
},
{
Key: MetricTypeKey,
Value: distance.L2,
},
{
Key: SearchParamsKey,
Value: `{"nprobe": 10}`,
},
{
Key: RoundDecimalKey,
Value: "-1",
},
}
// search task with timeout
ctx1, cancel := context.WithTimeout(ctx, time.Second)
defer cancel()
// before preExecute
assert.Equal(t, typeutil.ZeroTimestamp, task.TimeoutTimestamp)
task.ctx = ctx1
assert.NoError(t, task.PreExecute(ctx))
// after preExecute
assert.Greater(t, task.TimeoutTimestamp, typeutil.ZeroTimestamp)
// output field does not exist
task.query.OutputFields = []string{int64Field + funcutil.GenRandomStr()}
assert.Error(t, task.PreExecute(ctx))
// output fields contain a vector field
task.query.OutputFields = []string{floatVecField}
assert.Error(t, task.PreExecute(ctx))
task.query.OutputFields = []string{int64Field}
// partition listing fails
rc.showPartitionsFunc = func(ctx context.Context, request *milvuspb.ShowPartitionsRequest) (*milvuspb.ShowPartitionsResponse, error) {
return nil, errors.New("mock")
}
assert.Error(t, task.PreExecute(ctx))
rc.showPartitionsFunc = nil
// TODO(dragondriver): test partition-related error
}
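
// TestSearchTask_Ts verifies that SetTs propagates the timestamp to both
// BeginTs and EndTs.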
func TestSearchTask_Ts(t *testing.T) {
Params.Init()
task := &searchTask{
SearchRequest: &internalpb.SearchRequest{
Base: nil,
},
tr: timerecord.NewTimeRecorder("search"),
}
assert.NoError(t, task.OnEnqueue())
ts := Timestamp(time.Now().Nanosecond())
task.SetTs(ts)
assert.Equal(t, ts, task.BeginTs())
assert.Equal(t, ts, task.EndTs())
}
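
// TestSearchTask_Execute runs Execute against a freshly created collection and
// then forces a failure by making the DQL channel lookup return an error.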
func TestSearchTask_Execute(t *testing.T) {
var err error
Params.Init()
Params.ProxyCfg.SearchResultChannelNames = []string{funcutil.GenRandomStr()}
rc := NewRootCoordMock()
rc.Start()
defer rc.Stop()
qc := NewQueryCoordMock()
qc.Start()
defer qc.Stop()
ctx := context.Background()
err = InitMetaCache(rc)
assert.NoError(t, err)
dmlChannelsFunc := getDmlChannelsFunc(ctx, rc)
query := newMockGetChannelsService()
factory := newSimpleMockMsgStreamFactory()
chMgr := newChannelsMgrImpl(dmlChannelsFunc, nil, query.GetChannels, nil, factory)
defer chMgr.removeAllDMLStream()
defer chMgr.removeAllDQLStream()
prefix := "TestSearchTask_Execute"
collectionName := prefix + funcutil.GenRandomStr()
shardsNum := int32(2)
dbName := ""
int64Field := "int64"
floatVecField := "fvec"
dim := 128
task := &searchTask{
ctx: ctx,
SearchRequest: &internalpb.SearchRequest{
Base: &commonpb.MsgBase{
MsgType: commonpb.MsgType_Search,
MsgID: 0,
Timestamp: uint64(time.Now().UnixNano()),
SourceID: 0,
},
},
query: &milvuspb.SearchRequest{
CollectionName: collectionName,
},
result: &milvuspb.SearchResults{
Status: &commonpb.Status{},
Results: nil,
},
chMgr: chMgr,
qc: qc,
tr: timerecord.NewTimeRecorder("search"),
}
assert.NoError(t, task.OnEnqueue())
// collection does not exist yet
assert.Error(t, task.PreExecute(ctx))
schema := constructCollectionSchema(int64Field, floatVecField, dim, collectionName)
marshaledSchema, err := proto.Marshal(schema)
assert.NoError(t, err)
createColT := &createCollectionTask{
Condition: NewTaskCondition(ctx),
CreateCollectionRequest: &milvuspb.CreateCollectionRequest{
Base: nil,
DbName: dbName,
CollectionName: collectionName,
Schema: marshaledSchema,
ShardsNum: shardsNum,
},
ctx: ctx,
rootCoord: rc,
result: nil,
schema: nil,
}
assert.NoError(t, createColT.OnEnqueue())
assert.NoError(t, createColT.PreExecute(ctx))
assert.NoError(t, createColT.Execute(ctx))
assert.NoError(t, createColT.PostExecute(ctx))
assert.NoError(t, task.Execute(ctx))
_ = chMgr.removeAllDQLStream()
query.f = func(collectionID UniqueID) (map[vChan]pChan, error) {
return nil, errors.New("mock")
}
assert.Error(t, task.Execute(ctx))
// TODO(dragondriver): cover getDQLStream
}
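
// genSearchResultData builds a SearchResultData for nq queries and the given
// topk, filled with the supplied ids and scores.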
func genSearchResultData(nq int64, topk int64, ids []int64, scores []float32) *schemapb.SearchResultData {
return &schemapb.SearchResultData{
NumQueries: nq,
TopK: topk,
FieldsData: nil,
Scores: scores,
Ids: &schemapb.IDs{
IdField: &schemapb.IDs_IntId{
IntId: &schemapb.LongArray{
Data: ids,
},
},
},
Topks: make([]int64, nq),
}
}
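
// TestSearchTask_Reduce exercises reduceSearchResultData: two identical shard
// results collapse back to the original ids (case1), and overlapping results
// from different shards are merged by score (case2).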
func TestSearchTask_Reduce(t *testing.T) {
const (
nq = 1
topk = 4
metricType = "L2"
)
t.Run("case1", func(t *testing.T) {
ids := []int64{1, 2, 3, 4}
scores := []float32{-1.0, -2.0, -3.0, -4.0}
data1 := genSearchResultData(nq, topk, ids, scores)
data2 := genSearchResultData(nq, topk, ids, scores)
dataArray := make([]*schemapb.SearchResultData, 0)
dataArray = append(dataArray, data1)
dataArray = append(dataArray, data2)
res, err := reduceSearchResultData(dataArray, nq, topk, metricType)
assert.Nil(t, err)
assert.Equal(t, ids, res.Results.Ids.GetIntId().Data)
assert.Equal(t, []float32{1.0, 2.0, 3.0, 4.0}, res.Results.Scores)
})
t.Run("case2", func(t *testing.T) {
ids1 := []int64{1, 2, 3, 4}
scores1 := []float32{-1.0, -2.0, -3.0, -4.0}
ids2 := []int64{5, 1, 3, 4}
scores2 := []float32{-1.0, -1.0, -3.0, -4.0}
data1 := genSearchResultData(nq, topk, ids1, scores1)
data2 := genSearchResultData(nq, topk, ids2, scores2)
dataArray := make([]*schemapb.SearchResultData, 0)
dataArray = append(dataArray, data1)
dataArray = append(dataArray, data2)
res, err := reduceSearchResultData(dataArray, nq, topk, metricType)
assert.Nil(t, err)
assert.ElementsMatch(t, []int64{1, 5, 2, 3}, res.Results.Ids.GetIntId().Data)
})
}
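
// TestSearchTaskWithInvalidRoundDecimal runs the full search setup with
// round_decimal set to 7, which PreExecute is expected to reject.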
func TestSearchTaskWithInvalidRoundDecimal(t *testing.T) {
var err error
Params.Init()
Params.ProxyCfg.SearchResultChannelNames = []string{funcutil.GenRandomStr()}
rc := NewRootCoordMock()
rc.Start()
defer rc.Stop()
ctx := context.Background()
err = InitMetaCache(rc)
assert.NoError(t, err)
shardsNum := int32(2)
prefix := "TestSearchTask_all"
dbName := ""
collectionName := prefix + funcutil.GenRandomStr()
dim := 128
expr := fmt.Sprintf("%s > 0", testInt64Field)
nq := 10
topk := 10
roundDecimal := 7
nprobe := 10
fieldName2Types := map[string]schemapb.DataType{
testBoolField: schemapb.DataType_Bool,
testInt32Field: schemapb.DataType_Int32,
testInt64Field: schemapb.DataType_Int64,
testFloatField: schemapb.DataType_Float,
testDoubleField: schemapb.DataType_Double,
testFloatVecField: schemapb.DataType_FloatVector,
}
if enableMultipleVectorFields {
fieldName2Types[testBinaryVecField] = schemapb.DataType_BinaryVector
}
schema := constructCollectionSchemaByDataType(collectionName, fieldName2Types, testInt64Field, false)
marshaledSchema, err := proto.Marshal(schema)
assert.NoError(t, err)
createColT := &createCollectionTask{
Condition: NewTaskCondition(ctx),
CreateCollectionRequest: &milvuspb.CreateCollectionRequest{
Base: nil,
DbName: dbName,
CollectionName: collectionName,
Schema: marshaledSchema,
ShardsNum: shardsNum,
},
ctx: ctx,
rootCoord: rc,
result: nil,
schema: nil,
}
assert.NoError(t, createColT.OnEnqueue())
assert.NoError(t, createColT.PreExecute(ctx))
assert.NoError(t, createColT.Execute(ctx))
assert.NoError(t, createColT.PostExecute(ctx))
dmlChannelsFunc := getDmlChannelsFunc(ctx, rc)
query := newMockGetChannelsService()
factory := newSimpleMockMsgStreamFactory()
chMgr := newChannelsMgrImpl(dmlChannelsFunc, nil, query.GetChannels, nil, factory)
defer chMgr.removeAllDMLStream()
defer chMgr.removeAllDQLStream()
collectionID, err := globalMetaCache.GetCollectionID(ctx, collectionName)
assert.NoError(t, err)
qc := NewQueryCoordMock()
qc.Start()
defer qc.Stop()
status, err := qc.LoadCollection(ctx, &querypb.LoadCollectionRequest{
Base: &commonpb.MsgBase{
MsgType: commonpb.MsgType_LoadCollection,
MsgID: 0,
Timestamp: 0,
SourceID: Params.ProxyCfg.ProxyID,
},
DbID: 0,
CollectionID: collectionID,
Schema: nil,
})
assert.NoError(t, err)
assert.Equal(t, commonpb.ErrorCode_Success, status.ErrorCode)
req := constructSearchRequest(dbName, collectionName,
expr,
testFloatVecField,
nq, dim, nprobe, topk, roundDecimal)
task := &searchTask{
Condition: NewTaskCondition(ctx),
SearchRequest: &internalpb.SearchRequest{
Base: &commonpb.MsgBase{
MsgType: commonpb.MsgType_Search,
MsgID: 0,
Timestamp: 0,
SourceID: Params.ProxyCfg.ProxyID,
},
ResultChannelID: strconv.FormatInt(Params.ProxyCfg.ProxyID, 10),
DbID: 0,
CollectionID: 0,
PartitionIDs: nil,
Dsl: "",
PlaceholderGroup: nil,
DslType: 0,
SerializedExprPlan: nil,
OutputFieldsId: nil,
TravelTimestamp: 0,
GuaranteeTimestamp: 0,
},
ctx: ctx,
resultBuf: make(chan []*internalpb.SearchResults),
result: nil,
query: req,
chMgr: chMgr,
qc: qc,
tr: timerecord.NewTimeRecorder("search"),
}
// simple mock for query node
// TODO(dragondriver): should we replace this mock with RocksMq or MemMsgStream?
err = chMgr.createDQLStream(collectionID)
assert.NoError(t, err)
stream, err := chMgr.getDQLStream(collectionID)
assert.NoError(t, err)
var wg sync.WaitGroup
wg.Add(1)
consumeCtx, cancel := context.WithCancel(ctx)
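// The goroutine below plays the role of a query node: it drains SearchMsg
// packs from the DQL stream and answers on task.resultBuf with two
// SearchResults, the second of which has no SlicedBlob.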
go func() {
defer wg.Done()
for {
select {
case <-consumeCtx.Done():
return
case pack, ok := <-stream.Chan():
assert.True(t, ok)
if pack == nil {
continue
}
for _, msg := range pack.Msgs {
_, ok := msg.(*msgstream.SearchMsg)
assert.True(t, ok)
// TODO(dragondriver): construct result according to the request
constructSearchResultData := func() *schemapb.SearchResultData {
resultData := &schemapb.SearchResultData{
NumQueries: int64(nq),
TopK: int64(topk),
Scores: make([]float32, nq*topk),
Ids: &schemapb.IDs{
IdField: &schemapb.IDs_IntId{
IntId: &schemapb.LongArray{
Data: make([]int64, nq*topk),
},
},
},
Topks: make([]int64, nq),
}
fieldID := common.StartOfUserFieldID
for fieldName, dataType := range fieldName2Types {
resultData.FieldsData = append(resultData.FieldsData, generateFieldData(dataType, fieldName, int64(fieldID), nq*topk))
fieldID++
}
for i := 0; i < nq; i++ {
for j := 0; j < topk; j++ {
offset := i*topk + j
score := float32(uniquegenerator.GetUniqueIntGeneratorIns().GetInt()) // scores increase monotonically
id := int64(uniquegenerator.GetUniqueIntGeneratorIns().GetInt())
resultData.Scores[offset] = score
resultData.Ids.IdField.(*schemapb.IDs_IntId).IntId.Data[offset] = id
}
resultData.Topks[i] = int64(topk)
}
return resultData
}
result1 := &internalpb.SearchResults{
Base: &commonpb.MsgBase{
MsgType: commonpb.MsgType_SearchResult,
MsgID: 0,
Timestamp: 0,
SourceID: 0,
},
Status: &commonpb.Status{
ErrorCode: commonpb.ErrorCode_Success,
Reason: "",
},
ResultChannelID: "",
MetricType: distance.L2,
NumQueries: int64(nq),
TopK: int64(topk),
SealedSegmentIDsSearched: nil,
ChannelIDsSearched: nil,
GlobalSealedSegmentIDs: nil,
SlicedBlob: nil,
SlicedNumCount: 1,
SlicedOffset: 0,
}
resultData := constructSearchResultData()
sliceBlob, err := proto.Marshal(resultData)
assert.NoError(t, err)
result1.SlicedBlob = sliceBlob
// result2.SlicedBlob stays nil, so it will be skipped in the decode stage
result2 := &internalpb.SearchResults{
Base: &commonpb.MsgBase{
MsgType: commonpb.MsgType_SearchResult,
MsgID: 0,
Timestamp: 0,
SourceID: 0,
},
Status: &commonpb.Status{
ErrorCode: commonpb.ErrorCode_Success,
Reason: "",
},
ResultChannelID: "",
MetricType: distance.L2,
NumQueries: int64(nq),
TopK: int64(topk),
SealedSegmentIDsSearched: nil,
ChannelIDsSearched: nil,
GlobalSealedSegmentIDs: nil,
SlicedBlob: nil,
SlicedNumCount: 1,
SlicedOffset: 0,
}
// send search result
task.resultBuf <- []*internalpb.SearchResults{result1, result2}
}
}
}
}()
assert.NoError(t, task.OnEnqueue())
assert.Error(t, task.PreExecute(ctx))
cancel()
wg.Wait()
}
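
// TestSearchTask_all drives the whole search path end to end against mocked
// coordinators and a mocked query-node consumer: OnEnqueue, PreExecute,
// Execute and PostExecute should all succeed.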
func TestSearchTask_all(t *testing.T) {
var err error
Params.Init()
Params.ProxyCfg.SearchResultChannelNames = []string{funcutil.GenRandomStr()}
rc := NewRootCoordMock()
rc.Start()
defer rc.Stop()
ctx := context.Background()
err = InitMetaCache(rc)
assert.NoError(t, err)
shardsNum := int32(2)
prefix := "TestSearchTask_all"
dbName := ""
collectionName := prefix + funcutil.GenRandomStr()
dim := 128
expr := fmt.Sprintf("%s > 0", testInt64Field)
nq := 10
topk := 10
roundDecimal := 3
nprobe := 10
fieldName2Types := map[string]schemapb.DataType{
testBoolField: schemapb.DataType_Bool,
testInt32Field: schemapb.DataType_Int32,
testInt64Field: schemapb.DataType_Int64,
testFloatField: schemapb.DataType_Float,
testDoubleField: schemapb.DataType_Double,
testFloatVecField: schemapb.DataType_FloatVector,
}
if enableMultipleVectorFields {
fieldName2Types[testBinaryVecField] = schemapb.DataType_BinaryVector
}
schema := constructCollectionSchemaByDataType(collectionName, fieldName2Types, testInt64Field, false)
marshaledSchema, err := proto.Marshal(schema)
assert.NoError(t, err)
createColT := &createCollectionTask{
Condition: NewTaskCondition(ctx),
CreateCollectionRequest: &milvuspb.CreateCollectionRequest{
Base: nil,
DbName: dbName,
CollectionName: collectionName,
Schema: marshaledSchema,
ShardsNum: shardsNum,
},
ctx: ctx,
rootCoord: rc,
result: nil,
schema: nil,
}
assert.NoError(t, createColT.OnEnqueue())
assert.NoError(t, createColT.PreExecute(ctx))
assert.NoError(t, createColT.Execute(ctx))
assert.NoError(t, createColT.PostExecute(ctx))
dmlChannelsFunc := getDmlChannelsFunc(ctx, rc)
query := newMockGetChannelsService()
factory := newSimpleMockMsgStreamFactory()
chMgr := newChannelsMgrImpl(dmlChannelsFunc, nil, query.GetChannels, nil, factory)
defer chMgr.removeAllDMLStream()
defer chMgr.removeAllDQLStream()
collectionID, err := globalMetaCache.GetCollectionID(ctx, collectionName)
assert.NoError(t, err)
qc := NewQueryCoordMock()
qc.Start()
defer qc.Stop()
status, err := qc.LoadCollection(ctx, &querypb.LoadCollectionRequest{
Base: &commonpb.MsgBase{
MsgType: commonpb.MsgType_LoadCollection,
MsgID: 0,
Timestamp: 0,
SourceID: Params.ProxyCfg.ProxyID,
},
DbID: 0,
CollectionID: collectionID,
Schema: nil,
})
assert.NoError(t, err)
assert.Equal(t, commonpb.ErrorCode_Success, status.ErrorCode)
req := constructSearchRequest(dbName, collectionName,
expr,
testFloatVecField,
nq, dim, nprobe, topk, roundDecimal)
task := &searchTask{
Condition: NewTaskCondition(ctx),
SearchRequest: &internalpb.SearchRequest{
Base: &commonpb.MsgBase{
MsgType: commonpb.MsgType_Search,
MsgID: 0,
Timestamp: 0,
SourceID: Params.ProxyCfg.ProxyID,
},
ResultChannelID: strconv.FormatInt(Params.ProxyCfg.ProxyID, 10),
DbID: 0,
CollectionID: 0,
PartitionIDs: nil,
Dsl: "",
PlaceholderGroup: nil,
DslType: 0,
SerializedExprPlan: nil,
OutputFieldsId: nil,
TravelTimestamp: 0,
GuaranteeTimestamp: 0,
},
ctx: ctx,
resultBuf: make(chan []*internalpb.SearchResults),
result: nil,
query: req,
chMgr: chMgr,
qc: qc,
tr: timerecord.NewTimeRecorder("search"),
}
// simple mock for query node
// TODO(dragondriver): should we replace this mock with RocksMq or MemMsgStream?
err = chMgr.createDQLStream(collectionID)
assert.NoError(t, err)
stream, err := chMgr.getDQLStream(collectionID)
assert.NoError(t, err)
var wg sync.WaitGroup
wg.Add(1)
consumeCtx, cancel := context.WithCancel(ctx)
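// The goroutine below plays the role of a query node: it drains SearchMsg
// packs from the DQL stream and answers on task.resultBuf with two
// SearchResults, the second of which has no SlicedBlob.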
go func() {
defer wg.Done()
for {
select {
case <-consumeCtx.Done():
return
case pack, ok := <-stream.Chan():
assert.True(t, ok)
if pack == nil {
continue
}
for _, msg := range pack.Msgs {
_, ok := msg.(*msgstream.SearchMsg)
assert.True(t, ok)
// TODO(dragondriver): construct result according to the request
constructSearchResultData := func() *schemapb.SearchResultData {
resultData := &schemapb.SearchResultData{
NumQueries: int64(nq),
TopK: int64(topk),
Scores: make([]float32, nq*topk),
Ids: &schemapb.IDs{
IdField: &schemapb.IDs_IntId{
IntId: &schemapb.LongArray{
Data: make([]int64, nq*topk),
},
},
},
Topks: make([]int64, nq),
}
fieldID := common.StartOfUserFieldID
for fieldName, dataType := range fieldName2Types {
resultData.FieldsData = append(resultData.FieldsData, generateFieldData(dataType, fieldName, int64(fieldID), nq*topk))
fieldID++
}
for i := 0; i < nq; i++ {
for j := 0; j < topk; j++ {
offset := i*topk + j
score := float32(uniquegenerator.GetUniqueIntGeneratorIns().GetInt()) // scores increase monotonically
id := int64(uniquegenerator.GetUniqueIntGeneratorIns().GetInt())
resultData.Scores[offset] = score
resultData.Ids.IdField.(*schemapb.IDs_IntId).IntId.Data[offset] = id
}
resultData.Topks[i] = int64(topk)
}
return resultData
}
result1 := &internalpb.SearchResults{
Base: &commonpb.MsgBase{
MsgType: commonpb.MsgType_SearchResult,
MsgID: 0,
Timestamp: 0,
SourceID: 0,
},
Status: &commonpb.Status{
ErrorCode: commonpb.ErrorCode_Success,
Reason: "",
},
ResultChannelID: "",
MetricType: distance.L2,
NumQueries: int64(nq),
TopK: int64(topk),
SealedSegmentIDsSearched: nil,
ChannelIDsSearched: nil,
GlobalSealedSegmentIDs: nil,
SlicedBlob: nil,
SlicedNumCount: 1,
SlicedOffset: 0,
}
resultData := constructSearchResultData()
sliceBlob, err := proto.Marshal(resultData)
assert.NoError(t, err)
result1.SlicedBlob = sliceBlob
// result2.SlicedBlob stays nil, so it will be skipped in the decode stage
result2 := &internalpb.SearchResults{
Base: &commonpb.MsgBase{
MsgType: commonpb.MsgType_SearchResult,
MsgID: 0,
Timestamp: 0,
SourceID: 0,
},
Status: &commonpb.Status{
ErrorCode: commonpb.ErrorCode_Success,
Reason: "",
},
ResultChannelID: "",
MetricType: distance.L2,
NumQueries: int64(nq),
TopK: int64(topk),
SealedSegmentIDsSearched: nil,
ChannelIDsSearched: nil,
GlobalSealedSegmentIDs: nil,
SlicedBlob: nil,
SlicedNumCount: 1,
SlicedOffset: 0,
}
// send search result
task.resultBuf <- []*internalpb.SearchResults{result1, result2}
}
}
}
}()
assert.NoError(t, task.OnEnqueue())
assert.NoError(t, task.PreExecute(ctx))
assert.NoError(t, task.Execute(ctx))
assert.NoError(t, task.PostExecute(ctx))
cancel()
wg.Wait()
}
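
// TestSearchTask_7803_reduce (named after a Milvus issue) covers reducing shard
// results in which only part of each topk is valid, the remaining slots being
// padded with minFloat32 scores and id -1.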
func TestSearchTask_7803_reduce(t *testing.T) {
var err error
Params.Init()
Params.ProxyCfg.SearchResultChannelNames = []string{funcutil.GenRandomStr()}
rc := NewRootCoordMock()
rc.Start()
defer rc.Stop()
ctx := context.Background()
err = InitMetaCache(rc)
assert.NoError(t, err)
shardsNum := int32(2)
prefix := "TestSearchTask_7803_reduce"
dbName := ""
collectionName := prefix + funcutil.GenRandomStr()
int64Field := "int64"
floatVecField := "fvec"
dim := 128
expr := fmt.Sprintf("%s > 0", int64Field)
nq := 10
topk := 10
roundDecimal := 3
nprobe := 10
schema := constructCollectionSchema(
int64Field,
floatVecField,
dim,
collectionName)
marshaledSchema, err := proto.Marshal(schema)
assert.NoError(t, err)
createColT := &createCollectionTask{
Condition: NewTaskCondition(ctx),
CreateCollectionRequest: &milvuspb.CreateCollectionRequest{
Base: nil,
DbName: dbName,
CollectionName: collectionName,
Schema: marshaledSchema,
ShardsNum: shardsNum,
},
ctx: ctx,
rootCoord: rc,
result: nil,
schema: nil,
}
assert.NoError(t, createColT.OnEnqueue())
assert.NoError(t, createColT.PreExecute(ctx))
assert.NoError(t, createColT.Execute(ctx))
assert.NoError(t, createColT.PostExecute(ctx))
dmlChannelsFunc := getDmlChannelsFunc(ctx, rc)
query := newMockGetChannelsService()
factory := newSimpleMockMsgStreamFactory()
chMgr := newChannelsMgrImpl(dmlChannelsFunc, nil, query.GetChannels, nil, factory)
defer chMgr.removeAllDMLStream()
defer chMgr.removeAllDQLStream()
collectionID, err := globalMetaCache.GetCollectionID(ctx, collectionName)
assert.NoError(t, err)
qc := NewQueryCoordMock()
qc.Start()
defer qc.Stop()
status, err := qc.LoadCollection(ctx, &querypb.LoadCollectionRequest{
Base: &commonpb.MsgBase{
MsgType: commonpb.MsgType_LoadCollection,
MsgID: 0,
Timestamp: 0,
SourceID: Params.ProxyCfg.ProxyID,
},
DbID: 0,
CollectionID: collectionID,
Schema: nil,
})
assert.NoError(t, err)
assert.Equal(t, commonpb.ErrorCode_Success, status.ErrorCode)
req := constructSearchRequest(dbName, collectionName,
expr,
floatVecField,
nq, dim, nprobe, topk, roundDecimal)
task := &searchTask{
Condition: NewTaskCondition(ctx),
SearchRequest: &internalpb.SearchRequest{
Base: &commonpb.MsgBase{
MsgType: commonpb.MsgType_Search,
MsgID: 0,
Timestamp: 0,
SourceID: Params.ProxyCfg.ProxyID,
},
ResultChannelID: strconv.FormatInt(Params.ProxyCfg.ProxyID, 10),
DbID: 0,
CollectionID: 0,
PartitionIDs: nil,
Dsl: "",
PlaceholderGroup: nil,
DslType: 0,
SerializedExprPlan: nil,
OutputFieldsId: nil,
TravelTimestamp: 0,
GuaranteeTimestamp: 0,
},
ctx: ctx,
resultBuf: make(chan []*internalpb.SearchResults),
result: nil,
query: req,
chMgr: chMgr,
qc: qc,
tr: timerecord.NewTimeRecorder("search"),
}
// simple mock for query node
// TODO(dragondriver): should we replace this mock with RocksMq or MemMsgStream?
err = chMgr.createDQLStream(collectionID)
assert.NoError(t, err)
stream, err := chMgr.getDQLStream(collectionID)
assert.NoError(t, err)
var wg sync.WaitGroup
wg.Add(1)
consumeCtx, cancel := context.WithCancel(ctx)
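// The goroutine below plays the role of a query node: it drains SearchMsg
// packs from the DQL stream and answers on task.resultBuf with two
// SearchResults whose topk entries are each only partially valid.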
go func() {
defer wg.Done()
for {
select {
case <-consumeCtx.Done():
return
case pack, ok := <-stream.Chan():
assert.True(t, ok)
if pack == nil {
continue
}
for _, msg := range pack.Msgs {
_, ok := msg.(*msgstream.SearchMsg)
assert.True(t, ok)
// TODO(dragondriver): construct result according to the request
constructSearchResultData := func(invalidNum int) *schemapb.SearchResultData {
resultData := &schemapb.SearchResultData{
NumQueries: int64(nq),
TopK: int64(topk),
FieldsData: nil,
Scores: make([]float32, nq*topk),
Ids: &schemapb.IDs{
IdField: &schemapb.IDs_IntId{
IntId: &schemapb.LongArray{
Data: make([]int64, nq*topk),
},
},
},
Topks: make([]int64, nq),
}
for i := 0; i < nq; i++ {
for j := 0; j < topk; j++ {
offset := i*topk + j
if j >= invalidNum {
resultData.Scores[offset] = minFloat32
resultData.Ids.IdField.(*schemapb.IDs_IntId).IntId.Data[offset] = -1
} else {
score := float32(uniquegenerator.GetUniqueIntGeneratorIns().GetInt()) // scores increase monotonically
id := int64(uniquegenerator.GetUniqueIntGeneratorIns().GetInt())
resultData.Scores[offset] = score
resultData.Ids.IdField.(*schemapb.IDs_IntId).IntId.Data[offset] = id
}
}
resultData.Topks[i] = int64(topk)
}
return resultData
}
result1 := &internalpb.SearchResults{
Base: &commonpb.MsgBase{
MsgType: commonpb.MsgType_SearchResult,
MsgID: 0,
Timestamp: 0,
SourceID: 0,
},
Status: &commonpb.Status{
ErrorCode: commonpb.ErrorCode_Success,
Reason: "",
},
ResultChannelID: "",
MetricType: distance.L2,
NumQueries: int64(nq),
TopK: int64(topk),
SealedSegmentIDsSearched: nil,
ChannelIDsSearched: nil,
GlobalSealedSegmentIDs: nil,
SlicedBlob: nil,
SlicedNumCount: 1,
SlicedOffset: 0,
}
resultData := constructSearchResultData(topk / 2)
sliceBlob, err := proto.Marshal(resultData)
assert.NoError(t, err)
result1.SlicedBlob = sliceBlob
result2 := &internalpb.SearchResults{
Base: &commonpb.MsgBase{
MsgType: commonpb.MsgType_SearchResult,
MsgID: 0,
Timestamp: 0,
SourceID: 0,
},
Status: &commonpb.Status{
ErrorCode: commonpb.ErrorCode_Success,
Reason: "",
},
ResultChannelID: "",
MetricType: distance.L2,
NumQueries: int64(nq),
TopK: int64(topk),
SealedSegmentIDsSearched: nil,
ChannelIDsSearched: nil,
GlobalSealedSegmentIDs: nil,
SlicedBlob: nil,
SlicedNumCount: 1,
SlicedOffset: 0,
}
resultData2 := constructSearchResultData(topk - topk/2)
sliceBlob2, err := proto.Marshal(resultData2)
assert.NoError(t, err)
result2.SlicedBlob = sliceBlob2
// send search result
task.resultBuf <- []*internalpb.SearchResults{result1, result2}
}
}
}
}()
assert.NoError(t, task.OnEnqueue())
assert.NoError(t, task.PreExecute(ctx))
assert.NoError(t, task.Execute(ctx))
assert.NoError(t, task.PostExecute(ctx))
cancel()
wg.Wait()
}