mirror of https://github.com/milvus-io/milvus.git
Check target node ID for query/search (#20976)
Signed-off-by: yah01 <yang.cen@zilliz.com> Signed-off-by: yah01 <yang.cen@zilliz.com>pull/20984/head
parent
90ca7e1e2d
commit
d14271f30c
|
@ -386,8 +386,10 @@ func (t *queryTask) PostExecute(ctx context.Context) error {
|
|||
}
|
||||
|
||||
func (t *queryTask) queryShard(ctx context.Context, nodeID int64, qn types.QueryNode, channelIDs []string) error {
|
||||
retrieveReq := typeutil.Clone(t.RetrieveRequest)
|
||||
retrieveReq.GetBase().TargetID = nodeID
|
||||
req := &querypb.QueryRequest{
|
||||
Req: t.RetrieveRequest,
|
||||
Req: retrieveReq,
|
||||
DmlChannels: channelIDs,
|
||||
Scope: querypb.DataScope_All,
|
||||
}
|
||||
|
|
|
@ -476,8 +476,10 @@ func (t *searchTask) PostExecute(ctx context.Context) error {
|
|||
}
|
||||
|
||||
func (t *searchTask) searchShard(ctx context.Context, nodeID int64, qn types.QueryNode, channelIDs []string) error {
|
||||
searchReq := typeutil.Clone(t.SearchRequest)
|
||||
searchReq.GetBase().TargetID = nodeID
|
||||
req := &querypb.SearchRequest{
|
||||
Req: t.SearchRequest,
|
||||
Req: searchReq,
|
||||
DmlChannels: channelIDs,
|
||||
Scope: querypb.DataScope_All,
|
||||
}
|
||||
|
|
|
@ -709,6 +709,15 @@ func (node *QueryNode) Search(ctx context.Context, req *querypb.SearchRequest) (
|
|||
zap.Uint64("guaranteeTimestamp", req.GetReq().GetGuaranteeTimestamp()),
|
||||
zap.Uint64("timeTravel", req.GetReq().GetTravelTimestamp()))
|
||||
|
||||
if req.GetReq().GetBase().GetTargetID() != node.session.ServerID {
|
||||
return &internalpb.SearchResults{
|
||||
Status: &commonpb.Status{
|
||||
ErrorCode: commonpb.ErrorCode_NodeIDNotMatch,
|
||||
Reason: common.WrapNodeIDNotMatchMsg(req.GetReq().GetBase().GetTargetID(), node.session.ServerID),
|
||||
},
|
||||
}, nil
|
||||
}
|
||||
|
||||
failRet := &internalpb.SearchResults{
|
||||
Status: &commonpb.Status{
|
||||
ErrorCode: commonpb.ErrorCode_Success,
|
||||
|
@ -1065,6 +1074,15 @@ func (node *QueryNode) Query(ctx context.Context, req *querypb.QueryRequest) (*i
|
|||
zap.Uint64("guaranteeTimestamp", req.Req.GetGuaranteeTimestamp()),
|
||||
zap.Uint64("timeTravel", req.GetReq().GetTravelTimestamp()))
|
||||
|
||||
if req.GetReq().GetBase().GetTargetID() != node.session.ServerID {
|
||||
return &internalpb.RetrieveResults{
|
||||
Status: &commonpb.Status{
|
||||
ErrorCode: commonpb.ErrorCode_NodeIDNotMatch,
|
||||
Reason: common.WrapNodeIDNotMatchMsg(req.GetReq().GetBase().GetTargetID(), node.session.ServerID),
|
||||
},
|
||||
}, nil
|
||||
}
|
||||
|
||||
failRet := &internalpb.RetrieveResults{
|
||||
Status: &commonpb.Status{
|
||||
ErrorCode: commonpb.ErrorCode_Success,
|
||||
|
|
|
@ -680,6 +680,15 @@ func TestImpl_Search(t *testing.T) {
|
|||
DmlChannels: []string{defaultDMLChannel},
|
||||
})
|
||||
assert.NoError(t, err)
|
||||
|
||||
req.GetBase().TargetID = -1
|
||||
ret, err := node.Search(ctx, &queryPb.SearchRequest{
|
||||
Req: req,
|
||||
FromShardLeader: false,
|
||||
DmlChannels: []string{defaultDMLChannel},
|
||||
})
|
||||
assert.NoError(t, err)
|
||||
assert.Equal(t, commonpb.ErrorCode_NodeIDNotMatch, ret.GetStatus().GetErrorCode())
|
||||
}
|
||||
|
||||
func TestImpl_searchWithDmlChannel(t *testing.T) {
|
||||
|
@ -790,6 +799,15 @@ func TestImpl_Query(t *testing.T) {
|
|||
DmlChannels: []string{defaultDMLChannel},
|
||||
})
|
||||
assert.NoError(t, err)
|
||||
|
||||
req.GetBase().TargetID = -1
|
||||
ret, err := node.Query(ctx, &queryPb.QueryRequest{
|
||||
Req: req,
|
||||
FromShardLeader: false,
|
||||
DmlChannels: []string{defaultDMLChannel},
|
||||
})
|
||||
assert.NoError(t, err)
|
||||
assert.Equal(t, commonpb.ErrorCode_NodeIDNotMatch, ret.GetStatus().GetErrorCode())
|
||||
}
|
||||
|
||||
func TestImpl_queryWithDmlChannel(t *testing.T) {
|
||||
|
|
|
@ -956,8 +956,10 @@ func (sc *ShardCluster) Search(ctx context.Context, req *querypb.SearchRequest,
|
|||
|
||||
// dispatch request to followers
|
||||
for nodeID, segments := range segAllocs {
|
||||
internalReq := typeutil.Clone(req.GetReq())
|
||||
internalReq.GetBase().TargetID = nodeID
|
||||
nodeReq := &querypb.SearchRequest{
|
||||
Req: req.Req,
|
||||
Req: internalReq,
|
||||
DmlChannels: req.DmlChannels,
|
||||
FromShardLeader: true,
|
||||
Scope: querypb.DataScope_Historical,
|
||||
|
@ -1041,8 +1043,10 @@ func (sc *ShardCluster) Query(ctx context.Context, req *querypb.QueryRequest, wi
|
|||
|
||||
// dispatch request to followers
|
||||
for nodeID, segments := range segAllocs {
|
||||
internalReq := typeutil.Clone(req.GetReq())
|
||||
internalReq.GetBase().TargetID = nodeID
|
||||
nodeReq := &querypb.QueryRequest{
|
||||
Req: req.Req,
|
||||
Req: internalReq,
|
||||
FromShardLeader: true,
|
||||
SegmentIDs: segments,
|
||||
Scope: querypb.DataScope_Historical,
|
||||
|
|
|
@ -1164,6 +1164,9 @@ func TestShardCluster_Search(t *testing.T) {
|
|||
require.EqualValues(t, available, sc.state.Load())
|
||||
|
||||
result, err := sc.Search(ctx, &querypb.SearchRequest{
|
||||
Req: &internalpb.SearchRequest{
|
||||
Base: &commonpb.MsgBase{},
|
||||
},
|
||||
DmlChannels: []string{vchannelName},
|
||||
}, streamingDoNothing)
|
||||
assert.NoError(t, err)
|
||||
|
@ -1215,6 +1218,9 @@ func TestShardCluster_Search(t *testing.T) {
|
|||
require.EqualValues(t, available, sc.state.Load())
|
||||
|
||||
_, err := sc.Search(ctx, &querypb.SearchRequest{
|
||||
Req: &internalpb.SearchRequest{
|
||||
Base: &commonpb.MsgBase{},
|
||||
},
|
||||
DmlChannels: []string{vchannelName},
|
||||
}, func(ctx context.Context) error { return errors.New("mocked") })
|
||||
assert.Error(t, err)
|
||||
|
@ -1273,6 +1279,9 @@ func TestShardCluster_Search(t *testing.T) {
|
|||
require.EqualValues(t, available, sc.state.Load())
|
||||
|
||||
_, err := sc.Search(ctx, &querypb.SearchRequest{
|
||||
Req: &internalpb.SearchRequest{
|
||||
Base: &commonpb.MsgBase{},
|
||||
},
|
||||
DmlChannels: []string{vchannelName},
|
||||
}, streamingDoNothing)
|
||||
assert.Error(t, err)
|
||||
|
@ -1325,6 +1334,9 @@ func TestShardCluster_Search(t *testing.T) {
|
|||
require.EqualValues(t, available, sc.state.Load())
|
||||
|
||||
_, err := sc.Search(ctx, &querypb.SearchRequest{
|
||||
Req: &internalpb.SearchRequest{
|
||||
Base: &commonpb.MsgBase{},
|
||||
},
|
||||
DmlChannels: []string{vchannelName},
|
||||
}, streamingDoNothing)
|
||||
assert.Error(t, err)
|
||||
|
@ -1385,6 +1397,9 @@ func TestShardCluster_Query(t *testing.T) {
|
|||
require.EqualValues(t, unavailable, sc.state.Load())
|
||||
|
||||
_, err := sc.Query(ctx, &querypb.QueryRequest{
|
||||
Req: &internalpb.RetrieveRequest{
|
||||
Base: &commonpb.MsgBase{},
|
||||
},
|
||||
DmlChannels: []string{vchannelName},
|
||||
}, streamingDoNothing)
|
||||
assert.Error(t, err)
|
||||
|
@ -1398,6 +1413,9 @@ func TestShardCluster_Query(t *testing.T) {
|
|||
sc.SetupFirstVersion()
|
||||
|
||||
_, err := sc.Query(ctx, &querypb.QueryRequest{
|
||||
Req: &internalpb.RetrieveRequest{
|
||||
Base: &commonpb.MsgBase{},
|
||||
},
|
||||
DmlChannels: []string{vchannelName + "_suffix"},
|
||||
}, streamingDoNothing)
|
||||
assert.Error(t, err)
|
||||
|
@ -1447,6 +1465,9 @@ func TestShardCluster_Query(t *testing.T) {
|
|||
require.EqualValues(t, available, sc.state.Load())
|
||||
|
||||
result, err := sc.Query(ctx, &querypb.QueryRequest{
|
||||
Req: &internalpb.RetrieveRequest{
|
||||
Base: &commonpb.MsgBase{},
|
||||
},
|
||||
DmlChannels: []string{vchannelName},
|
||||
}, streamingDoNothing)
|
||||
assert.NoError(t, err)
|
||||
|
@ -1497,6 +1518,9 @@ func TestShardCluster_Query(t *testing.T) {
|
|||
require.EqualValues(t, available, sc.state.Load())
|
||||
|
||||
_, err := sc.Query(ctx, &querypb.QueryRequest{
|
||||
Req: &internalpb.RetrieveRequest{
|
||||
Base: &commonpb.MsgBase{},
|
||||
},
|
||||
DmlChannels: []string{vchannelName},
|
||||
}, func(ctx context.Context) error { return errors.New("mocked") })
|
||||
assert.Error(t, err)
|
||||
|
@ -1555,6 +1579,9 @@ func TestShardCluster_Query(t *testing.T) {
|
|||
require.EqualValues(t, available, sc.state.Load())
|
||||
|
||||
_, err := sc.Query(ctx, &querypb.QueryRequest{
|
||||
Req: &internalpb.RetrieveRequest{
|
||||
Base: &commonpb.MsgBase{},
|
||||
},
|
||||
DmlChannels: []string{vchannelName},
|
||||
}, streamingDoNothing)
|
||||
assert.Error(t, err)
|
||||
|
@ -1608,6 +1635,9 @@ func TestShardCluster_Query(t *testing.T) {
|
|||
require.EqualValues(t, available, sc.state.Load())
|
||||
|
||||
_, err := sc.Query(ctx, &querypb.QueryRequest{
|
||||
Req: &internalpb.RetrieveRequest{
|
||||
Base: &commonpb.MsgBase{},
|
||||
},
|
||||
DmlChannels: []string{vchannelName},
|
||||
}, streamingDoNothing)
|
||||
assert.Error(t, err)
|
||||
|
|
Loading…
Reference in New Issue