milvus/internal/util/streamrpc/streamer.go

312 lines
7.3 KiB
Go

package streamrpc
import (
"context"
"io"
"sync"
"google.golang.org/grpc"
"google.golang.org/protobuf/proto"
"github.com/milvus-io/milvus-proto/go-api/v2/schemapb"
"github.com/milvus-io/milvus/internal/proto/internalpb"
"github.com/milvus-io/milvus/pkg/util/merr"
)
// QueryStreamServer is the sending side of a query result stream.
type QueryStreamServer interface {
	// Context returns the context associated with the stream.
	Context() context.Context
	// Send transmits one batch of retrieve results.
	Send(*internalpb.RetrieveResults) error
}
// QueryStreamClient is the receiving side of a query result stream.
type QueryStreamClient interface {
	// Context returns the context associated with the stream.
	Context() context.Context
	// Recv returns the next batch of retrieve results, or an error.
	Recv() (*internalpb.RetrieveResults, error)
	// CloseSend closes the client's sending direction of the stream.
	CloseSend() error
}
// ConcurrentQueryStreamServer wraps a QueryStreamServer so that Send may be
// called from several goroutines: every Send is serialized by a mutex.
type ConcurrentQueryStreamServer struct {
	server QueryStreamServer
	mu     sync.Mutex // guards calls to server.Send
}

// NewConcurrentQueryStreamServer returns a ConcurrentQueryStreamServer that
// forwards to srv. The zero-value mutex is ready to use.
func NewConcurrentQueryStreamServer(srv QueryStreamServer) *ConcurrentQueryStreamServer {
	return &ConcurrentQueryStreamServer{server: srv}
}

// Send forwards result to the wrapped server while holding the mutex.
func (s *ConcurrentQueryStreamServer) Send(result *internalpb.RetrieveResults) error {
	s.mu.Lock()
	defer s.mu.Unlock()
	return s.server.Send(result)
}

// Context returns the wrapped server's context; it does not take the mutex.
func (s *ConcurrentQueryStreamServer) Context() context.Context {
	return s.server.Context()
}
// RetrieveResultCache accumulates RetrieveResults into a single pending message
// and tracks its encoded size so callers can decide when to send it.
type RetrieveResultCache struct {
	result *internalpb.RetrieveResults // pending merged result; nil after Flush
	size   int                         // proto.Size of result when last stored or merged
	cap    int                         // soft limit in bytes; see Alloc and IsFull
}

// Put stores result as the pending message, or merges it into the message
// already pending.
func (c *RetrieveResultCache) Put(result *internalpb.RetrieveResults) {
	if c.result != nil {
		c.merge(result)
		return
	}
	c.result, c.size = result, proto.Size(result)
}

// Flush returns the pending message (possibly nil) and resets the cache.
func (c *RetrieveResultCache) Flush() *internalpb.RetrieveResults {
	pending := c.result
	c.result, c.size = nil, 0
	return pending
}

// Alloc reports whether adding result's encoded size to the current size
// stays within cap.
func (c *RetrieveResultCache) Alloc(result *internalpb.RetrieveResults) bool {
	return c.size+proto.Size(result) <= c.cap
}

// IsFull reports whether the tracked size has grown beyond cap.
func (c *RetrieveResultCache) IsFull() bool {
	return c.cap < c.size
}

