mirror of https://github.com/milvus-io/milvus.git
783 lines
23 KiB
Go
783 lines
23 KiB
Go
// Copyright (C) 2019-2020 Zilliz. All rights reserved.
|
|
//
|
|
// Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance
|
|
// with the License. You may obtain a copy of the License at
|
|
//
|
|
// http://www.apache.org/licenses/LICENSE-2.0
|
|
//
|
|
// Unless required by applicable law or agreed to in writing, software distributed under the License
|
|
// is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express
|
|
// or implied. See the License for the specific language governing permissions and limitations under the License.
|
|
|
|
package querynode
|
|
|
|
import (
|
|
"bytes"
|
|
"context"
|
|
"encoding/binary"
|
|
"math"
|
|
"math/rand"
|
|
"sync"
|
|
"testing"
|
|
"time"
|
|
|
|
"github.com/golang/protobuf/proto"
|
|
"github.com/stretchr/testify/assert"
|
|
"go.uber.org/zap"
|
|
|
|
"github.com/milvus-io/milvus/internal/common"
|
|
etcdkv "github.com/milvus-io/milvus/internal/kv/etcd"
|
|
"github.com/milvus-io/milvus/internal/log"
|
|
"github.com/milvus-io/milvus/internal/msgstream"
|
|
"github.com/milvus-io/milvus/internal/proto/commonpb"
|
|
"github.com/milvus-io/milvus/internal/proto/internalpb"
|
|
"github.com/milvus-io/milvus/internal/proto/milvuspb"
|
|
"github.com/milvus-io/milvus/internal/proto/querypb"
|
|
"github.com/milvus-io/milvus/internal/proto/schemapb"
|
|
"github.com/milvus-io/milvus/internal/proto/segcorepb"
|
|
"github.com/milvus-io/milvus/internal/util/tsoutil"
|
|
"github.com/milvus-io/milvus/internal/util/typeutil"
|
|
)
|
|
|
|
func genSimpleQueryCollection(ctx context.Context, cancel context.CancelFunc) (*queryCollection, error) {
|
|
tSafe := newTSafeReplica()
|
|
historical, err := genSimpleHistorical(ctx, tSafe)
|
|
if err != nil {
|
|
return nil, err
|
|
}
|
|
|
|
streaming, err := genSimpleStreaming(ctx, tSafe)
|
|
if err != nil {
|
|
return nil, err
|
|
}
|
|
|
|
fac, err := genFactory()
|
|
if err != nil {
|
|
return nil, err
|
|
}
|
|
|
|
localCM, err := genLocalChunkManager()
|
|
if err != nil {
|
|
return nil, err
|
|
}
|
|
|
|
remoteCM, err := genRemoteChunkManager(ctx)
|
|
if err != nil {
|
|
return nil, err
|
|
}
|
|
|
|
queryCollection, err := newQueryCollection(ctx, cancel,
|
|
defaultCollectionID,
|
|
historical,
|
|
streaming,
|
|
fac,
|
|
localCM,
|
|
remoteCM,
|
|
false)
|
|
return queryCollection, err
|
|
}
|
|
|
|
func genSimpleSegmentInfo() *querypb.SegmentInfo {
|
|
return &querypb.SegmentInfo{
|
|
SegmentID: defaultSegmentID,
|
|
CollectionID: defaultCollectionID,
|
|
PartitionID: defaultPartitionID,
|
|
}
|
|
}
|
|
|
|
func genSimpleSealedSegmentsChangeInfo() *querypb.SealedSegmentsChangeInfo {
|
|
changeInfo := &querypb.SegmentChangeInfo{
|
|
OnlineNodeID: Params.QueryNodeID,
|
|
OnlineSegments: []*querypb.SegmentInfo{},
|
|
OfflineNodeID: Params.QueryNodeID,
|
|
OfflineSegments: []*querypb.SegmentInfo{},
|
|
}
|
|
return &querypb.SealedSegmentsChangeInfo{
|
|
Base: genCommonMsgBase(commonpb.MsgType_SealedSegmentsChangeInfo),
|
|
Infos: []*querypb.SegmentChangeInfo{changeInfo},
|
|
}
|
|
}
|
|
|
|
func genSimpleSealedSegmentsChangeInfoMsg() *msgstream.SealedSegmentsChangeInfoMsg {
|
|
return &msgstream.SealedSegmentsChangeInfoMsg{
|
|
BaseMsg: genMsgStreamBaseMsg(),
|
|
SealedSegmentsChangeInfo: *genSimpleSealedSegmentsChangeInfo(),
|
|
}
|
|
}
|
|
|
|
func updateTSafe(queryCollection *queryCollection, timestamp Timestamp) {
|
|
// register
|
|
queryCollection.tSafeWatchers[defaultVChannel] = newTSafeWatcher()
|
|
queryCollection.tSafeWatchers[defaultHistoricalVChannel] = newTSafeWatcher()
|
|
queryCollection.streaming.tSafeReplica.addTSafe(defaultVChannel)
|
|
queryCollection.streaming.tSafeReplica.registerTSafeWatcher(defaultVChannel, queryCollection.tSafeWatchers[defaultVChannel])
|
|
queryCollection.historical.tSafeReplica.addTSafe(defaultHistoricalVChannel)
|
|
queryCollection.historical.tSafeReplica.registerTSafeWatcher(defaultHistoricalVChannel, queryCollection.tSafeWatchers[defaultHistoricalVChannel])
|
|
queryCollection.addTSafeWatcher(defaultVChannel)
|
|
queryCollection.addTSafeWatcher(defaultHistoricalVChannel)
|
|
|
|
queryCollection.streaming.tSafeReplica.setTSafe(defaultVChannel, defaultCollectionID, timestamp)
|
|
queryCollection.historical.tSafeReplica.setTSafe(defaultHistoricalVChannel, defaultCollectionID, timestamp)
|
|
}
|
|
|
|
func TestQueryCollection_withoutVChannel(t *testing.T) {
|
|
ctx := context.Background()
|
|
m := map[string]interface{}{
|
|
"PulsarAddress": Params.PulsarAddress,
|
|
"ReceiveBufSize": 1024,
|
|
"PulsarBufSize": 1024}
|
|
factory := msgstream.NewPmsFactory()
|
|
err := factory.SetParams(m)
|
|
assert.Nil(t, err)
|
|
etcdKV, err := etcdkv.NewEtcdKV(Params.EtcdEndpoints, Params.MetaRootPath)
|
|
assert.Nil(t, err)
|
|
|
|
schema := genTestCollectionSchema(0, false, 2)
|
|
historicalReplica := newCollectionReplica(etcdKV)
|
|
tsReplica := newTSafeReplica()
|
|
streamingReplica := newCollectionReplica(etcdKV)
|
|
historical := newHistorical(context.Background(), historicalReplica, etcdKV, tsReplica)
|
|
|
|
//add a segment to historical data
|
|
err = historical.replica.addCollection(0, schema)
|
|
assert.Nil(t, err)
|
|
err = historical.replica.addPartition(0, 1)
|
|
assert.Nil(t, err)
|
|
err = historical.replica.addSegment(2, 1, 0, "testChannel", segmentTypeSealed, true)
|
|
assert.Nil(t, err)
|
|
segment, err := historical.replica.getSegmentByID(2)
|
|
assert.Nil(t, err)
|
|
const N = 2
|
|
rowID := []int32{1, 2}
|
|
timeStamp := []int64{0, 1}
|
|
age := []int64{10, 20}
|
|
vectorData := []float32{1, 2, 3, 4}
|
|
err = segment.segmentLoadFieldData(0, N, rowID)
|
|
assert.Nil(t, err)
|
|
err = segment.segmentLoadFieldData(1, N, timeStamp)
|
|
assert.Nil(t, err)
|
|
err = segment.segmentLoadFieldData(101, N, age)
|
|
assert.Nil(t, err)
|
|
err = segment.segmentLoadFieldData(100, N, vectorData)
|
|
assert.Nil(t, err)
|
|
|
|
//create a streaming
|
|
streaming := newStreaming(ctx, streamingReplica, factory, etcdKV, tsReplica)
|
|
err = streaming.replica.addCollection(0, schema)
|
|
assert.Nil(t, err)
|
|
err = streaming.replica.addPartition(0, 1)
|
|
assert.Nil(t, err)
|
|
|
|
ctx, cancel := context.WithCancel(context.Background())
|
|
queryCollection, err := newQueryCollection(ctx, cancel, 0, historical, streaming, factory, nil, nil, false)
|
|
assert.NoError(t, err)
|
|
|
|
producerChannels := []string{"testResultChannel"}
|
|
queryCollection.queryResultMsgStream.AsProducer(producerChannels)
|
|
|
|
dim := 2
|
|
// generate search rawData
|
|
var vec = make([]float32, dim)
|
|
for i := 0; i < dim; i++ {
|
|
vec[i] = rand.Float32()
|
|
}
|
|
dslString := "{\"bool\": { \n\"vector\": {\n \"vec\": {\n \"metric_type\": \"L2\", \n \"params\": {\n \"nprobe\": 10 \n},\n \"query\": \"$0\",\n \"topk\": 10 \n,\"round_decimal\": 6\n } \n } \n } \n }"
|
|
var searchRawData1 []byte
|
|
var searchRawData2 []byte
|
|
for i, ele := range vec {
|
|
buf := make([]byte, 4)
|
|
common.Endian.PutUint32(buf, math.Float32bits(ele+float32(i*2)))
|
|
searchRawData1 = append(searchRawData1, buf...)
|
|
}
|
|
for i, ele := range vec {
|
|
buf := make([]byte, 4)
|
|
common.Endian.PutUint32(buf, math.Float32bits(ele+float32(i*4)))
|
|
searchRawData2 = append(searchRawData2, buf...)
|
|
}
|
|
|
|
// generate placeholder
|
|
placeholderValue := milvuspb.PlaceholderValue{
|
|
Tag: "$0",
|
|
Type: milvuspb.PlaceholderType_FloatVector,
|
|
Values: [][]byte{searchRawData1, searchRawData2},
|
|
}
|
|
placeholderGroup := milvuspb.PlaceholderGroup{
|
|
Placeholders: []*milvuspb.PlaceholderValue{&placeholderValue},
|
|
}
|
|
placeGroupByte, err := proto.Marshal(&placeholderGroup)
|
|
assert.Nil(t, err)
|
|
|
|
queryMsg := &msgstream.SearchMsg{
|
|
BaseMsg: msgstream.BaseMsg{
|
|
Ctx: ctx,
|
|
BeginTimestamp: 10,
|
|
EndTimestamp: 10,
|
|
},
|
|
SearchRequest: internalpb.SearchRequest{
|
|
Base: &commonpb.MsgBase{
|
|
MsgType: commonpb.MsgType_Search,
|
|
MsgID: 1,
|
|
Timestamp: Timestamp(10),
|
|
SourceID: 1,
|
|
},
|
|
CollectionID: 0,
|
|
ResultChannelID: "testResultChannel",
|
|
Dsl: dslString,
|
|
PlaceholderGroup: placeGroupByte,
|
|
TravelTimestamp: 10,
|
|
GuaranteeTimestamp: 10,
|
|
},
|
|
}
|
|
err = queryCollection.receiveQueryMsg(queryMsg)
|
|
assert.Nil(t, err)
|
|
|
|
queryCollection.cancel()
|
|
queryCollection.close()
|
|
historical.close()
|
|
streaming.close()
|
|
}
|
|
|
|
func TestQueryCollection_unsolvedMsg(t *testing.T) {
|
|
ctx, cancel := context.WithCancel(context.Background())
|
|
|
|
queryCollection, err := genSimpleQueryCollection(ctx, cancel)
|
|
assert.NoError(t, err)
|
|
|
|
qm, err := genSimpleSearchMsg()
|
|
assert.NoError(t, err)
|
|
|
|
queryCollection.addToUnsolvedMsg(qm)
|
|
|
|
res := queryCollection.popAllUnsolvedMsg()
|
|
assert.NotNil(t, res)
|
|
assert.Len(t, res, 1)
|
|
}
|
|
|
|
func TestQueryCollection_consumeQuery(t *testing.T) {
|
|
ctx, cancel := context.WithCancel(context.Background())
|
|
|
|
runConsumeQuery := func(msg msgstream.TsMsg) {
|
|
queryCollection, err := genSimpleQueryCollection(ctx, cancel)
|
|
assert.NoError(t, err)
|
|
|
|
queryChannel := genQueryChannel()
|
|
queryCollection.queryMsgStream.AsConsumer([]Channel{queryChannel}, defaultSubName)
|
|
queryCollection.queryMsgStream.Start()
|
|
|
|
go queryCollection.consumeQuery()
|
|
|
|
producer, err := genQueryMsgStream(ctx)
|
|
assert.NoError(t, err)
|
|
producer.AsProducer([]Channel{queryChannel})
|
|
producer.Start()
|
|
msgPack := &msgstream.MsgPack{
|
|
BeginTs: 0,
|
|
EndTs: 10,
|
|
Msgs: []msgstream.TsMsg{msg},
|
|
}
|
|
err = producer.Produce(msgPack)
|
|
assert.NoError(t, err)
|
|
|
|
time.Sleep(20 * time.Millisecond)
|
|
}
|
|
|
|
t.Run("consume search", func(t *testing.T) {
|
|
msg, err := genSimpleSearchMsg()
|
|
assert.NoError(t, err)
|
|
runConsumeQuery(msg)
|
|
})
|
|
|
|
t.Run("consume retrieve", func(t *testing.T) {
|
|
msg, err := genSimpleRetrieveMsg()
|
|
assert.NoError(t, err)
|
|
runConsumeQuery(msg)
|
|
})
|
|
|
|
t.Run("consume SimpleSealedSegmentsChangeInfoMsg", func(t *testing.T) {
|
|
// test is success if it doesn't block
|
|
msg := genSimpleSealedSegmentsChangeInfoMsg()
|
|
simpleInfo := genSimpleSegmentInfo()
|
|
simpleInfo.CollectionID = 1000
|
|
msg.Infos[0].OnlineSegments = append(msg.Infos[0].OnlineSegments, simpleInfo)
|
|
runConsumeQuery(msg)
|
|
})
|
|
|
|
t.Run("consume invalid msg", func(t *testing.T) {
|
|
msg, err := genSimpleRetrieveMsg()
|
|
assert.NoError(t, err)
|
|
msg.Base.MsgType = commonpb.MsgType_CreateCollection
|
|
runConsumeQuery(msg)
|
|
})
|
|
}
|
|
|
|
func TestQueryCollection_TranslateHits(t *testing.T) {
|
|
fieldID := FieldID(0)
|
|
fieldIDs := []FieldID{fieldID}
|
|
|
|
genRawHits := func(dataType schemapb.DataType) [][]byte {
|
|
// ids
|
|
ids := make([]int64, 0)
|
|
for i := 0; i < defaultMsgLength; i++ {
|
|
ids = append(ids, int64(i))
|
|
}
|
|
|
|
// raw data
|
|
rawData := make([][]byte, 0)
|
|
switch dataType {
|
|
case schemapb.DataType_Bool:
|
|
var buf bytes.Buffer
|
|
for i := 0; i < defaultMsgLength; i++ {
|
|
err := binary.Write(&buf, common.Endian, true)
|
|
assert.NoError(t, err)
|
|
}
|
|
rawData = append(rawData, buf.Bytes())
|
|
case schemapb.DataType_Int8:
|
|
var buf bytes.Buffer
|
|
for i := 0; i < defaultMsgLength; i++ {
|
|
err := binary.Write(&buf, common.Endian, int8(i))
|
|
assert.NoError(t, err)
|
|
}
|
|
rawData = append(rawData, buf.Bytes())
|
|
case schemapb.DataType_Int16:
|
|
var buf bytes.Buffer
|
|
for i := 0; i < defaultMsgLength; i++ {
|
|
err := binary.Write(&buf, common.Endian, int16(i))
|
|
assert.NoError(t, err)
|
|
}
|
|
rawData = append(rawData, buf.Bytes())
|
|
case schemapb.DataType_Int32:
|
|
var buf bytes.Buffer
|
|
for i := 0; i < defaultMsgLength; i++ {
|
|
err := binary.Write(&buf, common.Endian, int32(i))
|
|
assert.NoError(t, err)
|
|
}
|
|
rawData = append(rawData, buf.Bytes())
|
|
case schemapb.DataType_Int64:
|
|
var buf bytes.Buffer
|
|
for i := 0; i < defaultMsgLength; i++ {
|
|
err := binary.Write(&buf, common.Endian, int64(i))
|
|
assert.NoError(t, err)
|
|
}
|
|
rawData = append(rawData, buf.Bytes())
|
|
case schemapb.DataType_Float:
|
|
var buf bytes.Buffer
|
|
for i := 0; i < defaultMsgLength; i++ {
|
|
err := binary.Write(&buf, common.Endian, float32(i))
|
|
assert.NoError(t, err)
|
|
}
|
|
rawData = append(rawData, buf.Bytes())
|
|
case schemapb.DataType_Double:
|
|
var buf bytes.Buffer
|
|
for i := 0; i < defaultMsgLength; i++ {
|
|
err := binary.Write(&buf, common.Endian, float64(i))
|
|
assert.NoError(t, err)
|
|
}
|
|
rawData = append(rawData, buf.Bytes())
|
|
}
|
|
hit := &milvuspb.Hits{
|
|
IDs: ids,
|
|
RowData: rawData,
|
|
}
|
|
hits := []*milvuspb.Hits{hit}
|
|
rawHits := make([][]byte, 0)
|
|
for _, h := range hits {
|
|
rawHit, err := proto.Marshal(h)
|
|
assert.NoError(t, err)
|
|
rawHits = append(rawHits, rawHit)
|
|
}
|
|
return rawHits
|
|
}
|
|
|
|
genSchema := func(dataType schemapb.DataType) *typeutil.SchemaHelper {
|
|
schema := &schemapb.CollectionSchema{
|
|
Name: defaultCollectionName,
|
|
AutoID: true,
|
|
Fields: []*schemapb.FieldSchema{
|
|
genConstantField(constFieldParam{
|
|
id: fieldID,
|
|
dataType: dataType,
|
|
}),
|
|
},
|
|
}
|
|
schemaHelper, err := typeutil.CreateSchemaHelper(schema)
|
|
assert.NoError(t, err)
|
|
return schemaHelper
|
|
}
|
|
|
|
t.Run("test bool field", func(t *testing.T) {
|
|
dataType := schemapb.DataType_Bool
|
|
_, err := translateHits(genSchema(dataType), fieldIDs, genRawHits(dataType))
|
|
assert.NoError(t, err)
|
|
})
|
|
|
|
t.Run("test int8 field", func(t *testing.T) {
|
|
dataType := schemapb.DataType_Int8
|
|
_, err := translateHits(genSchema(dataType), fieldIDs, genRawHits(dataType))
|
|
assert.NoError(t, err)
|
|
})
|
|
|
|
t.Run("test int16 field", func(t *testing.T) {
|
|
dataType := schemapb.DataType_Int16
|
|
_, err := translateHits(genSchema(dataType), fieldIDs, genRawHits(dataType))
|
|
assert.NoError(t, err)
|
|
})
|
|
|
|
t.Run("test int32 field", func(t *testing.T) {
|
|
dataType := schemapb.DataType_Int32
|
|
_, err := translateHits(genSchema(dataType), fieldIDs, genRawHits(dataType))
|
|
assert.NoError(t, err)
|
|
})
|
|
|
|
t.Run("test int64 field", func(t *testing.T) {
|
|
dataType := schemapb.DataType_Int64
|
|
_, err := translateHits(genSchema(dataType), fieldIDs, genRawHits(dataType))
|
|
assert.NoError(t, err)
|
|
})
|
|
|
|
t.Run("test float field", func(t *testing.T) {
|
|
dataType := schemapb.DataType_Float
|
|
_, err := translateHits(genSchema(dataType), fieldIDs, genRawHits(dataType))
|
|
assert.NoError(t, err)
|
|
})
|
|
|
|
t.Run("test double field", func(t *testing.T) {
|
|
dataType := schemapb.DataType_Double
|
|
_, err := translateHits(genSchema(dataType), fieldIDs, genRawHits(dataType))
|
|
assert.NoError(t, err)
|
|
})
|
|
|
|
t.Run("test field with error type", func(t *testing.T) {
|
|
dataType := schemapb.DataType_FloatVector
|
|
_, err := translateHits(genSchema(dataType), fieldIDs, genRawHits(dataType))
|
|
assert.Error(t, err)
|
|
|
|
dataType = schemapb.DataType_BinaryVector
|
|
_, err = translateHits(genSchema(dataType), fieldIDs, genRawHits(dataType))
|
|
assert.Error(t, err)
|
|
})
|
|
}
|
|
|
|
func TestQueryCollection_serviceableTime(t *testing.T) {
|
|
ctx, cancel := context.WithCancel(context.Background())
|
|
|
|
queryCollection, err := genSimpleQueryCollection(ctx, cancel)
|
|
assert.NoError(t, err)
|
|
|
|
st := Timestamp(1000)
|
|
queryCollection.setServiceableTime(st)
|
|
|
|
gracefulTimeInMilliSecond := Params.GracefulTime
|
|
gracefulTime := tsoutil.ComposeTS(gracefulTimeInMilliSecond, 0)
|
|
resST := queryCollection.getServiceableTime()
|
|
assert.Equal(t, st+gracefulTime, resST)
|
|
}
|
|
|
|
func TestQueryCollection_addTSafeWatcher(t *testing.T) {
|
|
ctx, cancel := context.WithCancel(context.Background())
|
|
|
|
queryCollection, err := genSimpleQueryCollection(ctx, cancel)
|
|
assert.NoError(t, err)
|
|
|
|
queryCollection.addTSafeWatcher(defaultVChannel)
|
|
}
|
|
|
|
func TestQueryCollection_waitNewTSafe(t *testing.T) {
|
|
ctx, cancel := context.WithCancel(context.Background())
|
|
|
|
queryCollection, err := genSimpleQueryCollection(ctx, cancel)
|
|
assert.NoError(t, err)
|
|
|
|
timestamp := Timestamp(1000)
|
|
updateTSafe(queryCollection, timestamp)
|
|
|
|
resTimestamp, err := queryCollection.waitNewTSafe()
|
|
assert.NoError(t, err)
|
|
assert.Equal(t, timestamp, resTimestamp)
|
|
}
|
|
|
|
func TestQueryCollection_mergeRetrieveResults(t *testing.T) {
|
|
const (
|
|
Dim = 8
|
|
Int64FieldName = "Int64Field"
|
|
FloatVectorFieldName = "FloatVectorField"
|
|
Int64FieldID = common.StartOfUserFieldID + 1
|
|
FloatVectorFieldID = common.StartOfUserFieldID + 2
|
|
)
|
|
Int64Array := []int64{11, 22}
|
|
FloatVector := []float32{1.0, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0, 8.0, 11.0, 22.0, 33.0, 44.0, 55.0, 66.0, 77.0, 88.0}
|
|
|
|
var fieldDataArray1 []*schemapb.FieldData
|
|
fieldDataArray1 = append(fieldDataArray1, genFieldData(Int64FieldName, Int64FieldID, schemapb.DataType_Int64, Int64Array[0:2], 1))
|
|
fieldDataArray1 = append(fieldDataArray1, genFieldData(FloatVectorFieldName, FloatVectorFieldID, schemapb.DataType_FloatVector, FloatVector[0:16], Dim))
|
|
|
|
var fieldDataArray2 []*schemapb.FieldData
|
|
fieldDataArray2 = append(fieldDataArray2, genFieldData(Int64FieldName, Int64FieldID, schemapb.DataType_Int64, Int64Array[0:2], 1))
|
|
fieldDataArray2 = append(fieldDataArray2, genFieldData(FloatVectorFieldName, FloatVectorFieldID, schemapb.DataType_FloatVector, FloatVector[0:16], Dim))
|
|
|
|
result1 := &segcorepb.RetrieveResults{
|
|
Ids: &schemapb.IDs{
|
|
IdField: &schemapb.IDs_IntId{
|
|
IntId: &schemapb.LongArray{
|
|
Data: []int64{0, 1},
|
|
},
|
|
},
|
|
},
|
|
Offset: []int64{0, 1},
|
|
FieldsData: fieldDataArray1,
|
|
}
|
|
result2 := &segcorepb.RetrieveResults{
|
|
Ids: &schemapb.IDs{
|
|
IdField: &schemapb.IDs_IntId{
|
|
IntId: &schemapb.LongArray{
|
|
Data: []int64{0, 1},
|
|
},
|
|
},
|
|
},
|
|
Offset: []int64{0, 1},
|
|
FieldsData: fieldDataArray2,
|
|
}
|
|
|
|
result, err := mergeRetrieveResults([]*segcorepb.RetrieveResults{result1, result2})
|
|
assert.NoError(t, err)
|
|
assert.Equal(t, 2, len(result.FieldsData[0].GetScalars().GetLongData().Data))
|
|
assert.Equal(t, 2*Dim, len(result.FieldsData[1].GetVectors().GetFloatVector().Data))
|
|
|
|
_, err = mergeRetrieveResults(nil)
|
|
assert.NoError(t, err)
|
|
}
|
|
|
|
func TestQueryCollection_doUnsolvedQueryMsg(t *testing.T) {
|
|
ctx, cancel := context.WithCancel(context.Background())
|
|
|
|
queryCollection, err := genSimpleQueryCollection(ctx, cancel)
|
|
assert.NoError(t, err)
|
|
|
|
timestamp := Timestamp(1000)
|
|
updateTSafe(queryCollection, timestamp)
|
|
|
|
go queryCollection.doUnsolvedQueryMsg()
|
|
|
|
msg, err := genSimpleSearchMsg()
|
|
assert.NoError(t, err)
|
|
queryCollection.addToUnsolvedMsg(msg)
|
|
|
|
time.Sleep(200 * time.Millisecond)
|
|
}
|
|
|
|
func TestQueryCollection_search(t *testing.T) {
|
|
ctx, cancel := context.WithCancel(context.Background())
|
|
|
|
queryCollection, err := genSimpleQueryCollection(ctx, cancel)
|
|
assert.NoError(t, err)
|
|
|
|
queryChannel := genQueryChannel()
|
|
queryCollection.queryResultMsgStream.AsProducer([]Channel{queryChannel})
|
|
queryCollection.queryResultMsgStream.Start()
|
|
|
|
err = queryCollection.streaming.replica.removeSegment(defaultSegmentID)
|
|
assert.NoError(t, err)
|
|
|
|
err = queryCollection.historical.replica.removeSegment(defaultSegmentID)
|
|
assert.NoError(t, err)
|
|
|
|
msg, err := genSimpleSearchMsg()
|
|
assert.NoError(t, err)
|
|
|
|
err = queryCollection.search(msg)
|
|
assert.NoError(t, err)
|
|
}
|
|
|
|
func TestQueryCollection_receive(t *testing.T) {
|
|
ctx, cancel := context.WithCancel(context.Background())
|
|
|
|
queryCollection, err := genSimpleQueryCollection(ctx, cancel)
|
|
assert.NoError(t, err)
|
|
|
|
queryChannel := genQueryChannel()
|
|
queryCollection.queryResultMsgStream.AsProducer([]Channel{queryChannel})
|
|
queryCollection.queryResultMsgStream.Start()
|
|
|
|
vecCM, err := genVectorChunkManager(ctx)
|
|
assert.NoError(t, err)
|
|
|
|
queryCollection.vectorChunkManager = vecCM
|
|
|
|
err = queryCollection.streaming.replica.removeSegment(defaultSegmentID)
|
|
assert.NoError(t, err)
|
|
|
|
err = queryCollection.historical.replica.removeSegment(defaultSegmentID)
|
|
assert.NoError(t, err)
|
|
|
|
msg, err := genSimpleRetrieveMsg()
|
|
assert.NoError(t, err)
|
|
|
|
err = queryCollection.retrieve(msg)
|
|
assert.NoError(t, err)
|
|
}
|
|
|
|
func TestQueryCollection_AddPopUnsolvedMsg(t *testing.T) {
|
|
ctx, cancel := context.WithCancel(context.TODO())
|
|
qCollection, err := genSimpleQueryCollection(ctx, cancel)
|
|
assert.Nil(t, err)
|
|
var i int64
|
|
for i = 0; i < 3; i++ {
|
|
qCollection.addToUnsolvedMsg(&msgstream.RetrieveMsg{
|
|
RetrieveRequest: internalpb.RetrieveRequest{
|
|
Base: &commonpb.MsgBase{MsgID: i},
|
|
},
|
|
})
|
|
}
|
|
|
|
unsolved := qCollection.popAllUnsolvedMsg()
|
|
assert.EqualValues(t, 3, len(unsolved))
|
|
for i := 0; i < 3; i++ {
|
|
assert.EqualValues(t, i, unsolved[i].ID())
|
|
}
|
|
|
|
// add new msg to unsolved msgs and check old unsolved msg
|
|
for i := 0; i < 3; i++ {
|
|
qCollection.addToUnsolvedMsg(&msgstream.RetrieveMsg{
|
|
RetrieveRequest: internalpb.RetrieveRequest{
|
|
Base: &commonpb.MsgBase{MsgID: 4},
|
|
},
|
|
})
|
|
}
|
|
|
|
for i := 0; i < 3; i++ {
|
|
assert.EqualValues(t, i, unsolved[i].ID())
|
|
}
|
|
}
|
|
|
|
func TestQueryCollection_adjustByChangeInfo(t *testing.T) {
|
|
ctx, cancel := context.WithCancel(context.TODO())
|
|
|
|
t.Run("test adjustByChangeInfo", func(t *testing.T) {
|
|
qc, err := genSimpleQueryCollection(ctx, cancel)
|
|
assert.Nil(t, err)
|
|
|
|
segmentChangeInfos := genSimpleSealedSegmentsChangeInfoMsg()
|
|
|
|
// test online
|
|
segmentChangeInfos.Infos[0].OnlineSegments = append(segmentChangeInfos.Infos[0].OnlineSegments, genSimpleSegmentInfo())
|
|
err = qc.adjustByChangeInfo(segmentChangeInfos)
|
|
assert.NoError(t, err)
|
|
ids := qc.globalSegmentManager.getGlobalSegmentIDs()
|
|
assert.Len(t, ids, 1)
|
|
|
|
// test offline
|
|
segmentChangeInfos.Infos[0].OnlineSegments = make([]*querypb.SegmentInfo, 0)
|
|
segmentChangeInfos.Infos[0].OfflineSegments = append(segmentChangeInfos.Infos[0].OfflineSegments, genSimpleSegmentInfo())
|
|
err = qc.adjustByChangeInfo(segmentChangeInfos)
|
|
assert.NoError(t, err)
|
|
ids = qc.globalSegmentManager.getGlobalSegmentIDs()
|
|
assert.Len(t, ids, 0)
|
|
})
|
|
|
|
t.Run("test mismatch collectionID when adjustByChangeInfo", func(t *testing.T) {
|
|
qc, err := genSimpleQueryCollection(ctx, cancel)
|
|
assert.Nil(t, err)
|
|
|
|
segmentChangeInfos := genSimpleSealedSegmentsChangeInfoMsg()
|
|
|
|
// test online
|
|
simpleInfo := genSimpleSegmentInfo()
|
|
simpleInfo.CollectionID = 1000
|
|
segmentChangeInfos.Infos[0].OnlineSegments = append(segmentChangeInfos.Infos[0].OnlineSegments, simpleInfo)
|
|
err = qc.adjustByChangeInfo(segmentChangeInfos)
|
|
assert.NoError(t, err)
|
|
})
|
|
|
|
t.Run("test no segment when adjustByChangeInfo", func(t *testing.T) {
|
|
qc, err := genSimpleQueryCollection(ctx, cancel)
|
|
assert.Nil(t, err)
|
|
|
|
err = qc.historical.replica.removeSegment(defaultSegmentID)
|
|
assert.NoError(t, err)
|
|
|
|
segmentChangeInfos := genSimpleSealedSegmentsChangeInfoMsg()
|
|
segmentChangeInfos.Infos[0].OfflineSegments = append(segmentChangeInfos.Infos[0].OfflineSegments, genSimpleSegmentInfo())
|
|
|
|
err = qc.adjustByChangeInfo(segmentChangeInfos)
|
|
assert.Nil(t, err)
|
|
})
|
|
}
|
|
|
|
func TestQueryCollection_search_while_release(t *testing.T) {
|
|
ctx, cancel := context.WithCancel(context.Background())
|
|
defer cancel()
|
|
|
|
t.Run("test search while release collection", func(t *testing.T) {
|
|
queryCollection, err := genSimpleQueryCollection(ctx, cancel)
|
|
assert.NoError(t, err)
|
|
|
|
queryChannel := genQueryChannel()
|
|
queryCollection.queryResultMsgStream.AsProducer([]Channel{queryChannel})
|
|
queryCollection.queryResultMsgStream.Start()
|
|
|
|
msg, err := genSimpleSearchMsg()
|
|
assert.NoError(t, err)
|
|
|
|
// To prevent data race in search trackCtx
|
|
searchMu := &sync.Mutex{}
|
|
|
|
runSearchWhileReleaseCollection := func(wg *sync.WaitGroup) {
|
|
go func() {
|
|
_ = queryCollection.streaming.replica.removeCollection(defaultCollectionID)
|
|
wg.Done()
|
|
}()
|
|
|
|
go func() {
|
|
searchMu.Lock()
|
|
_ = queryCollection.search(msg)
|
|
searchMu.Unlock()
|
|
wg.Done()
|
|
}()
|
|
}
|
|
|
|
wg := &sync.WaitGroup{}
|
|
for i := 0; i < 10; i++ {
|
|
log.Debug("runSearchWhileReleaseCollection", zap.Any("time", i))
|
|
wg.Add(2)
|
|
go runSearchWhileReleaseCollection(wg)
|
|
}
|
|
wg.Wait()
|
|
})
|
|
|
|
t.Run("test search while release partition", func(t *testing.T) {
|
|
queryCollection, err := genSimpleQueryCollection(ctx, cancel)
|
|
assert.NoError(t, err)
|
|
|
|
queryChannel := genQueryChannel()
|
|
queryCollection.queryResultMsgStream.AsProducer([]Channel{queryChannel})
|
|
queryCollection.queryResultMsgStream.Start()
|
|
|
|
msg, err := genSimpleSearchMsg()
|
|
assert.NoError(t, err)
|
|
|
|
// To prevent data race in search trackCtx
|
|
searchMu := &sync.Mutex{}
|
|
|
|
runSearchWhileReleasePartition := func(wg *sync.WaitGroup) {
|
|
go func() {
|
|
_ = queryCollection.streaming.replica.removePartition(defaultPartitionID)
|
|
wg.Done()
|
|
}()
|
|
|
|
go func() {
|
|
searchMu.Lock()
|
|
_ = queryCollection.search(msg)
|
|
searchMu.Unlock()
|
|
wg.Done()
|
|
}()
|
|
}
|
|
|
|
wg := &sync.WaitGroup{}
|
|
for i := 0; i < 10; i++ {
|
|
log.Debug("runSearchWhileReleasePartition", zap.Any("time", i))
|
|
wg.Add(2)
|
|
go runSearchWhileReleasePartition(wg)
|
|
}
|
|
wg.Wait()
|
|
})
|
|
}
|