mirror of https://github.com/milvus-io/milvus.git
Add test cases for search task (#7796)
Signed-off-by: dragondriver <jiquan.long@zilliz.com>pull/7812/head
parent
85d73358cc
commit
8d239b6473
|
@ -313,7 +313,7 @@ from each task concurrently.
|
|||
|
||||
The following figure is a schematic diagram of taskScheduer's scheduling of DqQueue.
|
||||
|
||||

|
||||

|
||||
|
||||
The tasks in DqQueue can be scheduled in parallel. In a scheduling process, taskScheduler will execute several tasks
|
||||
concurrently.
|
||||
|
|
|
@ -0,0 +1,27 @@
|
|||
// Copyright (C) 2019-2020 Zilliz. All rights reserved.
|
||||
//
|
||||
// Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance
|
||||
// with the License. You may obtain a copy of the License at
|
||||
//
|
||||
// http://www.apache.org/licenses/LICENSE-2.0
|
||||
//
|
||||
// Unless required by applicable law or agreed to in writing, software distributed under the License
|
||||
// is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express
|
||||
// or implied. See the License for the specific language governing permissions and limitations under the License.
|
||||
|
||||
package common
|
||||
|
||||
// system filed id:
|
||||
// 0: unique row id
|
||||
// 1: timestamp
|
||||
// 100: first user field id
|
||||
// 101: second user field id
|
||||
// 102: ...
|
||||
|
||||
const (
|
||||
StartOfUserFieldID = 100
|
||||
RowIDField = 0
|
||||
TimeStampField = 1
|
||||
RowIDFieldName = "RowID"
|
||||
TimeStampFieldName = "Timestamp"
|
||||
)
|
|
@ -13,11 +13,19 @@ package proxy
|
|||
|
||||
import (
|
||||
"context"
|
||||
"errors"
|
||||
"fmt"
|
||||
"runtime"
|
||||
"sort"
|
||||
"sync"
|
||||
|
||||
"github.com/milvus-io/milvus/internal/proto/querypb"
|
||||
|
||||
"github.com/milvus-io/milvus/internal/proto/commonpb"
|
||||
"github.com/milvus-io/milvus/internal/proto/milvuspb"
|
||||
|
||||
"github.com/milvus-io/milvus/internal/types"
|
||||
|
||||
"github.com/milvus-io/milvus/internal/util/uniquegenerator"
|
||||
|
||||
"github.com/milvus-io/milvus/internal/log"
|
||||
|
@ -41,6 +49,77 @@ type channelsMgr interface {
|
|||
type getChannelsFuncType = func(collectionID UniqueID) (map[vChan]pChan, error)
|
||||
type repackFuncType = func(tsMsgs []msgstream.TsMsg, hashKeys [][]int32) (map[int32]*msgstream.MsgPack, error)
|
||||
|
||||
func getDmlChannelsFunc(ctx context.Context, rc types.RootCoord) getChannelsFuncType {
|
||||
return func(collectionID UniqueID) (map[vChan]pChan, error) {
|
||||
req := &milvuspb.DescribeCollectionRequest{
|
||||
Base: &commonpb.MsgBase{
|
||||
MsgType: commonpb.MsgType_DescribeCollection,
|
||||
MsgID: 0, // todo
|
||||
Timestamp: 0, // todo
|
||||
SourceID: 0, // todo
|
||||
},
|
||||
DbName: "", // todo
|
||||
CollectionName: "", // todo
|
||||
CollectionID: collectionID,
|
||||
TimeStamp: 0, // todo
|
||||
}
|
||||
resp, err := rc.DescribeCollection(ctx, req)
|
||||
if err != nil {
|
||||
log.Warn("DescribeCollection", zap.Error(err))
|
||||
return nil, err
|
||||
}
|
||||
if resp.Status.ErrorCode != 0 {
|
||||
log.Warn("DescribeCollection",
|
||||
zap.Any("ErrorCode", resp.Status.ErrorCode),
|
||||
zap.Any("Reason", resp.Status.Reason))
|
||||
return nil, err
|
||||
}
|
||||
if len(resp.VirtualChannelNames) != len(resp.PhysicalChannelNames) {
|
||||
err := fmt.Errorf(
|
||||
"len(VirtualChannelNames): %v, len(PhysicalChannelNames): %v",
|
||||
len(resp.VirtualChannelNames),
|
||||
len(resp.PhysicalChannelNames))
|
||||
log.Warn("GetDmlChannels", zap.Error(err))
|
||||
return nil, err
|
||||
}
|
||||
|
||||
ret := make(map[vChan]pChan)
|
||||
for idx, name := range resp.VirtualChannelNames {
|
||||
if _, ok := ret[name]; ok {
|
||||
err := fmt.Errorf(
|
||||
"duplicated virtual channel found, vchan: %v, pchan: %v",
|
||||
name,
|
||||
resp.PhysicalChannelNames[idx])
|
||||
return nil, err
|
||||
}
|
||||
ret[name] = resp.PhysicalChannelNames[idx]
|
||||
}
|
||||
|
||||
return ret, nil
|
||||
}
|
||||
}
|
||||
|
||||
func getDqlChannelsFunc(ctx context.Context, proxyID int64, qc createQueryChannelInterface) getChannelsFuncType {
|
||||
return func(collectionID UniqueID) (map[vChan]pChan, error) {
|
||||
req := &querypb.CreateQueryChannelRequest{
|
||||
CollectionID: collectionID,
|
||||
ProxyID: proxyID,
|
||||
}
|
||||
resp, err := qc.CreateQueryChannel(ctx, req)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
if resp.Status.ErrorCode != commonpb.ErrorCode_Success {
|
||||
return nil, errors.New(resp.Status.Reason)
|
||||
}
|
||||
|
||||
m := make(map[vChan]pChan)
|
||||
m[resp.RequestChannel] = resp.RequestChannel
|
||||
|
||||
return m, nil
|
||||
}
|
||||
}
|
||||
|
||||
type streamType int
|
||||
|
||||
const (
|
||||
|
|
|
@ -14,6 +14,8 @@ package proxy
|
|||
import (
|
||||
"context"
|
||||
|
||||
"github.com/milvus-io/milvus/internal/proto/querypb"
|
||||
|
||||
"github.com/milvus-io/milvus/internal/proto/rootcoordpb"
|
||||
)
|
||||
|
||||
|
@ -38,3 +40,13 @@ type timestampAllocatorInterface interface {
|
|||
type getChannelsService interface {
|
||||
GetChannels(collectionID UniqueID) (map[vChan]pChan, error)
|
||||
}
|
||||
|
||||
// queryCoordShowCollectionsInterface used in searchTask & queryTask
|
||||
type queryCoordShowCollectionsInterface interface {
|
||||
ShowCollections(ctx context.Context, request *querypb.ShowCollectionsRequest) (*querypb.ShowCollectionsResponse, error)
|
||||
}
|
||||
|
||||
// createQueryChannelInterface defines CreateQueryChannel
|
||||
type createQueryChannelInterface interface {
|
||||
CreateQueryChannel(ctx context.Context, request *querypb.CreateQueryChannelRequest) (*querypb.CreateQueryChannelResponse, error)
|
||||
}
|
||||
|
|
|
@ -18,6 +18,8 @@ import (
|
|||
"sync"
|
||||
"time"
|
||||
|
||||
"github.com/milvus-io/milvus/internal/common"
|
||||
|
||||
"go.uber.org/zap"
|
||||
|
||||
"github.com/milvus-io/milvus/internal/log"
|
||||
|
@ -294,7 +296,7 @@ func (m *MetaCache) describeCollection(ctx context.Context, collectionName strin
|
|||
CreatedUtcTimestamp: coll.CreatedUtcTimestamp,
|
||||
}
|
||||
for _, field := range coll.Schema.Fields {
|
||||
if field.FieldID >= 100 { // TODO(dragondriver): use StartOfUserField to replace 100
|
||||
if field.FieldID >= common.StartOfUserFieldID {
|
||||
resp.Schema.Fields = append(resp.Schema.Fields, field)
|
||||
}
|
||||
}
|
||||
|
|
|
@ -17,6 +17,8 @@ import (
|
|||
"sync"
|
||||
"time"
|
||||
|
||||
"github.com/milvus-io/milvus/internal/proto/querypb"
|
||||
|
||||
"github.com/milvus-io/milvus/internal/proto/schemapb"
|
||||
|
||||
"github.com/milvus-io/milvus/internal/msgstream"
|
||||
|
@ -553,3 +555,33 @@ func generateHashKeys(numRows int) []uint32 {
|
|||
}
|
||||
return ret
|
||||
}
|
||||
|
||||
type mockQueryCoordShowCollectionsInterface struct {
|
||||
collectionIDs []int64
|
||||
inMemoryPercentages []int64
|
||||
}
|
||||
|
||||
func (ins *mockQueryCoordShowCollectionsInterface) ShowCollections(ctx context.Context, request *querypb.ShowCollectionsRequest) (*querypb.ShowCollectionsResponse, error) {
|
||||
resp := &querypb.ShowCollectionsResponse{
|
||||
Status: &commonpb.Status{
|
||||
ErrorCode: commonpb.ErrorCode_Success,
|
||||
Reason: "",
|
||||
},
|
||||
CollectionIDs: ins.collectionIDs,
|
||||
InMemoryPercentages: ins.inMemoryPercentages,
|
||||
}
|
||||
|
||||
return resp, nil
|
||||
}
|
||||
|
||||
func (ins *mockQueryCoordShowCollectionsInterface) addCollection(collectionID int64, inMemoryPercentage int64) {
|
||||
ins.collectionIDs = append(ins.collectionIDs, collectionID)
|
||||
ins.inMemoryPercentages = append(ins.inMemoryPercentages, collectionID)
|
||||
}
|
||||
|
||||
func newMockQueryCoordShowCollectionsInterface() *mockQueryCoordShowCollectionsInterface {
|
||||
return &mockQueryCoordShowCollectionsInterface{
|
||||
collectionIDs: make([]int64, 0),
|
||||
inMemoryPercentages: make([]int64, 0),
|
||||
}
|
||||
}
|
||||
|
|
|
@ -14,7 +14,6 @@ package proxy
|
|||
import (
|
||||
"context"
|
||||
"errors"
|
||||
"fmt"
|
||||
"math/rand"
|
||||
"sync"
|
||||
"sync/atomic"
|
||||
|
@ -31,7 +30,6 @@ import (
|
|||
"github.com/milvus-io/milvus/internal/msgstream"
|
||||
"github.com/milvus-io/milvus/internal/proto/commonpb"
|
||||
"github.com/milvus-io/milvus/internal/proto/internalpb"
|
||||
"github.com/milvus-io/milvus/internal/proto/milvuspb"
|
||||
"github.com/milvus-io/milvus/internal/proto/querypb"
|
||||
"github.com/milvus-io/milvus/internal/types"
|
||||
"github.com/milvus-io/milvus/internal/util/funcutil"
|
||||
|
@ -187,73 +185,9 @@ func (node *Proxy) Init() error {
|
|||
node.segAssigner = segAssigner
|
||||
node.segAssigner.PeerID = Params.ProxyID
|
||||
|
||||
getDmlChannelsFunc := func(collectionID UniqueID) (map[vChan]pChan, error) {
|
||||
req := &milvuspb.DescribeCollectionRequest{
|
||||
Base: &commonpb.MsgBase{
|
||||
MsgType: commonpb.MsgType_DescribeCollection,
|
||||
MsgID: 0, // todo
|
||||
Timestamp: 0, // todo
|
||||
SourceID: 0, // todo
|
||||
},
|
||||
DbName: "", // todo
|
||||
CollectionName: "", // todo
|
||||
CollectionID: collectionID,
|
||||
TimeStamp: 0, // todo
|
||||
}
|
||||
resp, err := node.rootCoord.DescribeCollection(node.ctx, req)
|
||||
if err != nil {
|
||||
log.Warn("DescribeCollection", zap.Error(err))
|
||||
return nil, err
|
||||
}
|
||||
if resp.Status.ErrorCode != 0 {
|
||||
log.Warn("DescribeCollection",
|
||||
zap.Any("ErrorCode", resp.Status.ErrorCode),
|
||||
zap.Any("Reason", resp.Status.Reason))
|
||||
return nil, err
|
||||
}
|
||||
if len(resp.VirtualChannelNames) != len(resp.PhysicalChannelNames) {
|
||||
err := fmt.Errorf(
|
||||
"len(VirtualChannelNames): %v, len(PhysicalChannelNames): %v",
|
||||
len(resp.VirtualChannelNames),
|
||||
len(resp.PhysicalChannelNames))
|
||||
log.Warn("GetDmlChannels", zap.Error(err))
|
||||
return nil, err
|
||||
}
|
||||
|
||||
ret := make(map[vChan]pChan)
|
||||
for idx, name := range resp.VirtualChannelNames {
|
||||
if _, ok := ret[name]; ok {
|
||||
err := fmt.Errorf(
|
||||
"duplicated virtual channel found, vchan: %v, pchan: %v",
|
||||
name,
|
||||
resp.PhysicalChannelNames[idx])
|
||||
return nil, err
|
||||
}
|
||||
ret[name] = resp.PhysicalChannelNames[idx]
|
||||
}
|
||||
|
||||
return ret, nil
|
||||
}
|
||||
getDqlChannelsFunc := func(collectionID UniqueID) (map[vChan]pChan, error) {
|
||||
req := &querypb.CreateQueryChannelRequest{
|
||||
CollectionID: collectionID,
|
||||
ProxyID: node.session.ServerID,
|
||||
}
|
||||
resp, err := node.queryCoord.CreateQueryChannel(node.ctx, req)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
if resp.Status.ErrorCode != commonpb.ErrorCode_Success {
|
||||
return nil, errors.New(resp.Status.Reason)
|
||||
}
|
||||
|
||||
m := make(map[vChan]pChan)
|
||||
m[resp.RequestChannel] = resp.RequestChannel
|
||||
|
||||
return m, nil
|
||||
}
|
||||
|
||||
chMgr := newChannelsMgrImpl(getDmlChannelsFunc, defaultInsertRepackFunc, getDqlChannelsFunc, nil, node.msFactory)
|
||||
dmlChannelsFunc := getDmlChannelsFunc(node.ctx, node.rootCoord)
|
||||
dqlChannelsFunc := getDqlChannelsFunc(node.ctx, node.session.ServerID, node.queryCoord)
|
||||
chMgr := newChannelsMgrImpl(dmlChannelsFunc, defaultInsertRepackFunc, dqlChannelsFunc, nil, node.msFactory)
|
||||
node.chMgr = chMgr
|
||||
|
||||
node.sched, err = newTaskScheduler(node.ctx, node.idAllocator, node.tsoAllocator, node.msFactory)
|
||||
|
|
|
@ -19,6 +19,8 @@ import (
|
|||
"sync/atomic"
|
||||
"time"
|
||||
|
||||
"github.com/milvus-io/milvus/internal/common"
|
||||
|
||||
"github.com/milvus-io/milvus/internal/log"
|
||||
"github.com/milvus-io/milvus/internal/util/metricsinfo"
|
||||
"go.uber.org/zap"
|
||||
|
@ -67,6 +69,16 @@ type partitionMap struct {
|
|||
partitionID2Meta map[typeutil.UniqueID]partitionMeta
|
||||
}
|
||||
|
||||
type RootCoordMockOption func(mock *RootCoordMock)
|
||||
|
||||
type describeCollectionFuncType func(ctx context.Context, request *milvuspb.DescribeCollectionRequest) (*milvuspb.DescribeCollectionResponse, error)
|
||||
|
||||
func SetDescribeCollectionFunc(f describeCollectionFuncType) RootCoordMockOption {
|
||||
return func(mock *RootCoordMock) {
|
||||
mock.SetDescribeCollectionFunc(f)
|
||||
}
|
||||
}
|
||||
|
||||
type RootCoordMock struct {
|
||||
nodeID typeutil.UniqueID
|
||||
address string
|
||||
|
@ -85,6 +97,8 @@ type RootCoordMock struct {
|
|||
collID2Partitions map[typeutil.UniqueID]partitionMap
|
||||
partitionMtx sync.RWMutex
|
||||
|
||||
describeCollectionFunc describeCollectionFuncType
|
||||
|
||||
// TODO(dragondriver): index-related
|
||||
|
||||
// TODO(dragondriver): segment-related
|
||||
|
@ -205,6 +219,9 @@ func (coord *RootCoordMock) CreateCollection(ctx context.Context, req *milvuspb.
|
|||
Reason: fmt.Sprintf("failed to parse schema, error: %v", err),
|
||||
}, nil
|
||||
}
|
||||
for i := range schema.Fields {
|
||||
schema.Fields[i].FieldID = int64(common.StartOfUserFieldID + i)
|
||||
}
|
||||
|
||||
collID := typeutil.UniqueID(uniquegenerator.GetUniqueIntGeneratorIns().GetInt())
|
||||
coord.collName2ID[req.CollectionName] = collID
|
||||
|
@ -311,6 +328,14 @@ func (coord *RootCoordMock) HasCollection(ctx context.Context, req *milvuspb.Has
|
|||
}, nil
|
||||
}
|
||||
|
||||
func (coord *RootCoordMock) SetDescribeCollectionFunc(f describeCollectionFuncType) {
|
||||
coord.describeCollectionFunc = f
|
||||
}
|
||||
|
||||
func (coord *RootCoordMock) ResetDescribeCollectionFunc(f describeCollectionFuncType) {
|
||||
coord.describeCollectionFunc = nil
|
||||
}
|
||||
|
||||
func (coord *RootCoordMock) DescribeCollection(ctx context.Context, req *milvuspb.DescribeCollectionRequest) (*milvuspb.DescribeCollectionResponse, error) {
|
||||
code := coord.state.Load().(internalpb.StateCode)
|
||||
if code != internalpb.StateCode_Healthy {
|
||||
|
@ -323,6 +348,11 @@ func (coord *RootCoordMock) DescribeCollection(ctx context.Context, req *milvusp
|
|||
CollectionID: 0,
|
||||
}, nil
|
||||
}
|
||||
|
||||
if coord.describeCollectionFunc != nil {
|
||||
return coord.describeCollectionFunc(ctx, req)
|
||||
}
|
||||
|
||||
coord.collMtx.RLock()
|
||||
defer coord.collMtx.RUnlock()
|
||||
|
||||
|
@ -828,8 +858,8 @@ func (coord *RootCoordMock) GetMetrics(ctx context.Context, req *milvuspb.GetMet
|
|||
}, nil
|
||||
}
|
||||
|
||||
func NewRootCoordMock() *RootCoordMock {
|
||||
return &RootCoordMock{
|
||||
func NewRootCoordMock(opts ...RootCoordMockOption) *RootCoordMock {
|
||||
rc := &RootCoordMock{
|
||||
nodeID: typeutil.UniqueID(uniquegenerator.GetUniqueIntGeneratorIns().GetInt()),
|
||||
address: funcutil.GenRandomStr(), // TODO(dragondriver): random address
|
||||
statisticsChannel: funcutil.GenRandomStr(),
|
||||
|
@ -839,4 +869,10 @@ func NewRootCoordMock() *RootCoordMock {
|
|||
collID2Partitions: make(map[typeutil.UniqueID]partitionMap),
|
||||
lastTs: typeutil.Timestamp(time.Now().UnixNano()),
|
||||
}
|
||||
|
||||
for _, opt := range opts {
|
||||
opt(rc)
|
||||
}
|
||||
|
||||
return rc
|
||||
}
|
||||
|
|
|
@ -27,6 +27,8 @@ import (
|
|||
"time"
|
||||
"unsafe"
|
||||
|
||||
"github.com/milvus-io/milvus/internal/common"
|
||||
|
||||
"go.uber.org/zap"
|
||||
|
||||
"github.com/golang/protobuf/proto"
|
||||
|
@ -1348,7 +1350,7 @@ type searchTask struct {
|
|||
result *milvuspb.SearchResults
|
||||
query *milvuspb.SearchRequest
|
||||
chMgr channelsMgr
|
||||
qc types.QueryCoord
|
||||
qc queryCoordShowCollectionsInterface
|
||||
}
|
||||
|
||||
func (st *searchTask) TraceCtx() context.Context {
|
||||
|
@ -1385,6 +1387,8 @@ func (st *searchTask) SetTs(ts Timestamp) {
|
|||
|
||||
func (st *searchTask) OnEnqueue() error {
|
||||
st.Base = &commonpb.MsgBase{}
|
||||
st.Base.MsgType = commonpb.MsgType_Search
|
||||
st.Base.SourceID = Params.ProxyID
|
||||
return nil
|
||||
}
|
||||
|
||||
|
@ -1513,10 +1517,22 @@ func (st *searchTask) PreExecute(ctx context.Context) error {
|
|||
SearchParams: searchParams,
|
||||
}
|
||||
|
||||
log.Debug("create query plan",
|
||||
//zap.Any("schema", schema),
|
||||
zap.String("dsl", st.query.Dsl),
|
||||
zap.String("anns field", annsField),
|
||||
zap.Any("query info", queryInfo))
|
||||
|
||||
plan, err := CreateQueryPlan(schema, st.query.Dsl, annsField, queryInfo)
|
||||
if err != nil {
|
||||
//return errors.New("invalid expression: " + st.query.Dsl)
|
||||
return err
|
||||
log.Debug("failed to create query plan",
|
||||
zap.Error(err),
|
||||
//zap.Any("schema", schema),
|
||||
zap.String("dsl", st.query.Dsl),
|
||||
zap.String("anns field", annsField),
|
||||
zap.Any("query info", queryInfo))
|
||||
|
||||
return fmt.Errorf("failed to create query plan: %v", err)
|
||||
}
|
||||
for _, name := range st.query.OutputFields {
|
||||
hitField := false
|
||||
|
@ -2686,7 +2702,7 @@ func (dct *describeCollectionTask) Execute(ctx context.Context) error {
|
|||
dct.result.CreatedUtcTimestamp = result.CreatedUtcTimestamp
|
||||
|
||||
for _, field := range result.Schema.Fields {
|
||||
if field.FieldID >= 100 { // TODO(dragondriver): use StartOfUserFieldID replacing 100
|
||||
if field.FieldID >= common.StartOfUserFieldID {
|
||||
dct.result.Schema.Fields = append(dct.result.Schema.Fields, &schemapb.FieldSchema{
|
||||
FieldID: field.FieldID,
|
||||
Name: field.Name,
|
||||
|
|
|
@ -1,11 +1,21 @@
|
|||
package proxy
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"context"
|
||||
"encoding/binary"
|
||||
"encoding/json"
|
||||
"fmt"
|
||||
"math/rand"
|
||||
"strconv"
|
||||
"sync"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"github.com/milvus-io/milvus/internal/msgstream"
|
||||
|
||||
"github.com/milvus-io/milvus/internal/util/distance"
|
||||
|
||||
"github.com/golang/protobuf/proto"
|
||||
"github.com/milvus-io/milvus/internal/log"
|
||||
"github.com/milvus-io/milvus/internal/proto/commonpb"
|
||||
|
@ -61,6 +71,85 @@ func constructCollectionSchema(
|
|||
}
|
||||
}
|
||||
|
||||
func constructPlaceholderGroup(
|
||||
nq, dim int,
|
||||
) *milvuspb.PlaceholderGroup {
|
||||
values := make([][]byte, 0, nq)
|
||||
for i := 0; i < nq; i++ {
|
||||
bs := make([]byte, 0, dim*4)
|
||||
for j := 0; j < dim; j++ {
|
||||
var buffer bytes.Buffer
|
||||
f := rand.Float32()
|
||||
err := binary.Write(&buffer, binary.LittleEndian, f)
|
||||
if err != nil {
|
||||
panic(err)
|
||||
}
|
||||
bs = append(bs, buffer.Bytes()...)
|
||||
}
|
||||
values = append(values, bs)
|
||||
}
|
||||
|
||||
return &milvuspb.PlaceholderGroup{
|
||||
Placeholders: []*milvuspb.PlaceholderValue{
|
||||
{
|
||||
Tag: "$0",
|
||||
Type: milvuspb.PlaceholderType_FloatVector,
|
||||
Values: values,
|
||||
},
|
||||
},
|
||||
}
|
||||
}
|
||||
|
||||
func constructSearchRequest(
|
||||
dbName, collectionName string,
|
||||
expr string,
|
||||
floatVecField string,
|
||||
nq, dim, nprobe, topk int,
|
||||
) *milvuspb.SearchRequest {
|
||||
params := make(map[string]string)
|
||||
params["nprobe"] = strconv.Itoa(nprobe)
|
||||
b, err := json.Marshal(params)
|
||||
if err != nil {
|
||||
panic(err)
|
||||
}
|
||||
plg := constructPlaceholderGroup(nq, dim)
|
||||
plgBs, err := proto.Marshal(plg)
|
||||
if err != nil {
|
||||
panic(err)
|
||||
}
|
||||
|
||||
return &milvuspb.SearchRequest{
|
||||
Base: nil,
|
||||
DbName: dbName,
|
||||
CollectionName: collectionName,
|
||||
PartitionNames: nil,
|
||||
Dsl: expr,
|
||||
PlaceholderGroup: plgBs,
|
||||
DslType: commonpb.DslType_BoolExprV1,
|
||||
OutputFields: nil,
|
||||
SearchParams: []*commonpb.KeyValuePair{
|
||||
{
|
||||
Key: MetricTypeKey,
|
||||
Value: distance.L2,
|
||||
},
|
||||
{
|
||||
Key: SearchParamsKey,
|
||||
Value: string(b),
|
||||
},
|
||||
{
|
||||
Key: AnnsFieldKey,
|
||||
Value: floatVecField,
|
||||
},
|
||||
{
|
||||
Key: TopKKey,
|
||||
Value: strconv.Itoa(topk),
|
||||
},
|
||||
},
|
||||
TravelTimestamp: 0,
|
||||
GuaranteeTimestamp: 0,
|
||||
}
|
||||
}
|
||||
|
||||
func TestGetNumRowsOfScalarField(t *testing.T) {
|
||||
cases := []struct {
|
||||
datas interface{}
|
||||
|
@ -668,6 +757,7 @@ func TestCreateCollectionTask(t *testing.T) {
|
|||
|
||||
rc := NewRootCoordMock()
|
||||
rc.Start()
|
||||
defer rc.Stop()
|
||||
ctx := context.Background()
|
||||
shardsNum := int32(2)
|
||||
prefix := "TestCreateCollectionTask"
|
||||
|
@ -904,6 +994,7 @@ func TestDropCollectionTask(t *testing.T) {
|
|||
Params.Init()
|
||||
rc := NewRootCoordMock()
|
||||
rc.Start()
|
||||
defer rc.Stop()
|
||||
ctx := context.Background()
|
||||
InitMetaCache(rc)
|
||||
|
||||
|
@ -988,6 +1079,7 @@ func TestHasCollectionTask(t *testing.T) {
|
|||
Params.Init()
|
||||
rc := NewRootCoordMock()
|
||||
rc.Start()
|
||||
defer rc.Stop()
|
||||
ctx := context.Background()
|
||||
InitMetaCache(rc)
|
||||
prefix := "TestHasCollectionTask"
|
||||
|
@ -1069,6 +1161,7 @@ func TestDescribeCollectionTask(t *testing.T) {
|
|||
Params.Init()
|
||||
rc := NewRootCoordMock()
|
||||
rc.Start()
|
||||
defer rc.Stop()
|
||||
ctx := context.Background()
|
||||
InitMetaCache(rc)
|
||||
prefix := "TestDescribeCollectionTask"
|
||||
|
@ -1120,6 +1213,7 @@ func TestCreatePartitionTask(t *testing.T) {
|
|||
Params.Init()
|
||||
rc := NewRootCoordMock()
|
||||
rc.Start()
|
||||
defer rc.Stop()
|
||||
ctx := context.Background()
|
||||
prefix := "TestCreatePartitionTask"
|
||||
dbName := ""
|
||||
|
@ -1166,6 +1260,7 @@ func TestDropPartitionTask(t *testing.T) {
|
|||
Params.Init()
|
||||
rc := NewRootCoordMock()
|
||||
rc.Start()
|
||||
defer rc.Stop()
|
||||
ctx := context.Background()
|
||||
prefix := "TestDropPartitionTask"
|
||||
dbName := ""
|
||||
|
@ -1212,6 +1307,7 @@ func TestHasPartitionTask(t *testing.T) {
|
|||
Params.Init()
|
||||
rc := NewRootCoordMock()
|
||||
rc.Start()
|
||||
defer rc.Stop()
|
||||
ctx := context.Background()
|
||||
prefix := "TestHasPartitionTask"
|
||||
dbName := ""
|
||||
|
@ -1258,6 +1354,7 @@ func TestShowPartitionsTask(t *testing.T) {
|
|||
Params.Init()
|
||||
rc := NewRootCoordMock()
|
||||
rc.Start()
|
||||
defer rc.Stop()
|
||||
ctx := context.Background()
|
||||
prefix := "TestShowPartitionsTask"
|
||||
dbName := ""
|
||||
|
@ -1308,3 +1405,202 @@ func TestShowPartitionsTask(t *testing.T) {
|
|||
assert.NotNil(t, err)
|
||||
|
||||
}
|
||||
|
||||
func TestSearchTask_all(t *testing.T) {
|
||||
var err error
|
||||
|
||||
Params.Init()
|
||||
Params.SearchResultChannelNames = []string{funcutil.GenRandomStr()}
|
||||
|
||||
rc := NewRootCoordMock()
|
||||
rc.Start()
|
||||
defer rc.Stop()
|
||||
|
||||
ctx := context.Background()
|
||||
|
||||
err = InitMetaCache(rc)
|
||||
assert.NoError(t, err)
|
||||
|
||||
shardsNum := int32(2)
|
||||
prefix := "TestSearchTask_all"
|
||||
dbName := ""
|
||||
collectionName := prefix + funcutil.GenRandomStr()
|
||||
int64Field := "int64"
|
||||
floatVecField := "fvec"
|
||||
dim := 128
|
||||
expr := fmt.Sprintf("%s > 0", int64Field)
|
||||
nq := 10
|
||||
topk := 10
|
||||
nprobe := 10
|
||||
|
||||
schema := constructCollectionSchema(int64Field, floatVecField, dim, collectionName)
|
||||
marshaledSchema, err := proto.Marshal(schema)
|
||||
assert.NoError(t, err)
|
||||
|
||||
createColT := &createCollectionTask{
|
||||
Condition: NewTaskCondition(ctx),
|
||||
CreateCollectionRequest: &milvuspb.CreateCollectionRequest{
|
||||
Base: nil,
|
||||
DbName: dbName,
|
||||
CollectionName: collectionName,
|
||||
Schema: marshaledSchema,
|
||||
ShardsNum: shardsNum,
|
||||
},
|
||||
ctx: ctx,
|
||||
rootCoord: rc,
|
||||
result: nil,
|
||||
schema: nil,
|
||||
}
|
||||
|
||||
assert.NoError(t, createColT.OnEnqueue())
|
||||
assert.NoError(t, createColT.PreExecute(ctx))
|
||||
assert.NoError(t, createColT.Execute(ctx))
|
||||
assert.NoError(t, createColT.PostExecute(ctx))
|
||||
|
||||
dmlChannelsFunc := getDmlChannelsFunc(ctx, rc)
|
||||
query := newMockGetChannelsService()
|
||||
factory := newSimpleMockMsgStreamFactory()
|
||||
chMgr := newChannelsMgrImpl(dmlChannelsFunc, nil, query.GetChannels, nil, factory)
|
||||
defer chMgr.removeAllDMLStream()
|
||||
defer chMgr.removeAllDQLStream()
|
||||
|
||||
collectionID, err := globalMetaCache.GetCollectionID(ctx, collectionName)
|
||||
assert.NoError(t, err)
|
||||
|
||||
qc := newMockQueryCoordShowCollectionsInterface()
|
||||
qc.addCollection(collectionID, 100)
|
||||
|
||||
req := constructSearchRequest(dbName, collectionName,
|
||||
expr,
|
||||
floatVecField,
|
||||
nq, dim, nprobe, topk)
|
||||
|
||||
task := &searchTask{
|
||||
Condition: NewTaskCondition(ctx),
|
||||
SearchRequest: &internalpb.SearchRequest{
|
||||
Base: &commonpb.MsgBase{
|
||||
MsgType: commonpb.MsgType_Search,
|
||||
MsgID: 0,
|
||||
Timestamp: 0,
|
||||
SourceID: Params.ProxyID,
|
||||
},
|
||||
ResultChannelID: strconv.FormatInt(Params.ProxyID, 10),
|
||||
DbID: 0,
|
||||
CollectionID: 0,
|
||||
PartitionIDs: nil,
|
||||
Dsl: "",
|
||||
PlaceholderGroup: nil,
|
||||
DslType: 0,
|
||||
SerializedExprPlan: nil,
|
||||
OutputFieldsId: nil,
|
||||
TravelTimestamp: 0,
|
||||
GuaranteeTimestamp: 0,
|
||||
},
|
||||
ctx: ctx,
|
||||
resultBuf: make(chan []*internalpb.SearchResults),
|
||||
result: nil,
|
||||
query: req,
|
||||
chMgr: chMgr,
|
||||
qc: qc,
|
||||
}
|
||||
|
||||
// simple mock for query node
|
||||
// TODO(dragondriver): should we replace this mock using RocksMq or MemMsgStream?
|
||||
|
||||
err = chMgr.createDQLStream(collectionID)
|
||||
assert.NoError(t, err)
|
||||
stream, err := chMgr.getDQLStream(collectionID)
|
||||
assert.NoError(t, err)
|
||||
|
||||
var wg sync.WaitGroup
|
||||
wg.Add(1)
|
||||
consumeCtx, cancel := context.WithCancel(ctx)
|
||||
go func() {
|
||||
defer wg.Done()
|
||||
for {
|
||||
select {
|
||||
case <-consumeCtx.Done():
|
||||
return
|
||||
case pack := <-stream.Chan():
|
||||
for _, msg := range pack.Msgs {
|
||||
_, ok := msg.(*msgstream.SearchMsg)
|
||||
assert.True(t, ok)
|
||||
// TODO(dragondriver): construct result according to the request
|
||||
|
||||
constructSearchResulstData := func() *schemapb.SearchResultData {
|
||||
resultData := &schemapb.SearchResultData{
|
||||
NumQueries: int64(nq),
|
||||
TopK: int64(topk),
|
||||
FieldsData: nil,
|
||||
Scores: make([]float32, nq*topk),
|
||||
Ids: &schemapb.IDs{
|
||||
IdField: &schemapb.IDs_IntId{
|
||||
IntId: &schemapb.LongArray{
|
||||
Data: make([]int64, nq*topk),
|
||||
},
|
||||
},
|
||||
},
|
||||
Topks: make([]int64, nq),
|
||||
}
|
||||
|
||||
// ids := make([]int64, topk)
|
||||
// for i := 0; i < topk; i++ {
|
||||
// ids[i] = int64(uniquegenerator.GetUniqueIntGeneratorIns().GetInt())
|
||||
// }
|
||||
|
||||
for i := 0; i < nq; i++ {
|
||||
for j := 0; j < topk; j++ {
|
||||
offset := i*topk + j
|
||||
score := rand.Float32()
|
||||
id := int64(uniquegenerator.GetUniqueIntGeneratorIns().GetInt())
|
||||
resultData.Scores[offset] = score
|
||||
resultData.Ids.IdField.(*schemapb.IDs_IntId).IntId.Data[offset] = id
|
||||
}
|
||||
resultData.Topks[i] = int64(topk)
|
||||
}
|
||||
|
||||
return resultData
|
||||
}
|
||||
|
||||
result1 := &internalpb.SearchResults{
|
||||
Base: &commonpb.MsgBase{
|
||||
MsgType: commonpb.MsgType_SearchResult,
|
||||
MsgID: 0,
|
||||
Timestamp: 0,
|
||||
SourceID: 0,
|
||||
},
|
||||
Status: &commonpb.Status{
|
||||
ErrorCode: commonpb.ErrorCode_Success,
|
||||
Reason: "",
|
||||
},
|
||||
ResultChannelID: "",
|
||||
MetricType: distance.L2,
|
||||
NumQueries: int64(nq),
|
||||
TopK: int64(topk),
|
||||
SealedSegmentIDsSearched: nil,
|
||||
ChannelIDsSearched: nil,
|
||||
GlobalSealedSegmentIDs: nil,
|
||||
SlicedBlob: nil,
|
||||
SlicedNumCount: 1,
|
||||
SlicedOffset: 0,
|
||||
}
|
||||
resultData := constructSearchResulstData()
|
||||
sliceBlob, err := proto.Marshal(resultData)
|
||||
assert.NoError(t, err)
|
||||
result1.SlicedBlob = sliceBlob
|
||||
|
||||
// send search result
|
||||
task.resultBuf <- []*internalpb.SearchResults{result1}
|
||||
}
|
||||
}
|
||||
}
|
||||
}()
|
||||
|
||||
assert.NoError(t, task.OnEnqueue())
|
||||
assert.NoError(t, task.PreExecute(ctx))
|
||||
assert.NoError(t, task.Execute(ctx))
|
||||
assert.NoError(t, task.PostExecute(ctx))
|
||||
|
||||
cancel()
|
||||
wg.Wait()
|
||||
}
|
||||
|
|
|
@ -11,6 +11,8 @@
|
|||
|
||||
package rootcoord
|
||||
|
||||
import "github.com/milvus-io/milvus/internal/common"
|
||||
|
||||
// system filed id:
|
||||
// 0: unique row id
|
||||
// 1: timestamp
|
||||
|
@ -19,9 +21,9 @@ package rootcoord
|
|||
// 102: ...
|
||||
|
||||
const (
|
||||
StartOfUserFieldID = 100
|
||||
RowIDField = 0
|
||||
TimeStampField = 1
|
||||
RowIDFieldName = "RowID"
|
||||
TimeStampFieldName = "Timestamp"
|
||||
StartOfUserFieldID = common.StartOfUserFieldID
|
||||
RowIDField = common.RowIDField
|
||||
TimeStampField = common.TimeStampField
|
||||
RowIDFieldName = common.RowIDFieldName
|
||||
TimeStampFieldName = common.TimeStampFieldName
|
||||
)
|
||||
|
|
Loading…
Reference in New Issue