// IsEmpty reports whether the tracked size is zero.
func (c *RetrieveResultCache) IsEmpty() bool {
	return c.size == 0
}
// merge folds result into the pending message. It appends result's primary
// keys, adds AllRetrieveCount, combines CostAggregation, and recomputes the
// cached size. Fields of result other than IDs and these counters are not
// copied.
//
// merge is only called by Put when c.result is non-nil.
func (c *RetrieveResultCache) merge(result *internalpb.RetrieveResults) {
	incoming := result.GetIds()
	switch {
	case incoming.GetIdField() == nil:
		// No primary keys to append.
	case c.result.GetIds().GetIdField() == nil:
		// The pending message holds no primary keys yet (e.g. it was seeded by
		// an empty batch). GetIntId()/GetStrId() would return nil and assigning
		// through them would panic, so adopt the incoming IDs as-is.
		c.result.Ids = incoming
	default:
		// NOTE(review): if the pending and incoming messages use different
		// primary-key kinds, the accessor below returns nil and the assignment
		// panics; presumably one collection never mixes PK kinds — confirm.
		switch incoming.GetIdField().(type) {
		case *schemapb.IDs_IntId:
			c.result.GetIds().GetIntId().Data = append(c.result.GetIds().GetIntId().GetData(), incoming.GetIntId().GetData()...)
		case *schemapb.IDs_StrId:
			c.result.GetIds().GetStrId().Data = append(c.result.GetIds().GetStrId().GetData(), incoming.GetStrId().GetData()...)
		}
	}
	c.result.AllRetrieveCount = c.result.AllRetrieveCount + result.AllRetrieveCount
	c.result.CostAggregation = mergeCostAggregation(c.result.GetCostAggregation(), result.GetCostAggregation())
	c.size = proto.Size(c.result)
}
// mergeCostAggregation combines two cost aggregations. If either argument is
// nil the other one is returned unchanged. Otherwise a new value is built in
// which ResponseTime, ServiceTime and TotalRelatedDataSize are summed, while
// TotalNQ is taken from a alone (NOTE(review): presumably because all partial
// results belong to the same request — confirm).
func mergeCostAggregation(a *internalpb.CostAggregation, b *internalpb.CostAggregation) *internalpb.CostAggregation {
	switch {
	case a == nil:
		return b
	case b == nil:
		return a
	}
	merged := &internalpb.CostAggregation{TotalNQ: a.GetTotalNQ()}
	merged.ResponseTime = a.GetResponseTime() + b.GetResponseTime()
	merged.ServiceTime = a.GetServiceTime() + b.GetServiceTime()
	merged.TotalRelatedDataSize = a.GetTotalRelatedDataSize() + b.GetTotalRelatedDataSize()
	return merged
}
// ResultCacheServer wraps a QueryStreamServer and batches results by encoded
// size. Incoming results are merged in a RetrieveResultCache and forwarded to
// the wrapped server once the cache grows beyond its capacity. Merged results
// larger than maxMsgSize are split before sending. Call Flush to send whatever
// is still cached. Send and Flush are serialized by mu.
type ResultCacheServer struct {
	mu         sync.Mutex
	srv        QueryStreamServer
	cache      *RetrieveResultCache
	maxMsgSize int
}

