mirror of https://github.com/milvus-io/milvus.git
312 lines · 7.3 KiB · Go
package streamrpc
|
|
|
|
import (
|
|
"context"
|
|
"io"
|
|
"sync"
|
|
|
|
"google.golang.org/grpc"
|
|
"google.golang.org/protobuf/proto"
|
|
|
|
"github.com/milvus-io/milvus-proto/go-api/v2/schemapb"
|
|
"github.com/milvus-io/milvus/internal/proto/internalpb"
|
|
"github.com/milvus-io/milvus/pkg/util/merr"
|
|
)
|
|
|
|
type QueryStreamServer interface {
|
|
Send(*internalpb.RetrieveResults) error
|
|
Context() context.Context
|
|
}
|
|
type QueryStreamClient interface {
|
|
Recv() (*internalpb.RetrieveResults, error)
|
|
Context() context.Context
|
|
CloseSend() error
|
|
}
|
|
|
|
type ConcurrentQueryStreamServer struct {
|
|
server QueryStreamServer
|
|
mu sync.Mutex
|
|
}
|
|
|
|
func (s *ConcurrentQueryStreamServer) Send(result *internalpb.RetrieveResults) error {
|
|
s.mu.Lock()
|
|
defer s.mu.Unlock()
|
|
return s.server.Send(result)
|
|
}
|
|
|
|
func (s *ConcurrentQueryStreamServer) Context() context.Context {
|
|
return s.server.Context()
|
|
}
|
|
|
|
func NewConcurrentQueryStreamServer(srv QueryStreamServer) *ConcurrentQueryStreamServer {
|
|
return &ConcurrentQueryStreamServer{
|
|
server: srv,
|
|
mu: sync.Mutex{},
|
|
}
|
|
}
|
|
|
|
type RetrieveResultCache struct {
|
|
result *internalpb.RetrieveResults
|
|
size int
|
|
cap int
|
|
}
|
|
|
|
func (c *RetrieveResultCache) Put(result *internalpb.RetrieveResults) {
|
|
if c.result == nil {
|
|
c.result = result
|
|
c.size = proto.Size(result)
|
|
return
|
|
}
|
|
|
|
c.merge(result)
|
|
}
|
|
|
|
func (c *RetrieveResultCache) Flush() *internalpb.RetrieveResults {
|
|
result := c.result
|
|
c.result = nil
|
|
c.size = 0
|
|
return result
|
|
}
|
|
|
|
func (c *RetrieveResultCache) Alloc(result *internalpb.RetrieveResults) bool {
|
|
return proto.Size(result)+c.size <= c.cap
|
|
}
|
|
|
|
func (c *RetrieveResultCache) IsFull() bool {
|
|
return c.size > c.cap
|
|
}
|
|
|
|
func (c *RetrieveResultCache) IsEmpty() bool {
|
|
return c.size == 0
|
|
}
|
|
|
|
func (c *RetrieveResultCache) merge(result *internalpb.RetrieveResults) {
|
|
switch result.GetIds().GetIdField().(type) {
|
|
case *schemapb.IDs_IntId:
|
|
c.result.GetIds().GetIntId().Data = append(c.result.GetIds().GetIntId().GetData(), result.GetIds().GetIntId().GetData()...)
|
|
case *schemapb.IDs_StrId:
|
|
c.result.GetIds().GetStrId().Data = append(c.result.GetIds().GetStrId().GetData(), result.GetIds().GetStrId().GetData()...)
|
|
}
|
|
c.result.AllRetrieveCount = c.result.AllRetrieveCount + result.AllRetrieveCount
|
|
c.result.CostAggregation = mergeCostAggregation(c.result.GetCostAggregation(), result.GetCostAggregation())
|
|
c.size = proto.Size(c.result)
|
|
}
|
|
|
|
func mergeCostAggregation(a *internalpb.CostAggregation, b *internalpb.CostAggregation) *internalpb.CostAggregation {
|
|
if a == nil {
|
|
return b
|
|
}
|
|
if b == nil {
|
|
return a
|
|
}
|
|
|
|
return &internalpb.CostAggregation{
|
|
ResponseTime: a.GetResponseTime() + b.GetResponseTime(),
|
|
ServiceTime: a.GetServiceTime() + b.GetServiceTime(),
|
|
TotalNQ: a.GetTotalNQ(),
|
|
TotalRelatedDataSize: a.GetTotalRelatedDataSize() + b.GetTotalRelatedDataSize(),
|
|
}
|
|
}
|
|
|
|
// Merge result by size and time.
|
|
type ResultCacheServer struct {
|
|
mu sync.Mutex
|
|
srv QueryStreamServer
|
|
cache *RetrieveResultCache
|
|
maxMsgSize int
|
|
}
|
|
|
|
func NewResultCacheServer(srv QueryStreamServer, cap int, maxMsgSize int) *ResultCacheServer {
|
|
return &ResultCacheServer{
|
|
srv: srv,
|
|
cache: &RetrieveResultCache{cap: cap},
|
|
maxMsgSize: maxMsgSize,
|
|
}
|
|
}
|
|
|
|
func (s *ResultCacheServer) splitMsgToMaxSize(result *internalpb.RetrieveResults) []*internalpb.RetrieveResults {
|
|
newpks := make([]*schemapb.IDs, 0)
|
|
switch result.GetIds().GetIdField().(type) {
|
|
case *schemapb.IDs_IntId:
|
|
pks := result.GetIds().GetIntId().Data
|
|
batch := s.maxMsgSize / 8
|
|
print(batch)
|
|
for start := 0; start < len(pks); start += batch {
|
|
newpks = append(newpks, &schemapb.IDs{IdField: &schemapb.IDs_IntId{IntId: &schemapb.LongArray{Data: pks[start:min(start+batch, len(pks))]}}})
|
|
}
|
|
|
|
case *schemapb.IDs_StrId:
|
|
pks := result.GetIds().GetStrId().Data
|
|
start := 0
|
|
size := 0
|
|
for i, pk := range pks {
|
|
if size+len(pk) > s.maxMsgSize {
|
|
newpks = append(newpks, &schemapb.IDs{IdField: &schemapb.IDs_StrId{StrId: &schemapb.StringArray{Data: pks[start:i]}}})
|
|
start = i
|
|
size = 0
|
|
}
|
|
size += len(pk)
|
|
}
|
|
if size > 0 {
|
|
newpks = append(newpks, &schemapb.IDs{IdField: &schemapb.IDs_StrId{StrId: &schemapb.StringArray{Data: pks[start:]}}})
|
|
}
|
|
}
|
|
|
|
results := make([]*internalpb.RetrieveResults, len(newpks))
|
|
for i, pks := range newpks {
|
|
results[i] = &internalpb.RetrieveResults{
|
|
Status: merr.Status(nil),
|
|
Ids: pks,
|
|
}
|
|
}
|
|
results[len(results)-1].AllRetrieveCount = result.AllRetrieveCount
|
|
results[len(results)-1].CostAggregation = result.CostAggregation
|
|
return results
|
|
}
|
|
|
|
func (s *ResultCacheServer) Send(result *internalpb.RetrieveResults) error {
|
|
s.mu.Lock()
|
|
defer s.mu.Unlock()
|
|
|
|
if !s.cache.Alloc(result) && !s.cache.IsEmpty() {
|
|
result := s.cache.Flush()
|
|
if err := s.srv.Send(result); err != nil {
|
|
return err
|
|
}
|
|
}
|
|
|
|
s.cache.Put(result)
|
|
if s.cache.IsFull() && s.cache.size <= s.maxMsgSize {
|
|
result := s.cache.Flush()
|
|
if err := s.srv.Send(result); err != nil {
|
|
return err
|
|
}
|
|
} else if s.cache.IsFull() && s.cache.size > s.maxMsgSize {
|
|
results := s.splitMsgToMaxSize(s.cache.Flush())
|
|
if proto.Size(results[len(results)-1]) < s.cache.cap {
|
|
s.cache.Put(results[len(results)-1])
|
|
results = results[:len(results)-1]
|
|
}
|
|
|
|
for _, result := range results {
|
|
if err := s.srv.Send(result); err != nil {
|
|
return err
|
|
}
|
|
}
|
|
}
|
|
return nil
|
|
}
|
|
|
|
func (s *ResultCacheServer) Flush() error {
|
|
s.mu.Lock()
|
|
defer s.mu.Unlock()
|
|
|
|
result := s.cache.Flush()
|
|
if result == nil {
|
|
return nil
|
|
}
|
|
|
|
if err := s.srv.Send(result); err != nil {
|
|
return err
|
|
}
|
|
return nil
|
|
}
|
|
|
|
func (s *ResultCacheServer) Context() context.Context {
|
|
return s.srv.Context()
|
|
}
|
|
|
|
// TODO LOCAL SERVER AND CLIENT FOR STANDALONE
|
|
// ONLY FOR TEST
|
|
type LocalQueryServer struct {
|
|
grpc.ServerStream
|
|
|
|
resultCh chan *internalpb.RetrieveResults
|
|
ctx context.Context
|
|
|
|
finishOnce sync.Once
|
|
errCh chan error
|
|
mu sync.Mutex
|
|
}
|
|
|
|
func (s *LocalQueryServer) Send(result *internalpb.RetrieveResults) error {
|
|
select {
|
|
case <-s.ctx.Done():
|
|
return s.ctx.Err()
|
|
default:
|
|
s.resultCh <- result
|
|
return nil
|
|
}
|
|
}
|
|
|
|
func (s *LocalQueryServer) FinishError() error {
|
|
return <-s.errCh
|
|
}
|
|
|
|
func (s *LocalQueryServer) Context() context.Context {
|
|
return s.ctx
|
|
}
|
|
|
|
func (s *LocalQueryServer) FinishSend(err error) error {
|
|
s.finishOnce.Do(func() {
|
|
close(s.resultCh)
|
|
if err != nil {
|
|
s.errCh <- err
|
|
} else {
|
|
s.errCh <- io.EOF
|
|
}
|
|
})
|
|
return nil
|
|
}
|
|
|
|
type LocalQueryClient struct {
|
|
grpc.ClientStream
|
|
|
|
server *LocalQueryServer
|
|
resultCh chan *internalpb.RetrieveResults
|
|
ctx context.Context
|
|
}
|
|
|
|
func (s *LocalQueryClient) RecvMsg(m interface{}) error {
|
|
// TODO implement me
|
|
panic("implement me")
|
|
}
|
|
|
|
func (s *LocalQueryClient) Recv() (*internalpb.RetrieveResults, error) {
|
|
select {
|
|
case <-s.ctx.Done():
|
|
return nil, s.ctx.Err()
|
|
default:
|
|
result, ok := <-s.resultCh
|
|
if !ok {
|
|
return nil, s.server.FinishError()
|
|
}
|
|
return result, nil
|
|
}
|
|
}
|
|
|
|
func (s *LocalQueryClient) Context() context.Context {
|
|
return s.ctx
|
|
}
|
|
|
|
func (s *LocalQueryClient) CloseSend() error {
|
|
return nil
|
|
}
|
|
|
|
func (s *LocalQueryClient) CreateServer() *LocalQueryServer {
|
|
s.server = &LocalQueryServer{
|
|
resultCh: s.resultCh,
|
|
ctx: s.ctx,
|
|
mu: sync.Mutex{},
|
|
errCh: make(chan error, 1),
|
|
}
|
|
return s.server
|
|
}
|
|
|
|
func NewLocalQueryClient(ctx context.Context) *LocalQueryClient {
|
|
return &LocalQueryClient{
|
|
resultCh: make(chan *internalpb.RetrieveResults, 64),
|
|
ctx: ctx,
|
|
}
|
|
}
|