milvus/internal/proxy/task_search_test.go

1556 lines
49 KiB
Go

package proxy
import (
"context"
"errors"
"testing"
"time"
"github.com/golang/protobuf/proto"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
"github.com/milvus-io/milvus/internal/types"
"github.com/milvus-io/milvus/internal/proto/commonpb"
"github.com/milvus-io/milvus/internal/proto/internalpb"
"github.com/milvus-io/milvus/internal/proto/milvuspb"
"github.com/milvus-io/milvus/internal/proto/querypb"
"github.com/milvus-io/milvus/internal/proto/schemapb"
"github.com/milvus-io/milvus/internal/util/distance"
"github.com/milvus-io/milvus/internal/util/funcutil"
"github.com/milvus-io/milvus/internal/util/timerecord"
"github.com/milvus-io/milvus/internal/util/typeutil"
)
// testShardsNum is the shard count requested for every collection that
// createColl creates.
const testShardsNum int32 = 2
// TestSearchTask_PostExecute checks that PostExecute returns no error and sets
// a Success status when the only buffered result is an empty SearchResults.
func TestSearchTask_PostExecute(t *testing.T) {
	t.Run("Test empty result", func(t *testing.T) {
		ctx, cancel := context.WithCancel(context.Background())
		defer cancel()

		qt := &searchTask{
			ctx:       ctx,
			Condition: NewTaskCondition(context.TODO()),
			SearchRequest: &internalpb.SearchRequest{
				Base: &commonpb.MsgBase{
					MsgType:  commonpb.MsgType_Search,
					SourceID: Params.ProxyCfg.GetNodeID(),
				},
			},
			request:         nil,
			qc:              nil,
			tr:              timerecord.NewTimeRecorder("search"),
			resultBuf:       make(chan *internalpb.SearchResults, 10),
			toReduceResults: make([]*internalpb.SearchResults, 0),
		}

		// no result
		qt.resultBuf <- &internalpb.SearchResults{}

		// The running-group context is already cancelled when PostExecute
		// starts. NOTE(review): presumably this is what stops PostExecute from
		// waiting for further results — confirm against searchTask.PostExecute.
		mockctx, mockcancel := context.WithCancel(ctx)
		qt.runningGroupCtx = mockctx
		mockcancel()

		err := qt.PostExecute(context.TODO())
		// Stop here on failure so the dereferences below cannot panic and
		// mask the real error.
		require.NoError(t, err)
		require.NotNil(t, qt.result)
		require.NotNil(t, qt.result.Status)
		// testify expects (expected, actual).
		assert.Equal(t, commonpb.ErrorCode_Success, qt.result.Status.ErrorCode)
	})
}
// createColl creates a collection called name through rc by driving a
// createCollectionTask through OnEnqueue, PreExecute, Execute and PostExecute.
//
// The schema comes from constructCollectionSchema(testInt64Field,
// testFloatVecField, testVecDim, name) and testShardsNum shards are requested.
// Any failing step aborts the calling test via require.
func createColl(t *testing.T, name string, rc types.RootCoord) {
	// Attribute failures to the caller's line, not to this helper.
	t.Helper()

	schema := constructCollectionSchema(testInt64Field, testFloatVecField, testVecDim, name)
	marshaledSchema, err := proto.Marshal(schema)
	require.NoError(t, err)

	ctx := context.TODO()

	createColT := &createCollectionTask{
		Condition: NewTaskCondition(context.TODO()),
		CreateCollectionRequest: &milvuspb.CreateCollectionRequest{
			CollectionName: name,
			Schema:         marshaledSchema,
			ShardsNum:      testShardsNum,
		},
		ctx:       ctx,
		rootCoord: rc,
	}

	require.NoError(t, createColT.OnEnqueue())
	require.NoError(t, createColT.PreExecute(ctx))
	require.NoError(t, createColT.Execute(ctx))
	require.NoError(t, createColT.PostExecute(ctx))
}
// getValidSearchParams returns the search parameter pairs used by the tests
// that expect PreExecute to accept the request: anns field, top-k, metric
// type, index search params and round decimal, in that order.
func getValidSearchParams() []*commonpb.KeyValuePair {
	pair := func(key, value string) *commonpb.KeyValuePair {
		return &commonpb.KeyValuePair{Key: key, Value: value}
	}

	return []*commonpb.KeyValuePair{
		pair(AnnsFieldKey, testFloatVecField),
		pair(TopKKey, "10"),
		pair(MetricTypeKey, distance.L2),
		pair(SearchParamsKey, `{"nprobe": 10}`),
		pair(RoundDecimalKey, "-1"),
	}
}
// TestSearchTask_PreExecute exercises searchTask.PreExecute and
// searchTask.checkIfLoaded against mock root and query coordinators.
//
// Subtests cover: a collection that does not exist, invalid collection and
// partition names, the error and success paths of checkIfLoaded (driven by
// overriding the query coord mock's ShowCollections/ShowPartitions hooks),
// requests without valid search params, and a request whose task context
// carries a deadline.
func TestSearchTask_PreExecute(t *testing.T) {
	Params.Init()

	var (
		rc             = NewRootCoordMock()
		qc             = NewQueryCoordMock()
		ctx            = context.TODO()
		collectionName = t.Name() + funcutil.GenRandomStr()
	)

	// Register each Stop only once the matching Start has succeeded.
	require.NoError(t, rc.Start())
	defer rc.Stop()

	mgr := newShardClientMgr()
	require.NoError(t, InitMetaCache(rc, qc, mgr))

	require.NoError(t, qc.Start())
	defer qc.Stop()

	// getSearchTask builds an enqueued searchTask that targets collName and
	// talks to the shared query coord mock.
	getSearchTask := func(t *testing.T, collName string) *searchTask {
		task := &searchTask{
			ctx:           ctx,
			SearchRequest: &internalpb.SearchRequest{},
			request: &milvuspb.SearchRequest{
				CollectionName: collName,
			},
			qc: qc,
			tr: timerecord.NewTimeRecorder("test-search"),
		}
		require.NoError(t, task.OnEnqueue())
		return task
	}

	t.Run("collection not exist", func(t *testing.T) {
		task := getSearchTask(t, collectionName)
		assert.Error(t, task.PreExecute(ctx))
	})

	t.Run("invalid collection name", func(t *testing.T) {
		task := getSearchTask(t, collectionName)
		createColl(t, collectionName, rc)

		invalidCollNameTests := []struct {
			inCollName  string
			description string
		}{
			{"$", "invalid collection name $"},
			{"0", "invalid collection name 0"},
		}

		for _, test := range invalidCollNameTests {
			t.Run(test.description, func(t *testing.T) {
				task.request.CollectionName = test.inCollName
				assert.Error(t, task.PreExecute(context.TODO()))
			})
		}
	})

	t.Run("invalid partition names", func(t *testing.T) {
		task := getSearchTask(t, collectionName)
		createColl(t, collectionName, rc)

		invalidPartNameTests := []struct {
			inPartNames []string
			description string
		}{
			{[]string{"$"}, "invalid partition name $"},
			{[]string{"0"}, "invalid partition name 0"},
			{[]string{"default", "$"}, "invalid partition name $ after a valid name"},
		}

		for _, test := range invalidPartNameTests {
			t.Run(test.description, func(t *testing.T) {
				task.request.PartitionNames = test.inPartNames
				assert.Error(t, task.PreExecute(context.TODO()))
			})
		}
	})

	t.Run("test checkIfLoaded error", func(t *testing.T) {
		collName := "test_checkIfLoaded_error" + funcutil.GenRandomStr()
		createColl(t, collName, rc)
		collID, err := globalMetaCache.GetCollectionID(context.TODO(), collName)
		require.NoError(t, err)
		task := getSearchTask(t, collName)
		task.collectionName = collName

		// Restore the default mock behaviour however this subtest exits, so
		// overrides cannot leak into the subtests that follow.
		defer qc.ResetShowPartitionsFunc()
		defer qc.ResetShowCollectionsFunc()

		t.Run("show collection err", func(t *testing.T) {
			qc.SetShowCollectionsFunc(func(ctx context.Context, request *querypb.ShowCollectionsRequest) (*querypb.ShowCollectionsResponse, error) {
				return nil, errors.New("mock")
			})

			loaded, err := task.checkIfLoaded(collID, []UniqueID{})
			assert.Error(t, err)
			assert.False(t, loaded)
		})

		t.Run("show collection status unexpected error", func(t *testing.T) {
			qc.SetShowCollectionsFunc(func(ctx context.Context, request *querypb.ShowCollectionsRequest) (*querypb.ShowCollectionsResponse, error) {
				return &querypb.ShowCollectionsResponse{
					Status: &commonpb.Status{
						ErrorCode: commonpb.ErrorCode_UnexpectedError,
						Reason:    "mock",
					},
				}, nil
			})

			loaded, err := task.checkIfLoaded(collID, []UniqueID{})
			assert.Error(t, err)
			assert.False(t, loaded)
			assert.Error(t, task.PreExecute(ctx))
			qc.ResetShowCollectionsFunc()
		})

		t.Run("show partition status unexpected error", func(t *testing.T) {
			qc.SetShowPartitionsFunc(func(ctx context.Context, req *querypb.ShowPartitionsRequest) (*querypb.ShowPartitionsResponse, error) {
				return &querypb.ShowPartitionsResponse{
					Status: &commonpb.Status{
						ErrorCode: commonpb.ErrorCode_UnexpectedError,
						Reason:    "mock",
					},
				}, nil
			})

			loaded, err := task.checkIfLoaded(collID, []UniqueID{1})
			assert.Error(t, err)
			assert.False(t, loaded)
		})

		t.Run("show partition error", func(t *testing.T) {
			qc.SetShowPartitionsFunc(func(ctx context.Context, req *querypb.ShowPartitionsRequest) (*querypb.ShowPartitionsResponse, error) {
				return nil, errors.New("mock error")
			})

			loaded, err := task.checkIfLoaded(collID, []UniqueID{1})
			assert.Error(t, err)
			assert.False(t, loaded)
		})

		t.Run("show partitions success", func(t *testing.T) {
			qc.SetShowPartitionsFunc(func(ctx context.Context, req *querypb.ShowPartitionsRequest) (*querypb.ShowPartitionsResponse, error) {
				return &querypb.ShowPartitionsResponse{
					Status: &commonpb.Status{
						ErrorCode: commonpb.ErrorCode_Success,
					},
				}, nil
			})

			loaded, err := task.checkIfLoaded(collID, []UniqueID{1})
			assert.NoError(t, err)
			assert.True(t, loaded)
			qc.ResetShowPartitionsFunc()
		})

		t.Run("show collection success but not loaded", func(t *testing.T) {
			// ShowCollections succeeds but reports the collection at 0% in
			// memory; the outcome then depends on what ShowPartitions returns.
			qc.SetShowCollectionsFunc(func(ctx context.Context, request *querypb.ShowCollectionsRequest) (*querypb.ShowCollectionsResponse, error) {
				return &querypb.ShowCollectionsResponse{
					Status: &commonpb.Status{
						ErrorCode: commonpb.ErrorCode_Success,
					},
					CollectionIDs:       []UniqueID{collID},
					InMemoryPercentages: []int64{0},
				}, nil
			})

			// ShowPartitions fails: checkIfLoaded must surface the error.
			qc.SetShowPartitionsFunc(func(ctx context.Context, req *querypb.ShowPartitionsRequest) (*querypb.ShowPartitionsResponse, error) {
				return nil, errors.New("mock error")
			})

			loaded, err := task.checkIfLoaded(collID, []UniqueID{})
			assert.Error(t, err)
			assert.False(t, loaded)

			// ShowPartitions succeeds and lists a partition: reported as loaded.
			qc.SetShowPartitionsFunc(func(ctx context.Context, req *querypb.ShowPartitionsRequest) (*querypb.ShowPartitionsResponse, error) {
				return &querypb.ShowPartitionsResponse{
					Status: &commonpb.Status{
						ErrorCode: commonpb.ErrorCode_Success,
					},
					PartitionIDs: []UniqueID{1},
				}, nil
			})

			loaded, err = task.checkIfLoaded(collID, []UniqueID{})
			assert.NoError(t, err)
			assert.True(t, loaded)
		})
	})

	t.Run("invalid key value pairs", func(t *testing.T) {
		// Each slice extends the previous one by the next required key, so
		// every entry in the table is missing (or corrupts) exactly one param.
		spNoTopk := []*commonpb.KeyValuePair{{
			Key:   AnnsFieldKey,
			Value: testFloatVecField}}

		spInvalidTopk := append(spNoTopk, &commonpb.KeyValuePair{
			Key:   TopKKey,
			Value: "invalid",
		})

		spNoMetricType := append(spNoTopk, &commonpb.KeyValuePair{
			Key:   TopKKey,
			Value: "10",
		})

		spNoSearchParams := append(spNoMetricType, &commonpb.KeyValuePair{
			Key:   MetricTypeKey,
			Value: distance.L2,
		})

		spNoRoundDecimal := append(spNoSearchParams, &commonpb.KeyValuePair{
			Key:   SearchParamsKey,
			Value: `{"nprobe": 10}`,
		})

		spInvalidRoundDecimal := append(spNoRoundDecimal, &commonpb.KeyValuePair{
			Key:   RoundDecimalKey,
			Value: "invalid",
		})

		tests := []struct {
			description   string
			invalidParams []*commonpb.KeyValuePair
		}{
			{"No_topk", spNoTopk},
			{"Invalid_topk", spInvalidTopk},
			{"No_Metric_type", spNoMetricType},
			{"No_search_params", spNoSearchParams},
			{"no_round_decimal", spNoRoundDecimal},
			{"Invalid_round_decimal", spInvalidRoundDecimal},
		}

		for _, test := range tests {
			t.Run(test.description, func(t *testing.T) {
				collName := "collection_" + test.description
				createColl(t, collName, rc)
				collID, err := globalMetaCache.GetCollectionID(context.TODO(), collName)
				require.NoError(t, err)

				// NOTE(review): test.invalidParams is never copied onto
				// task.request.SearchParams, so every case runs PreExecute
				// with no search params at all. Wiring it in looks like the
				// intent, but confirm that each case still yields an error
				// before doing so.
				task := getSearchTask(t, collName)
				task.request.DslType = commonpb.DslType_BoolExprV1

				status, err := qc.LoadCollection(ctx, &querypb.LoadCollectionRequest{
					Base: &commonpb.MsgBase{
						MsgType: commonpb.MsgType_LoadCollection,
					},
					CollectionID: collID,
				})
				require.NoError(t, err)
				require.Equal(t, commonpb.ErrorCode_Success, status.GetErrorCode())
				assert.Error(t, task.PreExecute(ctx))
			})
		}
	})

	t.Run("search with timeout", func(t *testing.T) {
		collName := "search_with_timeout" + funcutil.GenRandomStr()
		createColl(t, collName, rc)
		collID, err := globalMetaCache.GetCollectionID(context.TODO(), collName)
		require.NoError(t, err)

		status, err := qc.LoadCollection(ctx, &querypb.LoadCollectionRequest{
			Base: &commonpb.MsgBase{
				MsgType: commonpb.MsgType_LoadCollection,
			},
			CollectionID: collID,
		})
		require.NoError(t, err)
		require.Equal(t, commonpb.ErrorCode_Success, status.GetErrorCode())

		task := getSearchTask(t, collName)
		task.request.SearchParams = getValidSearchParams()
		task.request.DslType = commonpb.DslType_BoolExprV1

		ctxTimeout, cancel := context.WithTimeout(ctx, time.Second)
		defer cancel()
		require.Equal(t, typeutil.ZeroTimestamp, task.TimeoutTimestamp)

		// With a deadline on the task's own context, PreExecute is expected to
		// leave a non-zero TimeoutTimestamp on the task.
		task.ctx = ctxTimeout
		assert.NoError(t, task.PreExecute(ctx))
		assert.Greater(t, task.TimeoutTimestamp, typeutil.ZeroTimestamp)

		// field not exist
		task.ctx = context.TODO()
		task.request.OutputFields = []string{testInt64Field + funcutil.GenRandomStr()}
		assert.Error(t, task.PreExecute(ctx))

		// contain vector field
		task.request.OutputFields = []string{testFloatVecField}
		assert.Error(t, task.PreExecute(ctx))
	})
}
// TestSearchTaskV2_Execute starts the mock coordinators, initialises the meta
// cache, enqueues a searchTask for a fresh collection name and then creates
// that collection. It only asserts that each setup step succeeds; Execute
// itself is not invoked here.
func TestSearchTaskV2_Execute(t *testing.T) {
	Params.Init()

	rootCoordMock := NewRootCoordMock()
	queryCoordMock := NewQueryCoordMock()
	ctx := context.TODO()
	collName := t.Name() + funcutil.GenRandomStr()

	require.NoError(t, rootCoordMock.Start())
	defer rootCoordMock.Stop()

	require.NoError(t, InitMetaCache(rootCoordMock, queryCoordMock, newShardClientMgr()))

	require.NoError(t, queryCoordMock.Start())
	defer queryCoordMock.Stop()

	internalReq := &internalpb.SearchRequest{
		Base: &commonpb.MsgBase{
			MsgType:   commonpb.MsgType_Search,
			Timestamp: uint64(time.Now().UnixNano()),
		},
	}
	st := &searchTask{
		ctx:           ctx,
		SearchRequest: internalReq,
		request:       &milvuspb.SearchRequest{CollectionName: collName},
		result:        &milvuspb.SearchResults{Status: &commonpb.Status{}},
		qc:            queryCoordMock,
		tr:            timerecord.NewTimeRecorder("search"),
	}

	require.NoError(t, st.OnEnqueue())
	createColl(t, collName, rootCoordMock)
}
// genSearchResultData wraps ids and scores in a SearchResultData for nq
// queries with the given topk. The ids are stored as int64 ids, FieldsData is
// left nil, and Topks is a zero-filled slice of length nq.
func genSearchResultData(nq int64, topk int64, ids []int64, scores []float32) *schemapb.SearchResultData {
	intIDs := &schemapb.IDs_IntId{
		IntId: &schemapb.LongArray{Data: ids},
	}

	data := &schemapb.SearchResultData{
		NumQueries: nq,
		TopK:       topk,
		Scores:     scores,
	}
	data.Ids = &schemapb.IDs{IdField: intIDs}
	data.Topks = make([]int64, nq)
	return data
}
// TestSearchTask_Ts checks that a timestamp stored with SetTs is returned
// unchanged by both BeginTs and EndTs.
func TestSearchTask_Ts(t *testing.T) {
	Params.Init()

	task := &searchTask{
		SearchRequest: &internalpb.SearchRequest{},
		tr:            timerecord.NewTimeRecorder("test-search"),
	}
	require.NoError(t, task.OnEnqueue())

	// UnixNano is a full wall-clock timestamp. Nanosecond() would only give
	// the sub-second part (0..999999999), which is not a meaningful timestamp.
	ts := Timestamp(time.Now().UnixNano())
	task.SetTs(ts)
	assert.Equal(t, ts, task.BeginTs())
	assert.Equal(t, ts, task.EndTs())
}
// TestSearchTask_Reduce is currently disabled: its whole body is commented
// out. The disabled cases merged two SearchResultData sets through
// reduceSearchResultData and compared the returned ids and scores.
//
// The test is reported as skipped so the missing coverage shows up in test
// output instead of appearing as a pass.
func TestSearchTask_Reduce(t *testing.T) {
	t.Skip("test body is commented out; nothing is exercised")

	// const (
	// nq = 1
	// topk = 4
	// metricType = "L2"
	// )
	// t.Run("case1", func(t *testing.T) {
	// ids := []int64{1, 2, 3, 4}
	// scores := []float32{-1.0, -2.0, -3.0, -4.0}
	// data1 := genSearchResultData(nq, topk, ids, scores)
	// data2 := genSearchResultData(nq, topk, ids, scores)
	// dataArray := make([]*schemapb.SearchResultData, 0)
	// dataArray = append(dataArray, data1)
	// dataArray = append(dataArray, data2)
	// res, err := reduceSearchResultData(dataArray, nq, topk, metricType)
	// assert.Nil(t, err)
	// assert.Equal(t, ids, res.Results.Ids.GetIntId().Data)
	// assert.Equal(t, []float32{1.0, 2.0, 3.0, 4.0}, res.Results.Scores)
	// })
	// t.Run("case2", func(t *testing.T) {
	// ids1 := []int64{1, 2, 3, 4}
	// scores1 := []float32{-1.0, -2.0, -3.0, -4.0}
	// ids2 := []int64{5, 1, 3, 4}
	// scores2 := []float32{-1.0, -1.0, -3.0, -4.0}
	// data1 := genSearchResultData(nq, topk, ids1, scores1)
	// data2 := genSearchResultData(nq, topk, ids2, scores2)
	// dataArray := make([]*schemapb.SearchResultData, 0)
	// dataArray = append(dataArray, data1)
	// dataArray = append(dataArray, data2)
	// res, err := reduceSearchResultData(dataArray, nq, topk, metricType)
	// assert.Nil(t, err)
	// assert.ElementsMatch(t, []int64{1, 5, 2, 3}, res.Results.Ids.GetIntId().Data)
	// })
}
// TestSearchTaskWithInvalidRoundDecimal is currently disabled: its whole body
// is commented out. The disabled body built a search request with
// roundDecimal = 7 and expected PreExecute to return an error.
//
// The commented code uses identifiers that differ from what the live tests in
// this file use (searchTaskV2 instead of searchTask, a one-argument
// InitMetaCache, and a `stream` variable it never declares), so it needs
// porting before it can be re-enabled. Until then the test is reported as
// skipped instead of passing vacuously.
func TestSearchTaskWithInvalidRoundDecimal(t *testing.T) {
	t.Skip("test body is commented out; nothing is exercised")

	// var err error
	//
	// Params.Init()
	// Params.ProxyCfg.SearchResultChannelNames = []string{funcutil.GenRandomStr()}
	//
	// rc := NewRootCoordMock()
	// rc.Start()
	// defer rc.Stop()
	//
	// ctx := context.Background()
	//
	// err = InitMetaCache(rc)
	// assert.NoError(t, err)
	//
	// shardsNum := int32(2)
	// prefix := "TestSearchTaskV2_all"
	// collectionName := prefix + funcutil.GenRandomStr()
	//
	// dim := 128
	// expr := fmt.Sprintf("%s > 0", testInt64Field)
	// nq := 10
	// topk := 10
	// roundDecimal := 7
	// nprobe := 10
	//
	// fieldName2Types := map[string]schemapb.DataType{
	// testBoolField: schemapb.DataType_Bool,
	// testInt32Field: schemapb.DataType_Int32,
	// testInt64Field: schemapb.DataType_Int64,
	// testFloatField: schemapb.DataType_Float,
	// testDoubleField: schemapb.DataType_Double,
	// testFloatVecField: schemapb.DataType_FloatVector,
	// }
	// if enableMultipleVectorFields {
	// fieldName2Types[testBinaryVecField] = schemapb.DataType_BinaryVector
	// }
	// schema := constructCollectionSchemaByDataType(collectionName, fieldName2Types, testInt64Field, false)
	// marshaledSchema, err := proto.Marshal(schema)
	// assert.NoError(t, err)
	//
	// createColT := &createCollectionTask{
	// Condition: NewTaskCondition(ctx),
	// CreateCollectionRequest: &milvuspb.CreateCollectionRequest{
	// Base: nil,
	// CollectionName: collectionName,
	// Schema: marshaledSchema,
	// ShardsNum: shardsNum,
	// },
	// ctx: ctx,
	// rootCoord: rc,
	// result: nil,
	// schema: nil,
	// }
	//
	// assert.NoError(t, createColT.OnEnqueue())
	// assert.NoError(t, createColT.PreExecute(ctx))
	// assert.NoError(t, createColT.Execute(ctx))
	// assert.NoError(t, createColT.PostExecute(ctx))
	//
	// dmlChannelsFunc := getDmlChannelsFunc(ctx, rc)
	// query := newMockGetChannelsService()
	// factory := newSimpleMockMsgStreamFactory()
	//
	// collectionID, err := globalMetaCache.GetCollectionID(ctx, collectionName)
	// assert.NoError(t, err)
	//
	// qc := NewQueryCoordMock()
	// qc.Start()
	// defer qc.Stop()
	// status, err := qc.LoadCollection(ctx, &querypb.LoadCollectionRequest{
	// Base: &commonpb.MsgBase{
	// MsgType: commonpb.MsgType_LoadCollection,
	// MsgID: 0,
	// Timestamp: 0,
	// SourceID: Params.ProxyCfg.GetNodeID(),
	// },
	// DbID: 0,
	// CollectionID: collectionID,
	// Schema: nil,
	// })
	// assert.NoError(t, err)
	// assert.Equal(t, commonpb.ErrorCode_Success, status.ErrorCode)
	//
	// req := constructSearchRequest("", collectionName,
	// expr,
	// testFloatVecField,
	// nq, dim, nprobe, topk, roundDecimal)
	//
	// task := &searchTaskV2{
	// Condition: NewTaskCondition(ctx),
	// SearchRequest: &internalpb.SearchRequest{
	// Base: &commonpb.MsgBase{
	// MsgType: commonpb.MsgType_Search,
	// MsgID: 0,
	// Timestamp: 0,
	// SourceID: Params.ProxyCfg.GetNodeID(),
	// },
	// ResultChannelID: strconv.FormatInt(Params.ProxyCfg.GetNodeID(), 10),
	// DbID: 0,
	// CollectionID: 0,
	// PartitionIDs: nil,
	// Dsl: "",
	// PlaceholderGroup: nil,
	// DslType: 0,
	// SerializedExprPlan: nil,
	// OutputFieldsId: nil,
	// TravelTimestamp: 0,
	// GuaranteeTimestamp: 0,
	// },
	// ctx: ctx,
	// resultBuf: make(chan *internalpb.SearchResults, 10),
	// result: nil,
	// request: req,
	// qc: qc,
	// tr: timerecord.NewTimeRecorder("search"),
	// }
	//
	// // simple mock for query node
	// // TODO(dragondriver): should we replace this mock using RocksMq or MemMsgStream?
	//
	//
	// var wg sync.WaitGroup
	// wg.Add(1)
	// consumeCtx, cancel := context.WithCancel(ctx)
	// go func() {
	// defer wg.Done()
	// for {
	// select {
	// case <-consumeCtx.Done():
	// return
	// case pack, ok := <-stream.Chan():
	// assert.True(t, ok)
	// if pack == nil {
	// continue
	// }
	//
	// for _, msg := range pack.Msgs {
	// _, ok := msg.(*msgstream.SearchMsg)
	// assert.True(t, ok)
	// // TODO(dragondriver): construct result according to the request
	//
	// constructSearchResulstData := func() *schemapb.SearchResultData {
	// resultData := &schemapb.SearchResultData{
	// NumQueries: int64(nq),
	// TopK: int64(topk),
	// Scores: make([]float32, nq*topk),
	// Ids: &schemapb.IDs{
	// IdField: &schemapb.IDs_IntId{
	// IntId: &schemapb.LongArray{
	// Data: make([]int64, nq*topk),
	// },
	// },
	// },
	// Topks: make([]int64, nq),
	// }
	//
	// fieldID := common.StartOfUserFieldID
	// for fieldName, dataType := range fieldName2Types {
	// resultData.FieldsData = append(resultData.FieldsData, generateFieldData(dataType, fieldName, int64(fieldID), nq*topk))
	// fieldID++
	// }
	//
	// for i := 0; i < nq; i++ {
	// for j := 0; j < topk; j++ {
	// offset := i*topk + j
	// score := float32(uniquegenerator.GetUniqueIntGeneratorIns().GetInt()) // increasingly
	// id := int64(uniquegenerator.GetUniqueIntGeneratorIns().GetInt())
	// resultData.Scores[offset] = score
	// resultData.Ids.IdField.(*schemapb.IDs_IntId).IntId.Data[offset] = id
	// }
	// resultData.Topks[i] = int64(topk)
	// }
	//
	// return resultData
	// }
	//
	// result1 := &internalpb.SearchResults{
	// Base: &commonpb.MsgBase{
	// MsgType: commonpb.MsgType_SearchResult,
	// MsgID: 0,
	// Timestamp: 0,
	// SourceID: 0,
	// },
	// Status: &commonpb.Status{
	// ErrorCode: commonpb.ErrorCode_Success,
	// Reason: "",
	// },
	// ResultChannelID: "",
	// MetricType: distance.L2,
	// NumQueries: int64(nq),
	// TopK: int64(topk),
	// SealedSegmentIDsSearched: nil,
	// ChannelIDsSearched: nil,
	// GlobalSealedSegmentIDs: nil,
	// SlicedBlob: nil,
	// SlicedNumCount: 1,
	// SlicedOffset: 0,
	// }
	// resultData := constructSearchResulstData()
	// sliceBlob, err := proto.Marshal(resultData)
	// assert.NoError(t, err)
	// result1.SlicedBlob = sliceBlob
	//
	// // result2.SliceBlob = nil, will be skipped in decode stage
	// result2 := &internalpb.SearchResults{
	// Base: &commonpb.MsgBase{
	// MsgType: commonpb.MsgType_SearchResult,
	// MsgID: 0,
	// Timestamp: 0,
	// SourceID: 0,
	// },
	// Status: &commonpb.Status{
	// ErrorCode: commonpb.ErrorCode_Success,
	// Reason: "",
	// },
	// ResultChannelID: "",
	// MetricType: distance.L2,
	// NumQueries: int64(nq),
	// TopK: int64(topk),
	// SealedSegmentIDsSearched: nil,
	// ChannelIDsSearched: nil,
	// GlobalSealedSegmentIDs: nil,
	// SlicedBlob: nil,
	// SlicedNumCount: 1,
	// SlicedOffset: 0,
	// }
	//
	// // send search result
	// task.resultBuf <- result1
	// task.resultBuf <- result2
	// }
	// }
	// }
	// }()
	//
	// assert.NoError(t, task.OnEnqueue())
	// assert.Error(t, task.PreExecute(ctx))
	//
	// cancel()
	// wg.Wait()
}
// TestSearchTaskV2_all is currently disabled: its whole body is commented out.
// The disabled body ran a search task through OnEnqueue, PreExecute, Execute
// and PostExecute, expecting no error at each step, while a goroutine fed two
// SearchResults into task.resultBuf (one with a marshalled SlicedBlob and one
// with a nil SlicedBlob).
//
// The commented code uses identifiers that differ from what the live tests in
// this file use (searchTaskV2 instead of searchTask, a one-argument
// InitMetaCache, and a `stream` variable it never declares), so it needs
// porting before it can be re-enabled. Until then the test is reported as
// skipped instead of passing vacuously.
func TestSearchTaskV2_all(t *testing.T) {
	t.Skip("test body is commented out; nothing is exercised")

	// var err error
	//
	// Params.Init()
	// Params.ProxyCfg.SearchResultChannelNames = []string{funcutil.GenRandomStr()}
	//
	// rc := NewRootCoordMock()
	// rc.Start()
	// defer rc.Stop()
	//
	// ctx := context.Background()
	//
	// err = InitMetaCache(rc)
	// assert.NoError(t, err)
	//
	// shardsNum := int32(2)
	// prefix := "TestSearchTaskV2_all"
	// collectionName := prefix + funcutil.GenRandomStr()
	//
	// dim := 128
	// expr := fmt.Sprintf("%s > 0", testInt64Field)
	// nq := 10
	// topk := 10
	// roundDecimal := 3
	// nprobe := 10
	//
	// fieldName2Types := map[string]schemapb.DataType{
	// testBoolField: schemapb.DataType_Bool,
	// testInt32Field: schemapb.DataType_Int32,
	// testInt64Field: schemapb.DataType_Int64,
	// testFloatField: schemapb.DataType_Float,
	// testDoubleField: schemapb.DataType_Double,
	// testFloatVecField: schemapb.DataType_FloatVector,
	// }
	// if enableMultipleVectorFields {
	// fieldName2Types[testBinaryVecField] = schemapb.DataType_BinaryVector
	// }
	//
	// schema := constructCollectionSchemaByDataType(collectionName, fieldName2Types, testInt64Field, false)
	// marshaledSchema, err := proto.Marshal(schema)
	// assert.NoError(t, err)
	//
	// createColT := &createCollectionTask{
	// Condition: NewTaskCondition(ctx),
	// CreateCollectionRequest: &milvuspb.CreateCollectionRequest{
	// Base: nil,
	// CollectionName: collectionName,
	// Schema: marshaledSchema,
	// ShardsNum: shardsNum,
	// },
	// ctx: ctx,
	// rootCoord: rc,
	// result: nil,
	// schema: nil,
	// }
	//
	// assert.NoError(t, createColT.OnEnqueue())
	// assert.NoError(t, createColT.PreExecute(ctx))
	// assert.NoError(t, createColT.Execute(ctx))
	// assert.NoError(t, createColT.PostExecute(ctx))
	//
	// dmlChannelsFunc := getDmlChannelsFunc(ctx, rc)
	// query := newMockGetChannelsService()
	// factory := newSimpleMockMsgStreamFactory()
	//
	// collectionID, err := globalMetaCache.GetCollectionID(ctx, collectionName)
	// assert.NoError(t, err)
	//
	// qc := NewQueryCoordMock()
	// qc.Start()
	// defer qc.Stop()
	// status, err := qc.LoadCollection(ctx, &querypb.LoadCollectionRequest{
	// Base: &commonpb.MsgBase{
	// MsgType: commonpb.MsgType_LoadCollection,
	// MsgID: 0,
	// Timestamp: 0,
	// SourceID: Params.ProxyCfg.GetNodeID(),
	// },
	// DbID: 0,
	// CollectionID: collectionID,
	// Schema: nil,
	// })
	// assert.NoError(t, err)
	// assert.Equal(t, commonpb.ErrorCode_Success, status.ErrorCode)
	//
	// req := constructSearchRequest("", collectionName,
	// expr,
	// testFloatVecField,
	// nq, dim, nprobe, topk, roundDecimal)
	//
	// task := &searchTaskV2{
	// Condition: NewTaskCondition(ctx),
	// SearchRequest: &internalpb.SearchRequest{
	// Base: &commonpb.MsgBase{
	// MsgType: commonpb.MsgType_Search,
	// MsgID: 0,
	// Timestamp: 0,
	// SourceID: Params.ProxyCfg.GetNodeID(),
	// },
	// ResultChannelID: strconv.FormatInt(Params.ProxyCfg.GetNodeID(), 10),
	// DbID: 0,
	// CollectionID: 0,
	// PartitionIDs: nil,
	// Dsl: "",
	// PlaceholderGroup: nil,
	// DslType: 0,
	// SerializedExprPlan: nil,
	// OutputFieldsId: nil,
	// TravelTimestamp: 0,
	// GuaranteeTimestamp: 0,
	// },
	// ctx: ctx,
	// resultBuf: make(chan *internalpb.SearchResults, 10),
	// result: nil,
	// request: req,
	// qc: qc,
	// tr: timerecord.NewTimeRecorder("search"),
	// }
	//
	// // simple mock for query node
	// // TODO(dragondriver): should we replace this mock using RocksMq or MemMsgStream?
	//
	// var wg sync.WaitGroup
	// wg.Add(1)
	// consumeCtx, cancel := context.WithCancel(ctx)
	// go func() {
	// defer wg.Done()
	// for {
	// select {
	// case <-consumeCtx.Done():
	// return
	// case pack, ok := <-stream.Chan():
	// assert.True(t, ok)
	// if pack == nil {
	// continue
	// }
	//
	// for _, msg := range pack.Msgs {
	// _, ok := msg.(*msgstream.SearchMsg)
	// assert.True(t, ok)
	// // TODO(dragondriver): construct result according to the request
	//
	// constructSearchResulstData := func() *schemapb.SearchResultData {
	// resultData := &schemapb.SearchResultData{
	// NumQueries: int64(nq),
	// TopK: int64(topk),
	// Scores: make([]float32, nq*topk),
	// Ids: &schemapb.IDs{
	// IdField: &schemapb.IDs_IntId{
	// IntId: &schemapb.LongArray{
	// Data: make([]int64, nq*topk),
	// },
	// },
	// },
	// Topks: make([]int64, nq),
	// }
	//
	// fieldID := common.StartOfUserFieldID
	// for fieldName, dataType := range fieldName2Types {
	// resultData.FieldsData = append(resultData.FieldsData, generateFieldData(dataType, fieldName, int64(fieldID), nq*topk))
	// fieldID++
	// }
	//
	// for i := 0; i < nq; i++ {
	// for j := 0; j < topk; j++ {
	// offset := i*topk + j
	// score := float32(uniquegenerator.GetUniqueIntGeneratorIns().GetInt()) // increasingly
	// id := int64(uniquegenerator.GetUniqueIntGeneratorIns().GetInt())
	// resultData.Scores[offset] = score
	// resultData.Ids.IdField.(*schemapb.IDs_IntId).IntId.Data[offset] = id
	// }
	// resultData.Topks[i] = int64(topk)
	// }
	//
	// return resultData
	// }
	//
	// result1 := &internalpb.SearchResults{
	// Base: &commonpb.MsgBase{
	// MsgType: commonpb.MsgType_SearchResult,
	// MsgID: 0,
	// Timestamp: 0,
	// SourceID: 0,
	// },
	// Status: &commonpb.Status{
	// ErrorCode: commonpb.ErrorCode_Success,
	// Reason: "",
	// },
	// ResultChannelID: "",
	// MetricType: distance.L2,
	// NumQueries: int64(nq),
	// TopK: int64(topk),
	// SealedSegmentIDsSearched: nil,
	// ChannelIDsSearched: nil,
	// GlobalSealedSegmentIDs: nil,
	// SlicedBlob: nil,
	// SlicedNumCount: 1,
	// SlicedOffset: 0,
	// }
	// resultData := constructSearchResulstData()
	// sliceBlob, err := proto.Marshal(resultData)
	// assert.NoError(t, err)
	// result1.SlicedBlob = sliceBlob
	//
	// // result2.SliceBlob = nil, will be skipped in decode stage
	// result2 := &internalpb.SearchResults{
	// Base: &commonpb.MsgBase{
	// MsgType: commonpb.MsgType_SearchResult,
	// MsgID: 0,
	// Timestamp: 0,
	// SourceID: 0,
	// },
	// Status: &commonpb.Status{
	// ErrorCode: commonpb.ErrorCode_Success,
	// Reason: "",
	// },
	// ResultChannelID: "",
	// MetricType: distance.L2,
	// NumQueries: int64(nq),
	// TopK: int64(topk),
	// SealedSegmentIDsSearched: nil,
	// ChannelIDsSearched: nil,
	// GlobalSealedSegmentIDs: nil,
	// SlicedBlob: nil,
	// SlicedNumCount: 1,
	// SlicedOffset: 0,
	// }
	//
	// // send search result
	// task.resultBuf <- result1
	// task.resultBuf <- result2
	// }
	// }
	// }
	// }()
	//
	// assert.NoError(t, task.OnEnqueue())
	// assert.NoError(t, task.PreExecute(ctx))
	// assert.NoError(t, task.Execute(ctx))
	// assert.NoError(t, task.PostExecute(ctx))
	//
	// cancel()
	// wg.Wait()
}
// TestSearchTaskV2_7803_reduce is currently disabled: its whole body is
// commented out. The disabled body ran a full search task while feeding it two
// SearchResults in which, for every query, the entries at index >= invalidNum
// carried score minFloat32 and id -1 (invalidNum being topk/2 for the first
// result and topk-topk/2 for the second), and expected all four task phases to
// succeed.
//
// The commented code uses identifiers that differ from what the live tests in
// this file use (searchTaskV2 instead of searchTask, a one-argument
// InitMetaCache, and a `stream` variable it never declares), so it needs
// porting before it can be re-enabled. Until then the test is reported as
// skipped instead of passing vacuously.
func TestSearchTaskV2_7803_reduce(t *testing.T) {
	t.Skip("test body is commented out; nothing is exercised")

	// var err error
	//
	// Params.Init()
	// Params.ProxyCfg.SearchResultChannelNames = []string{funcutil.GenRandomStr()}
	//
	// rc := NewRootCoordMock()
	// rc.Start()
	// defer rc.Stop()
	//
	// ctx := context.Background()
	//
	// err = InitMetaCache(rc)
	// assert.NoError(t, err)
	//
	// shardsNum := int32(2)
	// prefix := "TestSearchTaskV2_7803_reduce"
	// collectionName := prefix + funcutil.GenRandomStr()
	// int64Field := "int64"
	// floatVecField := "fvec"
	// dim := 128
	// expr := fmt.Sprintf("%s > 0", int64Field)
	// nq := 10
	// topk := 10
	// roundDecimal := 3
	// nprobe := 10
	//
	// schema := constructCollectionSchema(
	// int64Field,
	// floatVecField,
	// dim,
	// collectionName)
	// marshaledSchema, err := proto.Marshal(schema)
	// assert.NoError(t, err)
	//
	// createColT := &createCollectionTask{
	// Condition: NewTaskCondition(ctx),
	// CreateCollectionRequest: &milvuspb.CreateCollectionRequest{
	// Base: nil,
	// CollectionName: collectionName,
	// Schema: marshaledSchema,
	// ShardsNum: shardsNum,
	// },
	// ctx: ctx,
	// rootCoord: rc,
	// result: nil,
	// schema: nil,
	// }
	//
	// assert.NoError(t, createColT.OnEnqueue())
	// assert.NoError(t, createColT.PreExecute(ctx))
	// assert.NoError(t, createColT.Execute(ctx))
	// assert.NoError(t, createColT.PostExecute(ctx))
	//
	// dmlChannelsFunc := getDmlChannelsFunc(ctx, rc)
	// query := newMockGetChannelsService()
	// factory := newSimpleMockMsgStreamFactory()
	//
	// collectionID, err := globalMetaCache.GetCollectionID(ctx, collectionName)
	// assert.NoError(t, err)
	//
	// qc := NewQueryCoordMock()
	// qc.Start()
	// defer qc.Stop()
	// status, err := qc.LoadCollection(ctx, &querypb.LoadCollectionRequest{
	// Base: &commonpb.MsgBase{
	// MsgType: commonpb.MsgType_LoadCollection,
	// MsgID: 0,
	// Timestamp: 0,
	// SourceID: Params.ProxyCfg.GetNodeID(),
	// },
	// DbID: 0,
	// CollectionID: collectionID,
	// Schema: nil,
	// })
	// assert.NoError(t, err)
	// assert.Equal(t, commonpb.ErrorCode_Success, status.ErrorCode)
	//
	// req := constructSearchRequest("", collectionName,
	// expr,
	// floatVecField,
	// nq, dim, nprobe, topk, roundDecimal)
	//
	// task := &searchTaskV2{
	// Condition: NewTaskCondition(ctx),
	// SearchRequest: &internalpb.SearchRequest{
	// Base: &commonpb.MsgBase{
	// MsgType: commonpb.MsgType_Search,
	// MsgID: 0,
	// Timestamp: 0,
	// SourceID: Params.ProxyCfg.GetNodeID(),
	// },
	// ResultChannelID: strconv.FormatInt(Params.ProxyCfg.GetNodeID(), 10),
	// DbID: 0,
	// CollectionID: 0,
	// PartitionIDs: nil,
	// Dsl: "",
	// PlaceholderGroup: nil,
	// DslType: 0,
	// SerializedExprPlan: nil,
	// OutputFieldsId: nil,
	// TravelTimestamp: 0,
	// GuaranteeTimestamp: 0,
	// },
	// ctx: ctx,
	// resultBuf: make(chan *internalpb.SearchResults, 10),
	// result: nil,
	// request: req,
	// qc: qc,
	// tr: timerecord.NewTimeRecorder("search"),
	// }
	//
	// // simple mock for query node
	// // TODO(dragondriver): should we replace this mock using RocksMq or MemMsgStream?
	//
	// var wg sync.WaitGroup
	// wg.Add(1)
	// consumeCtx, cancel := context.WithCancel(ctx)
	// go func() {
	// defer wg.Done()
	// for {
	// select {
	// case <-consumeCtx.Done():
	// return
	// case pack, ok := <-stream.Chan():
	// assert.True(t, ok)
	// if pack == nil {
	// continue
	// }
	//
	// for _, msg := range pack.Msgs {
	// _, ok := msg.(*msgstream.SearchMsg)
	// assert.True(t, ok)
	// // TODO(dragondriver): construct result according to the request
	//
	// constructSearchResulstData := func(invalidNum int) *schemapb.SearchResultData {
	// resultData := &schemapb.SearchResultData{
	// NumQueries: int64(nq),
	// TopK: int64(topk),
	// FieldsData: nil,
	// Scores: make([]float32, nq*topk),
	// Ids: &schemapb.IDs{
	// IdField: &schemapb.IDs_IntId{
	// IntId: &schemapb.LongArray{
	// Data: make([]int64, nq*topk),
	// },
	// },
	// },
	// Topks: make([]int64, nq),
	// }
	//
	// for i := 0; i < nq; i++ {
	// for j := 0; j < topk; j++ {
	// offset := i*topk + j
	// if j >= invalidNum {
	// resultData.Scores[offset] = minFloat32
	// resultData.Ids.IdField.(*schemapb.IDs_IntId).IntId.Data[offset] = -1
	// } else {
	// score := float32(uniquegenerator.GetUniqueIntGeneratorIns().GetInt()) // increasingly
	// id := int64(uniquegenerator.GetUniqueIntGeneratorIns().GetInt())
	// resultData.Scores[offset] = score
	// resultData.Ids.IdField.(*schemapb.IDs_IntId).IntId.Data[offset] = id
	// }
	// }
	// resultData.Topks[i] = int64(topk)
	// }
	//
	// return resultData
	// }
	//
	// result1 := &internalpb.SearchResults{
	// Base: &commonpb.MsgBase{
	// MsgType: commonpb.MsgType_SearchResult,
	// MsgID: 0,
	// Timestamp: 0,
	// SourceID: 0,
	// },
	// Status: &commonpb.Status{
	// ErrorCode: commonpb.ErrorCode_Success,
	// Reason: "",
	// },
	// ResultChannelID: "",
	// MetricType: distance.L2,
	// NumQueries: int64(nq),
	// TopK: int64(topk),
	// SealedSegmentIDsSearched: nil,
	// ChannelIDsSearched: nil,
	// GlobalSealedSegmentIDs: nil,
	// SlicedBlob: nil,
	// SlicedNumCount: 1,
	// SlicedOffset: 0,
	// }
	// resultData := constructSearchResulstData(topk / 2)
	// sliceBlob, err := proto.Marshal(resultData)
	// assert.NoError(t, err)
	// result1.SlicedBlob = sliceBlob
	//
	// result2 := &internalpb.SearchResults{
	// Base: &commonpb.MsgBase{
	// MsgType: commonpb.MsgType_SearchResult,
	// MsgID: 0,
	// Timestamp: 0,
	// SourceID: 0,
	// },
	// Status: &commonpb.Status{
	// ErrorCode: commonpb.ErrorCode_Success,
	// Reason: "",
	// },
	// ResultChannelID: "",
	// MetricType: distance.L2,
	// NumQueries: int64(nq),
	// TopK: int64(topk),
	// SealedSegmentIDsSearched: nil,
	// ChannelIDsSearched: nil,
	// GlobalSealedSegmentIDs: nil,
	// SlicedBlob: nil,
	// SlicedNumCount: 1,
	// SlicedOffset: 0,
	// }
	// resultData2 := constructSearchResulstData(topk - topk/2)
	// sliceBlob2, err := proto.Marshal(resultData2)
	// assert.NoError(t, err)
	// result2.SlicedBlob = sliceBlob2
	//
	// // send search result
	// task.resultBuf <- result1
	// task.resultBuf <- result2
	// }
	// }
	// }
	// }()
	//
	// assert.NoError(t, task.OnEnqueue())
	// assert.NoError(t, task.PreExecute(ctx))
	// assert.NoError(t, task.Execute(ctx))
	// assert.NoError(t, task.PostExecute(ctx))
	//
	// cancel()
	// wg.Wait()
}
// Test_checkSearchResultData runs checkSearchResultData against a table of
// SearchResultData fixtures. Each case pairs a fixture with the nq/topk values
// handed to the checker and states whether an error is expected. Every case is
// named so a failing subtest can be identified from the test output.
func Test_checkSearchResultData(t *testing.T) {
	type args struct {
		data *schemapb.SearchResultData
		nq   int64
		topk int64
	}
	tests := []struct {
		name    string
		args    args
		wantErr bool
	}{
		{
			// The fixture declares 100 queries while 10 are passed in.
			name: "num queries mismatch",
			args: args{
				data: &schemapb.SearchResultData{NumQueries: 100},
				nq:   10,
			},
			wantErr: true,
		},
		{
			// The fixture declares TopK 1 while topk 10 is passed in.
			name: "topk mismatch",
			args: args{
				data: &schemapb.SearchResultData{NumQueries: 1, TopK: 1},
				nq:   1,
				topk: 10,
			},
			wantErr: true,
		},
		{
			name: "int64 id count mismatch",
			args: args{
				data: &schemapb.SearchResultData{
					NumQueries: 1,
					TopK:       1,
					Ids: &schemapb.IDs{
						IdField: &schemapb.IDs_IntId{
							IntId: &schemapb.LongArray{
								Data: []int64{1, 2}, // != nq * topk
							},
						},
					},
				},
				nq:   1,
				topk: 1,
			},
			wantErr: true,
		},
		{
			name: "string id count mismatch",
			args: args{
				data: &schemapb.SearchResultData{
					NumQueries: 1,
					TopK:       1,
					Ids: &schemapb.IDs{
						IdField: &schemapb.IDs_StrId{
							StrId: &schemapb.StringArray{
								Data: []string{"1", "2"}, // != nq * topk
							},
						},
					},
				},
				nq:   1,
				topk: 1,
			},
			wantErr: true,
		},
		{
			name: "score count mismatch",
			args: args{
				data: &schemapb.SearchResultData{
					NumQueries: 1,
					TopK:       1,
					Ids: &schemapb.IDs{
						IdField: &schemapb.IDs_IntId{
							IntId: &schemapb.LongArray{
								Data: []int64{1},
							},
						},
					},
					Scores: []float32{0.99, 0.98}, // != nq * topk
				},
				nq:   1,
				topk: 1,
			},
			wantErr: true,
		},
		{
			// One id and one score for nq = topk = 1: no error expected.
			name: "consistent result",
			args: args{
				data: &schemapb.SearchResultData{
					NumQueries: 1,
					TopK:       1,
					Ids: &schemapb.IDs{
						IdField: &schemapb.IDs_IntId{
							IntId: &schemapb.LongArray{
								Data: []int64{1},
							},
						},
					},
					Scores: []float32{0.99},
				},
				nq:   1,
				topk: 1,
			},
			wantErr: false,
		},
	}
	for _, tt := range tests {
		t.Run(tt.name, func(t *testing.T) {
			if err := checkSearchResultData(tt.args.data, tt.args.nq, tt.args.topk); (err != nil) != tt.wantErr {
				t.Errorf("checkSearchResultData(%v, %v, %v) error = %v, wantErr %v",
					tt.args.data, tt.args.nq, tt.args.topk, err, tt.wantErr)
			}
		})
	}
}
// Test_selectSearchResultData_int calls selectSearchResultData with two
// int64-keyed result sets and checks the returned value for query index 0.
func Test_selectSearchResultData_int(t *testing.T) {
	type args struct {
		dataArray     []*schemapb.SearchResultData
		resultOffsets [][]int64
		offsets       []int64
		qi            int64
	}
	tests := []struct {
		name string
		args args
		want int
	}{
		{
			name: "int64 ids, first query",
			args: args{
				// Both result sets hold nq = 3 queries with topk = 2 hits each.
				dataArray: []*schemapb.SearchResultData{
					{
						Ids: &schemapb.IDs{
							IdField: &schemapb.IDs_IntId{
								IntId: &schemapb.LongArray{
									Data: []int64{11, 9, 7, 5, 3, 1},
								},
							},
						},
						Scores: []float32{1.1, 0.9, 0.7, 0.5, 0.3, 0.1},
						Topks:  []int64{2, 2, 2},
					},
					{
						Ids: &schemapb.IDs{
							IdField: &schemapb.IDs_IntId{
								IntId: &schemapb.LongArray{
									Data: []int64{12, 10, 8, 6, 4, 2},
								},
							},
						},
						Scores: []float32{1.2, 1.0, 0.8, 0.6, 0.4, 0.2},
						Topks:  []int64{2, 2, 2},
					},
				},
				resultOffsets: [][]int64{{0, 2, 4}, {0, 2, 4}},
				offsets:       []int64{0, 1},
				qi:            0,
			},
			// NOTE(review): presumably the index into dataArray of the chosen
			// result set — confirm against selectSearchResultData.
			want: 0,
		},
	}
	for _, tt := range tests {
		t.Run(tt.name, func(t *testing.T) {
			if got := selectSearchResultData(tt.args.dataArray, tt.args.resultOffsets, tt.args.offsets, tt.args.qi); got != tt.want {
				t.Errorf("selectSearchResultData() = %v, want %v", got, tt.want)
			}
		})
	}
}
// Test_selectSearchResultData_str calls selectSearchResultData with two
// string-keyed result sets and checks the returned value for query index 1.
func Test_selectSearchResultData_str(t *testing.T) {
	type args struct {
		dataArray     []*schemapb.SearchResultData
		resultOffsets [][]int64
		offsets       []int64
		qi            int64
	}
	tests := []struct {
		name string
		args args
		want int
	}{
		{
			name: "string ids, second query",
			args: args{
				// Both result sets hold nq = 3 queries with topk = 2 hits each.
				dataArray: []*schemapb.SearchResultData{
					{
						Ids: &schemapb.IDs{
							IdField: &schemapb.IDs_StrId{
								StrId: &schemapb.StringArray{
									Data: []string{"11", "9", "7", "5", "3", "1"},
								},
							},
						},
						Scores: []float32{1.1, 0.9, 0.7, 0.5, 0.3, 0.1},
						Topks:  []int64{2, 2, 2},
					},
					{
						Ids: &schemapb.IDs{
							IdField: &schemapb.IDs_StrId{
								StrId: &schemapb.StringArray{
									Data: []string{"12", "10", "8", "6", "4", "2"},
								},
							},
						},
						Scores: []float32{1.2, 1.0, 0.8, 0.6, 0.4, 0.2},
						Topks:  []int64{2, 2, 2},
					},
				},
				resultOffsets: [][]int64{{0, 2, 4}, {0, 2, 4}},
				offsets:       []int64{0, 1},
				qi:            1,
			},
			// NOTE(review): presumably the index into dataArray of the chosen
			// result set — confirm against selectSearchResultData.
			want: 0,
		},
	}
	for _, tt := range tests {
		t.Run(tt.name, func(t *testing.T) {
			if got := selectSearchResultData(tt.args.dataArray, tt.args.resultOffsets, tt.args.offsets, tt.args.qi); got != tt.want {
				t.Errorf("selectSearchResultData() = %v, want %v", got, tt.want)
			}
		})
	}
}
// Test_reduceSearchResultData_int passes two int64-keyed partial results
// (3 queries, top-k of 2) to reduceSearchResultData with the L2 metric and
// checks the set of ids present in the reduced output, ignoring order.
func Test_reduceSearchResultData_int(t *testing.T) {
	const (
		nq   int64 = 3
		topk int64 = 2
	)

	first := &schemapb.SearchResultData{
		NumQueries: nq,
		TopK:       topk,
		Ids: &schemapb.IDs{
			IdField: &schemapb.IDs_IntId{
				IntId: &schemapb.LongArray{
					Data: []int64{11, 9, 7, 5, 3, 1},
				},
			},
		},
		Scores: []float32{1.1, 0.9, 0.7, 0.5, 0.3, 0.1},
		Topks:  []int64{2, 2, 2},
	}
	second := &schemapb.SearchResultData{
		NumQueries: nq,
		TopK:       topk,
		Ids: &schemapb.IDs{
			IdField: &schemapb.IDs_IntId{
				IntId: &schemapb.LongArray{
					Data: []int64{12, 10, 8, 6, 4, 2},
				},
			},
		},
		Scores: []float32{1.2, 1.0, 0.8, 0.6, 0.4, 0.2},
		Topks:  []int64{2, 2, 2},
	}
	partials := []*schemapb.SearchResultData{first, second}

	reduced, err := reduceSearchResultData(partials, nq, topk, distance.L2, schemapb.DataType_Int64)
	assert.NoError(t, err)

	wantIDs := []int64{3, 4, 7, 8, 11, 12}
	gotIDs := reduced.GetResults().GetIds().GetIntId().GetData()
	assert.ElementsMatch(t, wantIDs, gotIDs)
	// Scores are floating point and awkward to compare exactly.
	// TODO: compare scores.
}
// Test_reduceSearchResultData_str passes two string-keyed partial results
// (3 queries, top-k of 2) to reduceSearchResultData with the L2 metric and
// checks the set of ids present in the reduced output, ignoring order.
func Test_reduceSearchResultData_str(t *testing.T) {
	const (
		nq   int64 = 3
		topk int64 = 2
	)

	first := &schemapb.SearchResultData{
		NumQueries: nq,
		TopK:       topk,
		Ids: &schemapb.IDs{
			IdField: &schemapb.IDs_StrId{
				StrId: &schemapb.StringArray{
					Data: []string{"11", "9", "7", "5", "3", "1"},
				},
			},
		},
		Scores: []float32{1.1, 0.9, 0.7, 0.5, 0.3, 0.1},
		Topks:  []int64{2, 2, 2},
	}
	second := &schemapb.SearchResultData{
		NumQueries: nq,
		TopK:       topk,
		Ids: &schemapb.IDs{
			IdField: &schemapb.IDs_StrId{
				StrId: &schemapb.StringArray{
					Data: []string{"12", "10", "8", "6", "4", "2"},
				},
			},
		},
		Scores: []float32{1.2, 1.0, 0.8, 0.6, 0.4, 0.2},
		Topks:  []int64{2, 2, 2},
	}
	partials := []*schemapb.SearchResultData{first, second}

	reduced, err := reduceSearchResultData(partials, nq, topk, distance.L2, schemapb.DataType_VarChar)
	assert.NoError(t, err)

	wantIDs := []string{"3", "4", "7", "8", "11", "12"}
	gotIDs := reduced.GetResults().GetIds().GetStrId().GetData()
	assert.ElementsMatch(t, wantIDs, gotIDs)
	// Scores are floating point and awkward to compare exactly.
	// TODO: compare scores.
}