mirror of https://github.com/milvus-io/milvus.git
fix: use the default partition for the limit quota when the request partition name is empty (#38005)
- issue: #37685 Signed-off-by: SimFG <bang.fu@zilliz.com>pull/38040/head
parent
49ee46ec1d
commit
302650ae0e
|
@ -264,6 +264,7 @@ type partitionInfo struct {
|
|||
partitionID typeutil.UniqueID
|
||||
createdTimestamp uint64
|
||||
createdUtcTimestamp uint64
|
||||
isDefault bool
|
||||
}
|
||||
|
||||
func (info *collectionInfo) isCollectionCached() bool {
|
||||
|
@ -427,12 +428,14 @@ func (m *MetaCache) update(ctx context.Context, database, collectionName string,
|
|||
return nil, merr.WrapErrParameterInvalidMsg("partition names and timestamps number is not aligned, response: %s", partitions.String())
|
||||
}
|
||||
|
||||
defaultPartitionName := Params.CommonCfg.DefaultPartitionName.GetValue()
|
||||
infos := lo.Map(partitions.GetPartitionIDs(), func(partitionID int64, idx int) *partitionInfo {
|
||||
return &partitionInfo{
|
||||
name: partitions.PartitionNames[idx],
|
||||
partitionID: partitions.PartitionIDs[idx],
|
||||
createdTimestamp: partitions.CreatedTimestamps[idx],
|
||||
createdUtcTimestamp: partitions.CreatedUtcTimestamps[idx],
|
||||
isDefault: partitions.PartitionNames[idx] == defaultPartitionName,
|
||||
}
|
||||
})
|
||||
|
||||
|
@ -630,6 +633,14 @@ func (m *MetaCache) GetPartitionInfo(ctx context.Context, database, collectionNa
|
|||
return nil, err
|
||||
}
|
||||
|
||||
if partitionName == "" {
|
||||
for _, info := range partitions.partitionInfos {
|
||||
if info.isDefault {
|
||||
return info, nil
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
info, ok := partitions.name2Info[partitionName]
|
||||
if !ok {
|
||||
return nil, merr.WrapErrPartitionNotFound(partitionName)
|
||||
|
|
|
@ -84,8 +84,14 @@ func getCollectionAndPartitionID(ctx context.Context, r reqPartName) (int64, map
|
|||
return 0, nil, err
|
||||
}
|
||||
if r.GetPartitionName() == "" {
|
||||
collectionSchema, err := globalMetaCache.GetCollectionSchema(ctx, r.GetDbName(), r.GetCollectionName())
|
||||
if err != nil {
|
||||
return 0, nil, err
|
||||
}
|
||||
if collectionSchema.IsPartitionKeyCollection() {
|
||||
return db.dbID, map[int64][]int64{collectionID: {}}, nil
|
||||
}
|
||||
}
|
||||
part, err := globalMetaCache.GetPartitionInfo(ctx, r.GetDbName(), r.GetCollectionName(), r.GetPartitionName())
|
||||
if err != nil {
|
||||
return 0, nil, err
|
||||
|
|
|
@ -299,6 +299,7 @@ func TestRateLimitInterceptor(t *testing.T) {
|
|||
dbID: 100,
|
||||
createdTimestamp: 1,
|
||||
}, nil)
|
||||
mockCache.EXPECT().GetCollectionSchema(mock.Anything, mock.Anything, mock.Anything).Return(&schemaInfo{}, nil)
|
||||
globalMetaCache = mockCache
|
||||
|
||||
limiter := limiterMock{rate: 100}
|
||||
|
@ -437,6 +438,41 @@ func TestGetInfo(t *testing.T) {
|
|||
}
|
||||
})
|
||||
|
||||
t.Run("fail to get collection schema", func(t *testing.T) {
|
||||
mockCache.EXPECT().GetDatabaseInfo(mock.Anything, mock.Anything).Return(&databaseInfo{
|
||||
dbID: 100,
|
||||
createdTimestamp: 1,
|
||||
}, nil).Once()
|
||||
mockCache.EXPECT().GetCollectionID(mock.Anything, mock.Anything, mock.Anything).Return(int64(1), nil).Once()
|
||||
mockCache.EXPECT().GetCollectionSchema(mock.Anything, mock.Anything, mock.Anything).Return(nil, errors.New("mock error")).Once()
|
||||
|
||||
_, _, err := getCollectionAndPartitionID(ctx, &milvuspb.InsertRequest{
|
||||
DbName: "foo",
|
||||
CollectionName: "coo",
|
||||
})
|
||||
assert.Error(t, err)
|
||||
})
|
||||
|
||||
t.Run("partition key mode", func(t *testing.T) {
|
||||
mockCache.EXPECT().GetDatabaseInfo(mock.Anything, mock.Anything).Return(&databaseInfo{
|
||||
dbID: 100,
|
||||
createdTimestamp: 1,
|
||||
}, nil).Once()
|
||||
mockCache.EXPECT().GetCollectionID(mock.Anything, mock.Anything, mock.Anything).Return(int64(1), nil).Once()
|
||||
mockCache.EXPECT().GetCollectionSchema(mock.Anything, mock.Anything, mock.Anything).Return(&schemaInfo{
|
||||
hasPartitionKeyField: true,
|
||||
}, nil).Once()
|
||||
|
||||
db, col2par, err := getCollectionAndPartitionID(ctx, &milvuspb.InsertRequest{
|
||||
DbName: "foo",
|
||||
CollectionName: "coo",
|
||||
})
|
||||
assert.NoError(t, err)
|
||||
assert.Equal(t, int64(100), db)
|
||||
assert.NotNil(t, col2par[1])
|
||||
assert.Equal(t, 0, len(col2par[1]))
|
||||
})
|
||||
|
||||
t.Run("fail to get partition", func(t *testing.T) {
|
||||
mockCache.EXPECT().GetDatabaseInfo(mock.Anything, mock.Anything).Return(&databaseInfo{
|
||||
dbID: 100,
|
||||
|
@ -467,11 +503,12 @@ func TestGetInfo(t *testing.T) {
|
|||
dbID: 100,
|
||||
createdTimestamp: 1,
|
||||
}, nil).Times(3)
|
||||
mockCache.EXPECT().GetCollectionSchema(mock.Anything, mock.Anything, mock.Anything).Return(&schemaInfo{}, nil).Times(1)
|
||||
mockCache.EXPECT().GetCollectionID(mock.Anything, mock.Anything, mock.Anything).Return(int64(10), nil).Times(3)
|
||||
mockCache.EXPECT().GetPartitionInfo(mock.Anything, mock.Anything, mock.Anything, mock.Anything).Return(&partitionInfo{
|
||||
name: "p1",
|
||||
partitionID: 100,
|
||||
}, nil).Twice()
|
||||
}, nil).Times(3)
|
||||
{
|
||||
db, col2par, err := getCollectionAndPartitionID(ctx, &milvuspb.InsertRequest{
|
||||
DbName: "foo",
|
||||
|
@ -491,7 +528,7 @@ func TestGetInfo(t *testing.T) {
|
|||
assert.NoError(t, err)
|
||||
assert.Equal(t, int64(100), db)
|
||||
assert.NotNil(t, col2par[10])
|
||||
assert.Equal(t, 0, len(col2par[10]))
|
||||
assert.Equal(t, int64(100), col2par[10][0])
|
||||
}
|
||||
{
|
||||
db, col2par, err := getCollectionAndPartitionIDs(ctx, &milvuspb.SearchRequest{
|
||||
|
|
|
@ -202,7 +202,12 @@ func (it *insertTask) PreExecute(ctx context.Context) error {
|
|||
// insert to _default partition
|
||||
partitionTag := it.insertMsg.GetPartitionName()
|
||||
if len(partitionTag) <= 0 {
|
||||
partitionTag = Params.CommonCfg.DefaultPartitionName.GetValue()
|
||||
pinfo, err := globalMetaCache.GetPartitionInfo(ctx, it.insertMsg.GetDbName(), collectionName, "")
|
||||
if err != nil {
|
||||
log.Warn("get partition info failed", zap.String("collectionName", collectionName), zap.Error(err))
|
||||
return err
|
||||
}
|
||||
partitionTag = pinfo.name
|
||||
it.insertMsg.PartitionName = partitionTag
|
||||
}
|
||||
|
||||
|
|
|
@ -3651,6 +3651,204 @@ func TestPartitionKey(t *testing.T) {
|
|||
})
|
||||
}
|
||||
|
||||
func TestDefaultPartition(t *testing.T) {
|
||||
rc := NewRootCoordMock()
|
||||
|
||||
defer rc.Close()
|
||||
qc := getQueryCoordClient()
|
||||
qc.EXPECT().ShowCollections(mock.Anything, mock.Anything).Return(&querypb.ShowCollectionsResponse{}, nil).Maybe()
|
||||
|
||||
ctx := context.Background()
|
||||
|
||||
mgr := newShardClientMgr()
|
||||
err := InitMetaCache(ctx, rc, qc, mgr)
|
||||
assert.NoError(t, err)
|
||||
|
||||
shardsNum := common.DefaultShardsNum
|
||||
prefix := "TestInsertTaskWithPartitionKey"
|
||||
collectionName := prefix + funcutil.GenRandomStr()
|
||||
|
||||
fieldName2Type := make(map[string]schemapb.DataType)
|
||||
fieldName2Type["int64_field"] = schemapb.DataType_Int64
|
||||
fieldName2Type["varChar_field"] = schemapb.DataType_VarChar
|
||||
fieldName2Type["fvec_field"] = schemapb.DataType_FloatVector
|
||||
schema := constructCollectionSchemaByDataType(collectionName, fieldName2Type, "int64_field", false)
|
||||
marshaledSchema, err := proto.Marshal(schema)
|
||||
assert.NoError(t, err)
|
||||
|
||||
t.Run("create collection", func(t *testing.T) {
|
||||
createCollectionTask := &createCollectionTask{
|
||||
Condition: NewTaskCondition(ctx),
|
||||
CreateCollectionRequest: &milvuspb.CreateCollectionRequest{
|
||||
Base: &commonpb.MsgBase{
|
||||
MsgID: UniqueID(uniquegenerator.GetUniqueIntGeneratorIns().GetInt()),
|
||||
Timestamp: Timestamp(time.Now().UnixNano()),
|
||||
},
|
||||
DbName: "",
|
||||
CollectionName: collectionName,
|
||||
Schema: marshaledSchema,
|
||||
ShardsNum: shardsNum,
|
||||
},
|
||||
ctx: ctx,
|
||||
rootCoord: rc,
|
||||
result: nil,
|
||||
schema: nil,
|
||||
}
|
||||
err = createCollectionTask.PreExecute(ctx)
|
||||
assert.NoError(t, err)
|
||||
err = createCollectionTask.Execute(ctx)
|
||||
assert.NoError(t, err)
|
||||
})
|
||||
|
||||
collectionID, err := globalMetaCache.GetCollectionID(ctx, GetCurDBNameFromContextOrDefault(ctx), collectionName)
|
||||
assert.NoError(t, err)
|
||||
|
||||
dmlChannelsFunc := getDmlChannelsFunc(ctx, rc)
|
||||
factory := newSimpleMockMsgStreamFactory()
|
||||
chMgr := newChannelsMgrImpl(dmlChannelsFunc, nil, factory)
|
||||
defer chMgr.removeAllDMLStream()
|
||||
|
||||
_, err = chMgr.getOrCreateDmlStream(collectionID)
|
||||
assert.NoError(t, err)
|
||||
pchans, err := chMgr.getChannels(collectionID)
|
||||
assert.NoError(t, err)
|
||||
|
||||
interval := time.Millisecond * 10
|
||||
tso := newMockTsoAllocator()
|
||||
|
||||
ticker := newChannelsTimeTicker(ctx, interval, []string{}, newGetStatisticsFunc(pchans), tso)
|
||||
_ = ticker.start()
|
||||
defer ticker.close()
|
||||
|
||||
idAllocator, err := allocator.NewIDAllocator(ctx, rc, paramtable.GetNodeID())
|
||||
assert.NoError(t, err)
|
||||
_ = idAllocator.Start()
|
||||
defer idAllocator.Close()
|
||||
|
||||
segAllocator, err := newSegIDAssigner(ctx, &mockDataCoord{expireTime: Timestamp(2500)}, getLastTick1)
|
||||
assert.NoError(t, err)
|
||||
segAllocator.Init()
|
||||
_ = segAllocator.Start()
|
||||
defer segAllocator.Close()
|
||||
|
||||
nb := 10
|
||||
fieldID := common.StartOfUserFieldID
|
||||
fieldDatas := make([]*schemapb.FieldData, 0)
|
||||
for fieldName, dataType := range fieldName2Type {
|
||||
fieldData := generateFieldData(dataType, fieldName, nb)
|
||||
fieldData.FieldId = int64(fieldID)
|
||||
fieldDatas = append(fieldDatas, generateFieldData(dataType, fieldName, nb))
|
||||
fieldID++
|
||||
}
|
||||
|
||||
t.Run("Insert", func(t *testing.T) {
|
||||
it := &insertTask{
|
||||
insertMsg: &BaseInsertTask{
|
||||
BaseMsg: msgstream.BaseMsg{},
|
||||
InsertRequest: &msgpb.InsertRequest{
|
||||
Base: &commonpb.MsgBase{
|
||||
MsgType: commonpb.MsgType_Insert,
|
||||
MsgID: 0,
|
||||
SourceID: paramtable.GetNodeID(),
|
||||
},
|
||||
CollectionName: collectionName,
|
||||
FieldsData: fieldDatas,
|
||||
NumRows: uint64(nb),
|
||||
Version: msgpb.InsertDataVersion_ColumnBased,
|
||||
},
|
||||
},
|
||||
|
||||
Condition: NewTaskCondition(ctx),
|
||||
ctx: ctx,
|
||||
result: &milvuspb.MutationResult{
|
||||
Status: merr.Success(),
|
||||
IDs: nil,
|
||||
SuccIndex: nil,
|
||||
ErrIndex: nil,
|
||||
Acknowledged: false,
|
||||
InsertCnt: 0,
|
||||
DeleteCnt: 0,
|
||||
UpsertCnt: 0,
|
||||
Timestamp: 0,
|
||||
},
|
||||
idAllocator: idAllocator,
|
||||
segIDAssigner: segAllocator,
|
||||
chMgr: chMgr,
|
||||
chTicker: ticker,
|
||||
vChannels: nil,
|
||||
pChannels: nil,
|
||||
schema: nil,
|
||||
}
|
||||
|
||||
it.insertMsg.PartitionName = ""
|
||||
assert.NoError(t, it.OnEnqueue())
|
||||
assert.NoError(t, it.PreExecute(ctx))
|
||||
assert.NoError(t, it.Execute(ctx))
|
||||
assert.NoError(t, it.PostExecute(ctx))
|
||||
})
|
||||
|
||||
t.Run("Upsert", func(t *testing.T) {
|
||||
hash := testutils.GenerateHashKeys(nb)
|
||||
ut := &upsertTask{
|
||||
ctx: ctx,
|
||||
Condition: NewTaskCondition(ctx),
|
||||
baseMsg: msgstream.BaseMsg{
|
||||
HashValues: hash,
|
||||
},
|
||||
req: &milvuspb.UpsertRequest{
|
||||
Base: commonpbutil.NewMsgBase(
|
||||
commonpbutil.WithMsgType(commonpb.MsgType_Upsert),
|
||||
commonpbutil.WithSourceID(paramtable.GetNodeID()),
|
||||
),
|
||||
CollectionName: collectionName,
|
||||
FieldsData: fieldDatas,
|
||||
NumRows: uint32(nb),
|
||||
},
|
||||
|
||||
result: &milvuspb.MutationResult{
|
||||
Status: merr.Success(),
|
||||
IDs: &schemapb.IDs{
|
||||
IdField: nil,
|
||||
},
|
||||
},
|
||||
idAllocator: idAllocator,
|
||||
segIDAssigner: segAllocator,
|
||||
chMgr: chMgr,
|
||||
chTicker: ticker,
|
||||
}
|
||||
|
||||
ut.req.PartitionName = ""
|
||||
assert.NoError(t, ut.OnEnqueue())
|
||||
assert.NoError(t, ut.PreExecute(ctx))
|
||||
assert.NoError(t, ut.Execute(ctx))
|
||||
assert.NoError(t, ut.PostExecute(ctx))
|
||||
})
|
||||
|
||||
t.Run("delete", func(t *testing.T) {
|
||||
dt := &deleteTask{
|
||||
Condition: NewTaskCondition(ctx),
|
||||
req: &milvuspb.DeleteRequest{
|
||||
CollectionName: collectionName,
|
||||
Expr: "int64_field in [0, 1]",
|
||||
},
|
||||
ctx: ctx,
|
||||
primaryKeys: &schemapb.IDs{
|
||||
IdField: &schemapb.IDs_IntId{IntId: &schemapb.LongArray{Data: []int64{0, 1}}},
|
||||
},
|
||||
idAllocator: idAllocator,
|
||||
chMgr: chMgr,
|
||||
chTicker: ticker,
|
||||
collectionID: collectionID,
|
||||
vChannels: []string{"test-channel"},
|
||||
}
|
||||
|
||||
dt.req.PartitionName = ""
|
||||
assert.NoError(t, dt.PreExecute(ctx))
|
||||
assert.NoError(t, dt.Execute(ctx))
|
||||
assert.NoError(t, dt.PostExecute(ctx))
|
||||
})
|
||||
}
|
||||
|
||||
func TestClusteringKey(t *testing.T) {
|
||||
rc := NewRootCoordMock()
|
||||
|
||||
|
|
|
@ -317,8 +317,12 @@ func (it *upsertTask) PreExecute(ctx context.Context) error {
|
|||
// insert to _default partition
|
||||
partitionTag := it.req.GetPartitionName()
|
||||
if len(partitionTag) <= 0 {
|
||||
partitionTag = Params.CommonCfg.DefaultPartitionName.GetValue()
|
||||
it.req.PartitionName = partitionTag
|
||||
pinfo, err := globalMetaCache.GetPartitionInfo(ctx, it.req.GetDbName(), collectionName, "")
|
||||
if err != nil {
|
||||
log.Warn("get partition info failed", zap.String("collectionName", collectionName), zap.Error(err))
|
||||
return err
|
||||
}
|
||||
it.req.PartitionName = pinfo.name
|
||||
}
|
||||
}
|
||||
|
||||
|
|
Loading…
Reference in New Issue