mirror of https://github.com/milvus-io/milvus.git
QueryNodes send search/query results by rpc (#15223)
Signed-off-by: dragondriver <jiquan.long@zilliz.com>pull/15250/head
parent
b0e3460569
commit
675e6d352b
|
@ -141,3 +141,29 @@ func (c *Client) ReleaseDQLMessageStream(ctx context.Context, req *proxypb.Relea
|
|||
}
|
||||
return ret.(*commonpb.Status), err
|
||||
}
|
||||
|
||||
func (c *Client) SendSearchResult(ctx context.Context, results *internalpb.SearchResults) (*commonpb.Status, error) {
|
||||
ret, err := c.grpcClient.ReCall(ctx, func(client interface{}) (interface{}, error) {
|
||||
if !funcutil.CheckCtxValid(ctx) {
|
||||
return nil, ctx.Err()
|
||||
}
|
||||
return client.(proxypb.ProxyClient).SendSearchResult(ctx, results)
|
||||
})
|
||||
if err != nil || ret == nil {
|
||||
return nil, err
|
||||
}
|
||||
return ret.(*commonpb.Status), err
|
||||
}
|
||||
|
||||
func (c *Client) SendRetrieveResult(ctx context.Context, results *internalpb.RetrieveResults) (*commonpb.Status, error) {
|
||||
ret, err := c.grpcClient.ReCall(ctx, func(client interface{}) (interface{}, error) {
|
||||
if !funcutil.CheckCtxValid(ctx) {
|
||||
return nil, ctx.Err()
|
||||
}
|
||||
return client.(proxypb.ProxyClient).SendRetrieveResult(ctx, results)
|
||||
})
|
||||
if err != nil || ret == nil {
|
||||
return nil, err
|
||||
}
|
||||
return ret.(*commonpb.Status), err
|
||||
}
|
||||
|
|
|
@ -68,6 +68,12 @@ func Test_NewClient(t *testing.T) {
|
|||
|
||||
r4, err := client.ReleaseDQLMessageStream(ctx, nil)
|
||||
retCheck(retNotNil, r4, err)
|
||||
|
||||
r5, err := client.SendSearchResult(ctx, nil)
|
||||
retCheck(retNotNil, r5, err)
|
||||
|
||||
r6, err := client.SendRetrieveResult(ctx, nil)
|
||||
retCheck(retNotNil, r6, err)
|
||||
}
|
||||
|
||||
client.grpcClient = &mock.ClientBase{
|
||||
|
|
|
@ -571,3 +571,11 @@ func (s *Server) GetCompactionStateWithPlans(ctx context.Context, req *milvuspb.
|
|||
func (s *Server) GetFlushState(ctx context.Context, req *milvuspb.GetFlushStateRequest) (*milvuspb.GetFlushStateResponse, error) {
|
||||
return s.proxy.GetFlushState(ctx, req)
|
||||
}
|
||||
|
||||
func (s *Server) SendSearchResult(ctx context.Context, results *internalpb.SearchResults) (*commonpb.Status, error) {
|
||||
return s.proxy.SendSearchResult(ctx, results)
|
||||
}
|
||||
|
||||
func (s *Server) SendRetrieveResult(ctx context.Context, results *internalpb.RetrieveResults) (*commonpb.Status, error) {
|
||||
return s.proxy.SendRetrieveResult(ctx, results)
|
||||
}
|
||||
|
|
|
@ -636,6 +636,14 @@ func (m *MockProxy) GetFlushState(ctx context.Context, req *milvuspb.GetFlushSta
|
|||
return nil, nil
|
||||
}
|
||||
|
||||
func (m *MockProxy) SendSearchResult(ctx context.Context, req *internalpb.SearchResults) (*commonpb.Status, error) {
|
||||
return nil, nil
|
||||
}
|
||||
|
||||
func (m *MockProxy) SendRetrieveResult(ctx context.Context, req *internalpb.RetrieveResults) (*commonpb.Status, error) {
|
||||
return nil, nil
|
||||
}
|
||||
|
||||
///////////////////////////////////////////////////////////////////////////////////////////////////////////////////////
|
||||
func Test_NewServer(t *testing.T) {
|
||||
ctx := context.Background()
|
||||
|
|
|
@ -15,6 +15,9 @@ service Proxy {
|
|||
rpc GetDdChannel(internal.GetDdChannelRequest) returns (milvus.StringResponse) {}
|
||||
|
||||
rpc ReleaseDQLMessageStream(ReleaseDQLMessageStreamRequest) returns (common.Status) {}
|
||||
|
||||
rpc SendSearchResult(internal.SearchResults) returns (common.Status) {}
|
||||
rpc SendRetrieveResult(internal.RetrieveResults) returns (common.Status) {}
|
||||
}
|
||||
|
||||
message InvalidateCollMetaCacheRequest {
|
||||
|
|
|
@ -145,33 +145,36 @@ func init() {
|
|||
func init() { proto.RegisterFile("proxy.proto", fileDescriptor_700b50b08ed8dbaf) }
|
||||
|
||||
var fileDescriptor_700b50b08ed8dbaf = []byte{
|
||||
// 413 bytes of a gzipped FileDescriptorProto
|
||||
0x1f, 0x8b, 0x08, 0x00, 0x00, 0x00, 0x00, 0x00, 0x02, 0xff, 0x9c, 0x52, 0xd1, 0x0a, 0xd3, 0x30,
|
||||
0x14, 0x5d, 0xdd, 0x9c, 0x98, 0x95, 0x09, 0x41, 0xd8, 0xa8, 0x3a, 0x46, 0x05, 0x1d, 0x82, 0xeb,
|
||||
0xa8, 0x7e, 0xc1, 0x5a, 0x18, 0x03, 0x27, 0xda, 0xbd, 0xf9, 0x22, 0x69, 0x7b, 0xe9, 0x02, 0x69,
|
||||
0xd2, 0x35, 0xe9, 0xd0, 0x5f, 0xf0, 0xd9, 0x1f, 0xf1, 0x0f, 0xa5, 0x69, 0xb7, 0xd9, 0x6d, 0x9d,
|
||||
0xe8, 0x5b, 0xcf, 0xcd, 0xb9, 0x9c, 0x73, 0x6e, 0x0f, 0x1a, 0x64, 0xb9, 0xf8, 0xf6, 0x7d, 0x9e,
|
||||
0xe5, 0x42, 0x09, 0x8c, 0x53, 0xca, 0x0e, 0x85, 0xac, 0xd0, 0x5c, 0xbf, 0x58, 0x66, 0x24, 0xd2,
|
||||
0x54, 0xf0, 0x6a, 0x66, 0x0d, 0x29, 0x57, 0x90, 0x73, 0xc2, 0x6a, 0x6c, 0xfe, 0xb9, 0x61, 0xff,
|
||||
0x34, 0xd0, 0x64, 0xcd, 0x0f, 0x84, 0xd1, 0x98, 0x28, 0xf0, 0x04, 0x63, 0x1b, 0x50, 0xc4, 0x23,
|
||||
0xd1, 0x0e, 0x02, 0xd8, 0x17, 0x20, 0x15, 0x5e, 0xa0, 0x5e, 0x48, 0x24, 0x8c, 0x8d, 0xa9, 0x31,
|
||||
0x1b, 0xb8, 0xcf, 0xe7, 0x0d, 0xc5, 0x5a, 0x6a, 0x23, 0x93, 0x25, 0x91, 0x10, 0x68, 0x26, 0x1e,
|
||||
0xa1, 0x47, 0x71, 0xf8, 0x95, 0x93, 0x14, 0xc6, 0x0f, 0xa6, 0xc6, 0xec, 0x71, 0xd0, 0x8f, 0xc3,
|
||||
0x8f, 0x24, 0x05, 0xfc, 0x1a, 0x3d, 0x89, 0x04, 0x63, 0x10, 0x29, 0x2a, 0x78, 0x45, 0xe8, 0x6a,
|
||||
0xc2, 0xf0, 0x3c, 0x2e, 0x89, 0xf6, 0x0f, 0x03, 0x4d, 0x02, 0x60, 0x40, 0x24, 0xf8, 0x9f, 0x3f,
|
||||
0x6c, 0x40, 0x4a, 0x92, 0xc0, 0x56, 0xe5, 0x40, 0xd2, 0xff, 0xb7, 0x85, 0x51, 0x2f, 0x0e, 0xd7,
|
||||
0xbe, 0xf6, 0xd4, 0x0d, 0xf4, 0x37, 0xb6, 0x91, 0x79, 0x96, 0x5e, 0xfb, 0xda, 0x4e, 0x37, 0x68,
|
||||
0xcc, 0xdc, 0x5f, 0x3d, 0xf4, 0xf0, 0x53, 0x79, 0x59, 0x9c, 0x21, 0xbc, 0x02, 0xe5, 0x89, 0x34,
|
||||
0x13, 0x1c, 0xb8, 0xda, 0x2a, 0xa2, 0x40, 0xe2, 0x45, 0x53, 0xfb, 0x74, 0xef, 0x6b, 0x6a, 0xed,
|
||||
0xdd, 0x7a, 0xd5, 0xb2, 0x71, 0x41, 0xb7, 0x3b, 0x78, 0x8f, 0x9e, 0xae, 0x40, 0x43, 0x2a, 0x15,
|
||||
0x8d, 0xa4, 0xb7, 0x23, 0x9c, 0x03, 0xc3, 0x6e, 0xbb, 0xe6, 0x15, 0xf9, 0xa8, 0xfa, 0xb2, 0xb9,
|
||||
0x53, 0x83, 0xad, 0xca, 0x29, 0x4f, 0x02, 0x90, 0x99, 0xe0, 0x12, 0xec, 0x0e, 0xce, 0xd1, 0x8b,
|
||||
0x66, 0x23, 0xaa, 0x43, 0x9c, 0x7a, 0x71, 0xa9, 0x5d, 0xd5, 0xf1, 0x7e, 0x89, 0xac, 0x67, 0x37,
|
||||
0xff, 0x4f, 0x69, 0xb5, 0x28, 0x63, 0x12, 0x64, 0xae, 0x40, 0xf9, 0xf1, 0x31, 0xde, 0x9b, 0xf6,
|
||||
0x78, 0x27, 0xd2, 0x3f, 0xc6, 0x62, 0x68, 0xd4, 0xd2, 0xa8, 0xdb, 0x81, 0xee, 0xd7, 0xef, 0x2f,
|
||||
0x81, 0x96, 0xef, 0xbf, 0xb8, 0x09, 0x55, 0xbb, 0x22, 0x2c, 0x5f, 0x9c, 0x8a, 0xfa, 0x96, 0x8a,
|
||||
0xfa, 0xcb, 0x39, 0x06, 0x72, 0xf4, 0xb6, 0xa3, 0x15, 0xb3, 0x30, 0xec, 0x6b, 0xf8, 0xee, 0x77,
|
||||
0x00, 0x00, 0x00, 0xff, 0xff, 0x46, 0xe1, 0x7c, 0xd9, 0xe3, 0x03, 0x00, 0x00,
|
||||
// 460 bytes of a gzipped FileDescriptorProto
|
||||
0x1f, 0x8b, 0x08, 0x00, 0x00, 0x00, 0x00, 0x00, 0x02, 0xff, 0x9c, 0x52, 0x61, 0x6b, 0x13, 0x41,
|
||||
0x10, 0xed, 0x99, 0xb6, 0xe2, 0x36, 0x54, 0x59, 0x84, 0x96, 0xa8, 0xa5, 0x9c, 0xa2, 0x45, 0x30,
|
||||
0x29, 0xd1, 0x5f, 0xd0, 0x04, 0x42, 0xc0, 0x88, 0xee, 0x7d, 0x10, 0xf4, 0x83, 0xcc, 0xdd, 0x0d,
|
||||
0xc9, 0xc2, 0xde, 0xee, 0x75, 0x77, 0x2e, 0xe8, 0x5f, 0xf0, 0xb3, 0xff, 0xd1, 0xbf, 0x21, 0xb7,
|
||||
0x77, 0x49, 0x7b, 0x6d, 0x2f, 0xc1, 0x7e, 0xbb, 0x99, 0x7d, 0x33, 0xef, 0xbd, 0xb9, 0xc7, 0x0e,
|
||||
0x72, 0x6b, 0x7e, 0xfe, 0xea, 0xe7, 0xd6, 0x90, 0xe1, 0x3c, 0x93, 0x6a, 0x59, 0xb8, 0xaa, 0xea,
|
||||
0xfb, 0x97, 0x5e, 0x37, 0x31, 0x59, 0x66, 0x74, 0xd5, 0xeb, 0x1d, 0x4a, 0x4d, 0x68, 0x35, 0xa8,
|
||||
0xba, 0xee, 0x5e, 0x9f, 0x08, 0xff, 0x04, 0xec, 0x64, 0xaa, 0x97, 0xa0, 0x64, 0x0a, 0x84, 0x23,
|
||||
0xa3, 0xd4, 0x0c, 0x09, 0x46, 0x90, 0x2c, 0x50, 0xe0, 0x65, 0x81, 0x8e, 0xf8, 0x39, 0xdb, 0x8d,
|
||||
0xc1, 0xe1, 0x71, 0x70, 0x1a, 0x9c, 0x1d, 0x0c, 0x9f, 0xf7, 0x1b, 0x8c, 0x35, 0xd5, 0xcc, 0xcd,
|
||||
0x2f, 0xc0, 0xa1, 0xf0, 0x48, 0x7e, 0xc4, 0x1e, 0xa6, 0xf1, 0x0f, 0x0d, 0x19, 0x1e, 0x3f, 0x38,
|
||||
0x0d, 0xce, 0x1e, 0x89, 0xfd, 0x34, 0xfe, 0x04, 0x19, 0xf2, 0x37, 0xec, 0x71, 0x62, 0x94, 0xc2,
|
||||
0x84, 0xa4, 0xd1, 0x15, 0xa0, 0xe3, 0x01, 0x87, 0x57, 0xed, 0x12, 0x18, 0xfe, 0x0e, 0xd8, 0x89,
|
||||
0x40, 0x85, 0xe0, 0x70, 0xfc, 0xe5, 0xe3, 0x0c, 0x9d, 0x83, 0x39, 0x46, 0x64, 0x11, 0xb2, 0xfb,
|
||||
0xcb, 0xe2, 0x6c, 0x37, 0x8d, 0xa7, 0x63, 0xaf, 0xa9, 0x23, 0xfc, 0x37, 0x0f, 0x59, 0xf7, 0x8a,
|
||||
0x7a, 0x3a, 0xf6, 0x72, 0x3a, 0xa2, 0xd1, 0x1b, 0xfe, 0xdd, 0x63, 0x7b, 0x9f, 0xcb, 0xcb, 0xf2,
|
||||
0x9c, 0xf1, 0x09, 0xd2, 0xc8, 0x64, 0xb9, 0xd1, 0xa8, 0x29, 0x22, 0x20, 0x74, 0xfc, 0xbc, 0xc9,
|
||||
0xbd, 0xbe, 0xf7, 0x6d, 0x68, 0xad, 0xbd, 0xf7, 0xba, 0x65, 0xe2, 0x06, 0x3c, 0xdc, 0xe1, 0x97,
|
||||
0xec, 0xe9, 0x04, 0x7d, 0x29, 0x1d, 0xc9, 0xc4, 0x8d, 0x16, 0xa0, 0x35, 0x2a, 0x3e, 0x6c, 0xe7,
|
||||
0xbc, 0x05, 0x5e, 0xb1, 0xbe, 0x6c, 0xce, 0xd4, 0x45, 0x44, 0x56, 0xea, 0xb9, 0x40, 0x97, 0x1b,
|
||||
0xed, 0x30, 0xdc, 0xe1, 0x96, 0xbd, 0x68, 0x26, 0xa2, 0x3a, 0xc4, 0x3a, 0x17, 0x37, 0xb9, 0xab,
|
||||
0x38, 0x6e, 0x0e, 0x51, 0xef, 0xd9, 0x9d, 0xff, 0xa7, 0x94, 0x5a, 0x94, 0x36, 0x81, 0x75, 0x27,
|
||||
0x48, 0xe3, 0x74, 0x65, 0xef, 0x6d, 0xbb, 0xbd, 0x35, 0xe8, 0x3f, 0x6d, 0x29, 0x76, 0xd4, 0x92,
|
||||
0xa8, 0xbb, 0x0d, 0x6d, 0x8e, 0xdf, 0x36, 0x43, 0x5f, 0xd9, 0x93, 0x08, 0x75, 0x1a, 0x21, 0xd8,
|
||||
0x64, 0x21, 0xd0, 0x15, 0x8a, 0xf8, 0xab, 0x16, 0x53, 0xd7, 0x41, 0x6e, 0xdb, 0xe2, 0xef, 0x8c,
|
||||
0x97, 0x8b, 0x05, 0x92, 0x95, 0xb8, 0xc4, 0x7a, 0x75, 0x5b, 0xa0, 0x9a, 0xb0, 0x6d, 0xcb, 0x2f,
|
||||
0x3e, 0x7c, 0x1b, 0xce, 0x25, 0x2d, 0x8a, 0xb8, 0x7c, 0x19, 0x54, 0xd0, 0x77, 0xd2, 0xd4, 0x5f,
|
||||
0x83, 0xd5, 0xda, 0x81, 0x9f, 0x1e, 0xf8, 0x3b, 0xe5, 0x71, 0xbc, 0xef, 0xcb, 0xf7, 0xff, 0x02,
|
||||
0x00, 0x00, 0xff, 0xff, 0x48, 0x84, 0x3c, 0x39, 0x99, 0x04, 0x00, 0x00,
|
||||
}
|
||||
|
||||
// Reference imports to suppress errors if they are not otherwise used.
|
||||
|
@ -191,6 +194,8 @@ type ProxyClient interface {
|
|||
InvalidateCollectionMetaCache(ctx context.Context, in *InvalidateCollMetaCacheRequest, opts ...grpc.CallOption) (*commonpb.Status, error)
|
||||
GetDdChannel(ctx context.Context, in *internalpb.GetDdChannelRequest, opts ...grpc.CallOption) (*milvuspb.StringResponse, error)
|
||||
ReleaseDQLMessageStream(ctx context.Context, in *ReleaseDQLMessageStreamRequest, opts ...grpc.CallOption) (*commonpb.Status, error)
|
||||
SendSearchResult(ctx context.Context, in *internalpb.SearchResults, opts ...grpc.CallOption) (*commonpb.Status, error)
|
||||
SendRetrieveResult(ctx context.Context, in *internalpb.RetrieveResults, opts ...grpc.CallOption) (*commonpb.Status, error)
|
||||
}
|
||||
|
||||
type proxyClient struct {
|
||||
|
@ -246,6 +251,24 @@ func (c *proxyClient) ReleaseDQLMessageStream(ctx context.Context, in *ReleaseDQ
|
|||
return out, nil
|
||||
}
|
||||
|
||||
func (c *proxyClient) SendSearchResult(ctx context.Context, in *internalpb.SearchResults, opts ...grpc.CallOption) (*commonpb.Status, error) {
|
||||
out := new(commonpb.Status)
|
||||
err := c.cc.Invoke(ctx, "/milvus.proto.proxy.Proxy/SendSearchResult", in, out, opts...)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return out, nil
|
||||
}
|
||||
|
||||
func (c *proxyClient) SendRetrieveResult(ctx context.Context, in *internalpb.RetrieveResults, opts ...grpc.CallOption) (*commonpb.Status, error) {
|
||||
out := new(commonpb.Status)
|
||||
err := c.cc.Invoke(ctx, "/milvus.proto.proxy.Proxy/SendRetrieveResult", in, out, opts...)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return out, nil
|
||||
}
|
||||
|
||||
// ProxyServer is the server API for Proxy service.
|
||||
type ProxyServer interface {
|
||||
GetComponentStates(context.Context, *internalpb.GetComponentStatesRequest) (*internalpb.ComponentStates, error)
|
||||
|
@ -253,6 +276,8 @@ type ProxyServer interface {
|
|||
InvalidateCollectionMetaCache(context.Context, *InvalidateCollMetaCacheRequest) (*commonpb.Status, error)
|
||||
GetDdChannel(context.Context, *internalpb.GetDdChannelRequest) (*milvuspb.StringResponse, error)
|
||||
ReleaseDQLMessageStream(context.Context, *ReleaseDQLMessageStreamRequest) (*commonpb.Status, error)
|
||||
SendSearchResult(context.Context, *internalpb.SearchResults) (*commonpb.Status, error)
|
||||
SendRetrieveResult(context.Context, *internalpb.RetrieveResults) (*commonpb.Status, error)
|
||||
}
|
||||
|
||||
// UnimplementedProxyServer can be embedded to have forward compatible implementations.
|
||||
|
@ -274,6 +299,12 @@ func (*UnimplementedProxyServer) GetDdChannel(ctx context.Context, req *internal
|
|||
func (*UnimplementedProxyServer) ReleaseDQLMessageStream(ctx context.Context, req *ReleaseDQLMessageStreamRequest) (*commonpb.Status, error) {
|
||||
return nil, status.Errorf(codes.Unimplemented, "method ReleaseDQLMessageStream not implemented")
|
||||
}
|
||||
func (*UnimplementedProxyServer) SendSearchResult(ctx context.Context, req *internalpb.SearchResults) (*commonpb.Status, error) {
|
||||
return nil, status.Errorf(codes.Unimplemented, "method SendSearchResult not implemented")
|
||||
}
|
||||
func (*UnimplementedProxyServer) SendRetrieveResult(ctx context.Context, req *internalpb.RetrieveResults) (*commonpb.Status, error) {
|
||||
return nil, status.Errorf(codes.Unimplemented, "method SendRetrieveResult not implemented")
|
||||
}
|
||||
|
||||
func RegisterProxyServer(s *grpc.Server, srv ProxyServer) {
|
||||
s.RegisterService(&_Proxy_serviceDesc, srv)
|
||||
|
@ -369,6 +400,42 @@ func _Proxy_ReleaseDQLMessageStream_Handler(srv interface{}, ctx context.Context
|
|||
return interceptor(ctx, in, info, handler)
|
||||
}
|
||||
|
||||
func _Proxy_SendSearchResult_Handler(srv interface{}, ctx context.Context, dec func(interface{}) error, interceptor grpc.UnaryServerInterceptor) (interface{}, error) {
|
||||
in := new(internalpb.SearchResults)
|
||||
if err := dec(in); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
if interceptor == nil {
|
||||
return srv.(ProxyServer).SendSearchResult(ctx, in)
|
||||
}
|
||||
info := &grpc.UnaryServerInfo{
|
||||
Server: srv,
|
||||
FullMethod: "/milvus.proto.proxy.Proxy/SendSearchResult",
|
||||
}
|
||||
handler := func(ctx context.Context, req interface{}) (interface{}, error) {
|
||||
return srv.(ProxyServer).SendSearchResult(ctx, req.(*internalpb.SearchResults))
|
||||
}
|
||||
return interceptor(ctx, in, info, handler)
|
||||
}
|
||||
|
||||
func _Proxy_SendRetrieveResult_Handler(srv interface{}, ctx context.Context, dec func(interface{}) error, interceptor grpc.UnaryServerInterceptor) (interface{}, error) {
|
||||
in := new(internalpb.RetrieveResults)
|
||||
if err := dec(in); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
if interceptor == nil {
|
||||
return srv.(ProxyServer).SendRetrieveResult(ctx, in)
|
||||
}
|
||||
info := &grpc.UnaryServerInfo{
|
||||
Server: srv,
|
||||
FullMethod: "/milvus.proto.proxy.Proxy/SendRetrieveResult",
|
||||
}
|
||||
handler := func(ctx context.Context, req interface{}) (interface{}, error) {
|
||||
return srv.(ProxyServer).SendRetrieveResult(ctx, req.(*internalpb.RetrieveResults))
|
||||
}
|
||||
return interceptor(ctx, in, info, handler)
|
||||
}
|
||||
|
||||
var _Proxy_serviceDesc = grpc.ServiceDesc{
|
||||
ServiceName: "milvus.proto.proxy.Proxy",
|
||||
HandlerType: (*ProxyServer)(nil),
|
||||
|
@ -393,6 +460,14 @@ var _Proxy_serviceDesc = grpc.ServiceDesc{
|
|||
MethodName: "ReleaseDQLMessageStream",
|
||||
Handler: _Proxy_ReleaseDQLMessageStream_Handler,
|
||||
},
|
||||
{
|
||||
MethodName: "SendSearchResult",
|
||||
Handler: _Proxy_SendSearchResult_Handler,
|
||||
},
|
||||
{
|
||||
MethodName: "SendRetrieveResult",
|
||||
Handler: _Proxy_SendRetrieveResult_Handler,
|
||||
},
|
||||
},
|
||||
Streams: []grpc.StreamDesc{},
|
||||
Metadata: "proxy.proto",
|
||||
|
|
|
@ -90,6 +90,9 @@ type Proxy struct {
|
|||
|
||||
msFactory msgstream.Factory
|
||||
|
||||
searchResultCh chan *internalpb.SearchResults
|
||||
retrieveResultCh chan *internalpb.RetrieveResults
|
||||
|
||||
// Add callback functions at different stages
|
||||
startCallbacks []func()
|
||||
closeCallbacks []func()
|
||||
|
@ -99,10 +102,13 @@ type Proxy struct {
|
|||
func NewProxy(ctx context.Context, factory msgstream.Factory) (*Proxy, error) {
|
||||
rand.Seed(time.Now().UnixNano())
|
||||
ctx1, cancel := context.WithCancel(ctx)
|
||||
n := 1024 // better to be configurable
|
||||
node := &Proxy{
|
||||
ctx: ctx1,
|
||||
cancel: cancel,
|
||||
msFactory: factory,
|
||||
ctx: ctx1,
|
||||
cancel: cancel,
|
||||
msFactory: factory,
|
||||
searchResultCh: make(chan *internalpb.SearchResults, n),
|
||||
retrieveResultCh: make(chan *internalpb.RetrieveResults, n),
|
||||
}
|
||||
node.UpdateStateCode(internalpb.StateCode_Abnormal)
|
||||
logutil.Logger(ctx).Debug("create a new Proxy instance", zap.Any("state", node.stateCode.Load()))
|
||||
|
@ -231,7 +237,9 @@ func (node *Proxy) Init() error {
|
|||
log.Debug("create channels manager done", zap.String("role", typeutil.ProxyRole))
|
||||
|
||||
log.Debug("create task scheduler", zap.String("role", typeutil.ProxyRole))
|
||||
node.sched, err = newTaskScheduler(node.ctx, node.idAllocator, node.tsoAllocator, node.msFactory)
|
||||
node.sched, err = newTaskScheduler(node.ctx, node.idAllocator, node.tsoAllocator, node.msFactory,
|
||||
schedOptWithSearchResultCh(node.searchResultCh),
|
||||
schedOptWithRetrieveResultCh(node.retrieveResultCh))
|
||||
if err != nil {
|
||||
log.Warn("failed to create task scheduler", zap.Error(err), zap.String("role", typeutil.ProxyRole))
|
||||
return err
|
||||
|
|
|
@ -23,12 +23,21 @@ import (
|
|||
"encoding/json"
|
||||
"fmt"
|
||||
"math/rand"
|
||||
"net"
|
||||
"os"
|
||||
"strconv"
|
||||
"sync"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
ot "github.com/grpc-ecosystem/go-grpc-middleware/tracing/opentracing"
|
||||
|
||||
"github.com/milvus-io/milvus/internal/util/trace"
|
||||
"google.golang.org/grpc"
|
||||
"google.golang.org/grpc/keepalive"
|
||||
|
||||
"github.com/milvus-io/milvus/internal/util/paramtable"
|
||||
|
||||
"github.com/milvus-io/milvus/internal/util/etcd"
|
||||
"github.com/milvus-io/milvus/internal/util/sessionutil"
|
||||
"github.com/milvus-io/milvus/internal/util/tsoutil"
|
||||
|
@ -324,8 +333,97 @@ func runIndexNode(ctx context.Context, localMsg bool, alias string) *grpcindexno
|
|||
return in
|
||||
}
|
||||
|
||||
type proxyTestServer struct {
|
||||
*Proxy
|
||||
grpcServer *grpc.Server
|
||||
ch chan error
|
||||
}
|
||||
|
||||
func newProxyTestServer(node *Proxy) *proxyTestServer {
|
||||
return &proxyTestServer{
|
||||
Proxy: node,
|
||||
grpcServer: nil,
|
||||
ch: make(chan error, 1),
|
||||
}
|
||||
}
|
||||
|
||||
func (s *proxyTestServer) GetComponentStates(ctx context.Context, request *internalpb.GetComponentStatesRequest) (*internalpb.ComponentStates, error) {
|
||||
return s.Proxy.GetComponentStates(ctx)
|
||||
}
|
||||
|
||||
func (s *proxyTestServer) GetStatisticsChannel(ctx context.Context, request *internalpb.GetStatisticsChannelRequest) (*milvuspb.StringResponse, error) {
|
||||
return s.Proxy.GetStatisticsChannel(ctx)
|
||||
}
|
||||
|
||||
func (s *proxyTestServer) startGrpc(ctx context.Context, wg *sync.WaitGroup) {
|
||||
defer wg.Done()
|
||||
|
||||
var p paramtable.GrpcServerConfig
|
||||
p.InitOnce(typeutil.ProxyRole)
|
||||
Params.InitOnce()
|
||||
Params.ProxyCfg.NetworkAddress = p.GetAddress()
|
||||
Params.ProxyCfg.Refresh()
|
||||
|
||||
var kaep = keepalive.EnforcementPolicy{
|
||||
MinTime: 5 * time.Second, // If a client pings more than once every 5 seconds, terminate the connection
|
||||
PermitWithoutStream: true, // Allow pings even when there are no active streams
|
||||
}
|
||||
|
||||
var kasp = keepalive.ServerParameters{
|
||||
Time: 60 * time.Second, // Ping the client if it is idle for 60 seconds to ensure the connection is still active
|
||||
Timeout: 10 * time.Second, // Wait 10 second for the ping ack before assuming the connection is dead
|
||||
}
|
||||
|
||||
log.Debug("Proxy server listen on tcp", zap.Int("port", p.Port))
|
||||
lis, err := net.Listen("tcp", ":"+strconv.Itoa(p.Port))
|
||||
if err != nil {
|
||||
log.Warn("Proxy server failed to listen on", zap.Error(err), zap.Int("port", p.Port))
|
||||
s.ch <- err
|
||||
return
|
||||
}
|
||||
log.Debug("Proxy server already listen on tcp", zap.Int("port", p.Port))
|
||||
|
||||
ctx, cancel := context.WithCancel(ctx)
|
||||
defer cancel()
|
||||
|
||||
opts := trace.GetInterceptorOpts()
|
||||
s.grpcServer = grpc.NewServer(
|
||||
grpc.KeepaliveEnforcementPolicy(kaep),
|
||||
grpc.KeepaliveParams(kasp),
|
||||
grpc.MaxRecvMsgSize(p.ServerMaxRecvSize),
|
||||
grpc.MaxSendMsgSize(p.ServerMaxSendSize),
|
||||
grpc.UnaryInterceptor(ot.UnaryServerInterceptor(opts...)),
|
||||
grpc.StreamInterceptor(ot.StreamServerInterceptor(opts...)))
|
||||
proxypb.RegisterProxyServer(s.grpcServer, s)
|
||||
milvuspb.RegisterMilvusServiceServer(s.grpcServer, s)
|
||||
|
||||
log.Debug("create Proxy grpc server",
|
||||
zap.Any("enforcement policy", kaep),
|
||||
zap.Any("server parameters", kasp))
|
||||
|
||||
log.Debug("waiting for Proxy grpc server to be ready")
|
||||
go funcutil.CheckGrpcReady(ctx, s.ch)
|
||||
|
||||
log.Debug("Proxy grpc server has been ready, serve grpc requests on listen")
|
||||
if err := s.grpcServer.Serve(lis); err != nil {
|
||||
log.Warn("failed to serve on Proxy's listener", zap.Error(err))
|
||||
s.ch <- err
|
||||
}
|
||||
}
|
||||
|
||||
func (s *proxyTestServer) waitForGrpcReady() error {
|
||||
return <-s.ch
|
||||
}
|
||||
|
||||
func (s *proxyTestServer) gracefulStop() {
|
||||
if s.grpcServer != nil {
|
||||
s.grpcServer.GracefulStop()
|
||||
}
|
||||
}
|
||||
|
||||
func TestProxy(t *testing.T) {
|
||||
var err error
|
||||
var wg sync.WaitGroup
|
||||
|
||||
path := "/tmp/milvus/rocksmq" + funcutil.GenRandomStr()
|
||||
err = os.Setenv("ROCKSMQ_PATH", path)
|
||||
|
@ -419,6 +517,7 @@ func TestProxy(t *testing.T) {
|
|||
proxy, err := NewProxy(ctx, factory)
|
||||
assert.NoError(t, err)
|
||||
assert.NotNil(t, proxy)
|
||||
|
||||
Params.Init()
|
||||
log.Info("Initialize parameter table of Proxy")
|
||||
|
||||
|
@ -426,6 +525,12 @@ func TestProxy(t *testing.T) {
|
|||
defer etcdcli.Close()
|
||||
assert.NoError(t, err)
|
||||
proxy.SetEtcdClient(etcdcli)
|
||||
|
||||
testServer := newProxyTestServer(proxy)
|
||||
wg.Add(1)
|
||||
go testServer.startGrpc(ctx, &wg)
|
||||
assert.NoError(t, testServer.waitForGrpcReady())
|
||||
|
||||
rootCoordClient, err := rcc.NewClient(ctx, Params.BaseParams.MetaRootPath, etcdcli)
|
||||
assert.NoError(t, err)
|
||||
err = rootCoordClient.Init()
|
||||
|
@ -680,8 +785,6 @@ func TestProxy(t *testing.T) {
|
|||
}
|
||||
}
|
||||
|
||||
var wg sync.WaitGroup
|
||||
|
||||
wg.Add(1)
|
||||
t.Run("create collection", func(t *testing.T) {
|
||||
defer wg.Done()
|
||||
|
@ -1249,12 +1352,9 @@ func TestProxy(t *testing.T) {
|
|||
defer wg.Done()
|
||||
req := constructSearchRequest()
|
||||
|
||||
//resp, err := proxy.Search(ctx, req)
|
||||
_, err := proxy.Search(ctx, req)
|
||||
resp, err := proxy.Search(ctx, req)
|
||||
assert.NoError(t, err)
|
||||
// FIXME(dragondriver)
|
||||
// assert.Equal(t, commonpb.ErrorCode_Success, resp.Status.ErrorCode)
|
||||
// TODO(dragondriver): compare search result
|
||||
assert.Equal(t, commonpb.ErrorCode_Success, resp.Status.ErrorCode)
|
||||
})
|
||||
|
||||
wg.Add(1)
|
||||
|
@ -2565,6 +2665,8 @@ func TestProxy(t *testing.T) {
|
|||
assert.NotEqual(t, commonpb.ErrorCode_Success, resp.ErrorCode)
|
||||
})
|
||||
|
||||
testServer.gracefulStop()
|
||||
|
||||
wg.Wait()
|
||||
cancel()
|
||||
}
|
||||
|
|
|
@ -0,0 +1,24 @@
|
|||
package proxy
|
||||
|
||||
import (
|
||||
"context"
|
||||
|
||||
"github.com/milvus-io/milvus/internal/proto/commonpb"
|
||||
"github.com/milvus-io/milvus/internal/proto/internalpb"
|
||||
)
|
||||
|
||||
func (node *Proxy) SendSearchResult(ctx context.Context, req *internalpb.SearchResults) (*commonpb.Status, error) {
|
||||
node.searchResultCh <- req
|
||||
return &commonpb.Status{
|
||||
ErrorCode: commonpb.ErrorCode_Success,
|
||||
Reason: "",
|
||||
}, nil
|
||||
}
|
||||
|
||||
func (node *Proxy) SendRetrieveResult(ctx context.Context, req *internalpb.RetrieveResults) (*commonpb.Status, error) {
|
||||
node.retrieveResultCh <- req
|
||||
return &commonpb.Status{
|
||||
ErrorCode: commonpb.ErrorCode_Success,
|
||||
Reason: "",
|
||||
}, nil
|
||||
}
|
|
@ -387,12 +387,31 @@ type taskScheduler struct {
|
|||
cancel context.CancelFunc
|
||||
|
||||
msFactory msgstream.Factory
|
||||
|
||||
searchResultCh chan *internalpb.SearchResults
|
||||
retrieveResultCh chan *internalpb.RetrieveResults
|
||||
}
|
||||
|
||||
type schedOpt func(*taskScheduler)
|
||||
|
||||
func schedOptWithSearchResultCh(ch chan *internalpb.SearchResults) schedOpt {
|
||||
return func(sched *taskScheduler) {
|
||||
sched.searchResultCh = ch
|
||||
}
|
||||
}
|
||||
|
||||
func schedOptWithRetrieveResultCh(ch chan *internalpb.RetrieveResults) schedOpt {
|
||||
return func(sched *taskScheduler) {
|
||||
sched.retrieveResultCh = ch
|
||||
}
|
||||
}
|
||||
|
||||
func newTaskScheduler(ctx context.Context,
|
||||
idAllocatorIns idAllocatorInterface,
|
||||
tsoAllocatorIns tsoAllocator,
|
||||
factory msgstream.Factory) (*taskScheduler, error) {
|
||||
factory msgstream.Factory,
|
||||
opts ...schedOpt,
|
||||
) (*taskScheduler, error) {
|
||||
ctx1, cancel := context.WithCancel(ctx)
|
||||
s := &taskScheduler{
|
||||
ctx: ctx1,
|
||||
|
@ -403,6 +422,10 @@ func newTaskScheduler(ctx context.Context,
|
|||
s.dmQueue = newDmTaskQueue(tsoAllocatorIns, idAllocatorIns)
|
||||
s.dqQueue = newDqTaskQueue(tsoAllocatorIns, idAllocatorIns)
|
||||
|
||||
for _, opt := range opts {
|
||||
opt(s)
|
||||
}
|
||||
|
||||
return s, nil
|
||||
}
|
||||
|
||||
|
@ -557,6 +580,7 @@ func newSearchResultBuf(msgID UniqueID) *searchResultBuf {
|
|||
receivedSealedSegmentIDsSet: make(map[interface{}]struct{}),
|
||||
receivedGlobalSegmentIDsSet: make(map[interface{}]struct{}),
|
||||
haveError: false,
|
||||
msgID: msgID,
|
||||
},
|
||||
resultBuf: make([]*internalpb.SearchResults, 0),
|
||||
}
|
||||
|
@ -570,6 +594,7 @@ func newQueryResultBuf(msgID UniqueID) *queryResultBuf {
|
|||
receivedSealedSegmentIDsSet: make(map[interface{}]struct{}),
|
||||
receivedGlobalSegmentIDsSet: make(map[interface{}]struct{}),
|
||||
haveError: false,
|
||||
msgID: msgID,
|
||||
},
|
||||
resultBuf: make([]*internalpb.RetrieveResults, 0),
|
||||
}
|
||||
|
@ -632,6 +657,161 @@ func (qr *queryResultBuf) addPartialResult(result *internalpb.RetrieveResults) {
|
|||
result.GlobalSealedSegmentIDs)
|
||||
}
|
||||
|
||||
func (sched *taskScheduler) collectionResultLoopV2() {
|
||||
defer sched.wg.Done()
|
||||
|
||||
searchResultBufs := make(map[UniqueID]*searchResultBuf)
|
||||
searchResultBufFlags := newIDCache(Params.ProxyCfg.BufFlagExpireTime, Params.ProxyCfg.BufFlagCleanupInterval) // if value is true, we can ignore searchResult
|
||||
queryResultBufs := make(map[UniqueID]*queryResultBuf)
|
||||
queryResultBufFlags := newIDCache(Params.ProxyCfg.BufFlagExpireTime, Params.ProxyCfg.BufFlagCleanupInterval) // if value is true, we can ignore queryResult
|
||||
|
||||
processSearchResult := func(results *internalpb.SearchResults) error {
|
||||
reqID := results.Base.MsgID
|
||||
|
||||
ignoreThisResult, ok := searchResultBufFlags.Get(reqID)
|
||||
if !ok {
|
||||
searchResultBufFlags.Set(reqID, false)
|
||||
ignoreThisResult = false
|
||||
}
|
||||
if ignoreThisResult {
|
||||
log.Debug("got a search result, but we should ignore", zap.String("role", typeutil.ProxyRole), zap.Int64("ReqID", reqID))
|
||||
return nil
|
||||
}
|
||||
|
||||
log.Debug("got a search result", zap.String("role", typeutil.ProxyRole), zap.Int64("ReqID", reqID))
|
||||
|
||||
t := sched.getTaskByReqID(reqID)
|
||||
|
||||
if t == nil {
|
||||
log.Debug("got a search result, but not in task scheduler", zap.String("role", typeutil.ProxyRole), zap.Int64("ReqID", reqID))
|
||||
delete(searchResultBufs, reqID)
|
||||
searchResultBufFlags.Set(reqID, true)
|
||||
}
|
||||
|
||||
st, ok := t.(*searchTask)
|
||||
if !ok {
|
||||
log.Debug("got a search result, but the related task is not of search task", zap.String("role", typeutil.ProxyRole), zap.Int64("ReqID", reqID))
|
||||
delete(searchResultBufs, reqID)
|
||||
searchResultBufFlags.Set(reqID, true)
|
||||
return nil
|
||||
}
|
||||
|
||||
resultBuf, ok := searchResultBufs[reqID]
|
||||
if !ok {
|
||||
log.Debug("first receive search result of this task", zap.String("role", typeutil.ProxyRole), zap.Int64("reqID", reqID))
|
||||
resultBuf = newSearchResultBuf(reqID)
|
||||
vchans, err := st.getVChannels()
|
||||
if err != nil {
|
||||
delete(searchResultBufs, reqID)
|
||||
log.Warn("failed to get virtual channels", zap.String("role", typeutil.ProxyRole), zap.Error(err), zap.Int64("reqID", reqID))
|
||||
return err
|
||||
}
|
||||
for _, vchan := range vchans {
|
||||
resultBuf.usedVChans[vchan] = struct{}{}
|
||||
}
|
||||
searchResultBufs[reqID] = resultBuf
|
||||
}
|
||||
resultBuf.addPartialResult(results)
|
||||
|
||||
colName := t.(*searchTask).query.CollectionName
|
||||
log.Debug("process search result", zap.String("role", typeutil.ProxyRole), zap.String("collection", colName), zap.Int64("reqID", reqID), zap.Int("answer cnt", len(searchResultBufs[reqID].resultBuf)))
|
||||
|
||||
if resultBuf.readyToReduce() {
|
||||
log.Debug("process search result, ready to reduce", zap.String("role", typeutil.ProxyRole), zap.Int64("reqID", reqID))
|
||||
searchResultBufFlags.Set(reqID, true)
|
||||
st.resultBuf <- resultBuf.resultBuf
|
||||
delete(searchResultBufs, reqID)
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
processRetrieveResult := func(results *internalpb.RetrieveResults) error {
|
||||
reqID := results.Base.MsgID
|
||||
|
||||
ignoreThisResult, ok := queryResultBufFlags.Get(reqID)
|
||||
if !ok {
|
||||
queryResultBufFlags.Set(reqID, false)
|
||||
ignoreThisResult = false
|
||||
}
|
||||
if ignoreThisResult {
|
||||
log.Debug("got a retrieve result, but we should ignore", zap.String("role", typeutil.ProxyRole), zap.Int64("ReqID", reqID))
|
||||
return nil
|
||||
}
|
||||
|
||||
log.Debug("got a retrieve result", zap.String("role", typeutil.ProxyRole), zap.Int64("ReqID", reqID))
|
||||
|
||||
t := sched.getTaskByReqID(reqID)
|
||||
|
||||
if t == nil {
|
||||
log.Debug("got a retrieve result, but not in task scheduler", zap.String("role", typeutil.ProxyRole), zap.Int64("ReqID", reqID))
|
||||
delete(queryResultBufs, reqID)
|
||||
queryResultBufFlags.Set(reqID, true)
|
||||
}
|
||||
|
||||
st, ok := t.(*queryTask)
|
||||
if !ok {
|
||||
log.Debug("got a retrieve result, but the related task is not of retrieve task", zap.String("role", typeutil.ProxyRole), zap.Int64("ReqID", reqID))
|
||||
delete(queryResultBufs, reqID)
|
||||
queryResultBufFlags.Set(reqID, true)
|
||||
return nil
|
||||
}
|
||||
|
||||
resultBuf, ok := queryResultBufs[reqID]
|
||||
if !ok {
|
||||
log.Debug("first receive retrieve result of this task", zap.String("role", typeutil.ProxyRole), zap.Int64("reqID", reqID))
|
||||
resultBuf = newQueryResultBuf(reqID)
|
||||
vchans, err := st.getVChannels()
|
||||
if err != nil {
|
||||
delete(queryResultBufs, reqID)
|
||||
log.Warn("failed to get virtual channels", zap.String("role", typeutil.ProxyRole), zap.Error(err), zap.Int64("reqID", reqID))
|
||||
return err
|
||||
}
|
||||
for _, vchan := range vchans {
|
||||
resultBuf.usedVChans[vchan] = struct{}{}
|
||||
}
|
||||
queryResultBufs[reqID] = resultBuf
|
||||
}
|
||||
resultBuf.addPartialResult(results)
|
||||
|
||||
colName := t.(*queryTask).query.CollectionName
|
||||
log.Debug("process retrieve result", zap.String("role", typeutil.ProxyRole), zap.String("collection", colName), zap.Int64("reqID", reqID), zap.Int("answer cnt", len(queryResultBufs[reqID].resultBuf)))
|
||||
|
||||
if resultBuf.readyToReduce() {
|
||||
log.Debug("process retrieve result, ready to reduce", zap.String("role", typeutil.ProxyRole), zap.Int64("reqID", reqID))
|
||||
queryResultBufFlags.Set(reqID, true)
|
||||
st.resultBuf <- resultBuf.resultBuf
|
||||
delete(queryResultBufs, reqID)
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
for {
|
||||
select {
|
||||
case <-sched.ctx.Done():
|
||||
log.Info("task scheduler's result loop of Proxy exit", zap.String("reason", "context done"))
|
||||
return
|
||||
case sr, ok := <-sched.searchResultCh:
|
||||
if !ok {
|
||||
log.Info("task scheduler's result loop of Proxy exit", zap.String("reason", "search result channel closed"))
|
||||
return
|
||||
}
|
||||
if err := processSearchResult(sr); err != nil {
|
||||
log.Warn("failed to process search result", zap.Error(err))
|
||||
}
|
||||
case rr, ok := <-sched.retrieveResultCh:
|
||||
if !ok {
|
||||
log.Info("task scheduler's result loop of Proxy exit", zap.String("reason", "retrieve result channel closed"))
|
||||
return
|
||||
}
|
||||
if err := processRetrieveResult(rr); err != nil {
|
||||
log.Warn("failed to process retrieve result", zap.Error(err))
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func (sched *taskScheduler) collectResultLoop() {
|
||||
defer sched.wg.Done()
|
||||
|
||||
|
@ -848,7 +1028,8 @@ func (sched *taskScheduler) Start() error {
|
|||
go sched.queryLoop()
|
||||
|
||||
sched.wg.Add(1)
|
||||
go sched.collectResultLoop()
|
||||
// go sched.collectResultLoop()
|
||||
go sched.collectionResultLoopV2()
|
||||
|
||||
return nil
|
||||
}
|
||||
|
|
|
@ -0,0 +1,70 @@
|
|||
package querynode
|
||||
|
||||
import (
|
||||
"context"
|
||||
|
||||
"github.com/milvus-io/milvus/internal/proto/commonpb"
|
||||
"github.com/milvus-io/milvus/internal/proto/internalpb"
|
||||
"github.com/milvus-io/milvus/internal/proto/milvuspb"
|
||||
"github.com/milvus-io/milvus/internal/proto/proxypb"
|
||||
"github.com/milvus-io/milvus/internal/types"
|
||||
)
|
||||
|
||||
type mockProxy struct {
|
||||
}
|
||||
|
||||
func (m *mockProxy) Init() error {
|
||||
return nil
|
||||
}
|
||||
|
||||
func (m *mockProxy) Start() error {
|
||||
return nil
|
||||
}
|
||||
|
||||
func (m *mockProxy) Stop() error {
|
||||
return nil
|
||||
}
|
||||
|
||||
func (m *mockProxy) GetComponentStates(ctx context.Context) (*internalpb.ComponentStates, error) {
|
||||
return nil, nil
|
||||
}
|
||||
|
||||
func (m *mockProxy) GetStatisticsChannel(ctx context.Context) (*milvuspb.StringResponse, error) {
|
||||
return nil, nil
|
||||
}
|
||||
|
||||
func (m *mockProxy) Register() error {
|
||||
return nil
|
||||
}
|
||||
|
||||
func (m *mockProxy) InvalidateCollectionMetaCache(ctx context.Context, request *proxypb.InvalidateCollMetaCacheRequest) (*commonpb.Status, error) {
|
||||
return nil, nil
|
||||
}
|
||||
|
||||
func (m *mockProxy) ReleaseDQLMessageStream(ctx context.Context, in *proxypb.ReleaseDQLMessageStreamRequest) (*commonpb.Status, error) {
|
||||
return nil, nil
|
||||
}
|
||||
|
||||
func (m *mockProxy) SendSearchResult(ctx context.Context, req *internalpb.SearchResults) (*commonpb.Status, error) {
|
||||
return &commonpb.Status{
|
||||
ErrorCode: commonpb.ErrorCode_Success,
|
||||
Reason: "",
|
||||
}, nil
|
||||
}
|
||||
|
||||
func (m *mockProxy) SendRetrieveResult(ctx context.Context, req *internalpb.RetrieveResults) (*commonpb.Status, error) {
|
||||
return &commonpb.Status{
|
||||
ErrorCode: commonpb.ErrorCode_Success,
|
||||
Reason: "",
|
||||
}, nil
|
||||
}
|
||||
|
||||
func newMockProxy() types.Proxy {
|
||||
return &mockProxy{}
|
||||
}
|
||||
|
||||
func mockProxyCreator() proxyCreatorFunc {
|
||||
return func(ctx context.Context, addr string) (types.Proxy, error) {
|
||||
return newMockProxy(), nil
|
||||
}
|
||||
}
|
|
@ -0,0 +1,75 @@
|
|||
package querynode
|
||||
|
||||
import (
|
||||
"context"
|
||||
"errors"
|
||||
"fmt"
|
||||
"sync"
|
||||
|
||||
"github.com/milvus-io/milvus/internal/types"
|
||||
)
|
||||
|
||||
var errDisposed = errors.New("client is disposed")
|
||||
|
||||
type NodeInfo struct {
|
||||
NodeID int64
|
||||
Address string
|
||||
}
|
||||
|
||||
type proxyCreatorFunc func(ctx context.Context, addr string) (types.Proxy, error)
|
||||
|
||||
type Session struct {
|
||||
sync.Mutex
|
||||
info *NodeInfo
|
||||
client types.Proxy
|
||||
clientCreator proxyCreatorFunc
|
||||
isDisposed bool
|
||||
}
|
||||
|
||||
func NewSession(info *NodeInfo, creator proxyCreatorFunc) *Session {
|
||||
return &Session{
|
||||
info: info,
|
||||
clientCreator: creator,
|
||||
}
|
||||
}
|
||||
|
||||
func (n *Session) GetOrCreateClient(ctx context.Context) (types.Proxy, error) {
|
||||
n.Lock()
|
||||
defer n.Unlock()
|
||||
|
||||
if n.isDisposed {
|
||||
return nil, errDisposed
|
||||
}
|
||||
|
||||
if n.client != nil {
|
||||
return n.client, nil
|
||||
}
|
||||
|
||||
if n.clientCreator == nil {
|
||||
return nil, fmt.Errorf("unable to create client for %s because of a nil client creator", n.info.Address)
|
||||
}
|
||||
|
||||
err := n.initClient(ctx)
|
||||
return n.client, err
|
||||
}
|
||||
|
||||
func (n *Session) initClient(ctx context.Context) (err error) {
|
||||
if n.client, err = n.clientCreator(ctx, n.info.Address); err != nil {
|
||||
return
|
||||
}
|
||||
if err = n.client.Init(); err != nil {
|
||||
return
|
||||
}
|
||||
return n.client.Start()
|
||||
}
|
||||
|
||||
func (n *Session) Dispose() {
|
||||
n.Lock()
|
||||
defer n.Unlock()
|
||||
|
||||
if n.client != nil {
|
||||
n.client.Stop()
|
||||
n.client = nil
|
||||
}
|
||||
n.isDisposed = true
|
||||
}
|
|
@ -0,0 +1,152 @@
|
|||
package querynode
|
||||
|
||||
import (
|
||||
"context"
|
||||
"fmt"
|
||||
"sync"
|
||||
|
||||
"github.com/milvus-io/milvus/internal/util/funcutil"
|
||||
|
||||
"go.uber.org/zap"
|
||||
|
||||
"github.com/milvus-io/milvus/internal/log"
|
||||
"github.com/milvus-io/milvus/internal/proto/internalpb"
|
||||
|
||||
grpcproxyclient "github.com/milvus-io/milvus/internal/distributed/proxy/client"
|
||||
|
||||
"github.com/milvus-io/milvus/internal/types"
|
||||
)
|
||||
|
||||
type SessionManager struct {
|
||||
sessions struct {
|
||||
sync.RWMutex
|
||||
data map[int64]*Session
|
||||
}
|
||||
// sessions sync.Map // UniqueID -> Session
|
||||
sessionCreator proxyCreatorFunc
|
||||
}
|
||||
|
||||
// SessionOpt provides a way to set params in SessionManager
|
||||
type SessionOpt func(c *SessionManager)
|
||||
|
||||
func withSessionCreator(creator proxyCreatorFunc) SessionOpt {
|
||||
return func(c *SessionManager) { c.sessionCreator = creator }
|
||||
}
|
||||
|
||||
func defaultSessionCreator() proxyCreatorFunc {
|
||||
return func(ctx context.Context, addr string) (types.Proxy, error) {
|
||||
return grpcproxyclient.NewClient(ctx, addr)
|
||||
}
|
||||
}
|
||||
|
||||
// NewSessionManager creates a new SessionManager
|
||||
func NewSessionManager(options ...SessionOpt) *SessionManager {
|
||||
m := &SessionManager{
|
||||
sessions: struct {
|
||||
sync.RWMutex
|
||||
data map[int64]*Session
|
||||
}{data: make(map[int64]*Session)},
|
||||
sessionCreator: defaultSessionCreator(),
|
||||
}
|
||||
for _, opt := range options {
|
||||
opt(m)
|
||||
}
|
||||
return m
|
||||
}
|
||||
|
||||
// AddSession creates a new session
|
||||
func (c *SessionManager) AddSession(node *NodeInfo) {
|
||||
c.sessions.Lock()
|
||||
defer c.sessions.Unlock()
|
||||
|
||||
session := NewSession(node, c.sessionCreator)
|
||||
c.sessions.data[node.NodeID] = session
|
||||
}
|
||||
|
||||
func (c *SessionManager) Startup(nodes []*NodeInfo) {
|
||||
for _, node := range nodes {
|
||||
c.AddSession(node)
|
||||
}
|
||||
}
|
||||
|
||||
// DeleteSession removes the node session
|
||||
func (c *SessionManager) DeleteSession(node *NodeInfo) {
|
||||
c.sessions.Lock()
|
||||
defer c.sessions.Unlock()
|
||||
|
||||
if session, ok := c.sessions.data[node.NodeID]; ok {
|
||||
session.Dispose()
|
||||
delete(c.sessions.data, node.NodeID)
|
||||
}
|
||||
}
|
||||
|
||||
// GetSessions gets all node sessions
|
||||
func (c *SessionManager) GetSessions() []*Session {
|
||||
c.sessions.RLock()
|
||||
defer c.sessions.RUnlock()
|
||||
|
||||
ret := make([]*Session, 0, len(c.sessions.data))
|
||||
for _, s := range c.sessions.data {
|
||||
ret = append(ret, s)
|
||||
}
|
||||
return ret
|
||||
}
|
||||
|
||||
func (c *SessionManager) SendSearchResult(ctx context.Context, nodeID UniqueID, result *internalpb.SearchResults) error {
|
||||
cli, err := c.getClient(ctx, nodeID)
|
||||
if err != nil {
|
||||
log.Warn("failed to send search result, cannot get client", zap.Int64("nodeID", nodeID), zap.Error(err))
|
||||
return err
|
||||
}
|
||||
|
||||
resp, err := cli.SendSearchResult(ctx, result)
|
||||
if err := funcutil.VerifyResponse(resp, err); err != nil {
|
||||
log.Warn("failed to send search result", zap.Int64("node", nodeID), zap.Error(err))
|
||||
return err
|
||||
}
|
||||
|
||||
log.Debug("success to send search result", zap.Int64("node", nodeID), zap.Any("base", result.Base))
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
func (c *SessionManager) SendRetrieveResult(ctx context.Context, nodeID UniqueID, result *internalpb.RetrieveResults) error {
|
||||
cli, err := c.getClient(ctx, nodeID)
|
||||
if err != nil {
|
||||
log.Warn("failed to send retrieve result, cannot get client", zap.Int64("nodeID", nodeID), zap.Error(err))
|
||||
return err
|
||||
}
|
||||
|
||||
resp, err := cli.SendRetrieveResult(ctx, result)
|
||||
if err := funcutil.VerifyResponse(resp, err); err != nil {
|
||||
log.Warn("failed to send retrieve result", zap.Int64("node", nodeID), zap.Error(err))
|
||||
return err
|
||||
}
|
||||
|
||||
log.Debug("success to send retrieve result", zap.Int64("node", nodeID), zap.Any("base", result.Base))
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
func (c *SessionManager) getClient(ctx context.Context, nodeID int64) (types.Proxy, error) {
|
||||
c.sessions.RLock()
|
||||
session, ok := c.sessions.data[nodeID]
|
||||
c.sessions.RUnlock()
|
||||
|
||||
if !ok {
|
||||
return nil, fmt.Errorf("can not find session of node %d", nodeID)
|
||||
}
|
||||
|
||||
return session.GetOrCreateClient(ctx)
|
||||
}
|
||||
|
||||
// Close release sessions
|
||||
func (c *SessionManager) Close() {
|
||||
c.sessions.Lock()
|
||||
defer c.sessions.Unlock()
|
||||
|
||||
for _, s := range c.sessions.data {
|
||||
s.Dispose()
|
||||
}
|
||||
c.sessions.data = nil
|
||||
}
|
|
@ -78,8 +78,9 @@ type queryCollection struct {
|
|||
serviceableTimeMutex sync.RWMutex // guards serviceableTime
|
||||
serviceableTime Timestamp
|
||||
|
||||
queryMsgStream msgstream.MsgStream
|
||||
queryResultMsgStream msgstream.MsgStream
|
||||
queryMsgStream msgstream.MsgStream
|
||||
// queryResultMsgStream msgstream.MsgStream
|
||||
sessionManager *SessionManager
|
||||
|
||||
localChunkManager storage.ChunkManager
|
||||
remoteChunkManager storage.ChunkManager
|
||||
|
@ -89,6 +90,14 @@ type queryCollection struct {
|
|||
globalSegmentManager *globalSealedSegmentManager
|
||||
}
|
||||
|
||||
type qcOpt func(*queryCollection)
|
||||
|
||||
func qcOptWithSessionManager(s *SessionManager) qcOpt {
|
||||
return func(qc *queryCollection) {
|
||||
qc.sessionManager = s
|
||||
}
|
||||
}
|
||||
|
||||
func newQueryCollection(releaseCtx context.Context,
|
||||
cancel context.CancelFunc,
|
||||
collectionID UniqueID,
|
||||
|
@ -98,12 +107,13 @@ func newQueryCollection(releaseCtx context.Context,
|
|||
localChunkManager storage.ChunkManager,
|
||||
remoteChunkManager storage.ChunkManager,
|
||||
localCacheEnabled bool,
|
||||
opts ...qcOpt,
|
||||
) (*queryCollection, error) {
|
||||
|
||||
unsolvedMsg := make([]queryMsg, 0)
|
||||
|
||||
queryStream, _ := factory.NewQueryMsgStream(releaseCtx)
|
||||
queryResultStream, _ := factory.NewQueryMsgStream(releaseCtx)
|
||||
// queryResultStream, _ := factory.NewQueryMsgStream(releaseCtx)
|
||||
|
||||
condMu := sync.Mutex{}
|
||||
|
||||
|
@ -121,8 +131,8 @@ func newQueryCollection(releaseCtx context.Context,
|
|||
|
||||
unsolvedMsg: unsolvedMsg,
|
||||
|
||||
queryMsgStream: queryStream,
|
||||
queryResultMsgStream: queryResultStream,
|
||||
queryMsgStream: queryStream,
|
||||
// queryResultMsgStream: queryResultStream,
|
||||
|
||||
localChunkManager: localChunkManager,
|
||||
remoteChunkManager: remoteChunkManager,
|
||||
|
@ -130,6 +140,10 @@ func newQueryCollection(releaseCtx context.Context,
|
|||
globalSegmentManager: newGlobalSealedSegmentManager(collectionID),
|
||||
}
|
||||
|
||||
for _, opt := range opts {
|
||||
opt(qc)
|
||||
}
|
||||
|
||||
err := qc.registerCollectionTSafe()
|
||||
if err != nil {
|
||||
return nil, err
|
||||
|
@ -139,7 +153,7 @@ func newQueryCollection(releaseCtx context.Context,
|
|||
|
||||
func (q *queryCollection) start() {
|
||||
go q.queryMsgStream.Start()
|
||||
go q.queryResultMsgStream.Start()
|
||||
// go q.queryResultMsgStream.Start()
|
||||
go q.consumeQuery()
|
||||
go q.doUnsolvedQueryMsg()
|
||||
}
|
||||
|
@ -148,9 +162,9 @@ func (q *queryCollection) close() {
|
|||
if q.queryMsgStream != nil {
|
||||
q.queryMsgStream.Close()
|
||||
}
|
||||
if q.queryResultMsgStream != nil {
|
||||
q.queryResultMsgStream.Close()
|
||||
}
|
||||
// if q.queryResultMsgStream != nil {
|
||||
// q.queryResultMsgStream.Close()
|
||||
// }
|
||||
q.globalSegmentManager.close()
|
||||
}
|
||||
|
||||
|
@ -1089,7 +1103,7 @@ func (q *queryCollection) search(msg queryMsg) error {
|
|||
zap.Any("vChannels", collection.getVChannels()),
|
||||
zap.Any("sealedSegmentSearched", sealedSegmentSearched),
|
||||
)
|
||||
err = q.publishQueryResult(searchResultMsg, searchMsg.CollectionID)
|
||||
err = q.publishSearchResult(&searchResultMsg.SearchResults, searchMsg.Base.SourceID)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
@ -1196,7 +1210,7 @@ func (q *queryCollection) search(msg queryMsg) error {
|
|||
// fmt.Println(testHits.IDs)
|
||||
// fmt.Println(testHits.Scores)
|
||||
//}
|
||||
err = q.publishQueryResult(searchResultMsg, searchMsg.CollectionID)
|
||||
err = q.publishSearchResult(&searchResultMsg.SearchResults, searchMsg.Base.SourceID)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
@ -1309,7 +1323,7 @@ func (q *queryCollection) retrieve(msg queryMsg) error {
|
|||
},
|
||||
}
|
||||
|
||||
err = q.publishQueryResult(retrieveResultMsg, retrieveMsg.CollectionID)
|
||||
err = q.publishRetrieveResult(&retrieveResultMsg.RetrieveResults, retrieveMsg.Base.SourceID)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
@ -1377,31 +1391,28 @@ func mergeRetrieveResults(retrieveResults []*segcorepb.RetrieveResults) (*segcor
|
|||
return ret, nil
|
||||
}
|
||||
|
||||
func (q *queryCollection) publishQueryResult(msg msgstream.TsMsg, collectionID UniqueID) error {
|
||||
span, ctx := trace.StartSpanFromContext(msg.TraceCtx())
|
||||
defer span.Finish()
|
||||
msg.SetTraceCtx(ctx)
|
||||
msgPack := msgstream.MsgPack{}
|
||||
msgPack.Msgs = append(msgPack.Msgs, msg)
|
||||
err := q.queryResultMsgStream.Produce(&msgPack)
|
||||
if err != nil {
|
||||
log.Error(err.Error())
|
||||
}
|
||||
|
||||
return err
|
||||
func (q *queryCollection) publishSearchResultWithCtx(ctx context.Context, result *internalpb.SearchResults, nodeID UniqueID) error {
|
||||
return q.sessionManager.SendSearchResult(ctx, nodeID, result)
|
||||
}
|
||||
|
||||
func (q *queryCollection) publishFailedQueryResult(msg msgstream.TsMsg, errMsg string) error {
|
||||
msgType := msg.Type()
|
||||
span, ctx := trace.StartSpanFromContext(msg.TraceCtx())
|
||||
defer span.Finish()
|
||||
msg.SetTraceCtx(ctx)
|
||||
msgPack := msgstream.MsgPack{}
|
||||
func (q *queryCollection) publishSearchResult(result *internalpb.SearchResults, nodeID UniqueID) error {
|
||||
return q.publishSearchResultWithCtx(q.releaseCtx, result, nodeID)
|
||||
}
|
||||
|
||||
func (q *queryCollection) publishRetrieveResultWithCtx(ctx context.Context, result *internalpb.RetrieveResults, nodeID UniqueID) error {
|
||||
return q.sessionManager.SendRetrieveResult(ctx, nodeID, result)
|
||||
}
|
||||
|
||||
func (q *queryCollection) publishRetrieveResult(result *internalpb.RetrieveResults, nodeID UniqueID) error {
|
||||
return q.publishRetrieveResultWithCtx(q.releaseCtx, result, nodeID)
|
||||
}
|
||||
|
||||
func (q *queryCollection) publishFailedQueryResultWithCtx(ctx context.Context, msg msgstream.TsMsg, errMsg string) error {
|
||||
msgType := msg.Type()
|
||||
span, traceCtx := trace.StartSpanFromContext(msg.TraceCtx())
|
||||
defer span.Finish()
|
||||
msg.SetTraceCtx(traceCtx)
|
||||
|
||||
resultChannelInt := 0
|
||||
baseMsg := msgstream.BaseMsg{
|
||||
HashValues: []uint32{uint32(resultChannelInt)},
|
||||
}
|
||||
baseResult := &commonpb.MsgBase{
|
||||
MsgID: msg.ID(),
|
||||
Timestamp: msg.BeginTs(),
|
||||
|
@ -1412,32 +1423,92 @@ func (q *queryCollection) publishFailedQueryResult(msg msgstream.TsMsg, errMsg s
|
|||
case commonpb.MsgType_Retrieve:
|
||||
retrieveMsg := msg.(*msgstream.RetrieveMsg)
|
||||
baseResult.MsgType = commonpb.MsgType_RetrieveResult
|
||||
retrieveResultMsg := &msgstream.RetrieveResultMsg{
|
||||
BaseMsg: baseMsg,
|
||||
RetrieveResults: internalpb.RetrieveResults{
|
||||
Base: baseResult,
|
||||
Status: &commonpb.Status{ErrorCode: commonpb.ErrorCode_UnexpectedError, Reason: errMsg},
|
||||
ResultChannelID: retrieveMsg.ResultChannelID,
|
||||
Ids: nil,
|
||||
FieldsData: nil,
|
||||
},
|
||||
}
|
||||
msgPack.Msgs = append(msgPack.Msgs, retrieveResultMsg)
|
||||
return q.publishRetrieveResult(&internalpb.RetrieveResults{
|
||||
Base: baseResult,
|
||||
Status: &commonpb.Status{ErrorCode: commonpb.ErrorCode_UnexpectedError, Reason: errMsg},
|
||||
ResultChannelID: retrieveMsg.ResultChannelID,
|
||||
Ids: nil,
|
||||
FieldsData: nil,
|
||||
}, msg.SourceID())
|
||||
case commonpb.MsgType_Search:
|
||||
searchMsg := msg.(*msgstream.SearchMsg)
|
||||
baseResult.MsgType = commonpb.MsgType_SearchResult
|
||||
searchResultMsg := &msgstream.SearchResultMsg{
|
||||
BaseMsg: baseMsg,
|
||||
SearchResults: internalpb.SearchResults{
|
||||
Base: baseResult,
|
||||
Status: &commonpb.Status{ErrorCode: commonpb.ErrorCode_UnexpectedError, Reason: errMsg},
|
||||
ResultChannelID: searchMsg.ResultChannelID,
|
||||
},
|
||||
}
|
||||
msgPack.Msgs = append(msgPack.Msgs, searchResultMsg)
|
||||
return q.publishSearchResultWithCtx(ctx, &internalpb.SearchResults{
|
||||
Base: baseResult,
|
||||
Status: &commonpb.Status{ErrorCode: commonpb.ErrorCode_UnexpectedError, Reason: errMsg},
|
||||
ResultChannelID: searchMsg.ResultChannelID,
|
||||
}, msg.SourceID())
|
||||
default:
|
||||
return fmt.Errorf("publish invalid msgType %d", msgType)
|
||||
}
|
||||
|
||||
return q.queryResultMsgStream.Produce(&msgPack)
|
||||
}
|
||||
|
||||
func (q *queryCollection) publishFailedQueryResult(msg msgstream.TsMsg, errMsg string) error {
|
||||
return q.publishFailedQueryResultWithCtx(q.releaseCtx, msg, errMsg)
|
||||
}
|
||||
|
||||
// func (q *queryCollection) publishQueryResult(msg msgstream.TsMsg, collectionID UniqueID) error {
|
||||
// span, ctx := trace.StartSpanFromContext(msg.TraceCtx())
|
||||
// defer span.Finish()
|
||||
// msg.SetTraceCtx(ctx)
|
||||
// msgPack := msgstream.MsgPack{}
|
||||
// msgPack.Msgs = append(msgPack.Msgs, msg)
|
||||
// err := q.queryResultMsgStream.Produce(&msgPack)
|
||||
// if err != nil {
|
||||
// log.Error(err.Error())
|
||||
// }
|
||||
//
|
||||
// return err
|
||||
// }
|
||||
|
||||
// func (q *queryCollection) publishFailedQueryResult(msg msgstream.TsMsg, errMsg string) error {
|
||||
// msgType := msg.Type()
|
||||
// span, ctx := trace.StartSpanFromContext(msg.TraceCtx())
|
||||
// defer span.Finish()
|
||||
// msg.SetTraceCtx(ctx)
|
||||
// msgPack := msgstream.MsgPack{}
|
||||
//
|
||||
// resultChannelInt := 0
|
||||
// baseMsg := msgstream.BaseMsg{
|
||||
// HashValues: []uint32{uint32(resultChannelInt)},
|
||||
// }
|
||||
// baseResult := &commonpb.MsgBase{
|
||||
// MsgID: msg.ID(),
|
||||
// Timestamp: msg.BeginTs(),
|
||||
// SourceID: msg.SourceID(),
|
||||
// }
|
||||
//
|
||||
// switch msgType {
|
||||
// case commonpb.MsgType_Retrieve:
|
||||
// retrieveMsg := msg.(*msgstream.RetrieveMsg)
|
||||
// baseResult.MsgType = commonpb.MsgType_RetrieveResult
|
||||
// retrieveResultMsg := &msgstream.RetrieveResultMsg{
|
||||
// BaseMsg: baseMsg,
|
||||
// RetrieveResults: internalpb.RetrieveResults{
|
||||
// Base: baseResult,
|
||||
// Status: &commonpb.Status{ErrorCode: commonpb.ErrorCode_UnexpectedError, Reason: errMsg},
|
||||
// ResultChannelID: retrieveMsg.ResultChannelID,
|
||||
// Ids: nil,
|
||||
// FieldsData: nil,
|
||||
// },
|
||||
// }
|
||||
// msgPack.Msgs = append(msgPack.Msgs, retrieveResultMsg)
|
||||
// case commonpb.MsgType_Search:
|
||||
// searchMsg := msg.(*msgstream.SearchMsg)
|
||||
// baseResult.MsgType = commonpb.MsgType_SearchResult
|
||||
// searchResultMsg := &msgstream.SearchResultMsg{
|
||||
// BaseMsg: baseMsg,
|
||||
// SearchResults: internalpb.SearchResults{
|
||||
// Base: baseResult,
|
||||
// Status: &commonpb.Status{ErrorCode: commonpb.ErrorCode_UnexpectedError, Reason: errMsg},
|
||||
// ResultChannelID: searchMsg.ResultChannelID,
|
||||
// },
|
||||
// }
|
||||
// msgPack.Msgs = append(msgPack.Msgs, searchResultMsg)
|
||||
// default:
|
||||
// return fmt.Errorf("publish invalid msgType %d", msgType)
|
||||
// }
|
||||
//
|
||||
// return q.queryResultMsgStream.Produce(&msgPack)
|
||||
// }
|
||||
//
|
||||
|
|
|
@ -177,8 +177,14 @@ func TestQueryCollection_withoutVChannel(t *testing.T) {
|
|||
queryCollection, err := newQueryCollection(ctx, cancel, 0, historical, streaming, factory, nil, nil, false)
|
||||
assert.NoError(t, err)
|
||||
|
||||
producerChannels := []string{"testResultChannel"}
|
||||
queryCollection.queryResultMsgStream.AsProducer(producerChannels)
|
||||
// producerChannels := []string{"testResultChannel"}
|
||||
// queryCollection.queryResultMsgStream.AsProducer(producerChannels)
|
||||
sessionManager := NewSessionManager(withSessionCreator(mockProxyCreator()))
|
||||
sessionManager.AddSession(&NodeInfo{
|
||||
NodeID: 1,
|
||||
Address: "",
|
||||
})
|
||||
queryCollection.sessionManager = sessionManager
|
||||
|
||||
dim := 2
|
||||
// generate search rawData
|
||||
|
@ -269,6 +275,13 @@ func TestQueryCollection_consumeQuery(t *testing.T) {
|
|||
queryCollection.queryMsgStream.AsConsumer([]Channel{queryChannel}, defaultSubName)
|
||||
queryCollection.queryMsgStream.Start()
|
||||
|
||||
sessionManager := NewSessionManager(withSessionCreator(mockProxyCreator()))
|
||||
sessionManager.AddSession(&NodeInfo{
|
||||
NodeID: 1,
|
||||
Address: "",
|
||||
})
|
||||
queryCollection.sessionManager = sessionManager
|
||||
|
||||
go queryCollection.consumeQuery()
|
||||
|
||||
producer, err := genQueryMsgStream(ctx)
|
||||
|
@ -574,6 +587,13 @@ func TestQueryCollection_doUnsolvedQueryMsg(t *testing.T) {
|
|||
queryCollection, err := genSimpleQueryCollection(ctx, cancel)
|
||||
assert.NoError(t, err)
|
||||
|
||||
sessionManager := NewSessionManager(withSessionCreator(mockProxyCreator()))
|
||||
sessionManager.AddSession(&NodeInfo{
|
||||
NodeID: 1,
|
||||
Address: "",
|
||||
})
|
||||
queryCollection.sessionManager = sessionManager
|
||||
|
||||
timestamp := Timestamp(1000)
|
||||
err = updateTSafe(queryCollection, timestamp)
|
||||
assert.NoError(t, err)
|
||||
|
@ -591,6 +611,13 @@ func TestQueryCollection_doUnsolvedQueryMsg(t *testing.T) {
|
|||
queryCollection, err := genSimpleQueryCollection(ctx, cancel)
|
||||
assert.NoError(t, err)
|
||||
|
||||
sessionManager := NewSessionManager(withSessionCreator(mockProxyCreator()))
|
||||
sessionManager.AddSession(&NodeInfo{
|
||||
NodeID: 1,
|
||||
Address: "",
|
||||
})
|
||||
queryCollection.sessionManager = sessionManager
|
||||
|
||||
timestamp := Timestamp(1000)
|
||||
err = updateTSafe(queryCollection, timestamp)
|
||||
assert.NoError(t, err)
|
||||
|
@ -612,9 +639,15 @@ func TestQueryCollection_search(t *testing.T) {
|
|||
queryCollection, err := genSimpleQueryCollection(ctx, cancel)
|
||||
assert.NoError(t, err)
|
||||
|
||||
queryChannel := genQueryChannel()
|
||||
queryCollection.queryResultMsgStream.AsProducer([]Channel{queryChannel})
|
||||
queryCollection.queryResultMsgStream.Start()
|
||||
// queryChannel := genQueryChannel()
|
||||
// queryCollection.queryResultMsgStream.AsProducer([]Channel{queryChannel})
|
||||
// queryCollection.queryResultMsgStream.Start()
|
||||
sessionManager := NewSessionManager(withSessionCreator(mockProxyCreator()))
|
||||
sessionManager.AddSession(&NodeInfo{
|
||||
NodeID: 0,
|
||||
Address: "",
|
||||
})
|
||||
queryCollection.sessionManager = sessionManager
|
||||
|
||||
err = queryCollection.streaming.replica.removeSegment(defaultSegmentID)
|
||||
assert.NoError(t, err)
|
||||
|
@ -635,9 +668,15 @@ func TestQueryCollection_receive(t *testing.T) {
|
|||
queryCollection, err := genSimpleQueryCollection(ctx, cancel)
|
||||
assert.NoError(t, err)
|
||||
|
||||
queryChannel := genQueryChannel()
|
||||
queryCollection.queryResultMsgStream.AsProducer([]Channel{queryChannel})
|
||||
queryCollection.queryResultMsgStream.Start()
|
||||
// queryChannel := genQueryChannel()
|
||||
// queryCollection.queryResultMsgStream.AsProducer([]Channel{queryChannel})
|
||||
// queryCollection.queryResultMsgStream.Start()
|
||||
sessionManager := NewSessionManager(withSessionCreator(mockProxyCreator()))
|
||||
sessionManager.AddSession(&NodeInfo{
|
||||
NodeID: 0,
|
||||
Address: "",
|
||||
})
|
||||
queryCollection.sessionManager = sessionManager
|
||||
|
||||
vecCM, err := genVectorChunkManager(ctx)
|
||||
assert.NoError(t, err)
|
||||
|
@ -758,9 +797,15 @@ func TestQueryCollection_search_while_release(t *testing.T) {
|
|||
queryCollection, err := genSimpleQueryCollection(ctx, cancel)
|
||||
assert.NoError(t, err)
|
||||
|
||||
queryChannel := genQueryChannel()
|
||||
queryCollection.queryResultMsgStream.AsProducer([]Channel{queryChannel})
|
||||
queryCollection.queryResultMsgStream.Start()
|
||||
// queryChannel := genQueryChannel()
|
||||
// queryCollection.queryResultMsgStream.AsProducer([]Channel{queryChannel})
|
||||
// queryCollection.queryResultMsgStream.Start()
|
||||
sessionManager := NewSessionManager(withSessionCreator(mockProxyCreator()))
|
||||
sessionManager.AddSession(&NodeInfo{
|
||||
NodeID: 1,
|
||||
Address: "",
|
||||
})
|
||||
queryCollection.sessionManager = sessionManager
|
||||
|
||||
msg, err := genSimpleSearchMsg()
|
||||
assert.NoError(t, err)
|
||||
|
@ -797,9 +842,15 @@ func TestQueryCollection_search_while_release(t *testing.T) {
|
|||
queryCollection, err := genSimpleQueryCollection(ctx, cancel)
|
||||
assert.NoError(t, err)
|
||||
|
||||
queryChannel := genQueryChannel()
|
||||
queryCollection.queryResultMsgStream.AsProducer([]Channel{queryChannel})
|
||||
queryCollection.queryResultMsgStream.Start()
|
||||
// queryChannel := genQueryChannel()
|
||||
// queryCollection.queryResultMsgStream.AsProducer([]Channel{queryChannel})
|
||||
// queryCollection.queryResultMsgStream.Start()
|
||||
sessionManager := NewSessionManager(withSessionCreator(mockProxyCreator()))
|
||||
sessionManager.AddSession(&NodeInfo{
|
||||
NodeID: 1,
|
||||
Address: "",
|
||||
})
|
||||
queryCollection.sessionManager = sessionManager
|
||||
|
||||
msg, err := genSimpleSearchMsg()
|
||||
assert.NoError(t, err)
|
||||
|
|
|
@ -78,6 +78,8 @@ type QueryNode struct {
|
|||
queryNodeLoopCtx context.Context
|
||||
queryNodeLoopCancel context.CancelFunc
|
||||
|
||||
wg sync.WaitGroup
|
||||
|
||||
stateCode atomic.Value
|
||||
|
||||
//call once
|
||||
|
@ -110,7 +112,9 @@ type QueryNode struct {
|
|||
msFactory msgstream.Factory
|
||||
scheduler *taskScheduler
|
||||
|
||||
session *sessionutil.Session
|
||||
session *sessionutil.Session
|
||||
eventCh <-chan *sessionutil.SessionEvent
|
||||
sessionManager *SessionManager
|
||||
|
||||
minioKV kv.BaseKV // minio minioKV
|
||||
etcdKV *etcdkv.EtcdKV
|
||||
|
@ -180,6 +184,72 @@ func (node *QueryNode) InitSegcore() {
|
|||
C.free(unsafe.Pointer(cSimdType))
|
||||
}
|
||||
|
||||
func (node *QueryNode) initServiceDiscovery() error {
|
||||
if node.session == nil {
|
||||
return errors.New("session is nil")
|
||||
}
|
||||
|
||||
sessions, rev, err := node.session.GetSessions(typeutil.ProxyRole)
|
||||
if err != nil {
|
||||
log.Warn("QueryNode failed to init service discovery", zap.Error(err))
|
||||
return err
|
||||
}
|
||||
log.Debug("QueryNode success to get Proxy sessions", zap.Any("sessions", sessions))
|
||||
|
||||
nodes := make([]*NodeInfo, 0, len(sessions))
|
||||
for _, session := range sessions {
|
||||
info := &NodeInfo{
|
||||
NodeID: session.ServerID,
|
||||
Address: session.Address,
|
||||
}
|
||||
nodes = append(nodes, info)
|
||||
}
|
||||
|
||||
node.sessionManager.Startup(nodes)
|
||||
|
||||
node.eventCh = node.session.WatchServices(typeutil.ProxyRole, rev+1, nil)
|
||||
return nil
|
||||
}
|
||||
|
||||
func (node *QueryNode) watchService(ctx context.Context) {
|
||||
defer node.wg.Done()
|
||||
for {
|
||||
select {
|
||||
case <-ctx.Done():
|
||||
log.Debug("watch service shutdown")
|
||||
return
|
||||
case event, ok := <-node.eventCh:
|
||||
if !ok {
|
||||
//TODO add retry logic
|
||||
return
|
||||
}
|
||||
if err := node.handleSessionEvent(ctx, event); err != nil {
|
||||
log.Warn("handleSessionEvent", zap.Error(err))
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func (node *QueryNode) handleSessionEvent(ctx context.Context, event *sessionutil.SessionEvent) error {
|
||||
if event == nil {
|
||||
return nil
|
||||
}
|
||||
info := &NodeInfo{
|
||||
NodeID: event.Session.ServerID,
|
||||
Address: event.Session.Address,
|
||||
}
|
||||
switch event.EventType {
|
||||
case sessionutil.SessionAddEvent:
|
||||
node.sessionManager.AddSession(info)
|
||||
case sessionutil.SessionDelEvent:
|
||||
node.sessionManager.DeleteSession(info)
|
||||
default:
|
||||
log.Warn("receive unknown service event type",
|
||||
zap.Any("type", event.EventType))
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
// Init function init historical and streaming module to manage segments
|
||||
func (node *QueryNode) Init() error {
|
||||
var initError error = nil
|
||||
|
@ -235,6 +305,9 @@ func (node *QueryNode) Init() error {
|
|||
return
|
||||
}
|
||||
|
||||
// TODO: add session creator to node
|
||||
node.sessionManager = NewSessionManager(withSessionCreator(defaultSessionCreator()))
|
||||
|
||||
log.Debug("query node init successfully",
|
||||
zap.Any("queryNodeID", Params.QueryNodeCfg.QueryNodeID),
|
||||
zap.Any("IP", Params.QueryNodeCfg.QueryNodeIP),
|
||||
|
@ -262,7 +335,8 @@ func (node *QueryNode) Start() error {
|
|||
node.queryService = newQueryService(node.queryNodeLoopCtx,
|
||||
node.historical,
|
||||
node.streaming,
|
||||
node.msFactory)
|
||||
node.msFactory,
|
||||
qsOptWithSessionManager(node.sessionManager))
|
||||
|
||||
// start task scheduler
|
||||
go node.scheduler.Start()
|
||||
|
@ -271,6 +345,14 @@ func (node *QueryNode) Start() error {
|
|||
go node.watchChangeInfo()
|
||||
go node.statsService.start()
|
||||
|
||||
// watch proxy
|
||||
if err := node.initServiceDiscovery(); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
node.wg.Add(1)
|
||||
go node.watchService(node.queryNodeLoopCtx)
|
||||
|
||||
Params.QueryNodeCfg.CreatedTime = time.Now()
|
||||
Params.QueryNodeCfg.UpdatedTime = time.Now()
|
||||
|
||||
|
@ -305,6 +387,7 @@ func (node *QueryNode) Stop() error {
|
|||
node.statsService.close()
|
||||
}
|
||||
node.session.Revoke(time.Second)
|
||||
node.wg.Wait()
|
||||
return nil
|
||||
}
|
||||
|
||||
|
|
|
@ -44,15 +44,27 @@ type queryService struct {
|
|||
|
||||
factory msgstream.Factory
|
||||
|
||||
sessionManager *SessionManager
|
||||
|
||||
localChunkManager storage.ChunkManager
|
||||
remoteChunkManager storage.ChunkManager
|
||||
localCacheEnabled bool
|
||||
}
|
||||
|
||||
type qsOpt func(*queryService)
|
||||
|
||||
func qsOptWithSessionManager(s *SessionManager) qsOpt {
|
||||
return func(qs *queryService) {
|
||||
qs.sessionManager = s
|
||||
}
|
||||
}
|
||||
|
||||
func newQueryService(ctx context.Context,
|
||||
historical *historical,
|
||||
streaming *streaming,
|
||||
factory msgstream.Factory) *queryService {
|
||||
factory msgstream.Factory,
|
||||
opts ...qsOpt,
|
||||
) *queryService {
|
||||
|
||||
queryServiceCtx, queryServiceCancel := context.WithCancel(ctx)
|
||||
|
||||
|
@ -81,7 +93,7 @@ func newQueryService(ctx context.Context,
|
|||
}
|
||||
remoteChunkManager := storage.NewMinioChunkManager(client)
|
||||
|
||||
return &queryService{
|
||||
qs := &queryService{
|
||||
ctx: queryServiceCtx,
|
||||
cancel: queryServiceCancel,
|
||||
|
||||
|
@ -96,6 +108,12 @@ func newQueryService(ctx context.Context,
|
|||
remoteChunkManager: remoteChunkManager,
|
||||
localCacheEnabled: localCacheEnabled,
|
||||
}
|
||||
|
||||
for _, opt := range opts {
|
||||
opt(qs)
|
||||
}
|
||||
|
||||
return qs
|
||||
}
|
||||
|
||||
func (q *queryService) close() {
|
||||
|
@ -129,6 +147,7 @@ func (q *queryService) addQueryCollection(collectionID UniqueID) error {
|
|||
q.localChunkManager,
|
||||
q.remoteChunkManager,
|
||||
q.localCacheEnabled,
|
||||
qcOptWithSessionManager(q.sessionManager),
|
||||
)
|
||||
if err != nil {
|
||||
return err
|
||||
|
|
|
@ -182,9 +182,9 @@ func (r *addQueryChannelTask) Execute(ctx context.Context) error {
|
|||
}
|
||||
|
||||
// add result channel
|
||||
producerChannels := []string{r.req.QueryResultChannel}
|
||||
sc.queryResultMsgStream.AsProducer(producerChannels)
|
||||
log.Debug("QueryNode AsProducer", zap.Strings("channels", producerChannels))
|
||||
// producerChannels := []string{r.req.QueryResultChannel}
|
||||
// sc.queryResultMsgStream.AsProducer(producerChannels)
|
||||
// log.Debug("QueryNode AsProducer", zap.Strings("channels", producerChannels))
|
||||
|
||||
// init global sealed segments
|
||||
for _, segment := range r.req.GlobalSealedSegments {
|
||||
|
|
|
@ -647,6 +647,9 @@ type Proxy interface {
|
|||
//
|
||||
// error is returned only when some communication issue occurs.
|
||||
ReleaseDQLMessageStream(ctx context.Context, in *proxypb.ReleaseDQLMessageStreamRequest) (*commonpb.Status, error)
|
||||
|
||||
SendSearchResult(ctx context.Context, req *internalpb.SearchResults) (*commonpb.Status, error)
|
||||
SendRetrieveResult(ctx context.Context, req *internalpb.RetrieveResults) (*commonpb.Status, error)
|
||||
}
|
||||
|
||||
// ProxyComponent defines the interface of proxy component.
|
||||
|
|
|
@ -0,0 +1,47 @@
|
|||
package funcutil
|
||||
|
||||
import (
|
||||
"errors"
|
||||
|
||||
"github.com/milvus-io/milvus/internal/proto/commonpb"
|
||||
)
|
||||
|
||||
// errors for VerifyResponse
|
||||
var errNilResponse = errors.New("response is nil")
|
||||
var errNilStatusResponse = errors.New("response has nil status")
|
||||
var errUnknownResponseType = errors.New("unknown response type")
|
||||
|
||||
// Response response interface for verification
|
||||
type Response interface {
|
||||
GetStatus() *commonpb.Status
|
||||
}
|
||||
|
||||
// VerifyResponse verify grpc Response 1. check error is nil 2. check response.GetStatus() with status success
|
||||
func VerifyResponse(response interface{}, err error) error {
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
if response == nil {
|
||||
return errNilResponse
|
||||
}
|
||||
switch resp := response.(type) {
|
||||
case Response:
|
||||
// note that resp will not be nil here, since it's still an interface
|
||||
if resp.GetStatus() == nil {
|
||||
return errNilStatusResponse
|
||||
}
|
||||
if resp.GetStatus().GetErrorCode() != commonpb.ErrorCode_Success {
|
||||
return errors.New(resp.GetStatus().GetReason())
|
||||
}
|
||||
case *commonpb.Status:
|
||||
if resp == nil {
|
||||
return errNilResponse
|
||||
}
|
||||
if resp.ErrorCode != commonpb.ErrorCode_Success {
|
||||
return errors.New(resp.GetReason())
|
||||
}
|
||||
default:
|
||||
return errUnknownResponseType
|
||||
}
|
||||
return nil
|
||||
}
|
|
@ -50,3 +50,11 @@ func (m *ProxyClient) GetDdChannel(ctx context.Context, in *internalpb.GetDdChan
|
|||
func (m *ProxyClient) ReleaseDQLMessageStream(ctx context.Context, in *proxypb.ReleaseDQLMessageStreamRequest, opts ...grpc.CallOption) (*commonpb.Status, error) {
|
||||
return &commonpb.Status{}, m.Err
|
||||
}
|
||||
|
||||
func (m *ProxyClient) SendSearchResult(ctx context.Context, in *internalpb.SearchResults, opts ...grpc.CallOption) (*commonpb.Status, error) {
|
||||
return &commonpb.Status{}, m.Err
|
||||
}
|
||||
|
||||
func (m *ProxyClient) SendRetrieveResult(ctx context.Context, in *internalpb.RetrieveResults, opts ...grpc.CallOption) (*commonpb.Status, error) {
|
||||
return &commonpb.Status{}, m.Err
|
||||
}
|
||||
|
|
Loading…
Reference in New Issue