mirror of https://github.com/milvus-io/milvus.git
Add InsertTask and SearchTask to Proxy
Signed-off-by: dragondriver <jiquan.long@zilliz.com>
pull/4973/head^2
parent 7e182a230a
commit b8ccbc1c97
@@ -1,13 +1,14 @@
 package main

 import (
+	"context"
 	"flag"
-	"context"
 	"log"
 	"os"
 	"os/signal"
 	"syscall"

 	"go.uber.org/zap"

 	"github.com/zilliztech/milvus-distributed/internal/conf"
@@ -9,7 +9,6 @@ func NewEtcdKV() *kv.MemoryKV {
 	return kv.NewMemoryKV()
 }

 // use MemoryKV to mock EtcdKV
 func NewMemoryKV() *kv.MemoryKV {
 	return kv.NewMemoryKV()
 }
-}
@@ -1,23 +1,13 @@
-// Copyright 2020 TiKV Project Authors.
-//
-// Licensed under the Apache License, Version 2.0 (the "License");
-// you may not use this file except in compliance with the License.
-// You may obtain a copy of the License at
-//
-//     http://www.apache.org/licenses/LICENSE-2.0
-//
-// Unless required by applicable law or agreed to in writing, software
-// distributed under the License is distributed on an "AS IS" BASIS,
-// See the License for the specific language governing permissions and
-// limitations under the License.
-
 package id

 import (
 	"github.com/zilliztech/milvus-distributed/internal/kv"
 	"github.com/zilliztech/milvus-distributed/internal/master/tso"
 	"github.com/zilliztech/milvus-distributed/internal/util/typeutil"
 )

 type UniqueID = typeutil.UniqueID

 // GlobalTSOAllocator is the global single point TSO allocator.
@@ -25,8 +15,16 @@ type GlobalIdAllocator struct {
 	allocator tso.Allocator
 }

-var allocator GlobalIdAllocator = GlobalIdAllocator{
-	allocator: tso.NewGlobalTSOAllocator("idTimestamp"),
-}
+var allocator *GlobalIdAllocator
+
+func InitGlobalIdAllocator(key string, base kv.KVBase) {
+	allocator = NewGlobalIdAllocator(key, base)
+}
+
+func NewGlobalIdAllocator(key string, base kv.KVBase) *GlobalIdAllocator {
+	return &GlobalIdAllocator{
+		allocator: tso.NewGlobalTSOAllocator(key, base),
+	}
+}

 // Initialize will initialize the created global TSO allocator.
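With the KV backend now injected, callers decide where allocator state lives: the master wires in etcd, tests wire in the in-memory mock. A minimal usage sketch built only from functions visible in this diff (error handling kept deliberately terse):

package main

import (
	"log"

	"github.com/zilliztech/milvus-distributed/internal/kv/mockkv"
	"github.com/zilliztech/milvus-distributed/internal/master/id"
)

func main() {
	// Construct against the in-memory mock KV; the master instead passes
	// an etcd-backed KV created by newTSOKVBase("gid").
	idAlloc := id.NewGlobalIdAllocator("idTimestamp", mockkv.NewEtcdKV())
	if err := idAlloc.Initialize(); err != nil {
		log.Fatal(err)
	}
	one, err := idAlloc.AllocOne()
	if err != nil {
		log.Fatal(err)
	}
	log.Printf("allocated id: %d", one)
}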
@@ -0,0 +1,37 @@
+package id
+
+import (
+	"os"
+	"testing"
+
+	"github.com/stretchr/testify/assert"
+	"github.com/zilliztech/milvus-distributed/internal/kv/mockkv"
+)
+
+var GIdAllocator *GlobalIdAllocator
+
+func TestMain(m *testing.M) {
+	GIdAllocator = NewGlobalIdAllocator("idTimestamp", mockkv.NewEtcdKV())
+	exitCode := m.Run()
+	os.Exit(exitCode)
+}
+
+func TestGlobalIdAllocator_Initialize(t *testing.T) {
+	err := GIdAllocator.Initialize()
+	assert.Nil(t, err)
+}
+
+func TestGlobalIdAllocator_AllocOne(t *testing.T) {
+	one, err := GIdAllocator.AllocOne()
+	assert.Nil(t, err)
+	ano, err := GIdAllocator.AllocOne()
+	assert.Nil(t, err)
+	assert.NotEqual(t, one, ano)
+}
+
+func TestGlobalIdAllocator_Alloc(t *testing.T) {
+	count := uint32(2 << 10)
+	idStart, idEnd, err := GIdAllocator.Alloc(count)
+	assert.Nil(t, err)
+	assert.Equal(t, count, uint32(idEnd-idStart))
+}
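The last assertion pins down Alloc's contract: it returns a half-open range holding exactly `count` ids. A hypothetical consumer, under that reading:

// Reserve 1024 ids up front, then hand them out one by one: [idStart, idEnd).
idStart, idEnd, err := GIdAllocator.Alloc(1024)
if err != nil {
	log.Fatal(err)
}
for i := idStart; i < idEnd; i++ {
	// assign id `i` to one incoming row
}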
@@ -6,6 +6,7 @@ import (
 	"log"
 	"math/rand"
 	"net"
+	"path"
 	"strconv"
 	"sync"
 	"sync/atomic"
@@ -13,6 +14,7 @@ import (

 	"github.com/apache/pulsar-client-go/pulsar"
 	"github.com/golang/protobuf/proto"
+	"github.com/zilliztech/milvus-distributed/internal/master/id"
 	"github.com/zilliztech/milvus-distributed/internal/conf"
 	"github.com/zilliztech/milvus-distributed/internal/kv"
 	"github.com/zilliztech/milvus-distributed/internal/master/controller"
@@ -58,7 +60,18 @@ type Master struct {
 	closeCallbacks []func()
 }

-func newKvBase() *kv.EtcdKV {
+func newTSOKVBase(subPath string) *kv.EtcdKV {
+	etcdAddr := conf.Config.Etcd.Address
+	etcdAddr += ":"
+	etcdAddr += strconv.FormatInt(int64(conf.Config.Etcd.Port), 10)
+	client, _ := clientv3.New(clientv3.Config{
+		Endpoints:   []string{etcdAddr},
+		DialTimeout: 5 * time.Second,
+	})
+	return kv.NewEtcdKV(client, path.Join(conf.Config.Etcd.Rootpath, subPath))
+}
+
+func newKVBase() *kv.EtcdKV {
 	etcdAddr := conf.Config.Etcd.Address
 	etcdAddr += ":"
 	etcdAddr += strconv.FormatInt(int64(conf.Config.Etcd.Port), 10)
@@ -66,7 +79,6 @@ func newKvBase() *kv.EtcdKV {
 		Endpoints:   []string{etcdAddr},
 		DialTimeout: 5 * time.Second,
 	})
-	// defer cli.Close()
 	kvBase := kv.NewEtcdKV(cli, conf.Config.Etcd.Rootpath)
 	return kvBase
 }
@@ -74,17 +86,15 @@
 // CreateServer creates the UNINITIALIZED pd server with given configuration.
 func CreateServer(ctx context.Context) (*Master, error) {
 	rand.Seed(time.Now().UnixNano())
+	id.InitGlobalIdAllocator("idTimestamp", newTSOKVBase("gid"))
 	m := &Master{
 		ctx:            ctx,
 		startTimestamp: time.Now().Unix(),
-		kvBase:         newKvBase(),
+		kvBase:         newKVBase(),
 		ssChan:         make(chan internalpb.SegmentStatistics, 10),
 		pc:             informer.NewPulsarClient(),
+		tsoAllocator:   tso.NewGlobalTSOAllocator("timestamp", newTSOKVBase("tso")),
 	}
-	etcdAddr := conf.Config.Etcd.Address
-	etcdAddr += ":"
-	etcdAddr += strconv.FormatInt(int64(conf.Config.Etcd.Port), 10)
-	m.tsoAllocator = tso.NewGlobalTSOAllocator("timestamp")
 	m.grpcServer = grpc.NewServer()
 	masterpb.RegisterMasterServer(m.grpcServer, m)
 	return m, nil
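Note how newTSOKVBase gives each allocator its own etcd namespace under the configured root path, so the id counter and the TSO counter persist to disjoint keys. Illustratively (the exact leaf key names are internal to timestampOracle, so treat these as assumptions):

// Assuming Etcd.Rootpath = "by-dev" (illustrative value):
//   gid allocator state -> keys under "by-dev/gid/..."
//   tso allocator state -> keys under "by-dev/tso/..."
gidKV := newTSOKVBase("gid") // backs id.InitGlobalIdAllocator
tsoKV := newTSOKVBase("tso") // backs tso.NewGlobalTSOAllocator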
@@ -1,32 +1,14 @@
-// Copyright 2020 TiKV Project Authors.
-//
-// Licensed under the Apache License, Version 2.0 (the "License");
-// you may not use this file except in compliance with the License.
-// You may obtain a copy of the License at
-//
-//     http://www.apache.org/licenses/LICENSE-2.0
-//
-// Unless required by applicable law or agreed to in writing, software
-// distributed under the License is distributed on an "AS IS" BASIS,
-// See the License for the specific language governing permissions and
-// limitations under the License.
-
 package tso

 import (
 	"log"
-	"path"
-	"strconv"
 	"sync/atomic"
 	"time"

+	"github.com/zilliztech/milvus-distributed/internal/kv"

-	"github.com/zilliztech/milvus-distributed/internal/conf"
-	"github.com/zilliztech/milvus-distributed/internal/util/tsoutil"
-	"go.etcd.io/etcd/clientv3"

 	"github.com/zilliztech/milvus-distributed/internal/errors"
 	"github.com/zilliztech/milvus-distributed/internal/util/tsoutil"
 	"github.com/zilliztech/milvus-distributed/internal/util/typeutil"
 	"go.uber.org/zap"
 )
@@ -51,26 +33,16 @@ type Allocator interface {

 // GlobalTSOAllocator is the global single point TSO allocator.
 type GlobalTSOAllocator struct {
-	timestampOracle *timestampOracle
+	tso *timestampOracle
 }

 // NewGlobalTSOAllocator creates a new global TSO allocator.
-func NewGlobalTSOAllocator(key string) Allocator {
-
-	etcdAddr := conf.Config.Etcd.Address
-	etcdAddr += ":"
-	etcdAddr += strconv.FormatInt(int64(conf.Config.Etcd.Port), 10)
-
-	client, _ := clientv3.New(clientv3.Config{
-		Endpoints:   []string{etcdAddr},
-		DialTimeout: 5 * time.Second,
-	})
+func NewGlobalTSOAllocator(key string, kvBase kv.KVBase) Allocator {

 	var saveInterval time.Duration = 3 * time.Second
 	return &GlobalTSOAllocator{
-		timestampOracle: &timestampOracle{
-			kvBase:   kv.NewEtcdKV(client, path.Join(conf.Config.Etcd.Rootpath, "tso")),
-			rootPath: conf.Config.Etcd.Rootpath,
+		tso: &timestampOracle{
+			kvBase:        kvBase,
 			saveInterval:  saveInterval,
 			maxResetTSGap: func() time.Duration { return 3 * time.Second },
 			key:           key,
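The constructor no longer reaches out to etcd itself; the KV store is injected. A minimal sketch of what that buys in tests, mirroring the updated tso_test.go below:

// Build a TSO allocator against the in-memory mock KV; no etcd required.
var alloc Allocator = NewGlobalTSOAllocator("timestamp", mockkv.NewEtcdKV())
if err := alloc.Initialize(); err != nil {
	log.Fatal(err)
}
ts, err := alloc.GenerateTSO(1) // hand out one timestamp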
@@ -80,17 +52,17 @@ func NewGlobalTSOAllocator(key string) Allocator {

 // Initialize will initialize the created global TSO allocator.
 func (gta *GlobalTSOAllocator) Initialize() error {
-	return gta.timestampOracle.SyncTimestamp()
+	return gta.tso.SyncTimestamp()
 }

 // UpdateTSO is used to update the TSO in memory and the time window in etcd.
 func (gta *GlobalTSOAllocator) UpdateTSO() error {
-	return gta.timestampOracle.UpdateTimestamp()
+	return gta.tso.UpdateTimestamp()
 }

 // SetTSO sets the physical part with given tso.
 func (gta *GlobalTSOAllocator) SetTSO(tso uint64) error {
-	return gta.timestampOracle.ResetUserTimestamp(tso)
+	return gta.tso.ResetUserTimestamp(tso)
 }

 // GenerateTSO is used to generate a given number of TSOs.
@@ -104,7 +76,7 @@ func (gta *GlobalTSOAllocator) GenerateTSO(count uint32) (uint64, error) {
 	maxRetryCount := 10

 	for i := 0; i < maxRetryCount; i++ {
-		current := (*atomicObject)(atomic.LoadPointer(&gta.timestampOracle.TSO))
+		current := (*atomicObject)(atomic.LoadPointer(&gta.tso.TSO))
 		if current == nil || current.physical == typeutil.ZeroTime {
 			// If it's leader, maybe SyncTimestamp hasn't completed yet
 			log.Println("sync hasn't completed yet, wait for a while")
@@ -127,5 +99,5 @@ func (gta *GlobalTSOAllocator) GenerateTSO(count uint32) (uint64, error) {

 // Reset is used to reset the TSO allocator.
 func (gta *GlobalTSOAllocator) Reset() {
-	gta.timestampOracle.ResetTimestamp()
+	gta.tso.ResetTimestamp()
 }
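Background for readers: a TSO is one uint64 packing a physical millisecond timestamp with a logical counter, which is why GenerateTSO(count) can serve count distinct timestamps per physical tick. A sketch of the packing, assuming the PD-style 18-bit logical suffix (the project's real helpers live in internal/util/tsoutil):

const logicalBits = 18 // assumption: PD-style physical/logical split

// composeTS packs physical milliseconds and a logical counter into one TSO.
func composeTS(physicalMs, logical int64) uint64 {
	return uint64((physicalMs << logicalBits) + logical)
}

// parseTS splits a TSO back into its two halves.
func parseTS(ts uint64) (physicalMs, logical int64) {
	logical = int64(ts & ((1 << logicalBits) - 1))
	physicalMs = int64(ts >> logicalBits)
	return physicalMs, logical
}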
@@ -2,7 +2,6 @@ package tso

 import (
 	"github.com/stretchr/testify/assert"
-	"github.com/zilliztech/milvus-distributed/internal/conf"
 	"github.com/zilliztech/milvus-distributed/internal/kv/mockkv"
 	"github.com/zilliztech/milvus-distributed/internal/util/tsoutil"
 	"os"
@@ -10,18 +9,10 @@ import (
 	"time"
 )

-var GTsoAllocator *GlobalTSOAllocator
+var GTsoAllocator Allocator

 func TestMain(m *testing.M) {
-	GTsoAllocator = &GlobalTSOAllocator{
-		timestampOracle: &timestampOracle{
-			kvBase:        mockkv.NewEtcdKV(),
-			rootPath:      conf.Config.Etcd.Rootpath,
-			saveInterval:  3 * time.Second,
-			maxResetTSGap: func() time.Duration { return 3 * time.Second },
-			key:           "tso",
-		},
-	}
+	GTsoAllocator = NewGlobalTSOAllocator("timestamp", mockkv.NewEtcdKV())
 	exitCode := m.Run()
 	os.Exit(exitCode)
 }
@@ -64,4 +55,4 @@ func TestGlobalTSOAllocator_UpdateTSO(t *testing.T) {

 func TestGlobalTSOAllocator_Reset(t *testing.T) {
 	GTsoAllocator.Reset()
 }
@@ -46,7 +46,6 @@ type atomicObject struct {

 // timestampOracle is used to maintain the logic of tso.
 type timestampOracle struct {
-	rootPath string
 	key      string
 	kvBase   kv.KVBase
@@ -3,9 +3,10 @@ package proxy
 import (
 	"context"
+	"errors"
 	"github.com/gogo/protobuf/proto"
-	"github.com/zilliztech/milvus-distributed/internal/msgstream"
 	"log"

+	"github.com/zilliztech/milvus-distributed/internal/msgstream"
 	"github.com/zilliztech/milvus-distributed/internal/proto/commonpb"
 	"github.com/zilliztech/milvus-distributed/internal/proto/internalpb"
 	"github.com/zilliztech/milvus-distributed/internal/proto/schemapb"
@@ -14,7 +15,7 @@ import (

 func (p *Proxy) Insert(ctx context.Context, in *servicepb.RowBatch) (*servicepb.IntegerRangeResponse, error) {
 	it := &InsertTask{
-		baseInsertTask: baseInsertTask{
+		BaseInsertTask: BaseInsertTask{
 			BaseMsg: msgstream.BaseMsg{
 				HashValues: in.HashKeys,
 			},
@@ -53,10 +54,70 @@ func (p *Proxy) Insert(ctx context.Context, in *servicepb.RowBatch) (*servicepb.IntegerRangeResponse, error) {
 }

 func (p *Proxy) CreateCollection(ctx context.Context, req *schemapb.CollectionSchema) (*commonpb.Status, error) {
-	return &commonpb.Status{
-		ErrorCode: 0,
-		Reason:    "",
-	}, nil
+	cct := &CreateCollectionTask{
+		CreateCollectionRequest: internalpb.CreateCollectionRequest{
+			MsgType: internalpb.MsgType_kCreateCollection,
+			Schema:  &commonpb.Blob{},
+			// TODO: req_id, timestamp, proxy_id
+		},
+		masterClient: p.masterClient,
+		done:         make(chan error),
+		resultChan:   make(chan *commonpb.Status),
+	}
+	schemaBytes, _ := proto.Marshal(req)
+	cct.CreateCollectionRequest.Schema.Value = schemaBytes
+	cct.ctx, cct.cancel = context.WithCancel(ctx)
+	defer cct.cancel()
+
+	var t task = cct
+	p.taskSch.DdQueue.Enqueue(&t)
+	for {
+		select {
+		case <-ctx.Done():
+			log.Print("create collection timeout!")
+			return &commonpb.Status{
+				ErrorCode: commonpb.ErrorCode_UNEXPECTED_ERROR,
+				Reason:    "create collection timeout!",
+			}, errors.New("create collection timeout!")
+		case result := <-cct.resultChan:
+			return result, nil
+		}
+	}
+}
+
+func (p *Proxy) Search(ctx context.Context, req *servicepb.Query) (*servicepb.QueryResult, error) {
+	qt := &QueryTask{
+		SearchRequest: internalpb.SearchRequest{
+			MsgType: internalpb.MsgType_kSearch,
+			Query:   &commonpb.Blob{},
+			// TODO: req_id, proxy_id, timestamp, result_channel_id
+		},
+		queryMsgStream: p.queryMsgStream,
+		done:           make(chan error),
+		resultBuf:      make(chan []*internalpb.SearchResult),
+		resultChan:     make(chan *servicepb.QueryResult),
+	}
+	qt.ctx, qt.cancel = context.WithCancel(ctx)
+	queryBytes, _ := proto.Marshal(req)
+	qt.SearchRequest.Query.Value = queryBytes
+	defer qt.cancel()
+
+	var t task = qt
+	p.taskSch.DqQueue.Enqueue(&t)
+	for {
+		select {
+		case <-ctx.Done():
+			log.Print("query timeout!")
+			return &servicepb.QueryResult{
+				Status: &commonpb.Status{
+					ErrorCode: commonpb.ErrorCode_UNEXPECTED_ERROR,
+					Reason:    "query timeout!",
+				},
+			}, errors.New("query timeout!")
+		case result := <-qt.resultChan:
+			return result, nil
+		}
+	}
+}

 func (p *Proxy) DropCollection(ctx context.Context, req *servicepb.CollectionName) (*commonpb.Status, error) {
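Both new RPCs share one shape: build a task, push it onto the scheduler queue, then block on the caller's context against the task's result channel. Stripped of the Milvus types, the pattern is roughly this (a sketch assuming Go 1.18 generics; not the literal code above):

// enqueueAndWait captures the proxy's request pattern: hand work to a
// scheduler queue, then race ctx expiry against the result channel.
func enqueueAndWait[T any](ctx context.Context, enqueue func(), results <-chan T) (T, error) {
	var zero T
	enqueue()
	select {
	case <-ctx.Done():
		return zero, errors.New("request timeout")
	case r := <-results:
		return r, nil
	}
}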
@@ -3,6 +3,7 @@ package proxy

 import (
 	"context"
 	"github.com/zilliztech/milvus-distributed/internal/msgstream"
+	"github.com/zilliztech/milvus-distributed/internal/proto/internalpb"
 	"github.com/zilliztech/milvus-distributed/internal/proto/masterpb"
 	"github.com/zilliztech/milvus-distributed/internal/proto/servicepb"
 	"github.com/zilliztech/milvus-distributed/internal/util/typeutil"
@@ -30,6 +31,7 @@ type Proxy struct {
 	taskSch               *TaskScheduler
 	manipulationMsgStream *msgstream.PulsarMsgStream
 	queryMsgStream        *msgstream.PulsarMsgStream
+	queryResultMsgStream  *msgstream.PulsarMsgStream
 }

 func CreateProxy(ctx context.Context) (*Proxy, error) {
@@ -117,14 +119,60 @@ func (p *Proxy) connectMaster() error {
 	return nil
 }

+func (p *Proxy) receiveResultLoop() {
+	queryResultBuf := make(map[UniqueID][]*internalpb.SearchResult)
+
+	for {
+		msgPack := p.queryResultMsgStream.Consume()
+		if msgPack == nil {
+			continue
+		}
+		tsMsg := msgPack.Msgs[0]
+		searchResultMsg, _ := (*tsMsg).(*msgstream.SearchResultMsg)
+		reqId := UniqueID(searchResultMsg.GetReqId())
+		_, ok := queryResultBuf[reqId]
+		if !ok {
+			queryResultBuf[reqId] = make([]*internalpb.SearchResult, 0)
+		}
+		queryResultBuf[reqId] = append(queryResultBuf[reqId], &searchResultMsg.SearchResult)
+		if len(queryResultBuf[reqId]) == 4 {
+			// TODO: use the number of query nodes instead
+			t := p.taskSch.getTaskByReqId(reqId)
+			qt := (*t).(*QueryTask)
+			qt.resultBuf <- queryResultBuf[reqId]
+			delete(queryResultBuf, reqId)
+		}
+	}
+}
+
+func (p *Proxy) queryResultLoop() {
+	defer p.proxyLoopWg.Done()
+	p.queryResultMsgStream = &msgstream.PulsarMsgStream{}
+	// TODO: config
+	p.queryResultMsgStream.Start()
+
+	go p.receiveResultLoop()
+
+	ctx, cancel := context.WithCancel(p.proxyLoopCtx)
+	defer cancel()
+	for {
+		select {
+		case <-ctx.Done():
+			log.Print("proxy is closed...")
+			return
+		}
+	}
+}
+
 func (p *Proxy) startProxyLoop(ctx context.Context) {
 	p.proxyLoopCtx, p.proxyLoopCancel = context.WithCancel(ctx)
-	p.proxyLoopWg.Add(3)
+	p.proxyLoopWg.Add(4)

 	p.connectMaster()

 	go p.grpcLoop()
 	go p.pulsarMsgStreamLoop()
+	go p.queryResultLoop()
 	go p.scheduleLoop()
 }
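The aggregation above releases a request once exactly 4 partial results arrive; the TODO notes this should really be the query-node count. A hedged sketch of the generalized fan-in, with `expected` assumed to come from config (the name is illustrative):

// collect buffers one shard's partial result and reports whether the
// request now has a complete set of `expected` results.
func collect(buf map[UniqueID][]*internalpb.SearchResult,
	reqId UniqueID, part *internalpb.SearchResult, expected int) ([]*internalpb.SearchResult, bool) {
	buf[reqId] = append(buf[reqId], part)
	if len(buf[reqId]) < expected {
		return nil, false
	}
	complete := buf[reqId]
	delete(buf, reqId)
	return complete, true
}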
@@ -25,10 +25,10 @@ type task interface {
 	Notify(err error)
 }

-type baseInsertTask = msgstream.InsertMsg
+type BaseInsertTask = msgstream.InsertMsg

 type InsertTask struct {
-	baseInsertTask
+	BaseInsertTask
 	ts         Timestamp
 	done       chan error
 	resultChan chan *servicepb.IntegerRangeResponse
@@ -62,7 +62,7 @@ func (it *InsertTask) PreExecute() error {
 }

 func (it *InsertTask) Execute() error {
-	var tsMsg msgstream.TsMsg = it
+	var tsMsg msgstream.TsMsg = &it.BaseInsertTask
 	msgPack := &msgstream.MsgPack{
 		BeginTs: it.BeginTs(),
 		EndTs:   it.EndTs(),
@@ -161,3 +161,132 @@ func (cct *CreateCollectionTask) WaitToFinish() error {
 func (cct *CreateCollectionTask) Notify(err error) {
 	cct.done <- err
 }
+
+type QueryTask struct {
+	internalpb.SearchRequest
+	queryMsgStream *msgstream.PulsarMsgStream
+	done           chan error
+	resultBuf      chan []*internalpb.SearchResult
+	resultChan     chan *servicepb.QueryResult
+	ctx            context.Context
+	cancel         context.CancelFunc
+}
+
+func (qt *QueryTask) Id() UniqueID {
+	return qt.ReqId
+}
+
+func (qt *QueryTask) Type() internalpb.MsgType {
+	return qt.MsgType
+}
+
+func (qt *QueryTask) BeginTs() Timestamp {
+	return qt.Timestamp
+}
+
+func (qt *QueryTask) EndTs() Timestamp {
+	return qt.Timestamp
+}
+
+func (qt *QueryTask) SetTs(ts Timestamp) {
+	qt.Timestamp = ts
+}
+
+func (qt *QueryTask) PreExecute() error {
+	return nil
+}
+
+func (qt *QueryTask) Execute() error {
+	var tsMsg msgstream.TsMsg = &msgstream.SearchMsg{
+		SearchRequest: qt.SearchRequest,
+		BaseMsg: msgstream.BaseMsg{
+			BeginTimestamp: qt.Timestamp,
+			EndTimestamp:   qt.Timestamp,
+		},
+	}
+	msgPack := &msgstream.MsgPack{
+		BeginTs: qt.Timestamp,
+		EndTs:   qt.Timestamp,
+		Msgs:    make([]*msgstream.TsMsg, 1),
+	}
+	msgPack.Msgs[0] = &tsMsg
+	qt.queryMsgStream.Produce(msgPack)
+	return nil
+}
+
+func (qt *QueryTask) PostExecute() error {
+	return nil
+}
+
+func (qt *QueryTask) WaitToFinish() error {
+	defer qt.cancel()
+	for {
+		select {
+		case err := <-qt.done:
+			return err
+		case <-qt.ctx.Done():
+			log.Print("wait to finish failed, timeout!")
+			return errors.New("wait to finish failed, timeout!")
+		}
+	}
+}
+
+func (qt *QueryTask) Notify(err error) {
+	defer qt.cancel()
+	defer func() {
+		qt.done <- err
+	}()
+	for {
+		select {
+		case <-qt.ctx.Done():
+			log.Print("wait to finish failed, timeout!")
+			return
+		case searchResults := <-qt.resultBuf:
+			rlen := len(searchResults) // query num
+			if rlen <= 0 {
+				qt.resultChan <- &servicepb.QueryResult{}
+				return
+			}
+			n := len(searchResults[0].Hits) // n
+			if n <= 0 {
+				qt.resultChan <- &servicepb.QueryResult{}
+				return
+			}
+			k := len(searchResults[0].Hits[0].Ids) // k
+			queryResult := &servicepb.QueryResult{
+				Status: &commonpb.Status{
+					ErrorCode: 0,
+				},
+			}
+			// reduce by score, TODO: use better algorithm
+			// use merge-sort here, the number of ways to merge is `rlen`
+			// in this process, we must make sure:
+			//   len(queryResult.Hits) == n
+			//   len(queryResult.Hits[i].Ids) == k for i in range(n)
+			for i := 0; i < n; i++ { // n
+				locs := make([]int, rlen)
+				hits := &servicepb.Hits{}
+				for j := 0; j < k; j++ { // k
+					choice, maxScore := 0, float32(0)
+					for q, loc := range locs { // query num, the number of ways to merge
+						score := func(score *servicepb.Score) float32 {
+							// TODO: get score of root
+							return 0.0
+						}(searchResults[q].Hits[i].Scores[loc])
+						if score > maxScore {
+							choice = q
+							maxScore = score
+						}
+					}
+					choiceOffset := locs[choice]
+					hits.Ids = append(hits.Ids, searchResults[choice].Hits[i].Ids[choiceOffset])
+					hits.RowData = append(hits.RowData, searchResults[choice].Hits[i].RowData[choiceOffset])
+					hits.Scores = append(hits.Scores, searchResults[choice].Hits[i].Scores[choiceOffset])
+					locs[choice]++
+				}
+				queryResult.Hits = append(queryResult.Hits, hits)
+			}
+			qt.resultChan <- queryResult
+		}
+	}
+}
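The reduce step in Notify is a k-way merge: each of the rlen partial results is assumed sorted by score, and for every query row it repeatedly takes the highest-scoring head among the rlen cursors. The same idea in isolation, on plain float slices (a sketch, not the production reducer):

// mergeByScore merges rlen score-sorted lists, taking k items total,
// by always advancing the cursor whose head currently scores best.
func mergeByScore(lists [][]float32, k int) []float32 {
	locs := make([]int, len(lists)) // one cursor per list
	merged := make([]float32, 0, k)
	for j := 0; j < k; j++ {
		choice, maxScore := -1, float32(0)
		for q, loc := range locs {
			if loc < len(lists[q]) && lists[q][loc] > maxScore {
				choice, maxScore = q, lists[q][loc]
			}
		}
		if choice < 0 {
			break // all lists exhausted
		}
		merged = append(merged, lists[choice][locs[choice]])
		locs[choice]++
	}
	return merged
}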
@@ -7,27 +7,14 @@ import (
 	"sync"
 )

-type baseTaskQueue struct {
+type BaseTaskQueue struct {
 	unissuedTasks *list.List
 	activeTasks   map[Timestamp]*task
 	utLock        sync.Mutex
 	atLock        sync.Mutex
 }

-type ddTaskQueue struct {
-	baseTaskQueue
-	lock sync.Mutex
-}
-
-type dmTaskQueue struct {
-	baseTaskQueue
-}
-
-type dqTaskQueue struct {
-	baseTaskQueue
-}
-
-func (queue *baseTaskQueue) Empty() bool {
+func (queue *BaseTaskQueue) Empty() bool {
 	queue.utLock.Lock()
 	defer queue.utLock.Unlock()
 	queue.atLock.Lock()
@@ -35,13 +22,13 @@ func (queue *baseTaskQueue) Empty() bool {
 	return queue.unissuedTasks.Len() <= 0 && len(queue.activeTasks) <= 0
 }

-func (queue *baseTaskQueue) AddUnissuedTask(t *task) {
+func (queue *BaseTaskQueue) AddUnissuedTask(t *task) {
 	queue.utLock.Lock()
 	defer queue.utLock.Unlock()
 	queue.unissuedTasks.PushBack(t)
 }

-func (queue *baseTaskQueue) FrontUnissuedTask() *task {
+func (queue *BaseTaskQueue) FrontUnissuedTask() *task {
 	queue.utLock.Lock()
 	defer queue.utLock.Unlock()
 	if queue.unissuedTasks.Len() <= 0 {
@@ -51,7 +38,7 @@ func (queue *baseTaskQueue) FrontUnissuedTask() *task {
 	return queue.unissuedTasks.Front().Value.(*task)
 }

-func (queue *baseTaskQueue) PopUnissuedTask() *task {
+func (queue *BaseTaskQueue) PopUnissuedTask() *task {
 	queue.utLock.Lock()
 	defer queue.utLock.Unlock()
 	if queue.unissuedTasks.Len() <= 0 {
@@ -62,7 +49,7 @@ func (queue *baseTaskQueue) PopUnissuedTask() *task {
 	return queue.unissuedTasks.Remove(ft).(*task)
 }

-func (queue *baseTaskQueue) AddActiveTask(t *task) {
+func (queue *BaseTaskQueue) AddActiveTask(t *task) {
 	queue.atLock.Lock()
 	defer queue.atLock.Unlock()
 	ts := (*t).EndTs()
@@ -73,7 +60,7 @@ func (queue *baseTaskQueue) AddActiveTask(t *task) {
 	queue.activeTasks[ts] = t
 }

-func (queue *baseTaskQueue) PopActiveTask(ts Timestamp) *task {
+func (queue *BaseTaskQueue) PopActiveTask(ts Timestamp) *task {
 	queue.atLock.Lock()
 	defer queue.atLock.Unlock()
 	t, ok := queue.activeTasks[ts]
@@ -85,7 +72,27 @@ func (queue *baseTaskQueue) PopActiveTask(ts Timestamp) *task {
 	return nil
 }

-func (queue *baseTaskQueue) TaskDoneTest(ts Timestamp) bool {
+func (queue *BaseTaskQueue) getTaskByReqId(reqId UniqueID) *task {
+	queue.utLock.Lock()
+	defer queue.utLock.Unlock()
+	for e := queue.unissuedTasks.Front(); e != nil; e = e.Next() {
+		if (*(e.Value.(*task))).Id() == reqId {
+			return e.Value.(*task)
+		}
+	}
+
+	queue.atLock.Lock()
+	defer queue.atLock.Unlock()
+	for ats := range queue.activeTasks {
+		if (*(queue.activeTasks[ats])).Id() == reqId {
+			return queue.activeTasks[ats]
+		}
+	}
+
+	return nil
+}
+
+func (queue *BaseTaskQueue) TaskDoneTest(ts Timestamp) bool {
 	queue.utLock.Lock()
 	defer queue.utLock.Unlock()
 	for e := queue.unissuedTasks.Front(); e != nil; e = e.Next() {
@@ -105,6 +112,19 @@ func (queue *baseTaskQueue) TaskDoneTest(ts Timestamp) bool {
 	return true
 }

+type ddTaskQueue struct {
+	BaseTaskQueue
+	lock sync.Mutex
+}
+
+type dmTaskQueue struct {
+	BaseTaskQueue
+}
+
+type dqTaskQueue struct {
+	BaseTaskQueue
+}
+
 func (queue *ddTaskQueue) Enqueue(t *task) error {
 	queue.lock.Lock()
 	defer queue.lock.Unlock()
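Read together, the queue API implies one task lifecycle: wait as unissued, get popped for execution, sit in the active map keyed by end timestamp, then get popped again once finished. Compressed into a single scheduling round (simplified from the loops below):

// One scheduling round against a BaseTaskQueue-style queue (sketch).
t := queue.PopUnissuedTask()      // oldest waiting task
queue.AddActiveTask(t)            // visible to getTaskByReqId while running
err := (*t).Execute()             // do the work
(*t).Notify(err)                  // wake whoever blocks in WaitToFinish
queue.PopActiveTask((*t).EndTs()) // retire it from the active map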
@@ -148,9 +168,49 @@ func (sched *TaskScheduler) scheduleDqTask() *task {
 	return sched.DqQueue.PopUnissuedTask()
 }

+func (sched *TaskScheduler) getTaskByReqId(reqId UniqueID) *task {
+	if t := sched.DdQueue.getTaskByReqId(reqId); t != nil {
+		return t
+	}
+	if t := sched.DmQueue.getTaskByReqId(reqId); t != nil {
+		return t
+	}
+	if t := sched.DqQueue.getTaskByReqId(reqId); t != nil {
+		return t
+	}
+	return nil
+}
+
 func (sched *TaskScheduler) definitionLoop() {
 	defer sched.wg.Done()
 	defer sched.cancel()

 	for {
 		if sched.DdQueue.Empty() {
 			continue
 		}

 		//sched.DdQueue.atLock.Lock()
 		t := sched.scheduleDdTask()

 		err := (*t).PreExecute()
 		if err != nil {
 			return
 		}
 		err = (*t).Execute()
 		if err != nil {
 			log.Printf("execute definition task failed, error = %v", err)
 		}
 		(*t).Notify(err)

 		sched.DdQueue.AddActiveTask(t)
 		//sched.DdQueue.atLock.Unlock()

 		(*t).WaitToFinish()
 		(*t).PostExecute()

 		sched.DdQueue.PopActiveTask((*t).EndTs())
 	}
 }

 func (sched *TaskScheduler) manipulationLoop() {
@@ -193,6 +253,38 @@ func (sched *TaskScheduler) manipulationLoop() {
 func (sched *TaskScheduler) queryLoop() {
 	defer sched.wg.Done()
 	defer sched.cancel()

 	for {
 		if sched.DqQueue.Empty() {
 			continue
 		}

 		sched.DqQueue.atLock.Lock()
 		t := sched.scheduleDqTask()

 		if err := (*t).PreExecute(); err != nil {
 			return
 		}

 		go func() {
 			err := (*t).Execute()
 			if err != nil {
 				log.Printf("execute query task failed, error = %v", err)
 			}
 			(*t).Notify(err)
 		}()

 		sched.DqQueue.AddActiveTask(t)
 		sched.DqQueue.atLock.Unlock()

 		go func() {
 			(*t).WaitToFinish()
 			(*t).PostExecute()

 			// remove from active list
 			sched.DqQueue.PopActiveTask((*t).EndTs())
 		}()
 	}
 }

 func (sched *TaskScheduler) Start(ctx context.Context) error {
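One design wrinkle shows in both loops: `if queue.Empty() { continue }` busy-spins a core whenever its queue is idle. This commit leaves that as-is; a common alternative (not what the code above does) is to block on a signal channel, sketched here with hypothetical names:

// Hypothetical: AddUnissuedTask would do a non-blocking send on `ready`.
type signaledQueue struct {
	BaseTaskQueue
	ready chan struct{}
}

// waitForTask blocks until a task is enqueued or the scheduler stops.
func (q *signaledQueue) waitForTask(ctx context.Context) bool {
	select {
	case <-q.ready:
		return true
	case <-ctx.Done():
		return false
	}
}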