package proxy

import (
	"context"
	"encoding/json"
	"fmt"
	"log"
	"net"
	"sync"
	"time"

	"github.com/golang/protobuf/proto"
	"github.com/zilliztech/milvus-distributed/internal/master/collection"
	"github.com/zilliztech/milvus-distributed/internal/proto/commonpb"
	"github.com/zilliztech/milvus-distributed/internal/proto/etcdpb"
	"github.com/zilliztech/milvus-distributed/internal/proto/internalpb"
	mpb "github.com/zilliztech/milvus-distributed/internal/proto/masterpb"
	pb "github.com/zilliztech/milvus-distributed/internal/proto/message"
	"github.com/zilliztech/milvus-distributed/internal/proto/schemapb"
	"github.com/zilliztech/milvus-distributed/internal/proto/servicepb"
	"github.com/zilliztech/milvus-distributed/internal/util/typeutil"
	etcd "go.etcd.io/etcd/clientv3"
	"go.uber.org/atomic"
	"google.golang.org/grpc"
	"google.golang.org/grpc/codes"
	"google.golang.org/grpc/status"
)

const (
	keyCollectionPath = "collection"
	keySegmentPath    = "segment"
)

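// proxyServer implements servicepb.MilvusServiceServer. It schedules
// manipulation and query requests through the request scheduler and keeps a
// local cache of collection and segment metadata mirrored from etcd.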
type proxyServer struct {
	servicepb.UnimplementedMilvusServiceServer
	address       string
	masterAddress string
	rootPath      string   // etcd root path
	pulsarAddr    string   // pulsar address for reader
	readerTopics  []string // reader topics
	deleteTopic   string
	queryTopic    string
	resultTopic   string
	resultGroup   string
	numReaderNode int
	proxyId       int64
	getTimestamp  func(count uint32) ([]typeutil.Timestamp, error)
	client        *etcd.Client
	ctx           context.Context
	wg            sync.WaitGroup
	////////////////////////////////////////////////////////////////
	masterConn   *grpc.ClientConn
	masterClient mpb.MasterClient
	grpcServer   *grpc.Server
	reqSch       *requestScheduler
	///////////////////////////////////////////////////////////////
	collectionList   map[int64]*etcdpb.CollectionMeta
	nameCollectionId map[string]int64
	segmentList      map[int64]*etcdpb.SegmentMeta
	collectionMux    sync.Mutex
	queryId          atomic.Int64
}

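// The collection and partition handlers below are placeholders: each one
// acknowledges the request with a success status (or an empty response)
// without doing any work yet.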
func (s *proxyServer) CreateCollection(ctx context.Context, req *schemapb.CollectionSchema) (*commonpb.Status, error) {
	return &commonpb.Status{
		ErrorCode: 0,
		Reason:    "",
	}, nil
}

func (s *proxyServer) DropCollection(ctx context.Context, req *servicepb.CollectionName) (*commonpb.Status, error) {
	return &commonpb.Status{
		ErrorCode: 0,
		Reason:    "",
	}, nil
}

func (s *proxyServer) HasCollection(ctx context.Context, req *servicepb.CollectionName) (*servicepb.BoolResponse, error) {
	return &servicepb.BoolResponse{
		Status: &commonpb.Status{
			ErrorCode: 0,
			Reason:    "",
		},
		Value: true,
	}, nil
}

func (s *proxyServer) DescribeCollection(ctx context.Context, req *servicepb.CollectionName) (*servicepb.CollectionDescription, error) {
	return &servicepb.CollectionDescription{
		Status: &commonpb.Status{
			ErrorCode: 0,
			Reason:    "",
		},
	}, nil
}

func (s *proxyServer) ShowCollections(ctx context.Context, req *commonpb.Empty) (*servicepb.StringListResponse, error) {
	return &servicepb.StringListResponse{
		Status: &commonpb.Status{
			ErrorCode: 0,
			Reason:    "",
		},
	}, nil
}

func (s *proxyServer) CreatePartition(ctx context.Context, in *servicepb.PartitionName) (*commonpb.Status, error) {
	return &commonpb.Status{
		ErrorCode: 0,
		Reason:    "",
	}, nil
}

func (s *proxyServer) DropPartition(ctx context.Context, in *servicepb.PartitionName) (*commonpb.Status, error) {
	return &commonpb.Status{
		ErrorCode: 0,
		Reason:    "",
	}, nil
}

func (s *proxyServer) HasPartition(ctx context.Context, in *servicepb.PartitionName) (*servicepb.BoolResponse, error) {
	return &servicepb.BoolResponse{
		Status: &commonpb.Status{
			ErrorCode: 0,
			Reason:    "",
		},
		Value: true,
	}, nil
}

func (s *proxyServer) DescribePartition(ctx context.Context, in *servicepb.PartitionName) (*servicepb.PartitionDescription, error) {
	return &servicepb.PartitionDescription{
		Status: &commonpb.Status{
			ErrorCode: 0,
			Reason:    "",
		},
	}, nil
}

func (s *proxyServer) ShowPartitions(ctx context.Context, req *servicepb.CollectionName) (*servicepb.StringListResponse, error) {
	return &servicepb.StringListResponse{
		Status: &commonpb.Status{
			ErrorCode: 0,
			Reason:    "",
		},
	}, nil
}

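// DeleteByID builds a kDeleteEntityByID manipulation request from the given
// primary keys and drives it through the scheduler's
// PreExecute/Execute/PostExecute/WaitToFinish pipeline.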
func (s *proxyServer) DeleteByID(ctx context.Context, req *pb.DeleteByIDParam) (*commonpb.Status, error) {
	log.Printf("delete entities, total = %d", len(req.IdArray))
	mReqMsg := pb.ManipulationReqMsg{
		CollectionName: req.CollectionName,
		ReqType:        pb.ReqType_kDeleteEntityByID,
		ProxyId:        s.proxyId,
	}
	mReqMsg.PrimaryKeys = append(mReqMsg.PrimaryKeys, req.IdArray...)
	// NOTE: the condition here was originally `> 1`, which silently skipped
	// deletes of a single entity; `> 0` is assumed to be the intended check.
	if len(mReqMsg.PrimaryKeys) > 0 {
		mReq := &manipulationReq{
			stats: make([]commonpb.Status, 1),
			msgs:  []*pb.ManipulationReqMsg{&mReqMsg},
			proxy: s,
		}
		if st := mReq.PreExecute(); st.ErrorCode != commonpb.ErrorCode_SUCCESS {
			return &st, nil
		}
		if st := mReq.Execute(); st.ErrorCode != commonpb.ErrorCode_SUCCESS {
			return &st, nil
		}
		if st := mReq.PostExecute(); st.ErrorCode != commonpb.ErrorCode_SUCCESS {
			return &st, nil
		}
		if st := mReq.WaitToFinish(); st.ErrorCode != commonpb.ErrorCode_SUCCESS {
			return &st, nil
		}
	}
	return &commonpb.Status{ErrorCode: commonpb.ErrorCode_SUCCESS}, nil
}

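// Insert shards the incoming rows by hashing each entry of req.HashKeys onto
// a reader topic, groups the rows into one ManipulationReqMsg per channel,
// and runs the batch through the scheduler pipeline.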
func (s *proxyServer) Insert(ctx context.Context, req *servicepb.RowBatch) (*servicepb.IntegerRangeResponse, error) {
	log.Printf("Insert entities, total = %d", len(req.RowData))
	msgMap := make(map[uint32]*pb.ManipulationReqMsg)

	// TODO: check the collection schema's auto_id
	if len(req.RowData) == 0 { // primary key is empty, set primary key by server
		log.Printf("Set primary key")
	}
	if len(req.HashKeys) != len(req.RowData) {
		return &servicepb.IntegerRangeResponse{
			Status: &commonpb.Status{
				ErrorCode: commonpb.ErrorCode_UNEXPECTED_ERROR,
				Reason:    "length of HashKeys is not equal to length of RowData",
			},
		}, nil
	}

	for i := 0; i < len(req.HashKeys); i++ {
		key := int64(req.HashKeys[i])
		hash, err := typeutil.Hash32Int64(key)
		if err != nil {
			return nil, status.Errorf(codes.Unknown, "hash failed on %d", key)
		}
		hash = hash % uint32(len(s.readerTopics))
		ipm, ok := msgMap[hash]
		if !ok {
			segId, err := s.getSegmentId(int32(hash), req.CollectionName)
			if err != nil {
				return nil, err
			}
			msgMap[hash] = &pb.ManipulationReqMsg{
				CollectionName: req.CollectionName,
				PartitionTag:   req.PartitionTag,
				SegmentId:      segId,
				ChannelId:      int64(hash),
				ReqType:        pb.ReqType_kInsert,
				ProxyId:        s.proxyId,
				//ExtraParams: req.ExtraParams,
			}
			ipm = msgMap[hash]
		}
		ipm.PrimaryKeys = append(ipm.PrimaryKeys, key)
		ipm.RowsData = append(ipm.RowsData, &pb.RowData{Blob: req.RowData[i].Value}) // czs_tag
	}

	// TODO: alloc manipulation request id
	mReq := manipulationReq{
		stats: make([]commonpb.Status, len(msgMap)),
		// NOTE: msgs was originally allocated with a non-zero length and then
		// appended to, leaving nil messages at the front; a zero length with a
		// capacity hint is assumed to be the intent.
		msgs:  make([]*pb.ManipulationReqMsg, 0, len(msgMap)),
		wg:    sync.WaitGroup{},
		proxy: s,
	}
	for _, v := range msgMap {
		mReq.msgs = append(mReq.msgs, v)
	}
	if st := mReq.PreExecute(); st.ErrorCode != commonpb.ErrorCode_SUCCESS { // do nothing
		return &servicepb.IntegerRangeResponse{
			Status: &commonpb.Status{ErrorCode: commonpb.ErrorCode_UNEXPECTED_ERROR},
		}, nil
	}
	if st := mReq.Execute(); st.ErrorCode != commonpb.ErrorCode_SUCCESS { // push into chan
		return &servicepb.IntegerRangeResponse{
			Status: &commonpb.Status{ErrorCode: commonpb.ErrorCode_UNEXPECTED_ERROR},
		}, nil
	}
	if st := mReq.PostExecute(); st.ErrorCode != commonpb.ErrorCode_SUCCESS { // post to pulsar
		return &servicepb.IntegerRangeResponse{
			Status: &commonpb.Status{ErrorCode: commonpb.ErrorCode_UNEXPECTED_ERROR},
		}, nil
	}

	if st := mReq.WaitToFinish(); st.ErrorCode != commonpb.ErrorCode_SUCCESS {
		log.Printf("Wait to finish failed, error code = %d", st.ErrorCode)
	}

	return &servicepb.IntegerRangeResponse{
		Status: &commonpb.Status{
			ErrorCode: commonpb.ErrorCode_SUCCESS,
		},
	}, nil
}

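// Search builds a kSearch request with a fresh query id, drives it through
// the scheduler pipeline, and reduces the collected partial results into a
// single QueryResult.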
func (s *proxyServer) Search(ctx context.Context, req *servicepb.Query) (*servicepb.QueryResult, error) {
	qm := &queryReq{
		SearchRequest: internalpb.SearchRequest{
			ReqType:         internalpb.ReqType_kSearch,
			ProxyId:         s.proxyId,
			ReqId:           s.queryId.Add(1),
			Timestamp:       0,
			ResultChannelId: 0,
		},
		proxy: s,
	}
	log.Printf("search on collection %s, proxy id = %d, query id = %d", req.CollectionName, qm.ProxyId, qm.ReqId)
	if st := qm.PreExecute(); st.ErrorCode != commonpb.ErrorCode_SUCCESS {
		return &servicepb.QueryResult{
			Status: &st,
		}, nil
	}
	if st := qm.Execute(); st.ErrorCode != commonpb.ErrorCode_SUCCESS {
		return &servicepb.QueryResult{
			Status: &st,
		}, nil
	}
	if st := qm.PostExecute(); st.ErrorCode != commonpb.ErrorCode_SUCCESS {
		return &servicepb.QueryResult{
			Status: &st,
		}, nil
	}
	if st := qm.WaitToFinish(); st.ErrorCode != commonpb.ErrorCode_SUCCESS {
		return &servicepb.QueryResult{
			Status: &st,
		}, nil
	}
	return s.reduceResults(qm), nil
}

// check verifies that the proxyServer is configured correctly before start.
func (s *proxyServer) check() error {
	if len(s.address) == 0 {
		return fmt.Errorf("proxy address is unset")
	}
	if len(s.masterAddress) == 0 {
		return fmt.Errorf("master address is unset")
	}
	if len(s.rootPath) == 0 {
		return fmt.Errorf("root path for etcd is unset")
	}
	if len(s.pulsarAddr) == 0 {
		return fmt.Errorf("pulsar address is unset")
	}
	if len(s.readerTopics) == 0 {
		return fmt.Errorf("reader topics are unset")
	}
	if len(s.deleteTopic) == 0 {
		return fmt.Errorf("delete topic is unset")
	}
	if len(s.queryTopic) == 0 {
		return fmt.Errorf("query topic is unset")
	}
	if len(s.resultTopic) == 0 {
		return fmt.Errorf("result topic is unset")
	}
	if len(s.resultGroup) == 0 {
		return fmt.Errorf("result group is unset")
	}
	if s.numReaderNode <= 0 {
		return fmt.Errorf("number of reader nodes is unset")
	}
	if s.proxyId <= 0 {
		return fmt.Errorf("proxyId is unset")
	}
	log.Printf("proxy id = %d", s.proxyId)
	if s.getTimestamp == nil {
		return fmt.Errorf("getTimestamp is unset")
	}
	if s.client == nil {
		return fmt.Errorf("etcd client is unset")
	}
	if s.ctx == nil {
		return fmt.Errorf("context is unset")
	}
	return nil
}

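// getSegmentId returns the first segment id recorded for the named
// collection, failing if that segment is missing from the local cache; the
// channelId argument is currently used only in the error message.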
func (s *proxyServer) getSegmentId(channelId int32, colName string) (int64, error) {
	s.collectionMux.Lock()
	defer s.collectionMux.Unlock()
	colId, ok := s.nameCollectionId[colName]
	if !ok {
		return 0, status.Errorf(codes.Unknown, "can't get collection id of %s", colName)
	}
	colInfo, ok := s.collectionList[colId]
	if !ok {
		return 0, status.Errorf(codes.Unknown, "can't get collection, name = %s, id = %d", colName, colId)
	}
	for _, segId := range colInfo.SegmentIds {
		_, ok := s.segmentList[segId]
		if !ok {
			return 0, status.Errorf(codes.Unknown, "can't get segment of %d", segId)
		}
		return segId, nil
	}
	return 0, status.Errorf(codes.Unknown, "can't get segment id, channel id = %d", channelId)
}

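// connectMaster dials the master service with a 2-second timeout and keeps
// the connection and client on the server for later RPCs.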
func (s *proxyServer) connectMaster() error {
	ctx, cancel := context.WithTimeout(context.Background(), 2*time.Second)
	defer cancel()
	conn, err := grpc.DialContext(ctx, s.masterAddress, grpc.WithInsecure(), grpc.WithBlock())
	if err != nil {
		log.Printf("Connect to master failed, error = %v", err)
		return err
	}
	log.Printf("Connected to master, master_addr=%s", s.masterAddress)
	s.masterConn = conn
	s.masterClient = mpb.NewMasterClient(conn)
	return nil
}

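// Close releases the etcd client and the master connection, then stops the
// gRPC server.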
func (s *proxyServer) Close() {
	s.client.Close()
	s.masterConn.Close()
	s.grpcServer.Stop()
}

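// StartGrpcServer listens on the proxy address and serves the MilvusService
// API on a background goroutine tracked by the server's WaitGroup.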
func (s *proxyServer) StartGrpcServer() error {
	lis, err := net.Listen("tcp", s.address)
	if err != nil {
		return err
	}
	// Register on the WaitGroup before launching the goroutine so that a
	// concurrent Wait cannot observe a zero counter.
	s.wg.Add(1)
	go func() {
		defer s.wg.Done()
		// Store the server on s.grpcServer so Close can stop it; the original
		// kept it in a local variable, leaving the field nil.
		s.grpcServer = grpc.NewServer()
		servicepb.RegisterMilvusServiceServer(s.grpcServer, s)
		if err := s.grpcServer.Serve(lis); err != nil {
			log.Fatalf("Proxy grpc server fatal error=%v", err)
		}
	}()
	return nil
}

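// WatchEtcd loads the current collection and segment metadata from etcd into
// the local caches, then watches both key prefixes from the snapshot revision
// onward, applying updates under collectionMux.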
func (s *proxyServer) WatchEtcd() error {
	s.collectionMux.Lock()
	defer s.collectionMux.Unlock()

	cos, err := s.client.Get(s.ctx, s.rootPath+"/"+keyCollectionPath, etcd.WithPrefix())
	if err != nil {
		return err
	}
	for _, cob := range cos.Kvs {
		// TODO: simplify collection struct
		var co etcdpb.CollectionMeta
		var mco collection.Collection
		if err := json.Unmarshal(cob.Value, &mco); err != nil {
			return err
		}
		if err := proto.UnmarshalText(mco.GrpcMarshalString, &co); err != nil {
			return err
		}
		s.nameCollectionId[co.Schema.Name] = co.Id
		s.collectionList[co.Id] = &co
		log.Printf("watch collection, name = %s, id = %d", co.Schema.Name, co.Id)
	}
	segs, err := s.client.Get(s.ctx, s.rootPath+"/"+keySegmentPath, etcd.WithPrefix())
	if err != nil {
		return err
	}
	for _, segb := range segs.Kvs {
		var seg etcdpb.SegmentMeta
		if err := json.Unmarshal(segb.Value, &seg); err != nil {
			return err
		}
		s.segmentList[seg.SegmentId] = &seg
		log.Printf("watch segment id = %d", seg.SegmentId)
	}

	cow := s.client.Watch(s.ctx, s.rootPath+"/"+keyCollectionPath, etcd.WithPrefix(), etcd.WithRev(cos.Header.Revision+1))
	segw := s.client.Watch(s.ctx, s.rootPath+"/"+keySegmentPath, etcd.WithPrefix(), etcd.WithRev(segs.Header.Revision+1))
	// Register on the WaitGroup before launching the goroutine so that a
	// concurrent Wait cannot observe a zero counter.
	s.wg.Add(1)
	go func() {
		defer s.wg.Done()
		for {
			select {
			case <-s.ctx.Done():
				return
			case coe := <-cow:
				func() {
					s.collectionMux.Lock()
					defer s.collectionMux.Unlock()
					for _, e := range coe.Events {
						var co etcdpb.CollectionMeta
						var mco collection.Collection
						if err := json.Unmarshal(e.Kv.Value, &mco); err != nil {
							log.Printf("unmarshal Collection failed, error = %v", err)
							continue
						}
						if err := proto.UnmarshalText(mco.GrpcMarshalString, &co); err != nil {
							log.Printf("unmarshal CollectionMeta failed, error = %v", err)
							continue
						}
						s.nameCollectionId[co.Schema.Name] = co.Id
						s.collectionList[co.Id] = &co
						log.Printf("watch collection, name = %s, id = %d", co.Schema.Name, co.Id)
					}
				}()
			case sege := <-segw:
				func() {
					s.collectionMux.Lock()
					defer s.collectionMux.Unlock()
					for _, e := range sege.Events {
						var seg etcdpb.SegmentMeta
						if err := json.Unmarshal(e.Kv.Value, &seg); err != nil {
							log.Printf("unmarshal Segment failed, error = %v", err)
							continue
						}
						s.segmentList[seg.SegmentId] = &seg
						log.Printf("watch segment id = %d", seg.SegmentId)
					}
				}()
			}
		}
	}()

	return nil
}

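// startProxyServer validates the configuration, starts the request scheduler
// routines, connects to the master, primes the metadata caches from etcd,
// and finally starts serving gRPC.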
func startProxyServer(srv *proxyServer) error {
	if err := srv.check(); err != nil {
		return err
	}
	srv.reqSch = &requestScheduler{}
	if err := srv.restartManipulationRoutine(1024); err != nil {
		return err
	}
	if err := srv.restartQueryRoutine(1024); err != nil {
		return err
	}

	srv.nameCollectionId = make(map[string]int64)
	srv.collectionList = make(map[int64]*etcdpb.CollectionMeta)
	srv.segmentList = make(map[int64]*etcdpb.SegmentMeta)

	if err := srv.connectMaster(); err != nil {
		return err
	}
	if err := srv.WatchEtcd(); err != nil {
		return err
	}

	srv.queryId.Store(time.Now().UnixNano())

	return srv.StartGrpcServer()
}