mirror of https://github.com/milvus-io/milvus.git
fix: [2.4]not enable rate limiter for restful v1 (#39554)
issue: #39556 pr: #39553 2.5: #39555 Signed-off-by: lixinguo <xinguo.li@zilliz.com> Co-authored-by: lixinguo <xinguo.li@zilliz.com>pull/39575/head
parent
9848606a82
commit
6eadcacd94
|
@ -517,7 +517,7 @@ func (h *HandlersV1) query(c *gin.Context) {
|
|||
username, _ := c.Get(ContextUsername)
|
||||
ctx := proxy.NewContextWithMetadata(c, username.(string), req.DbName)
|
||||
response, err := h.executeRestRequestInterceptor(ctx, c, req, func(reqCtx context.Context, req any) (any, error) {
|
||||
if _, err := CheckLimiter(ctx, &req, h.proxy); err != nil {
|
||||
if _, err := CheckLimiter(ctx, req, h.proxy); err != nil {
|
||||
c.AbortWithStatusJSON(http.StatusOK, gin.H{
|
||||
HTTPReturnCode: merr.Code(err),
|
||||
HTTPReturnMessage: err.Error() + ", error: " + err.Error(),
|
||||
|
@ -595,7 +595,7 @@ func (h *HandlersV1) get(c *gin.Context) {
|
|||
return nil, RestRequestInterceptorErr
|
||||
}
|
||||
queryReq := req.(*milvuspb.QueryRequest)
|
||||
if _, err := CheckLimiter(ctx, &req, h.proxy); err != nil {
|
||||
if _, err := CheckLimiter(ctx, req, h.proxy); err != nil {
|
||||
c.AbortWithStatusJSON(http.StatusOK, gin.H{
|
||||
HTTPReturnCode: merr.Code(err),
|
||||
HTTPReturnMessage: err.Error() + ", error: " + err.Error(),
|
||||
|
@ -675,7 +675,7 @@ func (h *HandlersV1) delete(c *gin.Context) {
|
|||
}
|
||||
deleteReq.Expr = filter
|
||||
}
|
||||
if _, err := CheckLimiter(ctx, &req, h.proxy); err != nil {
|
||||
if _, err := CheckLimiter(ctx, req, h.proxy); err != nil {
|
||||
c.AbortWithStatusJSON(http.StatusOK, gin.H{
|
||||
HTTPReturnCode: merr.Code(err),
|
||||
HTTPReturnMessage: err.Error() + ", error: " + err.Error(),
|
||||
|
@ -758,7 +758,7 @@ func (h *HandlersV1) insert(c *gin.Context) {
|
|||
})
|
||||
return nil, RestRequestInterceptorErr
|
||||
}
|
||||
if _, err := CheckLimiter(ctx, &req, h.proxy); err != nil {
|
||||
if _, err := CheckLimiter(ctx, req, h.proxy); err != nil {
|
||||
c.AbortWithStatusJSON(http.StatusOK, gin.H{
|
||||
HTTPReturnCode: merr.Code(err),
|
||||
HTTPReturnMessage: err.Error() + ", error: " + err.Error(),
|
||||
|
@ -864,7 +864,7 @@ func (h *HandlersV1) upsert(c *gin.Context) {
|
|||
})
|
||||
return nil, RestRequestInterceptorErr
|
||||
}
|
||||
if _, err := CheckLimiter(ctx, &req, h.proxy); err != nil {
|
||||
if _, err := CheckLimiter(ctx, req, h.proxy); err != nil {
|
||||
c.AbortWithStatusJSON(http.StatusOK, gin.H{
|
||||
HTTPReturnCode: merr.Code(err),
|
||||
HTTPReturnMessage: err.Error() + ", error: " + err.Error(),
|
||||
|
@ -967,7 +967,7 @@ func (h *HandlersV1) search(c *gin.Context) {
|
|||
username, _ := c.Get(ContextUsername)
|
||||
ctx := proxy.NewContextWithMetadata(c, username.(string), req.DbName)
|
||||
response, err := h.executeRestRequestInterceptor(ctx, c, req, func(reqCtx context.Context, req any) (any, error) {
|
||||
if _, err := CheckLimiter(ctx, &req, h.proxy); err != nil {
|
||||
if _, err := CheckLimiter(ctx, req, h.proxy); err != nil {
|
||||
c.AbortWithStatusJSON(http.StatusOK, gin.H{
|
||||
HTTPReturnCode: merr.Code(err),
|
||||
HTTPReturnMessage: err.Error() + ", error: " + err.Error(),
|
||||
|
|
|
@ -1288,7 +1288,12 @@ func CheckLimiter(ctx context.Context, req interface{}, pxy types.ProxyComponent
|
|||
return nil, err
|
||||
}
|
||||
|
||||
dbID, collectionIDToPartIDs, rt, n, err := proxy.GetRequestInfo(ctx, req)
|
||||
request, ok := req.(proto.Message)
|
||||
if !ok {
|
||||
return nil, merr.WrapErrParameterInvalidMsg("wrong req format when check limiter")
|
||||
}
|
||||
|
||||
dbID, collectionIDToPartIDs, rt, n, err := proxy.GetRequestInfo(ctx, request)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
|
|
@ -22,6 +22,7 @@ import (
|
|||
|
||||
"go.uber.org/zap"
|
||||
"google.golang.org/grpc"
|
||||
"google.golang.org/protobuf/proto"
|
||||
|
||||
"github.com/milvus-io/milvus-proto/go-api/v2/milvuspb"
|
||||
"github.com/milvus-io/milvus/internal/types"
|
||||
|
@ -38,7 +39,11 @@ import (
|
|||
// RateLimitInterceptor returns a new unary server interceptors that performs request rate limiting.
|
||||
func RateLimitInterceptor(limiter types.Limiter) grpc.UnaryServerInterceptor {
|
||||
return func(ctx context.Context, req any, info *grpc.UnaryServerInfo, handler grpc.UnaryHandler) (interface{}, error) {
|
||||
dbID, collectionIDToPartIDs, rt, n, err := GetRequestInfo(ctx, req)
|
||||
request, ok := req.(proto.Message)
|
||||
if !ok {
|
||||
return nil, merr.WrapErrParameterInvalidMsg("wrong req format when check limiter")
|
||||
}
|
||||
dbID, collectionIDToPartIDs, rt, n, err := GetRequestInfo(ctx, request)
|
||||
if err != nil {
|
||||
log.Warn("failed to get request info", zap.Error(err))
|
||||
return handler(ctx, req)
|
||||
|
|
|
@ -20,6 +20,7 @@ import (
|
|||
"context"
|
||||
"encoding/json"
|
||||
"fmt"
|
||||
"reflect"
|
||||
"strconv"
|
||||
"strings"
|
||||
"time"
|
||||
|
@ -1911,7 +1912,7 @@ func GetRequestLabelFromContext(ctx context.Context) bool {
|
|||
}
|
||||
|
||||
// GetRequestInfo returns collection name and rateType of request and return tokens needed.
|
||||
func GetRequestInfo(ctx context.Context, req interface{}) (int64, map[int64][]int64, internalpb.RateType, int, error) {
|
||||
func GetRequestInfo(ctx context.Context, req proto.Message) (int64, map[int64][]int64, internalpb.RateType, int, error) {
|
||||
switch r := req.(type) {
|
||||
case *milvuspb.InsertRequest:
|
||||
dbID, collToPartIDs, err := getCollectionAndPartitionID(ctx, req.(reqPartName))
|
||||
|
@ -1989,6 +1990,7 @@ func GetRequestInfo(ctx context.Context, req interface{}) (int64, map[int64][]in
|
|||
if req == nil {
|
||||
return util.InvalidDBID, map[int64][]int64{}, 0, 0, fmt.Errorf("null request")
|
||||
}
|
||||
log.RatedWarn(60, "not supported request type for rate limiter", zap.String("type", reflect.TypeOf(req).String()))
|
||||
return util.InvalidDBID, map[int64][]int64{}, 0, 0, nil
|
||||
}
|
||||
}
|
||||
|
|
Loading…
Reference in New Issue