// NewResultCacheServer wraps srv with a cache whose soft capacity is cap bytes
// (as measured by proto.Size); outgoing merged messages larger than maxMsgSize
// bytes are split.
func NewResultCacheServer(srv QueryStreamServer, cap int, maxMsgSize int) *ResultCacheServer {
	cache := &RetrieveResultCache{cap: cap}
	return &ResultCacheServer{
		srv:        srv,
		cache:      cache,
		maxMsgSize: maxMsgSize,
	}
}
// splitMsgToMaxSize splits the primary keys of result into consecutive chunks
// of roughly s.maxMsgSize bytes and wraps each chunk in its own
// RetrieveResults with an OK status. AllRetrieveCount and CostAggregation are
// attached to the last chunk only, so they are reported exactly once. Fields of
// result other than IDs and those counters are not carried over.
//
// The returned slice always contains at least one element, even when result
// has no primary keys.
func (s *ResultCacheServer) splitMsgToMaxSize(result *internalpb.RetrieveResults) []*internalpb.RetrieveResults {
	var newpks []*schemapb.IDs
	switch result.GetIds().GetIdField().(type) {
	case *schemapb.IDs_IntId:
		pks := result.GetIds().GetIntId().GetData()
		// NOTE(review): 8 bytes per key is the fixed-width int64 size; the packed
		// varint encoding can use up to 10 bytes per value, so a chunk may still
		// exceed maxMsgSize slightly for large keys — confirm this is acceptable.
		batch := s.maxMsgSize / 8
		if batch < 1 {
			// A tiny maxMsgSize would otherwise make the loop below never advance.
			batch = 1
		}
		for start := 0; start < len(pks); start += batch {
			end := min(start+batch, len(pks))
			newpks = append(newpks, &schemapb.IDs{IdField: &schemapb.IDs_IntId{IntId: &schemapb.LongArray{Data: pks[start:end]}}})
		}
	case *schemapb.IDs_StrId:
		pks := result.GetIds().GetStrId().GetData()
		start, size := 0, 0
		for i, pk := range pks {
			// Cut only when the current chunk is non-empty: a single key larger
			// than maxMsgSize gets a chunk of its own instead of an empty chunk
			// being emitted ahead of it.
			if i > start && size+len(pk) > s.maxMsgSize {
				newpks = append(newpks, &schemapb.IDs{IdField: &schemapb.IDs_StrId{StrId: &schemapb.StringArray{Data: pks[start:i]}}})
				start, size = i, 0
			}
			size += len(pk)
		}
		// Decide on the remaining key count, not the byte count, so trailing
		// empty-string keys are not dropped.
		if start < len(pks) {
			newpks = append(newpks, &schemapb.IDs{IdField: &schemapb.IDs_StrId{StrId: &schemapb.StringArray{Data: pks[start:]}}})
		}
	}
	if len(newpks) == 0 {
		// Nothing to split (missing or empty IDs): still return one message so the
		// counters below are not lost and no caller indexes an empty slice.
		newpks = append(newpks, result.GetIds())
	}

	results := make([]*internalpb.RetrieveResults, len(newpks))
	for i, pks := range newpks {
		results[i] = &internalpb.RetrieveResults{
			Status: merr.Status(nil),
			Ids:    pks,
		}
	}
	last := results[len(results)-1]
	last.AllRetrieveCount = result.AllRetrieveCount
	last.CostAggregation = result.CostAggregation
	return results
}
// Send queues result in the cache and forwards cached data to the wrapped
// server when required:
//   - if result does not fit beside the data already cached, the cached data is
//     flushed and sent first;
//   - after result is added, if the cache has grown beyond its capacity it is
//     flushed. A flushed message no larger than maxMsgSize is sent as-is. A
//     larger one is split with splitMsgToMaxSize, and the final piece is kept
//     back in the cache when its size is below the cache capacity.
//
// The first error returned by the wrapped server is returned immediately.
func (s *ResultCacheServer) Send(result *internalpb.RetrieveResults) error {
	s.mu.Lock()
	defer s.mu.Unlock()

	if !s.cache.Alloc(result) && !s.cache.IsEmpty() {
		if err := s.srv.Send(s.cache.Flush()); err != nil {
			return err
		}
	}
	s.cache.Put(result)

	if !s.cache.IsFull() {
		return nil
	}
	if s.cache.size <= s.maxMsgSize {
		return s.srv.Send(s.cache.Flush())
	}

	pieces := s.splitMsgToMaxSize(s.cache.Flush())
	tail := len(pieces) - 1
	if proto.Size(pieces[tail]) < s.cache.cap {
		// Keep the small remainder cached so it can be merged with later results.
		s.cache.Put(pieces[tail])
		pieces = pieces[:tail]
	}
	for _, piece := range pieces {
		if err := s.srv.Send(piece); err != nil {
			return err
		}
	}
	return nil
}
// Flush sends whatever is still cached to the wrapped server and returns its
// error. It does nothing when the cache holds no pending message.
func (s *ResultCacheServer) Flush() error {
	s.mu.Lock()
	defer s.mu.Unlock()

	pending := s.cache.Flush()
	if pending == nil {
		return nil
	}
	return s.srv.Send(pending)
}

// Context returns the wrapped server's context.
func (s *ResultCacheServer) Context() context.Context {
	ctx := s.srv.Context()
	return ctx
}
// TODO: local in-process server and client for standalone mode.
// For tests only.
// LocalQueryServer is an in-process QueryStreamServer that hands results to a
// client over a channel instead of a network stream. Only Send, Context,
// FinishSend and FinishError are implemented here; the embedded
// grpc.ServerStream is not set by LocalQueryClient.CreateServer, so calling its
// other methods on such a server panics.
type LocalQueryServer struct {
	grpc.ServerStream

	resultCh   chan *internalpb.RetrieveResults // closed by FinishSend
	ctx        context.Context
	finishOnce sync.Once
	errCh      chan error // receives the terminal error exactly once (io.EOF on success)
	mu         sync.Mutex // guards finishErr
	finishErr  error      // terminal error cached by FinishError once received
}

