mirror of https://github.com/milvus-io/milvus.git
fix: [2.5]not enable rate limiter for restful v1 (#39555)
issue: #39556 pr: #39553 Signed-off-by: lixinguo <xinguo.li@zilliz.com> Co-authored-by: lixinguo <xinguo.li@zilliz.com>pull/39572/head
parent
6f7b2b4e75
commit
c8327934a6
|
@ -533,7 +533,7 @@ func (h *HandlersV1) query(c *gin.Context) {
|
|||
username, _ := c.Get(ContextUsername)
|
||||
ctx := proxy.NewContextWithMetadata(c, username.(string), req.DbName)
|
||||
response, err := h.executeRestRequestInterceptor(ctx, c, req, func(reqCtx context.Context, req any) (any, error) {
|
||||
if _, err := CheckLimiter(ctx, &req, h.proxy); err != nil {
|
||||
if _, err := CheckLimiter(ctx, req, h.proxy); err != nil {
|
||||
c.AbortWithStatusJSON(http.StatusOK, gin.H{
|
||||
HTTPReturnCode: merr.Code(err),
|
||||
HTTPReturnMessage: err.Error() + ", error: " + err.Error(),
|
||||
|
@ -611,7 +611,7 @@ func (h *HandlersV1) get(c *gin.Context) {
|
|||
return nil, RestRequestInterceptorErr
|
||||
}
|
||||
queryReq := req.(*milvuspb.QueryRequest)
|
||||
if _, err := CheckLimiter(ctx, &req, h.proxy); err != nil {
|
||||
if _, err := CheckLimiter(ctx, req, h.proxy); err != nil {
|
||||
c.AbortWithStatusJSON(http.StatusOK, gin.H{
|
||||
HTTPReturnCode: merr.Code(err),
|
||||
HTTPReturnMessage: err.Error() + ", error: " + err.Error(),
|
||||
|
@ -691,7 +691,7 @@ func (h *HandlersV1) delete(c *gin.Context) {
|
|||
}
|
||||
deleteReq.Expr = filter
|
||||
}
|
||||
if _, err := CheckLimiter(ctx, &req, h.proxy); err != nil {
|
||||
if _, err := CheckLimiter(ctx, req, h.proxy); err != nil {
|
||||
c.AbortWithStatusJSON(http.StatusOK, gin.H{
|
||||
HTTPReturnCode: merr.Code(err),
|
||||
HTTPReturnMessage: err.Error() + ", error: " + err.Error(),
|
||||
|
@ -774,7 +774,7 @@ func (h *HandlersV1) insert(c *gin.Context) {
|
|||
})
|
||||
return nil, RestRequestInterceptorErr
|
||||
}
|
||||
if _, err := CheckLimiter(ctx, &req, h.proxy); err != nil {
|
||||
if _, err := CheckLimiter(ctx, req, h.proxy); err != nil {
|
||||
c.AbortWithStatusJSON(http.StatusOK, gin.H{
|
||||
HTTPReturnCode: merr.Code(err),
|
||||
HTTPReturnMessage: err.Error() + ", error: " + err.Error(),
|
||||
|
@ -880,7 +880,7 @@ func (h *HandlersV1) upsert(c *gin.Context) {
|
|||
})
|
||||
return nil, RestRequestInterceptorErr
|
||||
}
|
||||
if _, err := CheckLimiter(ctx, &req, h.proxy); err != nil {
|
||||
if _, err := CheckLimiter(ctx, req, h.proxy); err != nil {
|
||||
c.AbortWithStatusJSON(http.StatusOK, gin.H{
|
||||
HTTPReturnCode: merr.Code(err),
|
||||
HTTPReturnMessage: err.Error() + ", error: " + err.Error(),
|
||||
|
@ -983,7 +983,7 @@ func (h *HandlersV1) search(c *gin.Context) {
|
|||
username, _ := c.Get(ContextUsername)
|
||||
ctx := proxy.NewContextWithMetadata(c, username.(string), req.DbName)
|
||||
response, err := h.executeRestRequestInterceptor(ctx, c, req, func(reqCtx context.Context, req any) (any, error) {
|
||||
if _, err := CheckLimiter(ctx, &req, h.proxy); err != nil {
|
||||
if _, err := CheckLimiter(ctx, req, h.proxy); err != nil {
|
||||
c.AbortWithStatusJSON(http.StatusOK, gin.H{
|
||||
HTTPReturnCode: merr.Code(err),
|
||||
HTTPReturnMessage: err.Error() + ", error: " + err.Error(),
|
||||
|
|
|
@ -1458,7 +1458,12 @@ func CheckLimiter(ctx context.Context, req interface{}, pxy types.ProxyComponent
|
|||
return nil, err
|
||||
}
|
||||
|
||||
dbID, collectionIDToPartIDs, rt, n, err := proxy.GetRequestInfo(ctx, req)
|
||||
request, ok := req.(proto.Message)
|
||||
if !ok {
|
||||
return nil, merr.WrapErrParameterInvalidMsg("wrong req format when check limiter")
|
||||
}
|
||||
|
||||
dbID, collectionIDToPartIDs, rt, n, err := proxy.GetRequestInfo(ctx, request)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
|
|
@ -22,6 +22,7 @@ import (
|
|||
|
||||
"go.uber.org/zap"
|
||||
"google.golang.org/grpc"
|
||||
"google.golang.org/protobuf/proto"
|
||||
|
||||
"github.com/milvus-io/milvus-proto/go-api/v2/milvuspb"
|
||||
"github.com/milvus-io/milvus/internal/types"
|
||||
|
@ -38,7 +39,11 @@ import (
|
|||
// RateLimitInterceptor returns a new unary server interceptors that performs request rate limiting.
|
||||
func RateLimitInterceptor(limiter types.Limiter) grpc.UnaryServerInterceptor {
|
||||
return func(ctx context.Context, req any, info *grpc.UnaryServerInfo, handler grpc.UnaryHandler) (interface{}, error) {
|
||||
dbID, collectionIDToPartIDs, rt, n, err := GetRequestInfo(ctx, req)
|
||||
request, ok := req.(proto.Message)
|
||||
if !ok {
|
||||
return nil, merr.WrapErrParameterInvalidMsg("wrong req format when check limiter")
|
||||
}
|
||||
dbID, collectionIDToPartIDs, rt, n, err := GetRequestInfo(ctx, request)
|
||||
if err != nil {
|
||||
log.Warn("failed to get request info", zap.Error(err))
|
||||
return handler(ctx, req)
|
||||
|
|
|
@ -19,6 +19,7 @@ package proxy
|
|||
import (
|
||||
"context"
|
||||
"fmt"
|
||||
"reflect"
|
||||
"strconv"
|
||||
"strings"
|
||||
"time"
|
||||
|
@ -2107,7 +2108,7 @@ func GetCostValue(status *commonpb.Status) int {
|
|||
}
|
||||
|
||||
// GetRequestInfo returns collection name and rateType of request and return tokens needed.
|
||||
func GetRequestInfo(ctx context.Context, req interface{}) (int64, map[int64][]int64, internalpb.RateType, int, error) {
|
||||
func GetRequestInfo(ctx context.Context, req proto.Message) (int64, map[int64][]int64, internalpb.RateType, int, error) {
|
||||
switch r := req.(type) {
|
||||
case *milvuspb.InsertRequest:
|
||||
dbID, collToPartIDs, err := getCollectionAndPartitionID(ctx, req.(reqPartName))
|
||||
|
@ -2185,6 +2186,7 @@ func GetRequestInfo(ctx context.Context, req interface{}) (int64, map[int64][]in
|
|||
if req == nil {
|
||||
return util.InvalidDBID, map[int64][]int64{}, 0, 0, fmt.Errorf("null request")
|
||||
}
|
||||
log.RatedWarn(60, "not supported request type for rate limiter", zap.String("type", reflect.TypeOf(req).String()))
|
||||
return util.InvalidDBID, map[int64][]int64{}, 0, 0, nil
|
||||
}
|
||||
}
|
||||
|
|
Loading…
Reference in New Issue