Check target node ID for query/search (#20976)

Signed-off-by: yah01 <yang.cen@zilliz.com>

Signed-off-by: yah01 <yang.cen@zilliz.com>
pull/20984/head
yah01 2022-12-05 15:09:20 +08:00 committed by GitHub
parent 90ca7e1e2d
commit d14271f30c
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
6 changed files with 78 additions and 4 deletions

View File

@ -386,8 +386,10 @@ func (t *queryTask) PostExecute(ctx context.Context) error {
}
func (t *queryTask) queryShard(ctx context.Context, nodeID int64, qn types.QueryNode, channelIDs []string) error {
retrieveReq := typeutil.Clone(t.RetrieveRequest)
retrieveReq.GetBase().TargetID = nodeID
req := &querypb.QueryRequest{
Req: t.RetrieveRequest,
Req: retrieveReq,
DmlChannels: channelIDs,
Scope: querypb.DataScope_All,
}

View File

@ -476,8 +476,10 @@ func (t *searchTask) PostExecute(ctx context.Context) error {
}
func (t *searchTask) searchShard(ctx context.Context, nodeID int64, qn types.QueryNode, channelIDs []string) error {
searchReq := typeutil.Clone(t.SearchRequest)
searchReq.GetBase().TargetID = nodeID
req := &querypb.SearchRequest{
Req: t.SearchRequest,
Req: searchReq,
DmlChannels: channelIDs,
Scope: querypb.DataScope_All,
}

View File

@ -709,6 +709,15 @@ func (node *QueryNode) Search(ctx context.Context, req *querypb.SearchRequest) (
zap.Uint64("guaranteeTimestamp", req.GetReq().GetGuaranteeTimestamp()),
zap.Uint64("timeTravel", req.GetReq().GetTravelTimestamp()))
if req.GetReq().GetBase().GetTargetID() != node.session.ServerID {
return &internalpb.SearchResults{
Status: &commonpb.Status{
ErrorCode: commonpb.ErrorCode_NodeIDNotMatch,
Reason: common.WrapNodeIDNotMatchMsg(req.GetReq().GetBase().GetTargetID(), node.session.ServerID),
},
}, nil
}
failRet := &internalpb.SearchResults{
Status: &commonpb.Status{
ErrorCode: commonpb.ErrorCode_Success,
@ -1065,6 +1074,15 @@ func (node *QueryNode) Query(ctx context.Context, req *querypb.QueryRequest) (*i
zap.Uint64("guaranteeTimestamp", req.Req.GetGuaranteeTimestamp()),
zap.Uint64("timeTravel", req.GetReq().GetTravelTimestamp()))
if req.GetReq().GetBase().GetTargetID() != node.session.ServerID {
return &internalpb.RetrieveResults{
Status: &commonpb.Status{
ErrorCode: commonpb.ErrorCode_NodeIDNotMatch,
Reason: common.WrapNodeIDNotMatchMsg(req.GetReq().GetBase().GetTargetID(), node.session.ServerID),
},
}, nil
}
failRet := &internalpb.RetrieveResults{
Status: &commonpb.Status{
ErrorCode: commonpb.ErrorCode_Success,

View File

@ -680,6 +680,15 @@ func TestImpl_Search(t *testing.T) {
DmlChannels: []string{defaultDMLChannel},
})
assert.NoError(t, err)
req.GetBase().TargetID = -1
ret, err := node.Search(ctx, &queryPb.SearchRequest{
Req: req,
FromShardLeader: false,
DmlChannels: []string{defaultDMLChannel},
})
assert.NoError(t, err)
assert.Equal(t, commonpb.ErrorCode_NodeIDNotMatch, ret.GetStatus().GetErrorCode())
}
func TestImpl_searchWithDmlChannel(t *testing.T) {
@ -790,6 +799,15 @@ func TestImpl_Query(t *testing.T) {
DmlChannels: []string{defaultDMLChannel},
})
assert.NoError(t, err)
req.GetBase().TargetID = -1
ret, err := node.Query(ctx, &queryPb.QueryRequest{
Req: req,
FromShardLeader: false,
DmlChannels: []string{defaultDMLChannel},
})
assert.NoError(t, err)
assert.Equal(t, commonpb.ErrorCode_NodeIDNotMatch, ret.GetStatus().GetErrorCode())
}
func TestImpl_queryWithDmlChannel(t *testing.T) {

View File

@ -956,8 +956,10 @@ func (sc *ShardCluster) Search(ctx context.Context, req *querypb.SearchRequest,
// dispatch request to followers
for nodeID, segments := range segAllocs {
internalReq := typeutil.Clone(req.GetReq())
internalReq.GetBase().TargetID = nodeID
nodeReq := &querypb.SearchRequest{
Req: req.Req,
Req: internalReq,
DmlChannels: req.DmlChannels,
FromShardLeader: true,
Scope: querypb.DataScope_Historical,
@ -1041,8 +1043,10 @@ func (sc *ShardCluster) Query(ctx context.Context, req *querypb.QueryRequest, wi
// dispatch request to followers
for nodeID, segments := range segAllocs {
internalReq := typeutil.Clone(req.GetReq())
internalReq.GetBase().TargetID = nodeID
nodeReq := &querypb.QueryRequest{
Req: req.Req,
Req: internalReq,
FromShardLeader: true,
SegmentIDs: segments,
Scope: querypb.DataScope_Historical,

View File

@ -1164,6 +1164,9 @@ func TestShardCluster_Search(t *testing.T) {
require.EqualValues(t, available, sc.state.Load())
result, err := sc.Search(ctx, &querypb.SearchRequest{
Req: &internalpb.SearchRequest{
Base: &commonpb.MsgBase{},
},
DmlChannels: []string{vchannelName},
}, streamingDoNothing)
assert.NoError(t, err)
@ -1215,6 +1218,9 @@ func TestShardCluster_Search(t *testing.T) {
require.EqualValues(t, available, sc.state.Load())
_, err := sc.Search(ctx, &querypb.SearchRequest{
Req: &internalpb.SearchRequest{
Base: &commonpb.MsgBase{},
},
DmlChannels: []string{vchannelName},
}, func(ctx context.Context) error { return errors.New("mocked") })
assert.Error(t, err)
@ -1273,6 +1279,9 @@ func TestShardCluster_Search(t *testing.T) {
require.EqualValues(t, available, sc.state.Load())
_, err := sc.Search(ctx, &querypb.SearchRequest{
Req: &internalpb.SearchRequest{
Base: &commonpb.MsgBase{},
},
DmlChannels: []string{vchannelName},
}, streamingDoNothing)
assert.Error(t, err)
@ -1325,6 +1334,9 @@ func TestShardCluster_Search(t *testing.T) {
require.EqualValues(t, available, sc.state.Load())
_, err := sc.Search(ctx, &querypb.SearchRequest{
Req: &internalpb.SearchRequest{
Base: &commonpb.MsgBase{},
},
DmlChannels: []string{vchannelName},
}, streamingDoNothing)
assert.Error(t, err)
@ -1385,6 +1397,9 @@ func TestShardCluster_Query(t *testing.T) {
require.EqualValues(t, unavailable, sc.state.Load())
_, err := sc.Query(ctx, &querypb.QueryRequest{
Req: &internalpb.RetrieveRequest{
Base: &commonpb.MsgBase{},
},
DmlChannels: []string{vchannelName},
}, streamingDoNothing)
assert.Error(t, err)
@ -1398,6 +1413,9 @@ func TestShardCluster_Query(t *testing.T) {
sc.SetupFirstVersion()
_, err := sc.Query(ctx, &querypb.QueryRequest{
Req: &internalpb.RetrieveRequest{
Base: &commonpb.MsgBase{},
},
DmlChannels: []string{vchannelName + "_suffix"},
}, streamingDoNothing)
assert.Error(t, err)
@ -1447,6 +1465,9 @@ func TestShardCluster_Query(t *testing.T) {
require.EqualValues(t, available, sc.state.Load())
result, err := sc.Query(ctx, &querypb.QueryRequest{
Req: &internalpb.RetrieveRequest{
Base: &commonpb.MsgBase{},
},
DmlChannels: []string{vchannelName},
}, streamingDoNothing)
assert.NoError(t, err)
@ -1497,6 +1518,9 @@ func TestShardCluster_Query(t *testing.T) {
require.EqualValues(t, available, sc.state.Load())
_, err := sc.Query(ctx, &querypb.QueryRequest{
Req: &internalpb.RetrieveRequest{
Base: &commonpb.MsgBase{},
},
DmlChannels: []string{vchannelName},
}, func(ctx context.Context) error { return errors.New("mocked") })
assert.Error(t, err)
@ -1555,6 +1579,9 @@ func TestShardCluster_Query(t *testing.T) {
require.EqualValues(t, available, sc.state.Load())
_, err := sc.Query(ctx, &querypb.QueryRequest{
Req: &internalpb.RetrieveRequest{
Base: &commonpb.MsgBase{},
},
DmlChannels: []string{vchannelName},
}, streamingDoNothing)
assert.Error(t, err)
@ -1608,6 +1635,9 @@ func TestShardCluster_Query(t *testing.T) {
require.EqualValues(t, available, sc.state.Load())
_, err := sc.Query(ctx, &querypb.QueryRequest{
Req: &internalpb.RetrieveRequest{
Base: &commonpb.MsgBase{},
},
DmlChannels: []string{vchannelName},
}, streamingDoNothing)
assert.Error(t, err)