Add test cases for search task (#7796)

Signed-off-by: dragondriver <jiquan.long@zilliz.com>
dragondriver 2021-09-13 17:12:19 +08:00 committed by GitHub
parent 85d73358cc
commit 8d239b6473
11 changed files with 518 additions and 82 deletions


@@ -313,7 +313,7 @@ from each task concurrently.
The following figure is a schematic diagram of taskScheduler's scheduling of DqQueue.
![task_scheduler_1](./graphs/task_scheduler_1.png)
![task_scheduler_2](./graphs/task_scheduler_2.png)
The tasks in DqQueue can be scheduled in parallel. In a scheduling process, taskScheduler will execute several tasks
concurrently.
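As a rough illustration of this parallel pass, here is a minimal sketch in Go (the `task` type is a hypothetical stand-in, not the actual scheduler types):

```go
package scheduler

import "sync"

// task is a hypothetical stand-in for the scheduler's dq task type.
type task interface{ Execute() }

// scheduleParallel drains one snapshot of the queue, running every task
// concurrently and waiting for all of them to finish, mirroring how
// taskScheduler schedules DqQueue in a single pass.
func scheduleParallel(tasks []task) {
	var wg sync.WaitGroup
	for _, t := range tasks {
		wg.Add(1)
		go func(t task) {
			defer wg.Done()
			t.Execute()
		}(t)
	}
	wg.Wait()
}
```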

internal/common/common.go (new file, 27 additions)

@@ -0,0 +1,27 @@
// Copyright (C) 2019-2020 Zilliz. All rights reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance
// with the License. You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software distributed under the License
// is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express
// or implied. See the License for the specific language governing permissions and limitations under the License.
package common
// system field id:
// 0: unique row id
// 1: timestamp
// 100: first user field id
// 101: second user field id
// 102: ...
const (
StartOfUserFieldID = 100
RowIDField = 0
TimeStampField = 1
RowIDFieldName = "RowID"
TimeStampFieldName = "Timestamp"
)
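// Example (sketch, not part of this change): callers can use
// StartOfUserFieldID to separate system fields from user-defined fields,
// mirroring the metaCache/describeCollection filtering later in this diff.
// The schemapb.FieldSchema type is assumed here for illustration.
//
//	func userFields(fields []*schemapb.FieldSchema) []*schemapb.FieldSchema {
//		ret := make([]*schemapb.FieldSchema, 0, len(fields))
//		for _, f := range fields {
//			if f.FieldID >= StartOfUserFieldID {
//				ret = append(ret, f)
//			}
//		}
//		return ret
//	}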


@@ -13,11 +13,19 @@ package proxy
import (
"context"
"errors"
"fmt"
"runtime"
"sort"
"sync"
"github.com/milvus-io/milvus/internal/proto/querypb"
"github.com/milvus-io/milvus/internal/proto/commonpb"
"github.com/milvus-io/milvus/internal/proto/milvuspb"
"github.com/milvus-io/milvus/internal/types"
"github.com/milvus-io/milvus/internal/util/uniquegenerator"
"github.com/milvus-io/milvus/internal/log"
@@ -41,6 +49,77 @@ type channelsMgr interface {
type getChannelsFuncType = func(collectionID UniqueID) (map[vChan]pChan, error)
type repackFuncType = func(tsMsgs []msgstream.TsMsg, hashKeys [][]int32) (map[int32]*msgstream.MsgPack, error)
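// getDmlChannelsFunc returns a closure that resolves the DML channels of a
// collection: it calls RootCoord.DescribeCollection and zips each virtual
// channel name with its physical counterpart.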
func getDmlChannelsFunc(ctx context.Context, rc types.RootCoord) getChannelsFuncType {
return func(collectionID UniqueID) (map[vChan]pChan, error) {
req := &milvuspb.DescribeCollectionRequest{
Base: &commonpb.MsgBase{
MsgType: commonpb.MsgType_DescribeCollection,
MsgID: 0, // todo
Timestamp: 0, // todo
SourceID: 0, // todo
},
DbName: "", // todo
CollectionName: "", // todo
CollectionID: collectionID,
TimeStamp: 0, // todo
}
resp, err := rc.DescribeCollection(ctx, req)
if err != nil {
log.Warn("DescribeCollection", zap.Error(err))
return nil, err
}
if resp.Status.ErrorCode != commonpb.ErrorCode_Success {
log.Warn("DescribeCollection",
zap.Any("ErrorCode", resp.Status.ErrorCode),
zap.Any("Reason", resp.Status.Reason))
return nil, errors.New(resp.Status.Reason)
}
if len(resp.VirtualChannelNames) != len(resp.PhysicalChannelNames) {
err := fmt.Errorf(
"len(VirtualChannelNames): %v, len(PhysicalChannelNames): %v",
len(resp.VirtualChannelNames),
len(resp.PhysicalChannelNames))
log.Warn("GetDmlChannels", zap.Error(err))
return nil, err
}
ret := make(map[vChan]pChan)
for idx, name := range resp.VirtualChannelNames {
if _, ok := ret[name]; ok {
err := fmt.Errorf(
"duplicated virtual channel found, vchan: %v, pchan: %v",
name,
resp.PhysicalChannelNames[idx])
return nil, err
}
ret[name] = resp.PhysicalChannelNames[idx]
}
return ret, nil
}
}
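// getDqlChannelsFunc returns a closure that asks the query coordinator to
// create the query channel of a collection; the single request channel is
// used as both the virtual and the physical channel.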
func getDqlChannelsFunc(ctx context.Context, proxyID int64, qc createQueryChannelInterface) getChannelsFuncType {
return func(collectionID UniqueID) (map[vChan]pChan, error) {
req := &querypb.CreateQueryChannelRequest{
CollectionID: collectionID,
ProxyID: proxyID,
}
resp, err := qc.CreateQueryChannel(ctx, req)
if err != nil {
return nil, err
}
if resp.Status.ErrorCode != commonpb.ErrorCode_Success {
return nil, errors.New(resp.Status.Reason)
}
m := make(map[vChan]pChan)
m[resp.RequestChannel] = resp.RequestChannel
return m, nil
}
}
type streamType int
const (


@@ -14,6 +14,8 @@ package proxy
import (
"context"
"github.com/milvus-io/milvus/internal/proto/querypb"
"github.com/milvus-io/milvus/internal/proto/rootcoordpb"
)
@@ -38,3 +40,13 @@ type timestampAllocatorInterface interface {
type getChannelsService interface {
GetChannels(collectionID UniqueID) (map[vChan]pChan, error)
}
// queryCoordShowCollectionsInterface is used by searchTask & queryTask to check whether a collection has been loaded
type queryCoordShowCollectionsInterface interface {
ShowCollections(ctx context.Context, request *querypb.ShowCollectionsRequest) (*querypb.ShowCollectionsResponse, error)
}
// createQueryChannelInterface defines the CreateQueryChannel method exposed by the query coordinator
type createQueryChannelInterface interface {
CreateQueryChannel(ctx context.Context, request *querypb.CreateQueryChannelRequest) (*querypb.CreateQueryChannelResponse, error)
}
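// For tests, a stub satisfying createQueryChannelInterface can be as small as
// the sketch below (hypothetical, in the same spirit as the mocks added
// elsewhere in this change):
//
//	type mockCreateQueryChannel struct{ channel string }
//
//	func (m *mockCreateQueryChannel) CreateQueryChannel(ctx context.Context, request *querypb.CreateQueryChannelRequest) (*querypb.CreateQueryChannelResponse, error) {
//		return &querypb.CreateQueryChannelResponse{
//			Status:         &commonpb.Status{ErrorCode: commonpb.ErrorCode_Success},
//			RequestChannel: m.channel,
//		}, nil
//	}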


@@ -18,6 +18,8 @@ import (
"sync"
"time"
"github.com/milvus-io/milvus/internal/common"
"go.uber.org/zap"
"github.com/milvus-io/milvus/internal/log"
@@ -294,7 +296,7 @@ func (m *MetaCache) describeCollection(ctx context.Context, collectionName strin
CreatedUtcTimestamp: coll.CreatedUtcTimestamp,
}
for _, field := range coll.Schema.Fields {
if field.FieldID >= 100 { // TODO(dragondriver): use StartOfUserField to replace 100
if field.FieldID >= common.StartOfUserFieldID {
resp.Schema.Fields = append(resp.Schema.Fields, field)
}
}


@@ -17,6 +17,8 @@ import (
"sync"
"time"
"github.com/milvus-io/milvus/internal/proto/querypb"
"github.com/milvus-io/milvus/internal/proto/schemapb"
"github.com/milvus-io/milvus/internal/msgstream"
@@ -553,3 +555,33 @@ func generateHashKeys(numRows int) []uint32 {
}
return ret
}
type mockQueryCoordShowCollectionsInterface struct {
collectionIDs []int64
inMemoryPercentages []int64
}
func (ins *mockQueryCoordShowCollectionsInterface) ShowCollections(ctx context.Context, request *querypb.ShowCollectionsRequest) (*querypb.ShowCollectionsResponse, error) {
resp := &querypb.ShowCollectionsResponse{
Status: &commonpb.Status{
ErrorCode: commonpb.ErrorCode_Success,
Reason: "",
},
CollectionIDs: ins.collectionIDs,
InMemoryPercentages: ins.inMemoryPercentages,
}
return resp, nil
}
func (ins *mockQueryCoordShowCollectionsInterface) addCollection(collectionID int64, inMemoryPercentage int64) {
ins.collectionIDs = append(ins.collectionIDs, collectionID)
ins.inMemoryPercentages = append(ins.inMemoryPercentages, inMemoryPercentage)
}
func newMockQueryCoordShowCollectionsInterface() *mockQueryCoordShowCollectionsInterface {
return &mockQueryCoordShowCollectionsInterface{
collectionIDs: make([]int64, 0),
inMemoryPercentages: make([]int64, 0),
}
}


@@ -14,7 +14,6 @@ package proxy
import (
"context"
"errors"
"fmt"
"math/rand"
"sync"
"sync/atomic"
@@ -31,7 +30,6 @@ import (
"github.com/milvus-io/milvus/internal/msgstream"
"github.com/milvus-io/milvus/internal/proto/commonpb"
"github.com/milvus-io/milvus/internal/proto/internalpb"
"github.com/milvus-io/milvus/internal/proto/milvuspb"
"github.com/milvus-io/milvus/internal/proto/querypb"
"github.com/milvus-io/milvus/internal/types"
"github.com/milvus-io/milvus/internal/util/funcutil"
@@ -187,73 +185,9 @@ func (node *Proxy) Init() error {
node.segAssigner = segAssigner
node.segAssigner.PeerID = Params.ProxyID
getDmlChannelsFunc := func(collectionID UniqueID) (map[vChan]pChan, error) {
req := &milvuspb.DescribeCollectionRequest{
Base: &commonpb.MsgBase{
MsgType: commonpb.MsgType_DescribeCollection,
MsgID: 0, // todo
Timestamp: 0, // todo
SourceID: 0, // todo
},
DbName: "", // todo
CollectionName: "", // todo
CollectionID: collectionID,
TimeStamp: 0, // todo
}
resp, err := node.rootCoord.DescribeCollection(node.ctx, req)
if err != nil {
log.Warn("DescribeCollection", zap.Error(err))
return nil, err
}
if resp.Status.ErrorCode != 0 {
log.Warn("DescribeCollection",
zap.Any("ErrorCode", resp.Status.ErrorCode),
zap.Any("Reason", resp.Status.Reason))
return nil, err
}
if len(resp.VirtualChannelNames) != len(resp.PhysicalChannelNames) {
err := fmt.Errorf(
"len(VirtualChannelNames): %v, len(PhysicalChannelNames): %v",
len(resp.VirtualChannelNames),
len(resp.PhysicalChannelNames))
log.Warn("GetDmlChannels", zap.Error(err))
return nil, err
}
ret := make(map[vChan]pChan)
for idx, name := range resp.VirtualChannelNames {
if _, ok := ret[name]; ok {
err := fmt.Errorf(
"duplicated virtual channel found, vchan: %v, pchan: %v",
name,
resp.PhysicalChannelNames[idx])
return nil, err
}
ret[name] = resp.PhysicalChannelNames[idx]
}
return ret, nil
}
getDqlChannelsFunc := func(collectionID UniqueID) (map[vChan]pChan, error) {
req := &querypb.CreateQueryChannelRequest{
CollectionID: collectionID,
ProxyID: node.session.ServerID,
}
resp, err := node.queryCoord.CreateQueryChannel(node.ctx, req)
if err != nil {
return nil, err
}
if resp.Status.ErrorCode != commonpb.ErrorCode_Success {
return nil, errors.New(resp.Status.Reason)
}
m := make(map[vChan]pChan)
m[resp.RequestChannel] = resp.RequestChannel
return m, nil
}
chMgr := newChannelsMgrImpl(getDmlChannelsFunc, defaultInsertRepackFunc, getDqlChannelsFunc, nil, node.msFactory)
dmlChannelsFunc := getDmlChannelsFunc(node.ctx, node.rootCoord)
dqlChannelsFunc := getDqlChannelsFunc(node.ctx, node.session.ServerID, node.queryCoord)
chMgr := newChannelsMgrImpl(dmlChannelsFunc, defaultInsertRepackFunc, dqlChannelsFunc, nil, node.msFactory)
node.chMgr = chMgr
node.sched, err = newTaskScheduler(node.ctx, node.idAllocator, node.tsoAllocator, node.msFactory)


@@ -19,6 +19,8 @@ import (
"sync/atomic"
"time"
"github.com/milvus-io/milvus/internal/common"
"github.com/milvus-io/milvus/internal/log"
"github.com/milvus-io/milvus/internal/util/metricsinfo"
"go.uber.org/zap"
@@ -67,6 +69,16 @@ type partitionMap struct {
partitionID2Meta map[typeutil.UniqueID]partitionMeta
}
type RootCoordMockOption func(mock *RootCoordMock)
type describeCollectionFuncType func(ctx context.Context, request *milvuspb.DescribeCollectionRequest) (*milvuspb.DescribeCollectionResponse, error)
func SetDescribeCollectionFunc(f describeCollectionFuncType) RootCoordMockOption {
return func(mock *RootCoordMock) {
mock.SetDescribeCollectionFunc(f)
}
}
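// Usage sketch: options are applied by NewRootCoordMock, so a test can inject
// a custom DescribeCollection handler at construction time (the failing
// handler below is illustrative only):
//
//	rc := NewRootCoordMock(SetDescribeCollectionFunc(
//		func(ctx context.Context, request *milvuspb.DescribeCollectionRequest) (*milvuspb.DescribeCollectionResponse, error) {
//			return nil, errors.New("injected DescribeCollection failure")
//		}))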
type RootCoordMock struct {
nodeID typeutil.UniqueID
address string
@@ -85,6 +97,8 @@ type RootCoordMock struct {
collID2Partitions map[typeutil.UniqueID]partitionMap
partitionMtx sync.RWMutex
describeCollectionFunc describeCollectionFuncType
// TODO(dragondriver): index-related
// TODO(dragondriver): segment-related
@@ -205,6 +219,9 @@ func (coord *RootCoordMock) CreateCollection(ctx context.Context, req *milvuspb.
Reason: fmt.Sprintf("failed to parse schema, error: %v", err),
}, nil
}
for i := range schema.Fields {
schema.Fields[i].FieldID = int64(common.StartOfUserFieldID + i)
}
collID := typeutil.UniqueID(uniquegenerator.GetUniqueIntGeneratorIns().GetInt())
coord.collName2ID[req.CollectionName] = collID
@@ -311,6 +328,14 @@ func (coord *RootCoordMock) HasCollection(ctx context.Context, req *milvuspb.Has
}, nil
}
func (coord *RootCoordMock) SetDescribeCollectionFunc(f describeCollectionFuncType) {
coord.describeCollectionFunc = f
}
func (coord *RootCoordMock) ResetDescribeCollectionFunc() {
coord.describeCollectionFunc = nil
}
func (coord *RootCoordMock) DescribeCollection(ctx context.Context, req *milvuspb.DescribeCollectionRequest) (*milvuspb.DescribeCollectionResponse, error) {
code := coord.state.Load().(internalpb.StateCode)
if code != internalpb.StateCode_Healthy {
@@ -323,6 +348,11 @@
CollectionID: 0,
}, nil
}
if coord.describeCollectionFunc != nil {
return coord.describeCollectionFunc(ctx, req)
}
coord.collMtx.RLock()
defer coord.collMtx.RUnlock()
@@ -828,8 +858,8 @@ func (coord *RootCoordMock) GetMetrics(ctx context.Context, req *milvuspb.GetMet
}, nil
}
func NewRootCoordMock() *RootCoordMock {
return &RootCoordMock{
func NewRootCoordMock(opts ...RootCoordMockOption) *RootCoordMock {
rc := &RootCoordMock{
nodeID: typeutil.UniqueID(uniquegenerator.GetUniqueIntGeneratorIns().GetInt()),
address: funcutil.GenRandomStr(), // TODO(dragondriver): random address
statisticsChannel: funcutil.GenRandomStr(),
@@ -839,4 +869,10 @@ func NewRootCoordMock() *RootCoordMock {
collID2Partitions: make(map[typeutil.UniqueID]partitionMap),
lastTs: typeutil.Timestamp(time.Now().UnixNano()),
}
for _, opt := range opts {
opt(rc)
}
return rc
}


@@ -27,6 +27,8 @@ import (
"time"
"unsafe"
"github.com/milvus-io/milvus/internal/common"
"go.uber.org/zap"
"github.com/golang/protobuf/proto"
@@ -1348,7 +1350,7 @@ type searchTask struct {
result *milvuspb.SearchResults
query *milvuspb.SearchRequest
chMgr channelsMgr
qc types.QueryCoord
qc queryCoordShowCollectionsInterface
}
func (st *searchTask) TraceCtx() context.Context {
@@ -1385,6 +1387,8 @@ func (st *searchTask) SetTs(ts Timestamp) {
func (st *searchTask) OnEnqueue() error {
st.Base = &commonpb.MsgBase{}
st.Base.MsgType = commonpb.MsgType_Search
st.Base.SourceID = Params.ProxyID
return nil
}
@@ -1513,10 +1517,22 @@ func (st *searchTask) PreExecute(ctx context.Context) error {
SearchParams: searchParams,
}
log.Debug("create query plan",
//zap.Any("schema", schema),
zap.String("dsl", st.query.Dsl),
zap.String("anns field", annsField),
zap.Any("query info", queryInfo))
plan, err := CreateQueryPlan(schema, st.query.Dsl, annsField, queryInfo)
if err != nil {
//return errors.New("invalid expression: " + st.query.Dsl)
return err
log.Debug("failed to create query plan",
zap.Error(err),
//zap.Any("schema", schema),
zap.String("dsl", st.query.Dsl),
zap.String("anns field", annsField),
zap.Any("query info", queryInfo))
return fmt.Errorf("failed to create query plan: %v", err)
}
for _, name := range st.query.OutputFields {
hitField := false
@@ -2686,7 +2702,7 @@ func (dct *describeCollectionTask) Execute(ctx context.Context) error {
dct.result.CreatedUtcTimestamp = result.CreatedUtcTimestamp
for _, field := range result.Schema.Fields {
if field.FieldID >= 100 { // TODO(dragondriver): use StartOfUserFieldID replacing 100
if field.FieldID >= common.StartOfUserFieldID {
dct.result.Schema.Fields = append(dct.result.Schema.Fields, &schemapb.FieldSchema{
FieldID: field.FieldID,
Name: field.Name,


@@ -1,11 +1,21 @@
package proxy
import (
"bytes"
"context"
"encoding/binary"
"encoding/json"
"fmt"
"math/rand"
"strconv"
"sync"
"testing"
"time"
"github.com/milvus-io/milvus/internal/msgstream"
"github.com/milvus-io/milvus/internal/util/distance"
"github.com/golang/protobuf/proto"
"github.com/milvus-io/milvus/internal/log"
"github.com/milvus-io/milvus/internal/proto/commonpb"
@@ -61,6 +71,85 @@ func constructCollectionSchema(
}
}
func constructPlaceholderGroup(
nq, dim int,
) *milvuspb.PlaceholderGroup {
values := make([][]byte, 0, nq)
for i := 0; i < nq; i++ {
bs := make([]byte, 0, dim*4)
for j := 0; j < dim; j++ {
var buffer bytes.Buffer
f := rand.Float32()
err := binary.Write(&buffer, binary.LittleEndian, f)
if err != nil {
panic(err)
}
bs = append(bs, buffer.Bytes()...)
}
values = append(values, bs)
}
return &milvuspb.PlaceholderGroup{
Placeholders: []*milvuspb.PlaceholderValue{
{
Tag: "$0",
Type: milvuspb.PlaceholderType_FloatVector,
Values: values,
},
},
}
}
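// Note: each query vector above is serialized as dim little-endian float32
// values, i.e. dim*4 bytes. An allocation-free equivalent of the inner loop
// (sketch, assuming an extra "math" import):
//
//	bs := make([]byte, dim*4)
//	for j := 0; j < dim; j++ {
//		binary.LittleEndian.PutUint32(bs[j*4:], math.Float32bits(rand.Float32()))
//	}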
func constructSearchRequest(
dbName, collectionName string,
expr string,
floatVecField string,
nq, dim, nprobe, topk int,
) *milvuspb.SearchRequest {
params := make(map[string]string)
params["nprobe"] = strconv.Itoa(nprobe)
b, err := json.Marshal(params)
if err != nil {
panic(err)
}
plg := constructPlaceholderGroup(nq, dim)
plgBs, err := proto.Marshal(plg)
if err != nil {
panic(err)
}
return &milvuspb.SearchRequest{
Base: nil,
DbName: dbName,
CollectionName: collectionName,
PartitionNames: nil,
Dsl: expr,
PlaceholderGroup: plgBs,
DslType: commonpb.DslType_BoolExprV1,
OutputFields: nil,
SearchParams: []*commonpb.KeyValuePair{
{
Key: MetricTypeKey,
Value: distance.L2,
},
{
Key: SearchParamsKey,
Value: string(b),
},
{
Key: AnnsFieldKey,
Value: floatVecField,
},
{
Key: TopKKey,
Value: strconv.Itoa(topk),
},
},
TravelTimestamp: 0,
GuaranteeTimestamp: 0,
}
}
func TestGetNumRowsOfScalarField(t *testing.T) {
cases := []struct {
datas interface{}
@@ -668,6 +757,7 @@ func TestCreateCollectionTask(t *testing.T) {
rc := NewRootCoordMock()
rc.Start()
defer rc.Stop()
ctx := context.Background()
shardsNum := int32(2)
prefix := "TestCreateCollectionTask"
@@ -904,6 +994,7 @@ func TestDropCollectionTask(t *testing.T) {
Params.Init()
rc := NewRootCoordMock()
rc.Start()
defer rc.Stop()
ctx := context.Background()
InitMetaCache(rc)
@@ -988,6 +1079,7 @@ func TestHasCollectionTask(t *testing.T) {
Params.Init()
rc := NewRootCoordMock()
rc.Start()
defer rc.Stop()
ctx := context.Background()
InitMetaCache(rc)
prefix := "TestHasCollectionTask"
@@ -1069,6 +1161,7 @@ func TestDescribeCollectionTask(t *testing.T) {
Params.Init()
rc := NewRootCoordMock()
rc.Start()
defer rc.Stop()
ctx := context.Background()
InitMetaCache(rc)
prefix := "TestDescribeCollectionTask"
@@ -1120,6 +1213,7 @@ func TestCreatePartitionTask(t *testing.T) {
Params.Init()
rc := NewRootCoordMock()
rc.Start()
defer rc.Stop()
ctx := context.Background()
prefix := "TestCreatePartitionTask"
dbName := ""
@@ -1166,6 +1260,7 @@ func TestDropPartitionTask(t *testing.T) {
Params.Init()
rc := NewRootCoordMock()
rc.Start()
defer rc.Stop()
ctx := context.Background()
prefix := "TestDropPartitionTask"
dbName := ""
@@ -1212,6 +1307,7 @@ func TestHasPartitionTask(t *testing.T) {
Params.Init()
rc := NewRootCoordMock()
rc.Start()
defer rc.Stop()
ctx := context.Background()
prefix := "TestHasPartitionTask"
dbName := ""
@@ -1258,6 +1354,7 @@ func TestShowPartitionsTask(t *testing.T) {
Params.Init()
rc := NewRootCoordMock()
rc.Start()
defer rc.Stop()
ctx := context.Background()
prefix := "TestShowPartitionsTask"
dbName := ""
@@ -1308,3 +1405,202 @@
assert.NotNil(t, err)
}
func TestSearchTask_all(t *testing.T) {
var err error
Params.Init()
Params.SearchResultChannelNames = []string{funcutil.GenRandomStr()}
rc := NewRootCoordMock()
rc.Start()
defer rc.Stop()
ctx := context.Background()
err = InitMetaCache(rc)
assert.NoError(t, err)
shardsNum := int32(2)
prefix := "TestSearchTask_all"
dbName := ""
collectionName := prefix + funcutil.GenRandomStr()
int64Field := "int64"
floatVecField := "fvec"
dim := 128
expr := fmt.Sprintf("%s > 0", int64Field)
nq := 10
topk := 10
nprobe := 10
schema := constructCollectionSchema(int64Field, floatVecField, dim, collectionName)
marshaledSchema, err := proto.Marshal(schema)
assert.NoError(t, err)
createColT := &createCollectionTask{
Condition: NewTaskCondition(ctx),
CreateCollectionRequest: &milvuspb.CreateCollectionRequest{
Base: nil,
DbName: dbName,
CollectionName: collectionName,
Schema: marshaledSchema,
ShardsNum: shardsNum,
},
ctx: ctx,
rootCoord: rc,
result: nil,
schema: nil,
}
assert.NoError(t, createColT.OnEnqueue())
assert.NoError(t, createColT.PreExecute(ctx))
assert.NoError(t, createColT.Execute(ctx))
assert.NoError(t, createColT.PostExecute(ctx))
dmlChannelsFunc := getDmlChannelsFunc(ctx, rc)
query := newMockGetChannelsService()
factory := newSimpleMockMsgStreamFactory()
chMgr := newChannelsMgrImpl(dmlChannelsFunc, nil, query.GetChannels, nil, factory)
defer chMgr.removeAllDMLStream()
defer chMgr.removeAllDQLStream()
collectionID, err := globalMetaCache.GetCollectionID(ctx, collectionName)
assert.NoError(t, err)
qc := newMockQueryCoordShowCollectionsInterface()
qc.addCollection(collectionID, 100)
req := constructSearchRequest(dbName, collectionName,
expr,
floatVecField,
nq, dim, nprobe, topk)
task := &searchTask{
Condition: NewTaskCondition(ctx),
SearchRequest: &internalpb.SearchRequest{
Base: &commonpb.MsgBase{
MsgType: commonpb.MsgType_Search,
MsgID: 0,
Timestamp: 0,
SourceID: Params.ProxyID,
},
ResultChannelID: strconv.FormatInt(Params.ProxyID, 10),
DbID: 0,
CollectionID: 0,
PartitionIDs: nil,
Dsl: "",
PlaceholderGroup: nil,
DslType: 0,
SerializedExprPlan: nil,
OutputFieldsId: nil,
TravelTimestamp: 0,
GuaranteeTimestamp: 0,
},
ctx: ctx,
resultBuf: make(chan []*internalpb.SearchResults),
result: nil,
query: req,
chMgr: chMgr,
qc: qc,
}
// simple mock for query node
// TODO(dragondriver): should we replace this mock using RocksMq or MemMsgStream?
err = chMgr.createDQLStream(collectionID)
assert.NoError(t, err)
stream, err := chMgr.getDQLStream(collectionID)
assert.NoError(t, err)
var wg sync.WaitGroup
wg.Add(1)
consumeCtx, cancel := context.WithCancel(ctx)
go func() {
defer wg.Done()
for {
select {
case <-consumeCtx.Done():
return
case pack := <-stream.Chan():
for _, msg := range pack.Msgs {
_, ok := msg.(*msgstream.SearchMsg)
assert.True(t, ok)
// TODO(dragondriver): construct result according to the request
constructSearchResultsData := func() *schemapb.SearchResultData {
resultData := &schemapb.SearchResultData{
NumQueries: int64(nq),
TopK: int64(topk),
FieldsData: nil,
Scores: make([]float32, nq*topk),
Ids: &schemapb.IDs{
IdField: &schemapb.IDs_IntId{
IntId: &schemapb.LongArray{
Data: make([]int64, nq*topk),
},
},
},
Topks: make([]int64, nq),
}
// ids := make([]int64, topk)
// for i := 0; i < topk; i++ {
// ids[i] = int64(uniquegenerator.GetUniqueIntGeneratorIns().GetInt())
// }
for i := 0; i < nq; i++ {
for j := 0; j < topk; j++ {
offset := i*topk + j
score := rand.Float32()
id := int64(uniquegenerator.GetUniqueIntGeneratorIns().GetInt())
resultData.Scores[offset] = score
resultData.Ids.IdField.(*schemapb.IDs_IntId).IntId.Data[offset] = id
}
resultData.Topks[i] = int64(topk)
}
return resultData
}
result1 := &internalpb.SearchResults{
Base: &commonpb.MsgBase{
MsgType: commonpb.MsgType_SearchResult,
MsgID: 0,
Timestamp: 0,
SourceID: 0,
},
Status: &commonpb.Status{
ErrorCode: commonpb.ErrorCode_Success,
Reason: "",
},
ResultChannelID: "",
MetricType: distance.L2,
NumQueries: int64(nq),
TopK: int64(topk),
SealedSegmentIDsSearched: nil,
ChannelIDsSearched: nil,
GlobalSealedSegmentIDs: nil,
SlicedBlob: nil,
SlicedNumCount: 1,
SlicedOffset: 0,
}
resultData := constructSearchResultsData()
sliceBlob, err := proto.Marshal(resultData)
assert.NoError(t, err)
result1.SlicedBlob = sliceBlob
// send search result
task.resultBuf <- []*internalpb.SearchResults{result1}
}
}
}
}()
assert.NoError(t, task.OnEnqueue())
assert.NoError(t, task.PreExecute(ctx))
assert.NoError(t, task.Execute(ctx))
assert.NoError(t, task.PostExecute(ctx))
cancel()
wg.Wait()
}


@@ -11,6 +11,8 @@
package rootcoord
import "github.com/milvus-io/milvus/internal/common"
// system field id:
// 0: unique row id
// 1: timestamp
@@ -19,9 +21,9 @@ package rootcoord
// 102: ...
const (
StartOfUserFieldID = 100
RowIDField = 0
TimeStampField = 1
RowIDFieldName = "RowID"
TimeStampFieldName = "Timestamp"
StartOfUserFieldID = common.StartOfUserFieldID
RowIDField = common.RowIDField
TimeStampField = common.TimeStampField
RowIDFieldName = common.RowIDFieldName
TimeStampFieldName = common.TimeStampFieldName
)