// Send delivers result to the client. It returns the context error if the
// context is already done, and it also stops waiting when the context is
// cancelled while the channel is full. Send must not be called after
// FinishSend, because sending on the closed channel panics.
func (s *LocalQueryServer) Send(result *internalpb.RetrieveResults) error {
	// Cancellation wins over an available buffer slot.
	if err := s.ctx.Err(); err != nil {
		return err
	}
	// A plain blocking send would hang forever if the client stopped receiving.
	select {
	case <-s.ctx.Done():
		return s.ctx.Err()
	case s.resultCh <- result:
		return nil
	}
}

// FinishError blocks until FinishSend has recorded the terminal error and
// returns it. Repeated calls return the same error instead of blocking again,
// because the value is cached after the first receive from errCh.
func (s *LocalQueryServer) FinishError() error {
	s.mu.Lock()
	defer s.mu.Unlock()
	if s.finishErr == nil {
		s.finishErr = <-s.errCh
	}
	return s.finishErr
}

// Context returns the stream's context.
func (s *LocalQueryServer) Context() context.Context {
	return s.ctx
}

// FinishSend closes the result channel and records err (or io.EOF when err is
// nil) as the terminal error. Only the first call has any effect; the return
// value is always nil.
func (s *LocalQueryServer) FinishSend(err error) error {
	s.finishOnce.Do(func() {
		close(s.resultCh)
		if err != nil {
			s.errCh <- err
		} else {
			s.errCh <- io.EOF
		}
	})
	return nil
}
// LocalQueryClient is the receiving end of an in-process query stream. Its
// counterpart is the LocalQueryServer returned by CreateServer, which shares
// the client's result channel and context.
type LocalQueryClient struct {
	grpc.ClientStream
	server   *LocalQueryServer
	resultCh chan *internalpb.RetrieveResults
	ctx      context.Context
}

// RecvMsg is not implemented; use Recv instead.
func (s *LocalQueryClient) RecvMsg(m interface{}) error {
	// TODO implement me
	panic("implement me")
}

// Recv returns the next result sent by the server. When the server has
// finished and the channel is drained, it returns the server's terminal error
// (io.EOF on success). If the context is done, either before the call or while
// waiting, it returns the context error.
func (s *LocalQueryClient) Recv() (*internalpb.RetrieveResults, error) {
	// Cancellation wins over results already buffered in the channel.
	if err := s.ctx.Err(); err != nil {
		return nil, err
	}
	// A plain blocking receive would ignore cancellation while waiting.
	select {
	case <-s.ctx.Done():
		return nil, s.ctx.Err()
	case result, ok := <-s.resultCh:
		if !ok {
			return nil, s.server.FinishError()
		}
		return result, nil
	}
}

// Context returns the stream's context.
func (s *LocalQueryClient) Context() context.Context {
	return s.ctx
}

// CloseSend is a no-op for the local stream.
func (s *LocalQueryClient) CloseSend() error {
	return nil
}

// CreateServer builds the server end of this stream, sharing the client's
// result channel and context. It remembers the server for Recv and returns it.
// Calling it again replaces the remembered server.
func (s *LocalQueryClient) CreateServer() *LocalQueryServer {
	s.server = &LocalQueryServer{
		resultCh: s.resultCh,
		ctx:      s.ctx,
		errCh:    make(chan error, 1),
	}
	return s.server
}

// NewLocalQueryClient returns a client whose result channel buffers up to 64
// results.
func NewLocalQueryClient(ctx context.Context) *LocalQueryClient {
	return &LocalQueryClient{
		resultCh: make(chan *internalpb.RetrieveResults, 64),
		ctx:      ctx,
	}
}