Check target node ID for query/search (#20967)

Signed-off-by: yah01 <yang.cen@zilliz.com>

Signed-off-by: yah01 <yang.cen@zilliz.com>
pull/20969/head
yah01 2022-12-04 20:05:17 +08:00 committed by GitHub
parent 8094eea2ed
commit 25a3b9ae19
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
4 changed files with 42 additions and 2 deletions

View File

@ -404,8 +404,10 @@ func (t *queryTask) PostExecute(ctx context.Context) error {
}
func (t *queryTask) queryShard(ctx context.Context, nodeID int64, qn types.QueryNode, channelIDs []string) error {
retrieveReq := typeutil.Clone(t.RetrieveRequest)
retrieveReq.GetBase().TargetID = nodeID
req := &querypb.QueryRequest{
Req: t.RetrieveRequest,
Req: retrieveReq,
DmlChannels: channelIDs,
Scope: querypb.DataScope_All,
}

View File

@ -477,8 +477,10 @@ func (t *searchTask) PostExecute(ctx context.Context) error {
}
func (t *searchTask) searchShard(ctx context.Context, nodeID int64, qn types.QueryNode, channelIDs []string) error {
searchReq := typeutil.Clone(t.SearchRequest)
searchReq.GetBase().TargetID = nodeID
req := &querypb.SearchRequest{
Req: t.SearchRequest,
Req: searchReq,
DmlChannels: channelIDs,
Scope: querypb.DataScope_All,
}

View File

@ -710,6 +710,15 @@ func (node *QueryNode) Search(ctx context.Context, req *querypb.SearchRequest) (
zap.Uint64("guaranteeTimestamp", req.GetReq().GetGuaranteeTimestamp()),
zap.Uint64("timeTravel", req.GetReq().GetTravelTimestamp()))
if req.GetReq().GetBase().GetTargetID() != paramtable.GetNodeID() {
return &internalpb.SearchResults{
Status: &commonpb.Status{
ErrorCode: commonpb.ErrorCode_NodeIDNotMatch,
Reason: common.WrapNodeIDNotMatchMsg(req.GetReq().GetBase().GetTargetID(), paramtable.GetNodeID()),
},
}, nil
}
failRet := &internalpb.SearchResults{
Status: &commonpb.Status{
ErrorCode: commonpb.ErrorCode_Success,
@ -1065,6 +1074,15 @@ func (node *QueryNode) Query(ctx context.Context, req *querypb.QueryRequest) (*i
zap.Uint64("guaranteeTimestamp", req.Req.GetGuaranteeTimestamp()),
zap.Uint64("timeTravel", req.GetReq().GetTravelTimestamp()))
if req.GetReq().GetBase().GetTargetID() != paramtable.GetNodeID() {
return &internalpb.RetrieveResults{
Status: &commonpb.Status{
ErrorCode: commonpb.ErrorCode_NodeIDNotMatch,
Reason: common.WrapNodeIDNotMatchMsg(req.GetReq().GetBase().GetTargetID(), paramtable.GetNodeID()),
},
}, nil
}
failRet := &internalpb.RetrieveResults{
Status: &commonpb.Status{
ErrorCode: commonpb.ErrorCode_Success,

View File

@ -686,6 +686,15 @@ func TestImpl_Search(t *testing.T) {
DmlChannels: []string{defaultDMLChannel},
})
assert.NoError(t, err)
req.GetBase().TargetID = -1
ret, err := node.Search(ctx, &queryPb.SearchRequest{
Req: req,
FromShardLeader: false,
DmlChannels: []string{defaultDMLChannel},
})
assert.NoError(t, err)
assert.Equal(t, commonpb.ErrorCode_NodeIDNotMatch, ret.GetStatus().GetErrorCode())
}
func TestImpl_searchWithDmlChannel(t *testing.T) {
@ -797,6 +806,15 @@ func TestImpl_Query(t *testing.T) {
DmlChannels: []string{defaultDMLChannel},
})
assert.NoError(t, err)
req.GetBase().TargetID = -1
ret, err := node.Query(ctx, &queryPb.QueryRequest{
Req: req,
FromShardLeader: false,
DmlChannels: []string{defaultDMLChannel},
})
assert.NoError(t, err)
assert.Equal(t, commonpb.ErrorCode_NodeIDNotMatch, ret.GetStatus().GetErrorCode())
}
func TestImpl_queryWithDmlChannel(t *testing.T) {