Support update (#20875)

Signed-off-by: lixinguo <xinguo.li@zilliz.com>
Co-authored-by: lixinguo <xinguo.li@zilliz.com>
pull/21267/head
smellthemoon 2023-01-04 17:21:36 +08:00 committed by GitHub
parent ff2a68e65a
commit bf3c02155a
19 changed files with 1984 additions and 32 deletions

View File

@@ -107,7 +107,7 @@ get_deleted_bitmap(int64_t del_barrier,
// insert after delete with same pk, delete will not take effect on this insert record
// and reset bitmap to 0
if (insert_record.timestamps_[insert_row_offset] > delete_timestamp) {
if (insert_record.timestamps_[insert_row_offset] >= delete_timestamp) {
bitmap->reset(insert_row_offset);
continue;
}
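Note: the one-character change above (from > to >=) is what lets an upsert's re-insert win the tie against its paired delete, since both carry the same timestamp. A minimal Go sketch of the tie-breaking rule (a simplified model, not the segcore implementation):

package main

import "fmt"

// visibleAfterDelete reports whether an insert on some primary key
// remains visible after a delete on the same key. Simplified model of
// the rule patched above: an insert at a timestamp >= the delete's
// timestamp survives, which covers upsert, where the delete and the
// re-insert share one timestamp.
func visibleAfterDelete(insertTs, deleteTs uint64) bool {
	return insertTs >= deleteTs
}

func main() {
	fmt.Println(visibleAfterDelete(100, 100)) // true: the upsert re-insert wins the tie
	fmt.Println(visibleAfterDelete(99, 100))  // false: an older insert stays deleted
}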

View File

@@ -791,7 +791,7 @@ TEST(CApiTest, InsertSamePkAfterDeleteOnSealedSegment) {
auto query_result = std::make_unique<proto::segcore::RetrieveResults>();
auto suc = query_result->ParseFromArray(retrieve_result.proto_blob, retrieve_result.proto_size);
ASSERT_TRUE(suc);
ASSERT_EQ(query_result->ids().int_id().data().size(), 3);
ASSERT_EQ(query_result->ids().int_id().data().size(), 4);
DeleteRetrievePlan(plan.release());
DeleteRetrieveResult(&retrieve_result);

View File

@@ -87,7 +87,7 @@ TEST(Util, GetDeleteBitmap) {
del_barrier = get_barrier(delete_record, query_timestamp);
res_bitmap = get_deleted_bitmap(del_barrier, insert_barrier, delete_record, insert_record, query_timestamp);
ASSERT_EQ(res_bitmap->bitmap_ptr->count(), N);
ASSERT_EQ(res_bitmap->bitmap_ptr->count(), N - 1);
// test case insert repeated pk1 (ts = {1 ... N}) -> delete pk1 (ts = N) -> query (ts = N/2)
query_timestamp = tss[N - 1] / 2;

View File

@@ -685,7 +685,7 @@ func (s *Server) Delete(ctx context.Context, request *milvuspb.DeleteRequest) (*
}
func (s *Server) Upsert(ctx context.Context, request *milvuspb.UpsertRequest) (*milvuspb.MutationResult, error) {
panic("TODO: not implement")
return s.proxy.Upsert(ctx, request)
}
func (s *Server) Search(ctx context.Context, request *milvuspb.SearchRequest) (*milvuspb.SearchResults, error) {
@@ -708,12 +708,12 @@ func (s *Server) GetDdChannel(ctx context.Context, request *internalpb.GetDdChan
return s.proxy.GetDdChannel(ctx, request)
}
//GetPersistentSegmentInfo notifies Proxy to get persistent segment info.
// GetPersistentSegmentInfo notifies Proxy to get persistent segment info.
func (s *Server) GetPersistentSegmentInfo(ctx context.Context, request *milvuspb.GetPersistentSegmentInfoRequest) (*milvuspb.GetPersistentSegmentInfoResponse, error) {
return s.proxy.GetPersistentSegmentInfo(ctx, request)
}
//GetQuerySegmentInfo notifies Proxy to get query segment info.
// GetQuerySegmentInfo notifies Proxy to get query segment info.
func (s *Server) GetQuerySegmentInfo(ctx context.Context, request *milvuspb.GetQuerySegmentInfoRequest) (*milvuspb.GetQuerySegmentInfoResponse, error) {
return s.proxy.GetQuerySegmentInfo(ctx, request)

View File

@@ -60,7 +60,7 @@ func TestMain(m *testing.M) {
os.Exit(code)
}
///////////////////////////////////////////////////////////////////////////////////////////////////////////////////////
// /////////////////////////////////////////////////////////////////////////////////////////////////////////////////////
type MockBase struct {
mock.Mock
isMockGetComponentStatesOn bool
@@ -102,7 +102,7 @@ func (m *MockBase) GetStatisticsChannel(ctx context.Context) (*milvuspb.StringRe
return nil, nil
}
///////////////////////////////////////////////////////////////////////////////////////////////////////////////////////
// /////////////////////////////////////////////////////////////////////////////////////////////////////////////////////
type MockRootCoord struct {
MockBase
initErr error
@@ -289,7 +289,7 @@ func (m *MockRootCoord) CheckHealth(ctx context.Context, req *milvuspb.CheckHeal
}, nil
}
///////////////////////////////////////////////////////////////////////////////////////////////////////////////////////
// /////////////////////////////////////////////////////////////////////////////////////////////////////////////////////
type MockIndexCoord struct {
MockBase
initErr error
@@ -354,7 +354,7 @@ func (m *MockIndexCoord) CheckHealth(ctx context.Context, req *milvuspb.CheckHea
return nil, nil
}
///////////////////////////////////////////////////////////////////////////////////////////////////////////////////////
// /////////////////////////////////////////////////////////////////////////////////////////////////////////////////////
type MockQueryCoord struct {
MockBase
initErr error
@@ -469,7 +469,7 @@ func (m *MockQueryCoord) CheckHealth(ctx context.Context, req *milvuspb.CheckHea
}, nil
}
///////////////////////////////////////////////////////////////////////////////////////////////////////////////////////
// /////////////////////////////////////////////////////////////////////////////////////////////////////////////////////
type MockDataCoord struct {
MockBase
err error
@@ -627,7 +627,7 @@ func (m *MockDataCoord) CheckHealth(ctx context.Context, req *milvuspb.CheckHeal
return nil, nil
}
///////////////////////////////////////////////////////////////////////////////////////////////////////////////////////
// /////////////////////////////////////////////////////////////////////////////////////////////////////////////////////
type MockProxy struct {
MockBase
err error
@@ -758,6 +758,10 @@ func (m *MockProxy) Delete(ctx context.Context, request *milvuspb.DeleteRequest)
return nil, nil
}
func (m *MockProxy) Upsert(ctx context.Context, request *milvuspb.UpsertRequest) (*milvuspb.MutationResult, error) {
return nil, nil
}
func (m *MockProxy) Search(ctx context.Context, request *milvuspb.SearchRequest) (*milvuspb.SearchResults, error) {
return nil, nil
}
@@ -1063,7 +1067,7 @@ func runAndWaitForServerReady(server *Server) error {
return nil
}
///////////////////////////////////////////////////////////////////////////////////////////////////////////////////////
// /////////////////////////////////////////////////////////////////////////////////////////////////////////////////////
func Test_NewServer(t *testing.T) {
paramtable.Init()
ctx := context.Background()
@@ -1212,6 +1216,11 @@ func Test_NewServer(t *testing.T) {
assert.Nil(t, err)
})
t.Run("Upsert", func(t *testing.T) {
_, err := server.Upsert(ctx, nil)
assert.Nil(t, err)
})
t.Run("Search", func(t *testing.T) {
_, err := server.Search(ctx, nil)
assert.Nil(t, err)

View File

@@ -35,6 +35,7 @@ const (
InsertLabel = "insert"
DeleteLabel = "delete"
UpsertLabel = "upsert"
SearchLabel = "search"
QueryLabel = "query"
CacheHitLabel = "hit"

View File

@@ -380,6 +380,12 @@ func (dt *DeleteMsg) CheckAligned() error {
return nil
}
// ///////////////////////////////////////Upsert//////////////////////////////////////////
type UpsertMsg struct {
InsertMsg *InsertMsg
DeleteMsg *DeleteMsg
}
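UpsertMsg is deliberately just a pair: the proxy translates one upsert into a delete of the old rows plus an insert of the new ones, and the two halves travel together. A hedged Go sketch of the pairing with simplified stand-in types (only the two-field shape comes from the diff; everything else is illustrative):

package main

import "fmt"

// Stand-ins for the real msgstream insert/delete messages.
type InsertMsg struct {
	CollectionName string
	NumRows        uint64
}

type DeleteMsg struct {
	CollectionName string
	NumRows        int64
}

// UpsertMsg mirrors the struct added above: one upsert is carried as a
// delete of the old rows plus an insert of the replacement rows.
type UpsertMsg struct {
	InsertMsg *InsertMsg
	DeleteMsg *DeleteMsg
}

func main() {
	up := &UpsertMsg{
		InsertMsg: &InsertMsg{CollectionName: "c1", NumRows: 3},
		DeleteMsg: &DeleteMsg{CollectionName: "c1", NumRows: 3},
	}
	// PreExecute later checks that both halves agree on the row count.
	fmt.Println(up.InsertMsg.NumRows == uint64(up.DeleteMsg.NumRows))
}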
/////////////////////////////////////////TimeTick//////////////////////////////////////////
// TimeTickMsg is a message pack that contains time tick only

View File

@@ -28,6 +28,7 @@ import (
"github.com/golang/protobuf/proto"
"github.com/milvus-io/milvus-proto/go-api/commonpb"
"github.com/milvus-io/milvus-proto/go-api/milvuspb"
"github.com/milvus-io/milvus-proto/go-api/schemapb"
"github.com/milvus-io/milvus/internal/common"
"github.com/milvus-io/milvus/internal/log"
"github.com/milvus-io/milvus/internal/metrics"
@@ -2018,11 +2019,9 @@ func (node *Proxy) Insert(ctx context.Context, request *milvuspb.InsertRequest)
}
method := "Insert"
tr := timerecord.NewTimeRecorder(method)
receiveSize := proto.Size(request)
rateCol.Add(internalpb.RateType_DMLInsert.String(), float64(receiveSize))
metrics.ProxyReceiveBytes.WithLabelValues(strconv.FormatInt(paramtable.GetNodeID(), 10), metrics.InsertLabel).Add(float64(receiveSize))
metrics.ProxyReceiveBytes.WithLabelValues(strconv.FormatInt(paramtable.GetNodeID(), 10), metrics.InsertLabel).Add(float64(proto.Size(request)))
metrics.ProxyFunctionCall.WithLabelValues(strconv.FormatInt(paramtable.GetNodeID(), 10), method, metrics.TotalLabel).Inc()
it := &insertTask{
ctx: ctx,
Condition: NewTaskCondition(ctx),
@@ -2119,6 +2118,9 @@ func (node *Proxy) Insert(ctx context.Context, request *milvuspb.InsertRequest)
// InsertCnt always equals the number of entities in the request
it.result.InsertCnt = int64(request.NumRows)
receiveSize := proto.Size(it.insertMsg)
rateCol.Add(internalpb.RateType_DMLInsert.String(), float64(receiveSize))
metrics.ProxyFunctionCall.WithLabelValues(strconv.FormatInt(paramtable.GetNodeID(), 10), method,
metrics.SuccessLabel).Inc()
successCnt := it.result.InsertCnt - int64(len(it.result.ErrIndex))
@@ -2128,10 +2130,6 @@ func (node *Proxy) Insert(ctx context.Context, request *milvuspb.InsertRequest)
return it.result, nil
}
func (node *Proxy) Upsert(ctx context.Context, request *milvuspb.UpsertRequest) (*milvuspb.MutationResult, error) {
panic("TODO: not implement")
}
// Delete deletes records from the collection; deleted records can no longer be searched.
func (node *Proxy) Delete(ctx context.Context, request *milvuspb.DeleteRequest) (*milvuspb.MutationResult, error) {
sp, ctx := trace.StartSpanFromContextWithOperationName(ctx, "Proxy-Delete")
@@ -2140,9 +2138,7 @@ func (node *Proxy) Delete(ctx context.Context, request *milvuspb.DeleteRequest)
log.Debug("Start processing delete request in Proxy")
defer log.Debug("Finish processing delete request in Proxy")
receiveSize := proto.Size(request)
rateCol.Add(internalpb.RateType_DMLDelete.String(), float64(receiveSize))
metrics.ProxyReceiveBytes.WithLabelValues(strconv.FormatInt(paramtable.GetNodeID(), 10), metrics.DeleteLabel).Add(float64(receiveSize))
metrics.ProxyReceiveBytes.WithLabelValues(strconv.FormatInt(paramtable.GetNodeID(), 10), metrics.DeleteLabel).Add(float64(proto.Size(request)))
if !node.checkHealthy() {
return &milvuspb.MutationResult{
@@ -2219,6 +2215,9 @@ func (node *Proxy) Delete(ctx context.Context, request *milvuspb.DeleteRequest)
}, nil
}
receiveSize := proto.Size(dt.deleteMsg)
rateCol.Add(internalpb.RateType_DMLDelete.String(), float64(receiveSize))
metrics.ProxyFunctionCall.WithLabelValues(strconv.FormatInt(paramtable.GetNodeID(), 10), method,
metrics.SuccessLabel).Inc()
metrics.ProxyMutationLatency.WithLabelValues(strconv.FormatInt(paramtable.GetNodeID(), 10), metrics.DeleteLabel).Observe(float64(tr.ElapseSpan().Milliseconds()))
@@ -2226,6 +2225,140 @@ func (node *Proxy) Delete(ctx context.Context, request *milvuspb.DeleteRequest)
return dt.result, nil
}
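A pattern worth noticing in both the Insert and Delete hunks above: rateCol.Add moves from request receipt to after the task has succeeded, and it now sizes the parsed DML message rather than the raw request, so the rate collector only counts traffic that was actually accepted. A minimal sketch of that ordering (hypothetical names, not the Milvus API):

package main

import "fmt"

// rateCollector is a stand-in for the proxy's rate collector, which
// aggregates accepted DML bytes for quota enforcement.
type rateCollector struct{ bytes map[string]float64 }

func (r *rateCollector) Add(label string, n float64) { r.bytes[label] += n }

// handleDML records the message size into the rate collector only after
// the task completes, mirroring the reordering in the diff above.
func handleDML(r *rateCollector, msgSize int, run func() error) error {
	if err := run(); err != nil {
		return err // a failed DML contributes nothing to the DML rate
	}
	r.Add("dml", float64(msgSize))
	return nil
}

func main() {
	r := &rateCollector{bytes: map[string]float64{}}
	_ = handleDML(r, 128, func() error { return nil })
	fmt.Println(r.bytes["dml"]) // 128
}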
// Upsert upserts records into the collection, replacing existing rows that share the same primary key.
func (node *Proxy) Upsert(ctx context.Context, request *milvuspb.UpsertRequest) (*milvuspb.MutationResult, error) {
sp, ctx := trace.StartSpanFromContextWithOperationName(ctx, "Proxy-Upsert")
defer sp.Finish()
log := log.Ctx(ctx).With(
zap.String("role", typeutil.ProxyRole),
zap.String("db", request.DbName),
zap.String("collection", request.CollectionName),
zap.String("partition", request.PartitionName),
zap.Uint32("NumRows", request.NumRows),
)
log.Debug("Start processing upsert request in Proxy")
if !node.checkHealthy() {
return &milvuspb.MutationResult{
Status: unhealthyStatus(),
}, nil
}
method := "Upsert"
tr := timerecord.NewTimeRecorder(method)
metrics.ProxyReceiveBytes.WithLabelValues(strconv.FormatInt(paramtable.GetNodeID(), 10), metrics.UpsertLabel).Add(float64(proto.Size(request)))
metrics.ProxyFunctionCall.WithLabelValues(strconv.FormatInt(paramtable.GetNodeID(), 10), method, metrics.TotalLabel).Inc()
it := &upsertTask{
baseMsg: msgstream.BaseMsg{
HashValues: request.HashKeys,
},
ctx: ctx,
Condition: NewTaskCondition(ctx),
req: &milvuspb.UpsertRequest{
Base: commonpbutil.NewMsgBase(
commonpbutil.WithMsgType(commonpb.MsgType(commonpb.MsgType_Upsert)),
commonpbutil.WithSourceID(paramtable.GetNodeID()),
),
CollectionName: request.CollectionName,
PartitionName: request.PartitionName,
FieldsData: request.FieldsData,
NumRows: request.NumRows,
},
result: &milvuspb.MutationResult{
Status: &commonpb.Status{
ErrorCode: commonpb.ErrorCode_Success,
},
IDs: &schemapb.IDs{
IdField: nil,
},
},
idAllocator: node.rowIDAllocator,
segIDAssigner: node.segAssigner,
chMgr: node.chMgr,
chTicker: node.chTicker,
}
if len(it.req.PartitionName) <= 0 {
it.req.PartitionName = Params.CommonCfg.DefaultPartitionName.GetValue()
}
constructFailedResponse := func(err error, errCode commonpb.ErrorCode) *milvuspb.MutationResult {
numRows := request.NumRows
errIndex := make([]uint32, numRows)
for i := uint32(0); i < numRows; i++ {
errIndex[i] = i
}
return &milvuspb.MutationResult{
Status: &commonpb.Status{
ErrorCode: errCode,
Reason: err.Error(),
},
ErrIndex: errIndex,
}
}
log.Debug("Enqueue upsert request in Proxy",
zap.Int("len(FieldsData)", len(request.FieldsData)),
zap.Int("len(HashKeys)", len(request.HashKeys)))
if err := node.sched.dmQueue.Enqueue(it); err != nil {
log.Info("Failed to enqueue upsert task",
zap.Error(err))
metrics.ProxyFunctionCall.WithLabelValues(strconv.FormatInt(paramtable.GetNodeID(), 10), method,
metrics.AbandonLabel).Inc()
return &milvuspb.MutationResult{
Status: &commonpb.Status{
ErrorCode: commonpb.ErrorCode_UnexpectedError,
Reason: err.Error(),
},
}, nil
}
log.Debug("Detail of upsert request in Proxy",
zap.Uint64("BeginTS", it.BeginTs()),
zap.Uint64("EndTS", it.EndTs()))
if err := it.WaitToFinish(); err != nil {
log.Info("Failed to execute insert task in task scheduler",
zap.Error(err))
metrics.ProxyFunctionCall.WithLabelValues(strconv.FormatInt(paramtable.GetNodeID(), 10), method,
metrics.FailLabel).Inc()
return constructFailedResponse(err, it.result.Status.ErrorCode), nil
}
if it.result.Status.ErrorCode != commonpb.ErrorCode_Success {
setErrorIndex := func() {
numRows := request.NumRows
errIndex := make([]uint32, numRows)
for i := uint32(0); i < numRows; i++ {
errIndex[i] = i
}
it.result.ErrIndex = errIndex
}
setErrorIndex()
}
insertReceiveSize := proto.Size(it.upsertMsg.InsertMsg)
deleteReceiveSize := proto.Size(it.upsertMsg.DeleteMsg)
rateCol.Add(internalpb.RateType_DMLDelete.String(), float64(deleteReceiveSize))
rateCol.Add(internalpb.RateType_DMLInsert.String(), float64(insertReceiveSize))
metrics.ProxyFunctionCall.WithLabelValues(strconv.FormatInt(paramtable.GetNodeID(), 10), method,
metrics.SuccessLabel).Inc()
metrics.ProxyMutationLatency.WithLabelValues(strconv.FormatInt(paramtable.GetNodeID(), 10), metrics.UpsertLabel).Observe(float64(tr.ElapseSpan().Milliseconds()))
metrics.ProxyCollectionMutationLatency.WithLabelValues(strconv.FormatInt(paramtable.GetNodeID(), 10), metrics.UpsertLabel, request.CollectionName).Observe(float64(tr.ElapseSpan().Milliseconds()))
log.Debug("Finish processing upsert request in Proxy")
return it.result, nil
}
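For orientation, here is the request shape this handler consumes, assembled the same way the tests below do; a sketch with placeholder data, assuming only the milvuspb fields that appear in this diff:

package main

import (
	"fmt"

	"github.com/milvus-io/milvus-proto/go-api/milvuspb"
	"github.com/milvus-io/milvus-proto/go-api/schemapb"
)

// buildUpsertRequest assembles an UpsertRequest with the fields used in
// this PR. The data here is placeholder; FieldsData must include the
// primary key column, since upsert rejects autoID collections.
func buildUpsertRequest(collection, partition string, fields []*schemapb.FieldData, rows uint32) *milvuspb.UpsertRequest {
	return &milvuspb.UpsertRequest{
		CollectionName: collection,
		PartitionName:  partition, // empty falls back to the default partition, per the handler above
		FieldsData:     fields,
		NumRows:        rows,
	}
}

func main() {
	req := buildUpsertRequest("demo_collection", "", nil, 0)
	fmt.Println(req.GetCollectionName())
}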
// Search searches for the records most similar to the given request vectors.
func (node *Proxy) Search(ctx context.Context, request *milvuspb.SearchRequest) (*milvuspb.SearchResults, error) {
receiveSize := proto.Size(request)

View File

@@ -691,6 +691,20 @@ func TestProxy(t *testing.T) {
}
}
constructCollectionUpsertRequest := func() *milvuspb.UpsertRequest {
fVecColumn := newFloatVectorFieldData(floatVecField, rowNum, dim)
hashKeys := generateHashKeys(rowNum)
return &milvuspb.UpsertRequest{
Base: nil,
DbName: dbName,
CollectionName: collectionName,
PartitionName: partitionName,
FieldsData: []*schemapb.FieldData{fVecColumn},
HashKeys: hashKeys,
NumRows: uint32(rowNum),
}
}
constructCreateIndexRequest := func() *milvuspb.CreateIndexRequest {
return &milvuspb.CreateIndexRequest{
Base: nil,
@@ -1085,7 +1099,7 @@ func TestProxy(t *testing.T) {
assert.Equal(t, int64(rowNum), resp.InsertCnt)
})
// TODO(dragondriver): proxy.Delete()
//TODO(dragondriver): proxy.Delete()
flushed := true
wg.Add(1)
@@ -2155,6 +2169,19 @@ func TestProxy(t *testing.T) {
assert.NoError(t, err)
})
wg.Add(1)
t.Run("upsert when autoID == true", func(t *testing.T) {
defer wg.Done()
req := constructCollectionUpsertRequest()
resp, err := proxy.Upsert(ctx, req)
assert.NoError(t, err)
assert.Equal(t, commonpb.ErrorCode_UpsertAutoIDTrue, resp.Status.ErrorCode)
assert.Equal(t, 0, len(resp.SuccIndex))
assert.Equal(t, rowNum, len(resp.ErrIndex))
assert.Equal(t, int64(0), resp.UpsertCnt)
})
wg.Add(1)
t.Run("drop collection", func(t *testing.T) {
defer wg.Done()
@@ -2607,6 +2634,14 @@ func TestProxy(t *testing.T) {
assert.NotEqual(t, commonpb.ErrorCode_Success, resp.Status.ErrorCode)
})
wg.Add(1)
t.Run("Upsert fail, unhealthy", func(t *testing.T) {
defer wg.Done()
resp, err := proxy.Upsert(ctx, &milvuspb.UpsertRequest{})
assert.NoError(t, err)
assert.NotEqual(t, commonpb.ErrorCode_Success, resp.Status.ErrorCode)
})
wg.Add(1)
t.Run("Search fail, unhealthy", func(t *testing.T) {
defer wg.Done()
@@ -2985,6 +3020,14 @@ func TestProxy(t *testing.T) {
assert.NotEqual(t, commonpb.ErrorCode_Success, resp.Status.ErrorCode)
})
wg.Add(1)
t.Run("Upsert fail, dm queue full", func(t *testing.T) {
defer wg.Done()
resp, err := proxy.Upsert(ctx, &milvuspb.UpsertRequest{})
assert.NoError(t, err)
assert.NotEqual(t, commonpb.ErrorCode_Success, resp.Status.ErrorCode)
})
proxy.sched.dmQueue.setMaxTaskNum(dmParallelism)
dqParallelism := proxy.sched.dqQueue.getMaxTaskNum()
@@ -3220,6 +3263,14 @@ func TestProxy(t *testing.T) {
assert.NotEqual(t, commonpb.ErrorCode_Success, resp.Status.ErrorCode)
})
wg.Add(1)
t.Run("Update fail, timeout", func(t *testing.T) {
defer wg.Done()
resp, err := proxy.Upsert(shortCtx, &milvuspb.UpsertRequest{})
assert.NoError(t, err)
assert.NotEqual(t, commonpb.ErrorCode_Success, resp.Status.ErrorCode)
})
wg.Add(1)
t.Run("Search fail, timeout", func(t *testing.T) {
defer wg.Done()
@@ -3295,6 +3346,161 @@ func TestProxy(t *testing.T) {
testProxyRoleTimeout(shortCtx, t, proxy)
testProxyPrivilegeTimeout(shortCtx, t, proxy)
constructCollectionSchema = func() *schemapb.CollectionSchema {
pk := &schemapb.FieldSchema{
FieldID: 0,
Name: int64Field,
IsPrimaryKey: true,
Description: "",
DataType: schemapb.DataType_Int64,
TypeParams: nil,
IndexParams: nil,
AutoID: false,
}
fVec := &schemapb.FieldSchema{
FieldID: 0,
Name: floatVecField,
IsPrimaryKey: false,
Description: "",
DataType: schemapb.DataType_FloatVector,
TypeParams: []*commonpb.KeyValuePair{
{
Key: "dim",
Value: strconv.Itoa(dim),
},
},
IndexParams: nil,
AutoID: false,
}
return &schemapb.CollectionSchema{
Name: collectionName,
Description: "",
AutoID: false,
Fields: []*schemapb.FieldSchema{
pk,
fVec,
},
}
}
schema = constructCollectionSchema()
constructCreateCollectionRequest = func() *milvuspb.CreateCollectionRequest {
bs, err := proto.Marshal(schema)
assert.NoError(t, err)
return &milvuspb.CreateCollectionRequest{
Base: nil,
DbName: dbName,
CollectionName: collectionName,
Schema: bs,
ShardsNum: shardsNum,
}
}
createCollectionReq = constructCreateCollectionRequest()
constructPartitionReqUpsertRequestValid := func() *milvuspb.UpsertRequest {
pkFieldData := newScalarFieldData(schema.Fields[0], int64Field, rowNum)
fVecColumn := newFloatVectorFieldData(floatVecField, rowNum, dim)
hashKeys := generateHashKeys(rowNum)
return &milvuspb.UpsertRequest{
Base: nil,
DbName: dbName,
CollectionName: collectionName,
PartitionName: partitionName,
FieldsData: []*schemapb.FieldData{pkFieldData, fVecColumn},
HashKeys: hashKeys,
NumRows: uint32(rowNum),
}
}
constructCollectionUpsertRequestValid := func() *milvuspb.UpsertRequest {
pkFieldData := newScalarFieldData(schema.Fields[0], int64Field, rowNum)
fVecColumn := newFloatVectorFieldData(floatVecField, rowNum, dim)
hashKeys := generateHashKeys(rowNum)
return &milvuspb.UpsertRequest{
Base: nil,
DbName: dbName,
CollectionName: collectionName,
PartitionName: partitionName,
FieldsData: []*schemapb.FieldData{pkFieldData, fVecColumn},
HashKeys: hashKeys,
NumRows: uint32(rowNum),
}
}
wg.Add(1)
t.Run("create collection upsert valid", func(t *testing.T) {
defer wg.Done()
req := createCollectionReq
resp, err := proxy.CreateCollection(ctx, req)
assert.NoError(t, err)
assert.Equal(t, commonpb.ErrorCode_Success, resp.ErrorCode)
reqInvalidField := constructCreateCollectionRequest()
schema := constructCollectionSchema()
schema.Fields = append(schema.Fields, &schemapb.FieldSchema{
Name: "StringField",
DataType: schemapb.DataType_String,
})
bs, err := proto.Marshal(schema)
assert.NoError(t, err)
reqInvalidField.CollectionName = "invalid_field_coll_upsert_valid"
reqInvalidField.Schema = bs
resp, err = proxy.CreateCollection(ctx, reqInvalidField)
assert.NoError(t, err)
assert.NotEqual(t, commonpb.ErrorCode_Success, resp.ErrorCode)
})
wg.Add(1)
t.Run("create partition", func(t *testing.T) {
defer wg.Done()
resp, err := proxy.CreatePartition(ctx, &milvuspb.CreatePartitionRequest{
Base: nil,
DbName: dbName,
CollectionName: collectionName,
PartitionName: partitionName,
})
assert.NoError(t, err)
assert.Equal(t, commonpb.ErrorCode_Success, resp.ErrorCode)
// create partition with non-exist collection -> fail
resp, err = proxy.CreatePartition(ctx, &milvuspb.CreatePartitionRequest{
Base: nil,
DbName: dbName,
CollectionName: otherCollectionName,
PartitionName: partitionName,
})
assert.NoError(t, err)
assert.NotEqual(t, commonpb.ErrorCode_Success, resp.ErrorCode)
})
wg.Add(1)
t.Run("upsert partition", func(t *testing.T) {
defer wg.Done()
req := constructPartitionReqUpsertRequestValid()
resp, err := proxy.Upsert(ctx, req)
assert.NoError(t, err)
assert.Equal(t, commonpb.ErrorCode_Success, resp.Status.ErrorCode)
assert.Equal(t, rowNum, len(resp.SuccIndex))
assert.Equal(t, 0, len(resp.ErrIndex))
assert.Equal(t, int64(rowNum), resp.UpsertCnt)
})
wg.Add(1)
t.Run("upsert when autoID == false", func(t *testing.T) {
defer wg.Done()
req := constructCollectionUpsertRequestValid()
resp, err := proxy.Upsert(ctx, req)
assert.NoError(t, err)
assert.Equal(t, commonpb.ErrorCode_Success, resp.Status.ErrorCode)
assert.Equal(t, rowNum, len(resp.SuccIndex))
assert.Equal(t, 0, len(resp.ErrIndex))
assert.Equal(t, int64(rowNum), resp.UpsertCnt)
})
testServer.gracefulStop()
wg.Wait()

View File

@@ -65,11 +65,12 @@ const (
ReleaseCollectionTaskName = "ReleaseCollectionTask"
LoadPartitionTaskName = "LoadPartitionsTask"
ReleasePartitionTaskName = "ReleasePartitionsTask"
deleteTaskName = "DeleteTask"
DeleteTaskName = "DeleteTask"
CreateAliasTaskName = "CreateAliasTask"
DropAliasTaskName = "DropAliasTask"
AlterAliasTaskName = "AlterAliasTask"
AlterCollectionTaskName = "AlterCollectionTask"
UpsertTaskName = "UpsertTask"
// minFloat32 minimum float.
minFloat32 = -1 * float32(math.MaxFloat32)

View File

@@ -58,7 +58,7 @@ func (dt *deleteTask) Type() commonpb.MsgType {
}
func (dt *deleteTask) Name() string {
return deleteTaskName
return DeleteTaskName
}
func (dt *deleteTask) BeginTs() Timestamp {

View File

@@ -126,12 +126,12 @@ func (it *insertTask) PreExecute(ctx context.Context) error {
return err
}
collSchema, err := globalMetaCache.GetCollectionSchema(ctx, collectionName)
schema, err := globalMetaCache.GetCollectionSchema(ctx, collectionName)
if err != nil {
log.Error("get collection schema from global meta cache failed", zap.String("collectionName", collectionName), zap.Error(err))
return err
}
it.schema = collSchema
it.schema = schema
rowNums := uint32(it.insertMsg.NRows())
// set insertTask.rowIDs
@@ -172,7 +172,7 @@ func (it *insertTask) PreExecute(ctx context.Context) error {
}
// set field ID to insert field data
err = fillFieldIDBySchema(it.insertMsg.GetFieldsData(), collSchema)
err = fillFieldIDBySchema(it.insertMsg.GetFieldsData(), schema)
if err != nil {
log.Error("set fieldID to fieldData failed",
zap.Error(err))

View File

@@ -1698,6 +1698,94 @@ func TestTask_VarCharPrimaryKey(t *testing.T) {
assert.NoError(t, task.PostExecute(ctx))
})
t.Run("upsert", func(t *testing.T) {
hash := generateHashKeys(nb)
task := &upsertTask{
upsertMsg: &msgstream.UpsertMsg{
InsertMsg: &BaseInsertTask{
BaseMsg: msgstream.BaseMsg{
HashValues: hash,
},
InsertRequest: internalpb.InsertRequest{
Base: &commonpb.MsgBase{
MsgType: commonpb.MsgType_Insert,
MsgID: 0,
SourceID: paramtable.GetNodeID(),
},
DbName: dbName,
CollectionName: collectionName,
PartitionName: partitionName,
NumRows: uint64(nb),
Version: internalpb.InsertDataVersion_ColumnBased,
},
},
DeleteMsg: &msgstream.DeleteMsg{
BaseMsg: msgstream.BaseMsg{
HashValues: hash,
},
DeleteRequest: internalpb.DeleteRequest{
Base: &commonpb.MsgBase{
MsgType: commonpb.MsgType_Delete,
MsgID: 0,
Timestamp: 0,
SourceID: paramtable.GetNodeID(),
},
DbName: dbName,
CollectionName: collectionName,
PartitionName: partitionName,
},
},
},
Condition: NewTaskCondition(ctx),
req: &milvuspb.UpsertRequest{
Base: &commonpb.MsgBase{
MsgType: commonpb.MsgType_Insert,
MsgID: 0,
SourceID: paramtable.GetNodeID(),
},
DbName: dbName,
CollectionName: collectionName,
PartitionName: partitionName,
HashKeys: hash,
NumRows: uint32(nb),
},
ctx: ctx,
result: &milvuspb.MutationResult{
Status: &commonpb.Status{
ErrorCode: commonpb.ErrorCode_Success,
Reason: "",
},
IDs: nil,
SuccIndex: nil,
ErrIndex: nil,
Acknowledged: false,
InsertCnt: 0,
DeleteCnt: 0,
UpsertCnt: 0,
Timestamp: 0,
},
idAllocator: idAllocator,
segIDAssigner: segAllocator,
chMgr: chMgr,
chTicker: ticker,
vChannels: nil,
pChannels: nil,
schema: nil,
}
fieldID := common.StartOfUserFieldID
for fieldName, dataType := range fieldName2Types {
task.req.FieldsData = append(task.req.FieldsData, generateFieldData(dataType, fieldName, nb))
fieldID++
}
assert.NoError(t, task.OnEnqueue())
assert.NoError(t, task.PreExecute(ctx))
assert.NoError(t, task.Execute(ctx))
assert.NoError(t, task.PostExecute(ctx))
})
t.Run("delete", func(t *testing.T) {
task := &deleteTask{
Condition: NewTaskCondition(ctx),

View File

@@ -0,0 +1,506 @@
// Licensed to the LF AI & Data foundation under one
// or more contributor license agreements. See the NOTICE file
// distributed with this work for additional information
// regarding copyright ownership. The ASF licenses this file
// to you under the Apache License, Version 2.0 (the
// "License"); you may not use this file except in compliance
// with the License. You may obtain a copy of the License at
//
//     http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
package proxy
import (
"context"
"fmt"
"strconv"
"github.com/milvus-io/milvus-proto/go-api/commonpb"
"github.com/milvus-io/milvus-proto/go-api/milvuspb"
"github.com/milvus-io/milvus-proto/go-api/schemapb"
"github.com/milvus-io/milvus/internal/allocator"
"github.com/milvus-io/milvus/internal/common"
"github.com/milvus-io/milvus/internal/log"
"github.com/milvus-io/milvus/internal/metrics"
"github.com/milvus-io/milvus/internal/mq/msgstream"
"github.com/milvus-io/milvus/internal/proto/internalpb"
"github.com/milvus-io/milvus/internal/util/commonpbutil"
"github.com/milvus-io/milvus/internal/util/paramtable"
"github.com/milvus-io/milvus/internal/util/timerecord"
"github.com/milvus-io/milvus/internal/util/trace"
"github.com/milvus-io/milvus/internal/util/typeutil"
"go.uber.org/zap"
)
type upsertTask struct {
Condition
upsertMsg *msgstream.UpsertMsg
req *milvuspb.UpsertRequest
baseMsg msgstream.BaseMsg
ctx context.Context
timestamps []uint64
rowIDs []int64
result *milvuspb.MutationResult
idAllocator *allocator.IDAllocator
segIDAssigner *segIDAssigner
collectionID UniqueID
chMgr channelsMgr
chTicker channelsTimeTicker
vChannels []vChan
pChannels []pChan
schema *schemapb.CollectionSchema
}
// TraceCtx returns upsertTask context
func (it *upsertTask) TraceCtx() context.Context {
return it.ctx
}
func (it *upsertTask) ID() UniqueID {
return it.req.Base.MsgID
}
func (it *upsertTask) SetID(uid UniqueID) {
it.req.Base.MsgID = uid
}
func (it *upsertTask) Name() string {
return UpsertTaskName
}
func (it *upsertTask) Type() commonpb.MsgType {
return it.req.Base.MsgType
}
func (it *upsertTask) BeginTs() Timestamp {
return it.baseMsg.BeginTimestamp
}
func (it *upsertTask) SetTs(ts Timestamp) {
it.baseMsg.BeginTimestamp = ts
it.baseMsg.EndTimestamp = ts
}
func (it *upsertTask) EndTs() Timestamp {
return it.baseMsg.EndTimestamp
}
func (it *upsertTask) getPChanStats() (map[pChan]pChanStatistics, error) {
ret := make(map[pChan]pChanStatistics)
channels, err := it.getChannels()
if err != nil {
return ret, err
}
beginTs := it.BeginTs()
endTs := it.EndTs()
for _, channel := range channels {
ret[channel] = pChanStatistics{
minTs: beginTs,
maxTs: endTs,
}
}
return ret, nil
}
func (it *upsertTask) getChannels() ([]pChan, error) {
collID, err := globalMetaCache.GetCollectionID(it.ctx, it.req.CollectionName)
if err != nil {
return nil, err
}
return it.chMgr.getChannels(collID)
}
func (it *upsertTask) OnEnqueue() error {
return nil
}
func (it *upsertTask) insertPreExecute(ctx context.Context) error {
collectionName := it.upsertMsg.InsertMsg.CollectionName
if err := validateCollectionName(collectionName); err != nil {
log.Error("valid collection name failed", zap.String("collectionName", collectionName), zap.Error(err))
return err
}
partitionTag := it.upsertMsg.InsertMsg.PartitionName
if err := validatePartitionTag(partitionTag, true); err != nil {
log.Error("valid partition name failed", zap.String("partition name", partitionTag), zap.Error(err))
return err
}
rowNums := uint32(it.upsertMsg.InsertMsg.NRows())
// set upsertTask.insertRequest.rowIDs
tr := timerecord.NewTimeRecorder("applyPK")
rowIDBegin, rowIDEnd, _ := it.idAllocator.Alloc(rowNums)
metrics.ProxyApplyPrimaryKeyLatency.WithLabelValues(strconv.FormatInt(paramtable.GetNodeID(), 10)).Observe(float64(tr.ElapseSpan().Milliseconds()))
it.upsertMsg.InsertMsg.RowIDs = make([]UniqueID, rowNums)
it.rowIDs = make([]UniqueID, rowNums)
for i := rowIDBegin; i < rowIDEnd; i++ {
offset := i - rowIDBegin
it.upsertMsg.InsertMsg.RowIDs[offset] = i
it.rowIDs[offset] = i
}
// set upsertTask.insertRequest.timeStamps
rowNum := it.upsertMsg.InsertMsg.NRows()
it.upsertMsg.InsertMsg.Timestamps = make([]uint64, rowNum)
it.timestamps = make([]uint64, rowNum)
for index := range it.timestamps {
it.upsertMsg.InsertMsg.Timestamps[index] = it.BeginTs()
it.timestamps[index] = it.BeginTs()
}
// set result.SuccIndex
sliceIndex := make([]uint32, rowNums)
for i := uint32(0); i < rowNums; i++ {
sliceIndex[i] = i
}
it.result.SuccIndex = sliceIndex
// check whether autoID is enabled on the primary field;
// upsert currently only supports autoID == false
var err error
it.result.IDs, err = upsertCheckPrimaryFieldData(it.schema, it.result, it.upsertMsg.InsertMsg)
log := log.Ctx(ctx).With(zap.String("collectionName", it.upsertMsg.InsertMsg.CollectionName))
if err != nil {
log.Error("check primary field data and hash primary key failed when upsert",
zap.Error(err))
return err
}
// set field ID to insert field data
err = fillFieldIDBySchema(it.upsertMsg.InsertMsg.GetFieldsData(), it.schema)
if err != nil {
log.Error("insert set fieldID to fieldData failed when upsert",
zap.Error(err))
return err
}
// check that all field's number rows are equal
if err = it.upsertMsg.InsertMsg.CheckAligned(); err != nil {
log.Error("field data is not aligned when upsert",
zap.Error(err))
return err
}
log.Debug("Proxy Upsert insertPreExecute done")
return nil
}
func (it *upsertTask) deletePreExecute(ctx context.Context) error {
collName := it.upsertMsg.DeleteMsg.CollectionName
log := log.Ctx(ctx).With(
zap.String("collectionName", collName))
if err := validateCollectionName(collName); err != nil {
log.Info("Invalid collection name", zap.Error(err))
return err
}
collID, err := globalMetaCache.GetCollectionID(ctx, collName)
if err != nil {
log.Info("Failed to get collection id", zap.Error(err))
return err
}
it.upsertMsg.DeleteMsg.CollectionID = collID
it.collectionID = collID
// If partitionName is not empty, partitionID will be set.
if len(it.upsertMsg.DeleteMsg.PartitionName) > 0 {
partName := it.upsertMsg.DeleteMsg.PartitionName
if err := validatePartitionTag(partName, true); err != nil {
log.Info("Invalid partition name", zap.String("partitionName", partName), zap.Error(err))
return err
}
partID, err := globalMetaCache.GetPartitionID(ctx, collName, partName)
if err != nil {
log.Info("Failed to get partition id", zap.String("collectionName", collName), zap.String("partitionName", partName), zap.Error(err))
return err
}
it.upsertMsg.DeleteMsg.PartitionID = partID
} else {
it.upsertMsg.DeleteMsg.PartitionID = common.InvalidPartitionID
}
it.upsertMsg.DeleteMsg.Timestamps = make([]uint64, it.upsertMsg.DeleteMsg.NumRows)
for index := range it.upsertMsg.DeleteMsg.Timestamps {
it.upsertMsg.DeleteMsg.Timestamps[index] = it.BeginTs()
}
log.Debug("Proxy Upsert deletePreExecute done")
return nil
}
func (it *upsertTask) PreExecute(ctx context.Context) error {
sp, ctx := trace.StartSpanFromContextWithOperationName(it.ctx, "Proxy-Upsert-PreExecute")
defer sp.Finish()
log := log.Ctx(ctx).With(zap.String("collectionName", it.req.CollectionName))
it.req.Base = commonpbutil.NewMsgBase(
commonpbutil.WithMsgType(commonpb.MsgType(commonpb.MsgType_Upsert)),
commonpbutil.WithSourceID(paramtable.GetNodeID()),
)
it.result = &milvuspb.MutationResult{
Status: &commonpb.Status{
ErrorCode: commonpb.ErrorCode_Success,
},
IDs: &schemapb.IDs{
IdField: nil,
},
Timestamp: it.EndTs(),
}
schema, err := globalMetaCache.GetCollectionSchema(ctx, it.req.CollectionName)
if err != nil {
log.Info("Failed to get collection schema", zap.Error(err))
return err
}
it.schema = schema
it.upsertMsg = &msgstream.UpsertMsg{
InsertMsg: &msgstream.InsertMsg{
InsertRequest: internalpb.InsertRequest{
Base: commonpbutil.NewMsgBase(
commonpbutil.WithMsgType(commonpb.MsgType_Insert),
commonpbutil.WithSourceID(paramtable.GetNodeID()),
),
CollectionName: it.req.CollectionName,
PartitionName: it.req.PartitionName,
FieldsData: it.req.FieldsData,
NumRows: uint64(it.req.NumRows),
Version: internalpb.InsertDataVersion_ColumnBased,
},
},
DeleteMsg: &msgstream.DeleteMsg{
DeleteRequest: internalpb.DeleteRequest{
Base: commonpbutil.NewMsgBase(
commonpbutil.WithMsgType(commonpb.MsgType_Delete),
commonpbutil.WithSourceID(paramtable.GetNodeID()),
),
DbName: it.req.DbName,
CollectionName: it.req.CollectionName,
NumRows: int64(it.req.NumRows),
PartitionName: it.req.PartitionName,
CollectionID: it.collectionID,
},
},
}
err = it.insertPreExecute(ctx)
if err != nil {
log.Info("Fail to insertPreExecute", zap.Error(err))
return err
}
err = it.deletePreExecute(ctx)
if err != nil {
log.Info("Fail to deletePreExecute", zap.Error(err))
return err
}
it.result.DeleteCnt = it.upsertMsg.DeleteMsg.NumRows
it.result.InsertCnt = int64(it.upsertMsg.InsertMsg.NumRows)
if it.result.DeleteCnt != it.result.InsertCnt {
log.Error("DeleteCnt and InsertCnt are not the same when upsert",
zap.Int64("DeleteCnt", it.result.DeleteCnt),
zap.Int64("InsertCnt", it.result.InsertCnt))
}
it.result.UpsertCnt = it.result.InsertCnt
log.Debug("Proxy Upsert PreExecute done")
return nil
}
func (it *upsertTask) insertExecute(ctx context.Context, msgPack *msgstream.MsgPack) error {
tr := timerecord.NewTimeRecorder(fmt.Sprintf("proxy insertExecute upsert %d", it.ID()))
defer tr.Elapse("insert execute done when insertExecute")
collectionName := it.upsertMsg.InsertMsg.CollectionName
collID, err := globalMetaCache.GetCollectionID(ctx, collectionName)
if err != nil {
return err
}
it.upsertMsg.InsertMsg.CollectionID = collID
log := log.Ctx(ctx).With(
zap.Int64("collectionID", collID))
var partitionID UniqueID
if len(it.upsertMsg.InsertMsg.PartitionName) > 0 {
partitionID, err = globalMetaCache.GetPartitionID(ctx, collectionName, it.req.PartitionName)
if err != nil {
return err
}
} else {
partitionID, err = globalMetaCache.GetPartitionID(ctx, collectionName, Params.CommonCfg.DefaultPartitionName.GetValue())
if err != nil {
return err
}
}
it.upsertMsg.InsertMsg.PartitionID = partitionID
tr.Record("get collection id & partition id from cache when insertExecute")
_, err = it.chMgr.getOrCreateDmlStream(collID)
if err != nil {
return err
}
tr.Record("get used message stream when insertExecute")
channelNames, err := it.chMgr.getVChannels(collID)
if err != nil {
log.Error("get vChannels failed when insertExecute",
zap.Error(err))
it.result.Status.ErrorCode = commonpb.ErrorCode_UnexpectedError
it.result.Status.Reason = err.Error()
return err
}
log.Debug("send insert request to virtual channels when insertExecute",
zap.String("collection", it.req.GetCollectionName()),
zap.String("partition", it.req.GetPartitionName()),
zap.Int64("collection_id", collID),
zap.Int64("partition_id", partitionID),
zap.Strings("virtual_channels", channelNames),
zap.Int64("task_id", it.ID()))
// assign segmentID for insert data and repack data by segmentID
insertMsgPack, err := assignSegmentID(it.TraceCtx(), it.upsertMsg.InsertMsg, it.result, channelNames, it.idAllocator, it.segIDAssigner)
if err != nil {
log.Error("assign segmentID and repack insert data failed when insertExecute",
zap.Error(err))
it.result.Status.ErrorCode = commonpb.ErrorCode_UnexpectedError
it.result.Status.Reason = err.Error()
return err
}
log.Debug("assign segmentID for insert data success when insertExecute",
zap.String("collectionName", it.req.CollectionName))
tr.Record("assign segment id")
msgPack.Msgs = append(msgPack.Msgs, insertMsgPack.Msgs...)
log.Debug("Proxy Insert Execute done when upsert",
zap.String("collectionName", collectionName))
return nil
}
func (it *upsertTask) deleteExecute(ctx context.Context, msgPack *msgstream.MsgPack) (err error) {
tr := timerecord.NewTimeRecorder(fmt.Sprintf("proxy deleteExecute upsert %d", it.ID()))
defer tr.Elapse("delete execute done when upsert")
collID := it.upsertMsg.DeleteMsg.CollectionID
log := log.Ctx(ctx).With(
zap.Int64("collectionID", collID))
// hash primary keys to channels
channelNames, err := it.chMgr.getVChannels(collID)
if err != nil {
log.Warn("get vChannels failed when deleteExecute", zap.Error(err))
it.result.Status.ErrorCode = commonpb.ErrorCode_UnexpectedError
it.result.Status.Reason = err.Error()
return err
}
it.upsertMsg.DeleteMsg.PrimaryKeys = it.result.IDs
it.upsertMsg.DeleteMsg.HashValues = typeutil.HashPK2Channels(it.upsertMsg.DeleteMsg.PrimaryKeys, channelNames)
log.Debug("send delete request to virtual channels when deleteExecute",
zap.Int64("collection_id", collID),
zap.Strings("virtual_channels", channelNames))
tr.Record("get vchannels")
// repack delete msg by dmChannel
result := make(map[uint32]msgstream.TsMsg)
collectionName := it.upsertMsg.DeleteMsg.CollectionName
collectionID := it.upsertMsg.DeleteMsg.CollectionID
partitionID := it.upsertMsg.DeleteMsg.PartitionID
partitionName := it.upsertMsg.DeleteMsg.PartitionName
proxyID := it.upsertMsg.DeleteMsg.Base.SourceID
for index, key := range it.upsertMsg.DeleteMsg.HashValues {
ts := it.upsertMsg.DeleteMsg.Timestamps[index]
_, ok := result[key]
if !ok {
sliceRequest := internalpb.DeleteRequest{
Base: commonpbutil.NewMsgBase(
commonpbutil.WithMsgType(commonpb.MsgType_Delete),
commonpbutil.WithTimeStamp(ts),
commonpbutil.WithSourceID(proxyID),
),
CollectionID: collectionID,
PartitionID: partitionID,
CollectionName: collectionName,
PartitionName: partitionName,
PrimaryKeys: &schemapb.IDs{},
}
deleteMsg := &msgstream.DeleteMsg{
BaseMsg: msgstream.BaseMsg{
Ctx: ctx,
},
DeleteRequest: sliceRequest,
}
result[key] = deleteMsg
}
curMsg := result[key].(*msgstream.DeleteMsg)
curMsg.HashValues = append(curMsg.HashValues, it.upsertMsg.DeleteMsg.HashValues[index])
curMsg.Timestamps = append(curMsg.Timestamps, it.upsertMsg.DeleteMsg.Timestamps[index])
typeutil.AppendIDs(curMsg.PrimaryKeys, it.upsertMsg.DeleteMsg.PrimaryKeys, index)
curMsg.NumRows++
}
// send delete request to log broker
deleteMsgPack := &msgstream.MsgPack{
BeginTs: it.upsertMsg.DeleteMsg.BeginTs(),
EndTs: it.upsertMsg.DeleteMsg.EndTs(),
}
for _, msg := range result {
if msg != nil {
deleteMsgPack.Msgs = append(deleteMsgPack.Msgs, msg)
}
}
msgPack.Msgs = append(msgPack.Msgs, deleteMsgPack.Msgs...)
log.Debug("Proxy Upsert deleteExecute done")
return nil
}
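The repack loop above is a group-by: rows are bucketed by their channel hash value, and one DeleteMsg is created lazily per bucket on first use. The same grouping pattern in isolation, as a hedged sketch with plain Go types:

package main

import "fmt"

// bucketByHash groups row indices by a per-row hash key, creating each
// bucket on first use, the way deleteExecute builds one DeleteMsg per
// channel and then appends PKs and timestamps row by row.
func bucketByHash(hashes []uint32) map[uint32][]int {
	buckets := make(map[uint32][]int)
	for i, h := range hashes {
		buckets[h] = append(buckets[h], i) // append allocates the bucket on first use
	}
	return buckets
}

func main() {
	fmt.Println(bucketByHash([]uint32{1, 0, 1, 2})) // map[0:[1] 1:[0 2] 2:[3]]
}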
func (it *upsertTask) Execute(ctx context.Context) (err error) {
sp, ctx := trace.StartSpanFromContextWithOperationName(it.ctx, "Proxy-Upsert-Execute")
defer sp.Finish()
log := log.Ctx(ctx).With(zap.String("collectionName", it.req.CollectionName))
tr := timerecord.NewTimeRecorder(fmt.Sprintf("proxy execute upsert %d", it.ID()))
stream, err := it.chMgr.getOrCreateDmlStream(it.collectionID)
if err != nil {
return err
}
msgPack := &msgstream.MsgPack{
BeginTs: it.BeginTs(),
EndTs: it.EndTs(),
}
err = it.insertExecute(ctx, msgPack)
if err != nil {
log.Info("Fail to insertExecute", zap.Error(err))
return err
}
err = it.deleteExecute(ctx, msgPack)
if err != nil {
log.Info("Fail to deleteExecute", zap.Error(err))
return err
}
tr.Record("pack messages in upsert")
err = stream.Produce(msgPack)
if err != nil {
it.result.Status.ErrorCode = commonpb.ErrorCode_UnexpectedError
it.result.Status.Reason = err.Error()
return err
}
sendMsgDur := tr.Record("send upsert request to dml channels")
metrics.ProxySendMutationReqLatency.WithLabelValues(strconv.FormatInt(paramtable.GetNodeID(), 10), metrics.UpsertLabel).Observe(float64(sendMsgDur.Milliseconds()))
log.Debug("Proxy Upsert Execute done")
return nil
}
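Execute packs the insert messages and the delete messages into one MsgPack under a single BeginTs/EndTs window before producing, so consumers see the delete and the re-insert at the same timestamp; combined with the >= rule patched at the top of this diff, the re-insert then wins. A sketch of the packing order with simplified types:

package main

import "fmt"

type tsMsg struct{ kind string }

type msgPack struct {
	beginTs, endTs uint64
	msgs           []tsMsg
}

// packUpsert mirrors Execute above: both halves of the upsert travel in
// one pack under one timestamp window and are produced together.
func packUpsert(ts uint64, inserts, deletes []tsMsg) msgPack {
	p := msgPack{beginTs: ts, endTs: ts}
	p.msgs = append(p.msgs, inserts...)
	p.msgs = append(p.msgs, deletes...)
	return p
}

func main() {
	p := packUpsert(100, []tsMsg{{"insert"}}, []tsMsg{{"delete"}})
	fmt.Println(len(p.msgs), p.beginTs) // 2 100
}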
func (it *upsertTask) PostExecute(ctx context.Context) error {
return nil
}

View File

@@ -0,0 +1,684 @@
// Licensed to the LF AI & Data foundation under one
// or more contributor license agreements. See the NOTICE file
// distributed with this work for additional information
// regarding copyright ownership. The ASF licenses this file
// to you under the Apache License, Version 2.0 (the
// "License"); you may not use this file except in compliance
// with the License. You may obtain a copy of the License at
//
//     http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
package proxy
import (
"testing"
"github.com/milvus-io/milvus-proto/go-api/commonpb"
"github.com/milvus-io/milvus-proto/go-api/milvuspb"
"github.com/milvus-io/milvus-proto/go-api/schemapb"
"github.com/milvus-io/milvus/internal/mq/msgstream"
"github.com/milvus-io/milvus/internal/proto/internalpb"
"github.com/milvus-io/milvus/internal/util/commonpbutil"
"github.com/stretchr/testify/assert"
)
func TestUpsertTask_CheckAligned(t *testing.T) {
var err error
// passed NumRows is 0
case1 := upsertTask{
req: &milvuspb.UpsertRequest{
NumRows: 0,
},
upsertMsg: &msgstream.UpsertMsg{
InsertMsg: &msgstream.InsertMsg{
InsertRequest: internalpb.InsertRequest{},
},
},
}
case1.upsertMsg.InsertMsg.InsertRequest = internalpb.InsertRequest{
Base: commonpbutil.NewMsgBase(
commonpbutil.WithMsgType(commonpb.MsgType_Insert),
),
CollectionName: case1.req.CollectionName,
PartitionName: case1.req.PartitionName,
FieldsData: case1.req.FieldsData,
NumRows: uint64(case1.req.NumRows),
Version: internalpb.InsertDataVersion_ColumnBased,
}
err = case1.upsertMsg.InsertMsg.CheckAligned()
assert.NoError(t, err)
// checkLengthOfFieldsData is already covered by TestUpsertTask_checkLengthOfFieldsData
boolFieldSchema := &schemapb.FieldSchema{DataType: schemapb.DataType_Bool}
int8FieldSchema := &schemapb.FieldSchema{DataType: schemapb.DataType_Int8}
int16FieldSchema := &schemapb.FieldSchema{DataType: schemapb.DataType_Int16}
int32FieldSchema := &schemapb.FieldSchema{DataType: schemapb.DataType_Int32}
int64FieldSchema := &schemapb.FieldSchema{DataType: schemapb.DataType_Int64}
floatFieldSchema := &schemapb.FieldSchema{DataType: schemapb.DataType_Float}
doubleFieldSchema := &schemapb.FieldSchema{DataType: schemapb.DataType_Double}
floatVectorFieldSchema := &schemapb.FieldSchema{DataType: schemapb.DataType_FloatVector}
binaryVectorFieldSchema := &schemapb.FieldSchema{DataType: schemapb.DataType_BinaryVector}
varCharFieldSchema := &schemapb.FieldSchema{DataType: schemapb.DataType_VarChar}
numRows := 20
dim := 128
case2 := upsertTask{
req: &milvuspb.UpsertRequest{
NumRows: uint32(numRows),
FieldsData: []*schemapb.FieldData{},
},
rowIDs: generateInt64Array(numRows),
timestamps: generateUint64Array(numRows),
schema: &schemapb.CollectionSchema{
Name: "TestUpsertTask_checkRowNums",
Description: "TestUpsertTask_checkRowNums",
AutoID: false,
Fields: []*schemapb.FieldSchema{
boolFieldSchema,
int8FieldSchema,
int16FieldSchema,
int32FieldSchema,
int64FieldSchema,
floatFieldSchema,
doubleFieldSchema,
floatVectorFieldSchema,
binaryVectorFieldSchema,
varCharFieldSchema,
},
},
upsertMsg: &msgstream.UpsertMsg{
InsertMsg: &msgstream.InsertMsg{
InsertRequest: internalpb.InsertRequest{},
},
},
}
// satisfied
case2.req.FieldsData = []*schemapb.FieldData{
newScalarFieldData(boolFieldSchema, "Bool", numRows),
newScalarFieldData(int8FieldSchema, "Int8", numRows),
newScalarFieldData(int16FieldSchema, "Int16", numRows),
newScalarFieldData(int32FieldSchema, "Int32", numRows),
newScalarFieldData(int64FieldSchema, "Int64", numRows),
newScalarFieldData(floatFieldSchema, "Float", numRows),
newScalarFieldData(doubleFieldSchema, "Double", numRows),
newFloatVectorFieldData("FloatVector", numRows, dim),
newBinaryVectorFieldData("BinaryVector", numRows, dim),
newScalarFieldData(varCharFieldSchema, "VarChar", numRows),
}
case2.upsertMsg.InsertMsg.InsertRequest = internalpb.InsertRequest{
Base: commonpbutil.NewMsgBase(
commonpbutil.WithMsgType(commonpb.MsgType_Insert),
),
CollectionName: case2.req.CollectionName,
PartitionName: case2.req.PartitionName,
FieldsData: case2.req.FieldsData,
NumRows: uint64(case2.req.NumRows),
RowIDs: case2.rowIDs,
Timestamps: case2.timestamps,
Version: internalpb.InsertDataVersion_ColumnBased,
}
err = case2.upsertMsg.InsertMsg.CheckAligned()
assert.NoError(t, err)
// less bool data
case2.req.FieldsData[0] = newScalarFieldData(boolFieldSchema, "Bool", numRows/2)
case2.upsertMsg.InsertMsg.FieldsData = case2.req.FieldsData
err = case2.upsertMsg.InsertMsg.CheckAligned()
assert.Error(t, err)
// more bool data
case2.req.FieldsData[0] = newScalarFieldData(boolFieldSchema, "Bool", numRows*2)
case2.upsertMsg.InsertMsg.FieldsData = case2.req.FieldsData
err = case2.upsertMsg.InsertMsg.CheckAligned()
assert.Error(t, err)
// revert
case2.req.FieldsData[0] = newScalarFieldData(boolFieldSchema, "Bool", numRows)
case2.upsertMsg.InsertMsg.FieldsData = case2.req.FieldsData
err = case2.upsertMsg.InsertMsg.CheckAligned()
assert.NoError(t, err)
// less int8 data
case2.req.FieldsData[1] = newScalarFieldData(int8FieldSchema, "Int8", numRows/2)
case2.upsertMsg.InsertMsg.FieldsData = case2.req.FieldsData
err = case2.upsertMsg.InsertMsg.CheckAligned()
assert.Error(t, err)
// more int8 data
case2.req.FieldsData[1] = newScalarFieldData(int8FieldSchema, "Int8", numRows*2)
case2.upsertMsg.InsertMsg.FieldsData = case2.req.FieldsData
err = case2.upsertMsg.InsertMsg.CheckAligned()
assert.Error(t, err)
// revert
case2.req.FieldsData[1] = newScalarFieldData(int8FieldSchema, "Int8", numRows)
case2.upsertMsg.InsertMsg.FieldsData = case2.req.FieldsData
err = case2.upsertMsg.InsertMsg.CheckAligned()
assert.NoError(t, err)
// less int16 data
case2.req.FieldsData[2] = newScalarFieldData(int16FieldSchema, "Int16", numRows/2)
case2.upsertMsg.InsertMsg.FieldsData = case2.req.FieldsData
err = case2.upsertMsg.InsertMsg.CheckAligned()
assert.Error(t, err)
// more int16 data
case2.req.FieldsData[2] = newScalarFieldData(int16FieldSchema, "Int16", numRows*2)
case2.upsertMsg.InsertMsg.FieldsData = case2.req.FieldsData
err = case2.upsertMsg.InsertMsg.CheckAligned()
assert.Error(t, err)
// revert
case2.req.FieldsData[2] = newScalarFieldData(int16FieldSchema, "Int16", numRows)
case2.upsertMsg.InsertMsg.FieldsData = case2.req.FieldsData
err = case2.upsertMsg.InsertMsg.CheckAligned()
assert.NoError(t, err)
// less int32 data
case2.req.FieldsData[3] = newScalarFieldData(int32FieldSchema, "Int32", numRows/2)
case2.upsertMsg.InsertMsg.FieldsData = case2.req.FieldsData
err = case2.upsertMsg.InsertMsg.CheckAligned()
assert.Error(t, err)
// more int32 data
case2.req.FieldsData[3] = newScalarFieldData(int32FieldSchema, "Int32", numRows*2)
case2.upsertMsg.InsertMsg.FieldsData = case2.req.FieldsData
err = case2.upsertMsg.InsertMsg.CheckAligned()
assert.Error(t, err)
// revert
case2.req.FieldsData[3] = newScalarFieldData(int32FieldSchema, "Int32", numRows)
case2.upsertMsg.InsertMsg.FieldsData = case2.req.FieldsData
err = case2.upsertMsg.InsertMsg.CheckAligned()
assert.NoError(t, err)
// less int64 data
case2.req.FieldsData[4] = newScalarFieldData(int64FieldSchema, "Int64", numRows/2)
case2.upsertMsg.InsertMsg.FieldsData = case2.req.FieldsData
err = case2.upsertMsg.InsertMsg.CheckAligned()
assert.Error(t, err)
// more int64 data
case2.req.FieldsData[4] = newScalarFieldData(int64FieldSchema, "Int64", numRows*2)
case2.upsertMsg.InsertMsg.FieldsData = case2.req.FieldsData
err = case2.upsertMsg.InsertMsg.CheckAligned()
assert.Error(t, err)
// revert
case2.req.FieldsData[4] = newScalarFieldData(int64FieldSchema, "Int64", numRows)
case2.upsertMsg.InsertMsg.FieldsData = case2.req.FieldsData
err = case2.upsertMsg.InsertMsg.CheckAligned()
assert.NoError(t, err)
// less float data
case2.req.FieldsData[5] = newScalarFieldData(floatFieldSchema, "Float", numRows/2)
case2.upsertMsg.InsertMsg.FieldsData = case2.req.FieldsData
err = case2.upsertMsg.InsertMsg.CheckAligned()
assert.Error(t, err)
// more float data
case2.req.FieldsData[5] = newScalarFieldData(floatFieldSchema, "Float", numRows*2)
case2.upsertMsg.InsertMsg.FieldsData = case2.req.FieldsData
err = case2.upsertMsg.InsertMsg.CheckAligned()
assert.Error(t, err)
// revert
case2.req.FieldsData[5] = newScalarFieldData(floatFieldSchema, "Float", numRows)
case2.upsertMsg.InsertMsg.FieldsData = case2.req.FieldsData
err = case2.upsertMsg.InsertMsg.CheckAligned()
assert.NoError(t, err)
// less double data
case2.req.FieldsData[6] = newScalarFieldData(doubleFieldSchema, "Double", numRows/2)
case2.upsertMsg.InsertMsg.FieldsData = case2.req.FieldsData
err = case2.upsertMsg.InsertMsg.CheckAligned()
assert.Error(t, err)
// more double data
case2.req.FieldsData[6] = newScalarFieldData(doubleFieldSchema, "Double", numRows*2)
case2.upsertMsg.InsertMsg.FieldsData = case2.req.FieldsData
err = case2.upsertMsg.InsertMsg.CheckAligned()
assert.Error(t, err)
// revert
case2.req.FieldsData[6] = newScalarFieldData(doubleFieldSchema, "Double", numRows)
case2.upsertMsg.InsertMsg.FieldsData = case2.req.FieldsData
err = case2.upsertMsg.InsertMsg.CheckAligned()
assert.NoError(t, err)
// less float vectors
case2.req.FieldsData[7] = newFloatVectorFieldData("FloatVector", numRows/2, dim)
case2.upsertMsg.InsertMsg.FieldsData = case2.req.FieldsData
err = case2.upsertMsg.InsertMsg.CheckAligned()
assert.Error(t, err)
// more float vectors
case2.req.FieldsData[7] = newFloatVectorFieldData("FloatVector", numRows*2, dim)
case2.upsertMsg.InsertMsg.FieldsData = case2.req.FieldsData
err = case2.upsertMsg.InsertMsg.CheckAligned()
assert.Error(t, err)
// revert
case2.req.FieldsData[7] = newFloatVectorFieldData("FloatVector", numRows, dim)
case2.upsertMsg.InsertMsg.FieldsData = case2.req.FieldsData
err = case2.upsertMsg.InsertMsg.CheckAligned()
assert.NoError(t, err)
// less binary vectors
case2.req.FieldsData[7] = newBinaryVectorFieldData("BinaryVector", numRows/2, dim)
case2.upsertMsg.InsertMsg.FieldsData = case2.req.FieldsData
err = case2.upsertMsg.InsertMsg.CheckAligned()
assert.Error(t, err)
// more binary vectors
case2.req.FieldsData[7] = newBinaryVectorFieldData("BinaryVector", numRows*2, dim)
case2.upsertMsg.InsertMsg.FieldsData = case2.req.FieldsData
err = case2.upsertMsg.InsertMsg.CheckAligned()
assert.Error(t, err)
// revert
case2.req.FieldsData[7] = newBinaryVectorFieldData("BinaryVector", numRows, dim)
case2.upsertMsg.InsertMsg.FieldsData = case2.req.FieldsData
err = case2.upsertMsg.InsertMsg.CheckAligned()
assert.NoError(t, err)
// less varChar data
case2.req.FieldsData[8] = newScalarFieldData(varCharFieldSchema, "VarChar", numRows/2)
case2.upsertMsg.InsertMsg.FieldsData = case2.req.FieldsData
err = case2.upsertMsg.InsertMsg.CheckAligned()
assert.Error(t, err)
// more varChar data
case2.req.FieldsData[8] = newScalarFieldData(varCharFieldSchema, "VarChar", numRows*2)
case2.upsertMsg.InsertMsg.FieldsData = case2.req.FieldsData
err = case2.upsertMsg.InsertMsg.CheckAligned()
assert.Error(t, err)
// revert
case2.req.FieldsData[8] = newScalarFieldData(varCharFieldSchema, "VarChar", numRows)
case2.upsertMsg.InsertMsg.FieldsData = case2.req.FieldsData
err = case2.upsertMsg.InsertMsg.CheckAligned()
assert.NoError(t, err)
}
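The per-type blocks above all repeat one cycle (shrink a field, grow it, revert, re-check), so a table-driven variant could express the same coverage far more compactly. A hedged sketch of the shape, with a checkAligned stand-in for InsertMsg.CheckAligned:

package main

import "fmt"

// checkAligned stands in for InsertMsg.CheckAligned: every field must
// carry exactly numRows entries.
func checkAligned(numRows int, fieldLens []int) error {
	for i, n := range fieldLens {
		if n != numRows {
			return fmt.Errorf("field %d has %d rows, want %d", i, n, numRows)
		}
	}
	return nil
}

func main() {
	const numRows = 20
	lens := []int{numRows, numRows, numRows}
	// one shrink/grow/revert cycle per field, as in the test above
	for i := range lens {
		for _, bad := range []int{numRows / 2, numRows * 2} {
			lens[i] = bad
			fmt.Println(checkAligned(numRows, lens) != nil) // true: misaligned
		}
		lens[i] = numRows // revert
	}
	fmt.Println(checkAligned(numRows, lens)) // <nil>: aligned again
}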
// func TestProxyUpsertValid(t *testing.T) {
// var err error
// var wg sync.WaitGroup
// paramtable.Init()
// path := "/tmp/milvus/rocksmq" + funcutil.GenRandomStr()
// t.Setenv("ROCKSMQ_PATH", path)
// defer os.RemoveAll(path)
// ctx, cancel := context.WithCancel(context.Background())
// ctx = GetContext(ctx, "root:123456")
// localMsg := true
// factory := dependency.NewDefaultFactory(localMsg)
// alias := "TestProxyUpsertValid"
// log.Info("Initialize parameter table of Proxy")
// rc := runRootCoord(ctx, localMsg)
// log.Info("running RootCoord ...")
// if rc != nil {
// defer func() {
// err := rc.Stop()
// assert.NoError(t, err)
// log.Info("stop RootCoord")
// }()
// }
// dc := runDataCoord(ctx, localMsg)
// log.Info("running DataCoord ...")
// if dc != nil {
// defer func() {
// err := dc.Stop()
// assert.NoError(t, err)
// log.Info("stop DataCoord")
// }()
// }
// dn := runDataNode(ctx, localMsg, alias)
// log.Info("running DataNode ...")
// if dn != nil {
// defer func() {
// err := dn.Stop()
// assert.NoError(t, err)
// log.Info("stop DataNode")
// }()
// }
// qc := runQueryCoord(ctx, localMsg)
// log.Info("running QueryCoord ...")
// if qc != nil {
// defer func() {
// err := qc.Stop()
// assert.NoError(t, err)
// log.Info("stop QueryCoord")
// }()
// }
// qn := runQueryNode(ctx, localMsg, alias)
// log.Info("running QueryNode ...")
// if qn != nil {
// defer func() {
// err := qn.Stop()
// assert.NoError(t, err)
// log.Info("stop query node")
// }()
// }
// ic := runIndexCoord(ctx, localMsg)
// log.Info("running IndexCoord ...")
// if ic != nil {
// defer func() {
// err := ic.Stop()
// assert.NoError(t, err)
// log.Info("stop IndexCoord")
// }()
// }
// in := runIndexNode(ctx, localMsg, alias)
// log.Info("running IndexNode ...")
// if in != nil {
// defer func() {
// err := in.Stop()
// assert.NoError(t, err)
// log.Info("stop IndexNode")
// }()
// }
// time.Sleep(10 * time.Millisecond)
// proxy, err := NewProxy(ctx, factory)
// assert.NoError(t, err)
// assert.NotNil(t, proxy)
// etcdcli, err := etcd.GetEtcdClient(
// Params.EtcdCfg.UseEmbedEtcd.GetAsBool(),
// Params.EtcdCfg.EtcdUseSSL.GetAsBool(),
// Params.EtcdCfg.Endpoints.GetAsStrings(),
// Params.EtcdCfg.EtcdTLSCert.GetValue(),
// Params.EtcdCfg.EtcdTLSKey.GetValue(),
// Params.EtcdCfg.EtcdTLSCACert.GetValue(),
// Params.EtcdCfg.EtcdTLSMinVersion.GetValue())
// defer etcdcli.Close()
// assert.NoError(t, err)
// proxy.SetEtcdClient(etcdcli)
// testServer := newProxyTestServer(proxy)
// wg.Add(1)
// base := paramtable.BaseTable{}
// base.Init(0)
// var p paramtable.GrpcServerConfig
// p.Init(typeutil.ProxyRole, &base)
// go testServer.startGrpc(ctx, &wg, &p)
// assert.NoError(t, testServer.waitForGrpcReady())
// rootCoordClient, err := rcc.NewClient(ctx, Params.EtcdCfg.MetaRootPath.GetValue(), etcdcli)
// assert.NoError(t, err)
// err = rootCoordClient.Init()
// assert.NoError(t, err)
// err = funcutil.WaitForComponentHealthy(ctx, rootCoordClient, typeutil.RootCoordRole, attempts, sleepDuration)
// assert.NoError(t, err)
// proxy.SetRootCoordClient(rootCoordClient)
// log.Info("Proxy set root coordinator client")
// dataCoordClient, err := grpcdatacoordclient2.NewClient(ctx, Params.EtcdCfg.MetaRootPath.GetValue(), etcdcli)
// assert.NoError(t, err)
// err = dataCoordClient.Init()
// assert.NoError(t, err)
// err = funcutil.WaitForComponentHealthy(ctx, dataCoordClient, typeutil.DataCoordRole, attempts, sleepDuration)
// assert.NoError(t, err)
// proxy.SetDataCoordClient(dataCoordClient)
// log.Info("Proxy set data coordinator client")
// queryCoordClient, err := grpcquerycoordclient.NewClient(ctx, Params.EtcdCfg.MetaRootPath.GetValue(), etcdcli)
// assert.NoError(t, err)
// err = queryCoordClient.Init()
// assert.NoError(t, err)
// err = funcutil.WaitForComponentHealthy(ctx, queryCoordClient, typeutil.QueryCoordRole, attempts, sleepDuration)
// assert.NoError(t, err)
// proxy.SetQueryCoordClient(queryCoordClient)
// log.Info("Proxy set query coordinator client")
// indexCoordClient, err := grpcindexcoordclient.NewClient(ctx, Params.EtcdCfg.MetaRootPath.GetValue(), etcdcli)
// assert.NoError(t, err)
// err = indexCoordClient.Init()
// assert.NoError(t, err)
// err = funcutil.WaitForComponentHealthy(ctx, indexCoordClient, typeutil.IndexCoordRole, attempts, sleepDuration)
// assert.NoError(t, err)
// proxy.SetIndexCoordClient(indexCoordClient)
// log.Info("Proxy set index coordinator client")
// proxy.UpdateStateCode(commonpb.StateCode_Initializing)
// err = proxy.Init()
// assert.NoError(t, err)
// err = proxy.Start()
// assert.NoError(t, err)
// assert.Equal(t, commonpb.StateCode_Healthy, proxy.stateCode.Load().(commonpb.StateCode))
// // register proxy
// err = proxy.Register()
// assert.NoError(t, err)
// log.Info("Register proxy done")
// defer func() {
// err := proxy.Stop()
// assert.NoError(t, err)
// }()
// prefix := "test_proxy_"
// partitionPrefix := "test_proxy_partition_"
// dbName := ""
// collectionName := prefix + funcutil.GenRandomStr()
// otherCollectionName := collectionName + "_other_" + funcutil.GenRandomStr()
// partitionName := partitionPrefix + funcutil.GenRandomStr()
// // otherPartitionName := partitionPrefix + "_other_" + funcutil.GenRandomStr()
// shardsNum := int32(2)
// int64Field := "int64"
// floatVecField := "fVec"
// dim := 128
// rowNum := 30
// // indexName := "_default"
// // nlist := 10
// // nprobe := 10
// // topk := 10
// // add a test parameter
// // roundDecimal := 6
// // nq := 10
// // expr := fmt.Sprintf("%s > 0", int64Field)
// // var segmentIDs []int64
// constructCollectionSchema := func() *schemapb.CollectionSchema {
// pk := &schemapb.FieldSchema{
// FieldID: 0,
// Name: int64Field,
// IsPrimaryKey: true,
// Description: "",
// DataType: schemapb.DataType_Int64,
// TypeParams: nil,
// IndexParams: nil,
// AutoID: false,
// }
// fVec := &schemapb.FieldSchema{
// FieldID: 0,
// Name: floatVecField,
// IsPrimaryKey: false,
// Description: "",
// DataType: schemapb.DataType_FloatVector,
// TypeParams: []*commonpb.KeyValuePair{
// {
// Key: "dim",
// Value: strconv.Itoa(dim),
// },
// },
// IndexParams: nil,
// AutoID: false,
// }
// return &schemapb.CollectionSchema{
// Name: collectionName,
// Description: "",
// AutoID: false,
// Fields: []*schemapb.FieldSchema{
// pk,
// fVec,
// },
// }
// }
// schema := constructCollectionSchema()
// constructCreateCollectionRequest := func() *milvuspb.CreateCollectionRequest {
// bs, err := proto.Marshal(schema)
// assert.NoError(t, err)
// return &milvuspb.CreateCollectionRequest{
// Base: nil,
// DbName: dbName,
// CollectionName: collectionName,
// Schema: bs,
// ShardsNum: shardsNum,
// }
// }
// createCollectionReq := constructCreateCollectionRequest()
// constructPartitionReqUpsertRequestValid := func() *milvuspb.UpsertRequest {
// pkFieldData := newScalarFieldData(schema.Fields[0], int64Field, rowNum)
// fVecColumn := newFloatVectorFieldData(floatVecField, rowNum, dim)
// hashKeys := generateHashKeys(rowNum)
// return &milvuspb.UpsertRequest{
// Base: nil,
// DbName: dbName,
// CollectionName: collectionName,
// PartitionName: partitionName,
// FieldsData: []*schemapb.FieldData{pkFieldData, fVecColumn},
// HashKeys: hashKeys,
// NumRows: uint32(rowNum),
// }
// }
// constructCollectionUpsertRequestValid := func() *milvuspb.UpsertRequest {
// pkFieldData := newScalarFieldData(schema.Fields[0], int64Field, rowNum)
// fVecColumn := newFloatVectorFieldData(floatVecField, rowNum, dim)
// hashKeys := generateHashKeys(rowNum)
// return &milvuspb.UpsertRequest{
// Base: nil,
// DbName: dbName,
// CollectionName: collectionName,
// PartitionName: partitionName,
// FieldsData: []*schemapb.FieldData{pkFieldData, fVecColumn},
// HashKeys: hashKeys,
// NumRows: uint32(rowNum),
// }
// }
// wg.Add(1)
// t.Run("create collection upsert valid", func(t *testing.T) {
// defer wg.Done()
// req := createCollectionReq
// resp, err := proxy.CreateCollection(ctx, req)
// assert.NoError(t, err)
// assert.Equal(t, commonpb.ErrorCode_Success, resp.ErrorCode)
// reqInvalidField := constructCreateCollectionRequest()
// schema := constructCollectionSchema()
// schema.Fields = append(schema.Fields, &schemapb.FieldSchema{
// Name: "StringField",
// DataType: schemapb.DataType_String,
// })
// bs, err := proto.Marshal(schema)
// assert.NoError(t, err)
// reqInvalidField.CollectionName = "invalid_field_coll_upsert_valid"
// reqInvalidField.Schema = bs
// resp, err = proxy.CreateCollection(ctx, reqInvalidField)
// assert.NoError(t, err)
// assert.NotEqual(t, commonpb.ErrorCode_Success, resp.ErrorCode)
// })
// wg.Add(1)
// t.Run("create partition", func(t *testing.T) {
// defer wg.Done()
// resp, err := proxy.CreatePartition(ctx, &milvuspb.CreatePartitionRequest{
// Base: nil,
// DbName: dbName,
// CollectionName: collectionName,
// PartitionName: partitionName,
// })
// assert.NoError(t, err)
// assert.Equal(t, commonpb.ErrorCode_Success, resp.ErrorCode)
// // create partition with non-exist collection -> fail
// resp, err = proxy.CreatePartition(ctx, &milvuspb.CreatePartitionRequest{
// Base: nil,
// DbName: dbName,
// CollectionName: otherCollectionName,
// PartitionName: partitionName,
// })
// assert.NoError(t, err)
// assert.NotEqual(t, commonpb.ErrorCode_Success, resp.ErrorCode)
// })
// wg.Add(1)
// t.Run("upsert partition", func(t *testing.T) {
// defer wg.Done()
// req := constructPartitionReqUpsertRequestValid()
// resp, err := proxy.Upsert(ctx, req)
// assert.NoError(t, err)
// assert.Equal(t, commonpb.ErrorCode_Success, resp.Status.ErrorCode)
// assert.Equal(t, rowNum, len(resp.SuccIndex))
// assert.Equal(t, 0, len(resp.ErrIndex))
// assert.Equal(t, int64(rowNum), resp.UpsertCnt)
// })
// wg.Add(1)
// t.Run("upsert when autoID == false", func(t *testing.T) {
// defer wg.Done()
// req := constructCollectionUpsertRequestValid()
// resp, err := proxy.Upsert(ctx, req)
// assert.NoError(t, err)
// assert.Equal(t, commonpb.ErrorCode_Success, resp.Status.ErrorCode)
// assert.Equal(t, rowNum, len(resp.SuccIndex))
// assert.Equal(t, 0, len(resp.ErrIndex))
// assert.Equal(t, int64(rowNum), resp.UpsertCnt)
// })
// proxy.UpdateStateCode(commonpb.StateCode_Abnormal)
// wg.Add(1)
// t.Run("Upsert fail, unhealthy", func(t *testing.T) {
// defer wg.Done()
// resp, err := proxy.Upsert(ctx, &milvuspb.UpsertRequest{})
// assert.NoError(t, err)
// assert.NotEqual(t, commonpb.ErrorCode_Success, resp.Status.ErrorCode)
// })
// dmParallelism := proxy.sched.dmQueue.getMaxTaskNum()
// proxy.sched.dmQueue.setMaxTaskNum(0)
// wg.Add(1)
// t.Run("Upsert fail, dm queue full", func(t *testing.T) {
// defer wg.Done()
// resp, err := proxy.Upsert(ctx, &milvuspb.UpsertRequest{})
// assert.NoError(t, err)
// assert.NotEqual(t, commonpb.ErrorCode_Success, resp.Status.ErrorCode)
// })
// proxy.sched.dmQueue.setMaxTaskNum(dmParallelism)
// timeout := time.Nanosecond
// shortCtx, shortCancel := context.WithTimeout(ctx, timeout)
// defer shortCancel()
// time.Sleep(timeout)
// wg.Add(1)
// t.Run("Update fail, timeout", func(t *testing.T) {
// defer wg.Done()
// resp, err := proxy.Upsert(shortCtx, &milvuspb.UpsertRequest{})
// assert.NoError(t, err)
// assert.NotEqual(t, commonpb.ErrorCode_Success, resp.Status.ErrorCode)
// })
// testServer.gracefulStop()
// wg.Wait()
// cancel()
// }

View File

@ -35,6 +35,7 @@ import (
"google.golang.org/grpc/metadata"
"github.com/milvus-io/milvus-proto/go-api/commonpb"
"github.com/milvus-io/milvus-proto/go-api/milvuspb"
"github.com/milvus-io/milvus-proto/go-api/schemapb"
"github.com/milvus-io/milvus/internal/log"
"github.com/milvus-io/milvus/internal/util"
@ -944,6 +945,49 @@ func checkPrimaryFieldData(schema *schemapb.CollectionSchema, insertMsg *msgstre
return ids, nil
}
// TODO(smellthemoon): can merge it with checkPrimaryFieldData
func upsertCheckPrimaryFieldData(schema *schemapb.CollectionSchema, result *milvuspb.MutationResult, insertMsg *msgstream.InsertMsg) (*schemapb.IDs, error) {
rowNums := uint32(insertMsg.NRows())
// TODO(dragondriver): in fact, NumRows is not trustworthy; we should check all input fields
if insertMsg.NRows() <= 0 {
return nil, errNumRowsLessThanOrEqualToZero(rowNums)
}
if err := checkLengthOfFieldsData(schema, insertMsg); err != nil {
return nil, err
}
primaryFieldSchema, err := typeutil.GetPrimaryFieldSchema(schema)
if err != nil {
log.Error("get primary field schema failed", zap.String("collectionName", insertMsg.CollectionName), zap.Any("schema", schema), zap.Error(err))
return nil, err
}
// get primaryFieldData whether autoID is true or not
var primaryFieldData *schemapb.FieldData
if primaryFieldSchema.AutoID {
// upsert is not supported when autoID == true
log.Info("can not upsert when auto id enabled",
zap.String("primaryFieldSchemaName", primaryFieldSchema.Name))
result.Status.ErrorCode = commonpb.ErrorCode_UpsertAutoIDTrue
return nil, fmt.Errorf("upsert can not assign primary field data when auto id enabled %v", primaryFieldSchema.Name)
}
primaryFieldData, err = typeutil.GetPrimaryFieldData(insertMsg.GetFieldsData(), primaryFieldSchema)
if err != nil {
log.Error("get primary field data failed when upsert", zap.String("collectionName", insertMsg.CollectionName), zap.Error(err))
return nil, err
}
// parse primaryFieldData into IDs, which are returned as the upserted primary keys
ids, err := parsePrimaryFieldData2IDs(primaryFieldData)
if err != nil {
log.Error("parse primary field data to IDs failed", zap.String("collectionName", insertMsg.CollectionName), zap.Error(err))
return nil, err
}
return ids, nil
}
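// An illustrative sketch (not part of this change) of how the check above might
// be wired into an upsert task's pre-execution step; the receiver `upsertTask`
// and its fields `schema`, `result` and `upsertMsg` are hypothetical stand-ins
// for the real task structure:
//
//	func (it *upsertTask) checkPrimaryKeys() error {
//		ids, err := upsertCheckPrimaryFieldData(it.schema, it.result, it.upsertMsg)
//		if err != nil {
//			// result.Status already carries ErrorCode_UpsertAutoIDTrue when autoID is enabled
//			return err
//		}
//		// echo the upserted primary keys back to the caller
//		it.result.IDs = ids
//		return nil
//	}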
func getCollectionProgress(ctx context.Context, queryCoord types.QueryCoord,
msgBase *commonpb.MsgBase, collectionID int64) (int64, error) {
resp, err := queryCoord.ShowCollections(ctx, &querypb.ShowCollectionsRequest{

View File

@ -29,6 +29,7 @@ import (
"google.golang.org/grpc/metadata"
"github.com/milvus-io/milvus-proto/go-api/commonpb"
"github.com/milvus-io/milvus-proto/go-api/milvuspb"
"github.com/milvus-io/milvus-proto/go-api/schemapb"
"github.com/milvus-io/milvus/internal/proto/internalpb"
"github.com/milvus-io/milvus/internal/proto/querypb"
@ -1226,3 +1227,264 @@ func Test_InsertTaskCheckPrimaryFieldData(t *testing.T) {
_, err = checkPrimaryFieldData(case4.schema, case4.insertMsg)
assert.NotEqual(t, nil, err)
}
func Test_UpsertTaskCheckPrimaryFieldData(t *testing.T) {
// schema is empty, though this won't happen in a real system
// num_rows(0) should be greater than 0
case1 := insertTask{
schema: &schemapb.CollectionSchema{
Name: "TestUpsertTask_checkPrimaryFieldData",
Description: "TestUpsertTask_checkPrimaryFieldData",
AutoID: false,
Fields: []*schemapb.FieldSchema{},
},
insertMsg: &BaseInsertTask{
InsertRequest: internalpb.InsertRequest{
Base: &commonpb.MsgBase{
MsgType: commonpb.MsgType_Insert,
},
DbName: "TestUpsertTask_checkPrimaryFieldData",
CollectionName: "TestUpsertTask_checkPrimaryFieldData",
PartitionName: "TestUpsertTask_checkPrimaryFieldData",
},
},
result: &milvuspb.MutationResult{
Status: &commonpb.Status{
ErrorCode: commonpb.ErrorCode_Success,
},
},
}
_, err := upsertCheckPrimaryFieldData(case1.schema, case1.result, case1.insertMsg)
assert.NotEqual(t, nil, err)
// the number of passed fields is less than required
case2 := insertTask{
schema: &schemapb.CollectionSchema{
Name: "TestUpsertTask_checkPrimaryFieldData",
Description: "TestUpsertTask_checkPrimaryFieldData",
AutoID: false,
Fields: []*schemapb.FieldSchema{
{
Name: "int64Field",
FieldID: 1,
DataType: schemapb.DataType_Int64,
},
{
Name: "floatField",
FieldID: 2,
DataType: schemapb.DataType_Float,
},
},
},
insertMsg: &BaseInsertTask{
InsertRequest: internalpb.InsertRequest{
Base: &commonpb.MsgBase{
MsgType: commonpb.MsgType_Insert,
},
RowData: []*commonpb.Blob{
{},
{},
},
FieldsData: []*schemapb.FieldData{
{
Type: schemapb.DataType_Int64,
FieldName: "int64Field",
},
},
},
},
result: &milvuspb.MutationResult{
Status: &commonpb.Status{
ErrorCode: commonpb.ErrorCode_Success,
},
},
}
_, err = upsertCheckPrimaryFieldData(case2.schema, case2.result, case2.insertMsg)
assert.NotEqual(t, nil, err)
// autoID == false and no primary field schema,
// so the primary field cannot be found
case3 := insertTask{
schema: &schemapb.CollectionSchema{
Name: "TestUpsertTask_checkPrimaryFieldData",
Description: "TestUpsertTask_checkPrimaryFieldData",
AutoID: false,
Fields: []*schemapb.FieldSchema{
{
Name: "int64Field",
DataType: schemapb.DataType_Int64,
},
{
Name: "floatField",
DataType: schemapb.DataType_Float,
},
},
},
insertMsg: &BaseInsertTask{
InsertRequest: internalpb.InsertRequest{
Base: &commonpb.MsgBase{
MsgType: commonpb.MsgType_Insert,
},
RowData: []*commonpb.Blob{
{},
{},
},
FieldsData: []*schemapb.FieldData{
{},
{},
},
},
},
result: &milvuspb.MutationResult{
Status: &commonpb.Status{
ErrorCode: commonpb.ErrorCode_Success,
},
},
}
_, err = upsertCheckPrimaryFieldData(case3.schema, case3.result, case3.insertMsg)
assert.NotEqual(t, nil, err)
// autoID == true, which upsert does not support
case4 := insertTask{
schema: &schemapb.CollectionSchema{
Name: "TestUpsertTask_checkPrimaryFieldData",
Description: "TestUpsertTask_checkPrimaryFieldData",
AutoID: false,
Fields: []*schemapb.FieldSchema{
{
Name: "int64Field",
FieldID: 1,
DataType: schemapb.DataType_Int64,
},
{
Name: "floatField",
FieldID: 2,
DataType: schemapb.DataType_Float,
},
},
},
insertMsg: &BaseInsertTask{
InsertRequest: internalpb.InsertRequest{
Base: &commonpb.MsgBase{
MsgType: commonpb.MsgType_Insert,
},
RowData: []*commonpb.Blob{
{},
{},
},
FieldsData: []*schemapb.FieldData{
{
Type: schemapb.DataType_Int64,
FieldName: "int64Field",
},
},
},
},
result: &milvuspb.MutationResult{
Status: &commonpb.Status{
ErrorCode: commonpb.ErrorCode_Success,
},
},
}
case4.schema.Fields[0].IsPrimaryKey = true
case4.schema.Fields[0].AutoID = true
_, err = upsertCheckPrimaryFieldData(case4.schema, case4.result, case4.insertMsg)
assert.Equal(t, commonpb.ErrorCode_UpsertAutoIDTrue, case4.result.Status.ErrorCode)
assert.NotEqual(t, nil, err)
// primary field data is nil, so GetPrimaryFieldData fails
case5 := insertTask{
schema: &schemapb.CollectionSchema{
Name: "TestUpsertTask_checkPrimaryFieldData",
Description: "TestUpsertTask_checkPrimaryFieldData",
AutoID: false,
Fields: []*schemapb.FieldSchema{
{
Name: "int64Field",
FieldID: 1,
DataType: schemapb.DataType_Int64,
},
{
Name: "floatField",
FieldID: 2,
DataType: schemapb.DataType_Float,
},
},
},
insertMsg: &BaseInsertTask{
InsertRequest: internalpb.InsertRequest{
Base: &commonpb.MsgBase{
MsgType: commonpb.MsgType_Insert,
},
RowData: []*commonpb.Blob{
{},
{},
},
FieldsData: []*schemapb.FieldData{
{},
{},
},
},
},
result: &milvuspb.MutationResult{
Status: &commonpb.Status{
ErrorCode: commonpb.ErrorCode_Success,
},
},
}
case5.schema.Fields[0].IsPrimaryKey = true
case5.schema.Fields[0].AutoID = false
_, err = upsertCheckPrimaryFieldData(case5.schema, case5.result, case5.insertMsg)
assert.NotEqual(t, nil, err)
// only DataType Int64 or VarChar is supported as the primary field
case6 := insertTask{
schema: &schemapb.CollectionSchema{
Name: "TestUpsertTask_checkPrimaryFieldData",
Description: "TestUpsertTask_checkPrimaryFieldData",
AutoID: false,
Fields: []*schemapb.FieldSchema{
{
Name: "floatVectorField",
FieldID: 1,
DataType: schemapb.DataType_FloatVector,
},
{
Name: "floatField",
FieldID: 2,
DataType: schemapb.DataType_Float,
},
},
},
insertMsg: &BaseInsertTask{
InsertRequest: internalpb.InsertRequest{
Base: &commonpb.MsgBase{
MsgType: commonpb.MsgType_Insert,
},
RowData: []*commonpb.Blob{
{},
{},
},
FieldsData: []*schemapb.FieldData{
{
Type: schemapb.DataType_FloatVector,
FieldName: "floatVectorField",
},
{
Type: schemapb.DataType_Int64,
FieldName: "floatField",
},
},
},
},
result: &milvuspb.MutationResult{
Status: &commonpb.Status{
ErrorCode: commonpb.ErrorCode_Success,
},
},
}
case6.schema.Fields[0].IsPrimaryKey = true
case6.schema.Fields[0].AutoID = false
_, err = upsertCheckPrimaryFieldData(case6.schema, case6.result, case6.insertMsg)
assert.NotEqual(t, nil, err)
}

View File

@ -135,7 +135,7 @@ func (fdmNode *filterDmNode) Operate(in []flowgraph.Msg) []flowgraph.Msg {
// filterInvalidDeleteMessage would filter out invalid delete messages
func (fdmNode *filterDmNode) filterInvalidDeleteMessage(msg *msgstream.DeleteMsg, loadType loadType) (*msgstream.DeleteMsg, error) {
if err := msg.CheckAligned(); err != nil {
return nil, fmt.Errorf("CheckAligned failed, err = %s", err)
return nil, fmt.Errorf("DeleteMessage CheckAligned failed, err = %s", err)
}
if len(msg.Timestamps) <= 0 {
@ -168,7 +168,7 @@ func (fdmNode *filterDmNode) filterInvalidDeleteMessage(msg *msgstream.DeleteMsg
// filterInvalidInsertMessage would filter out invalid insert messages
func (fdmNode *filterDmNode) filterInvalidInsertMessage(msg *msgstream.InsertMsg, loadType loadType) (*msgstream.InsertMsg, error) {
if err := msg.CheckAligned(); err != nil {
return nil, fmt.Errorf("CheckAligned failed, err = %s", err)
return nil, fmt.Errorf("InsertMessage CheckAligned failed, err = %s", err)
}
if len(msg.Timestamps) <= 0 {

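// For context, CheckAligned is expected to verify that the per-row slices of a
// message line up before the node processes it. A rough sketch of the idea only,
// with hypothetical names, not the actual msgstream implementation:
//
//	func checkAligned(timestamps []uint64, rowIDs []int64, numRows uint64) error {
//		if uint64(len(timestamps)) != numRows || uint64(len(rowIDs)) != numRows {
//			return fmt.Errorf("misaligned message: %d timestamps, %d rowIDs, %d rows",
//				len(timestamps), len(rowIDs), numRows)
//		}
//		return nil
//	}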
View File

@ -1145,6 +1145,18 @@ type ProxyComponent interface {
// error is always nil
Delete(ctx context.Context, request *milvuspb.DeleteRequest) (*milvuspb.MutationResult, error)
// Upsert notifies Proxy to upsert rows
//
// ctx is the context to control request deadline and cancellation
// req contains the request params, including database name (reserved), collection name, partition name (optional), and fields data
//
// The `Status` in the response struct `MutationResult` indicates whether this operation succeeded or carries the failure cause;
// the `IDs` in `MutationResult` returns the ID list of the upserted rows.
// the `SuccIndex` in `MutationResult` returns the index list of successfully upserted rows.
// the `ErrIndex` in `MutationResult` returns the index list of rows that failed to be upserted.
// error is always nil
Upsert(ctx context.Context, request *milvuspb.UpsertRequest) (*milvuspb.MutationResult, error)
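// A minimal caller-side sketch (names like `proxy`, `pkColumn`, `vecColumn` and
// `rowNum` are assumed examples; the request fields mirror the ones exercised by
// the upsert tests above):
//
//	resp, err := proxy.Upsert(ctx, &milvuspb.UpsertRequest{
//		DbName:         "",
//		CollectionName: "my_collection",
//		PartitionName:  "my_partition",
//		FieldsData:     []*schemapb.FieldData{pkColumn, vecColumn},
//		NumRows:        uint32(rowNum),
//	})
//	if err == nil && resp.Status.ErrorCode == commonpb.ErrorCode_Success {
//		// resp.UpsertCnt rows were written; resp.IDs echoes their primary keys
//	}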
// Search notifies Proxy to do search
//
// ctx is the context to control request deadline and cancellation