mirror of https://github.com/milvus-io/milvus.git

Support update (#20875)

Signed-off-by: lixinguo <xinguo.li@zilliz.com>
Co-authored-by: lixinguo <xinguo.li@zilliz.com>

branch: pull/21267/head
parent ff2a68e65a
commit bf3c02155a
@@ -107,7 +107,7 @@ get_deleted_bitmap(int64_t del_barrier,
         // insert after delete with same pk, delete will not task effect on this insert record
         // and reset bitmap to 0
-        if (insert_record.timestamps_[insert_row_offset] > delete_timestamp) {
+        if (insert_record.timestamps_[insert_row_offset] >= delete_timestamp) {
             bitmap->reset(insert_row_offset);
             continue;
         }
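This boundary matters for the upsert support added in this commit: an upsert is executed as a delete plus a re-insert of the same primary key, and the re-insert can carry the same timestamp as the delete, so an insert whose timestamp equals the delete timestamp must stay visible. A minimal Go sketch of the rule (the real check is the C++ segcore code above; the function name here is illustrative only):

package main

import "fmt"

// shouldMaskInsert reports whether a delete at deleteTs hides an insert of the
// same primary key written at insertTs. With the ">=" rule from the hunk above,
// an insert at the same timestamp as the delete is NOT masked.
func shouldMaskInsert(insertTs, deleteTs uint64) bool {
    return insertTs < deleteTs
}

func main() {
    fmt.Println(shouldMaskInsert(10, 10)) // false: same-timestamp re-insert (upsert) stays visible
    fmt.Println(shouldMaskInsert(9, 10))  // true: an older insert is hidden by the delete
}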
@@ -791,7 +791,7 @@ TEST(CApiTest, InsertSamePkAfterDeleteOnSealedSegment) {
     auto query_result = std::make_unique<proto::segcore::RetrieveResults>();
     auto suc = query_result->ParseFromArray(retrieve_result.proto_blob, retrieve_result.proto_size);
     ASSERT_TRUE(suc);
-    ASSERT_EQ(query_result->ids().int_id().data().size(), 3);
+    ASSERT_EQ(query_result->ids().int_id().data().size(), 4);

     DeleteRetrievePlan(plan.release());
     DeleteRetrieveResult(&retrieve_result);
@@ -87,7 +87,7 @@ TEST(Util, GetDeleteBitmap) {
     del_barrier = get_barrier(delete_record, query_timestamp);
     res_bitmap = get_deleted_bitmap(del_barrier, insert_barrier, delete_record, insert_record, query_timestamp);
-    ASSERT_EQ(res_bitmap->bitmap_ptr->count(), N);
+    ASSERT_EQ(res_bitmap->bitmap_ptr->count(), N - 1);

     // test case insert repeated pk1 (ts = {1 ... N}) -> delete pk1 (ts = N) -> query (ts = N/2)
     query_timestamp = tss[N - 1] / 2;
@@ -685,7 +685,7 @@ func (s *Server) Delete(ctx context.Context, request *milvuspb.DeleteRequest) (*
 }

 func (s *Server) Upsert(ctx context.Context, request *milvuspb.UpsertRequest) (*milvuspb.MutationResult, error) {
-    panic("TODO: not implement")
+    return s.proxy.Upsert(ctx, request)
 }

 func (s *Server) Search(ctx context.Context, request *milvuspb.SearchRequest) (*milvuspb.SearchResults, error) {
@@ -708,12 +708,12 @@ func (s *Server) GetDdChannel(ctx context.Context, request *internalpb.GetDdChan
     return s.proxy.GetDdChannel(ctx, request)
 }

-//GetPersistentSegmentInfo notifies Proxy to get persistent segment info.
+// GetPersistentSegmentInfo notifies Proxy to get persistent segment info.
 func (s *Server) GetPersistentSegmentInfo(ctx context.Context, request *milvuspb.GetPersistentSegmentInfoRequest) (*milvuspb.GetPersistentSegmentInfoResponse, error) {
     return s.proxy.GetPersistentSegmentInfo(ctx, request)
 }

-//GetQuerySegmentInfo notifies Proxy to get query segment info.
+// GetQuerySegmentInfo notifies Proxy to get query segment info.
 func (s *Server) GetQuerySegmentInfo(ctx context.Context, request *milvuspb.GetQuerySegmentInfoRequest) (*milvuspb.GetQuerySegmentInfoResponse, error) {
     return s.proxy.GetQuerySegmentInfo(ctx, request)
@@ -60,7 +60,7 @@ func TestMain(m *testing.M) {
     os.Exit(code)
 }

-///////////////////////////////////////////////////////////////////////////////////////////////////////////////////////
+// /////////////////////////////////////////////////////////////////////////////////////////////////////////////////////
 type MockBase struct {
     mock.Mock
     isMockGetComponentStatesOn bool
@@ -102,7 +102,7 @@ func (m *MockBase) GetStatisticsChannel(ctx context.Context) (*milvuspb.StringRe
     return nil, nil
 }

-///////////////////////////////////////////////////////////////////////////////////////////////////////////////////////
+// /////////////////////////////////////////////////////////////////////////////////////////////////////////////////////
 type MockRootCoord struct {
     MockBase
     initErr error
@@ -289,7 +289,7 @@ func (m *MockRootCoord) CheckHealth(ctx context.Context, req *milvuspb.CheckHeal
     }, nil
 }

-///////////////////////////////////////////////////////////////////////////////////////////////////////////////////////
+// /////////////////////////////////////////////////////////////////////////////////////////////////////////////////////
 type MockIndexCoord struct {
     MockBase
     initErr error
@@ -354,7 +354,7 @@ func (m *MockIndexCoord) CheckHealth(ctx context.Context, req *milvuspb.CheckHea
     return nil, nil
 }

-///////////////////////////////////////////////////////////////////////////////////////////////////////////////////////
+// /////////////////////////////////////////////////////////////////////////////////////////////////////////////////////
 type MockQueryCoord struct {
     MockBase
     initErr error
@@ -469,7 +469,7 @@ func (m *MockQueryCoord) CheckHealth(ctx context.Context, req *milvuspb.CheckHea
     }, nil
 }

-///////////////////////////////////////////////////////////////////////////////////////////////////////////////////////
+// /////////////////////////////////////////////////////////////////////////////////////////////////////////////////////
 type MockDataCoord struct {
     MockBase
     err error
@@ -627,7 +627,7 @@ func (m *MockDataCoord) CheckHealth(ctx context.Context, req *milvuspb.CheckHeal
     return nil, nil
 }

-///////////////////////////////////////////////////////////////////////////////////////////////////////////////////////
+// /////////////////////////////////////////////////////////////////////////////////////////////////////////////////////
 type MockProxy struct {
     MockBase
     err error
@@ -758,6 +758,10 @@ func (m *MockProxy) Delete(ctx context.Context, request *milvuspb.DeleteRequest)
     return nil, nil
 }

+func (m *MockProxy) Upsert(ctx context.Context, request *milvuspb.UpsertRequest) (*milvuspb.MutationResult, error) {
+    return nil, nil
+}
+
 func (m *MockProxy) Search(ctx context.Context, request *milvuspb.SearchRequest) (*milvuspb.SearchResults, error) {
     return nil, nil
 }
@@ -1063,7 +1067,7 @@ func runAndWaitForServerReady(server *Server) error {
     return nil
 }

-///////////////////////////////////////////////////////////////////////////////////////////////////////////////////////
+// /////////////////////////////////////////////////////////////////////////////////////////////////////////////////////
 func Test_NewServer(t *testing.T) {
     paramtable.Init()
     ctx := context.Background()
@@ -1212,6 +1216,11 @@ func Test_NewServer(t *testing.T) {
     assert.Nil(t, err)
     })

+    t.Run("Upsert", func(t *testing.T) {
+        _, err := server.Upsert(ctx, nil)
+        assert.Nil(t, err)
+    })
+
     t.Run("Search", func(t *testing.T) {
         _, err := server.Search(ctx, nil)
         assert.Nil(t, err)
@@ -35,6 +35,7 @@ const (

     InsertLabel = "insert"
     DeleteLabel = "delete"
+    UpsertLabel = "upsert"
     SearchLabel = "search"
     QueryLabel = "query"
     CacheHitLabel = "hit"
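The new UpsertLabel is consumed by the proxy metrics calls later in this diff (ProxyReceiveBytes, ProxyMutationLatency, and so on). As a hedged illustration of that pattern with the standard Prometheus client; the metric name below is made up and is not the real Milvus metric definition:

package metricsexample

import "github.com/prometheus/client_golang/prometheus"

// receiveBytes mimics the shape of the proxy counters used in this diff:
// one counter vector keyed by node id and a request label such as "upsert".
var receiveBytes = prometheus.NewCounterVec(
    prometheus.CounterOpts{
        Name: "example_proxy_receive_bytes", // hypothetical name
        Help: "bytes received per DML request type",
    },
    []string{"node_id", "msg_type"},
)

// recordUpsertBytes mirrors the call
// metrics.ProxyReceiveBytes.WithLabelValues(nodeID, metrics.UpsertLabel).Add(...)
// seen later in the diff.
func recordUpsertBytes(nodeID string, size int) {
    receiveBytes.WithLabelValues(nodeID, "upsert").Add(float64(size))
}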
@@ -380,6 +380,12 @@ func (dt *DeleteMsg) CheckAligned() error {
     return nil
 }

+// ///////////////////////////////////////Upsert//////////////////////////////////////////
+type UpsertMsg struct {
+    InsertMsg *InsertMsg
+    DeleteMsg *DeleteMsg
+}
+
 /////////////////////////////////////////TimeTick//////////////////////////////////////////

 // TimeTickMsg is a message pack that contains time tick only
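UpsertMsg is deliberately just a pair of the existing message types: the proxy-side upsert task added later in this diff fills both halves from a single UpsertRequest and keeps their row counts aligned. A small sketch of that pairing, assuming the embedded request fields shown elsewhere in the diff (NumRows is unsigned on the insert half and signed on the delete half):

package msgstreamexample

import "github.com/milvus-io/milvus/internal/mq/msgstream"

// pairUpsert shows how one logical upsert is carried as an insert half plus a
// delete half; both are produced from the same request in upsert_task.go.
func pairUpsert(insert *msgstream.InsertMsg, del *msgstream.DeleteMsg) *msgstream.UpsertMsg {
    return &msgstream.UpsertMsg{
        InsertMsg: insert,
        DeleteMsg: del,
    }
}

// rowsAligned is the invariant the upsert task's PreExecute checks; it logs an
// error when the two counts differ.
func rowsAligned(m *msgstream.UpsertMsg) bool {
    return int64(m.InsertMsg.NumRows) == m.DeleteMsg.NumRows
}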
@@ -28,6 +28,7 @@ import (
     "github.com/golang/protobuf/proto"
     "github.com/milvus-io/milvus-proto/go-api/commonpb"
     "github.com/milvus-io/milvus-proto/go-api/milvuspb"
     "github.com/milvus-io/milvus-proto/go-api/schemapb"
     "github.com/milvus-io/milvus/internal/common"
     "github.com/milvus-io/milvus/internal/log"
     "github.com/milvus-io/milvus/internal/metrics"
@@ -2018,11 +2019,9 @@ func (node *Proxy) Insert(ctx context.Context, request *milvuspb.InsertRequest)
     }
     method := "Insert"
     tr := timerecord.NewTimeRecorder(method)
-    receiveSize := proto.Size(request)
-    rateCol.Add(internalpb.RateType_DMLInsert.String(), float64(receiveSize))
-    metrics.ProxyReceiveBytes.WithLabelValues(strconv.FormatInt(paramtable.GetNodeID(), 10), metrics.InsertLabel).Add(float64(receiveSize))
+    metrics.ProxyReceiveBytes.WithLabelValues(strconv.FormatInt(paramtable.GetNodeID(), 10), metrics.InsertLabel).Add(float64(proto.Size(request)))
     metrics.ProxyFunctionCall.WithLabelValues(strconv.FormatInt(paramtable.GetNodeID(), 10), method, metrics.TotalLabel).Inc()

     it := &insertTask{
         ctx: ctx,
         Condition: NewTaskCondition(ctx),
@@ -2119,6 +2118,9 @@ func (node *Proxy) Insert(ctx context.Context, request *milvuspb.InsertRequest)
     // InsertCnt always equals to the number of entities in the request
     it.result.InsertCnt = int64(request.NumRows)

+    receiveSize := proto.Size(it.insertMsg)
+    rateCol.Add(internalpb.RateType_DMLInsert.String(), float64(receiveSize))
+
     metrics.ProxyFunctionCall.WithLabelValues(strconv.FormatInt(paramtable.GetNodeID(), 10), method,
         metrics.SuccessLabel).Inc()
     successCnt := it.result.InsertCnt - int64(len(it.result.ErrIndex))
@@ -2128,10 +2130,6 @@ func (node *Proxy) Insert(ctx context.Context, request *milvuspb.InsertRequest)
     return it.result, nil
 }

-func (node *Proxy) Upsert(ctx context.Context, request *milvuspb.UpsertRequest) (*milvuspb.MutationResult, error) {
-    panic("TODO: not implement")
-}
-
 // Delete delete records from collection, then these records cannot be searched.
 func (node *Proxy) Delete(ctx context.Context, request *milvuspb.DeleteRequest) (*milvuspb.MutationResult, error) {
     sp, ctx := trace.StartSpanFromContextWithOperationName(ctx, "Proxy-Delete")
@@ -2140,9 +2138,7 @@ func (node *Proxy) Delete(ctx context.Context, request *milvuspb.DeleteRequest)
     log.Debug("Start processing delete request in Proxy")
     defer log.Debug("Finish processing delete request in Proxy")

-    receiveSize := proto.Size(request)
-    rateCol.Add(internalpb.RateType_DMLDelete.String(), float64(receiveSize))
-    metrics.ProxyReceiveBytes.WithLabelValues(strconv.FormatInt(paramtable.GetNodeID(), 10), metrics.DeleteLabel).Add(float64(receiveSize))
+    metrics.ProxyReceiveBytes.WithLabelValues(strconv.FormatInt(paramtable.GetNodeID(), 10), metrics.DeleteLabel).Add(float64(proto.Size(request)))

     if !node.checkHealthy() {
         return &milvuspb.MutationResult{
@@ -2219,6 +2215,9 @@ func (node *Proxy) Delete(ctx context.Context, request *milvuspb.DeleteRequest)
         }, nil
     }

+    receiveSize := proto.Size(dt.deleteMsg)
+    rateCol.Add(internalpb.RateType_DMLDelete.String(), float64(receiveSize))
+
     metrics.ProxyFunctionCall.WithLabelValues(strconv.FormatInt(paramtable.GetNodeID(), 10), method,
         metrics.SuccessLabel).Inc()
     metrics.ProxyMutationLatency.WithLabelValues(strconv.FormatInt(paramtable.GetNodeID(), 10), metrics.DeleteLabel).Observe(float64(tr.ElapseSpan().Milliseconds()))
@@ -2226,6 +2225,140 @@ func (node *Proxy) Delete(ctx context.Context, request *milvuspb.DeleteRequest)
     return dt.result, nil
 }

+// Upsert upsert records into collection.
+func (node *Proxy) Upsert(ctx context.Context, request *milvuspb.UpsertRequest) (*milvuspb.MutationResult, error) {
+    sp, ctx := trace.StartSpanFromContextWithOperationName(ctx, "Proxy-Upsert")
+    defer sp.Finish()
+
+    log := log.Ctx(ctx).With(
+        zap.String("role", typeutil.ProxyRole),
+        zap.String("db", request.DbName),
+        zap.String("collection", request.CollectionName),
+        zap.String("partition", request.PartitionName),
+        zap.Uint32("NumRows", request.NumRows),
+    )
+    log.Debug("Start processing upsert request in Proxy")
+
+    if !node.checkHealthy() {
+        return &milvuspb.MutationResult{
+            Status: unhealthyStatus(),
+        }, nil
+    }
+    method := "Upsert"
+    tr := timerecord.NewTimeRecorder(method)
+
+    metrics.ProxyReceiveBytes.WithLabelValues(strconv.FormatInt(paramtable.GetNodeID(), 10), metrics.UpsertLabel).Add(float64(proto.Size(request)))
+    metrics.ProxyFunctionCall.WithLabelValues(strconv.FormatInt(paramtable.GetNodeID(), 10), method, metrics.TotalLabel).Inc()
+
+    it := &upsertTask{
+        baseMsg: msgstream.BaseMsg{
+            HashValues: request.HashKeys,
+        },
+        ctx: ctx,
+        Condition: NewTaskCondition(ctx),
+
+        req: &milvuspb.UpsertRequest{
+            Base: commonpbutil.NewMsgBase(
+                commonpbutil.WithMsgType(commonpb.MsgType(commonpb.MsgType_Upsert)),
+                commonpbutil.WithSourceID(paramtable.GetNodeID()),
+            ),
+            CollectionName: request.CollectionName,
+            PartitionName: request.PartitionName,
+            FieldsData: request.FieldsData,
+            NumRows: request.NumRows,
+        },
+
+        result: &milvuspb.MutationResult{
+            Status: &commonpb.Status{
+                ErrorCode: commonpb.ErrorCode_Success,
+            },
+            IDs: &schemapb.IDs{
+                IdField: nil,
+            },
+        },
+
+        idAllocator: node.rowIDAllocator,
+        segIDAssigner: node.segAssigner,
+        chMgr: node.chMgr,
+        chTicker: node.chTicker,
+    }
+
+    if len(it.req.PartitionName) <= 0 {
+        it.req.PartitionName = Params.CommonCfg.DefaultPartitionName.GetValue()
+    }
+
+    constructFailedResponse := func(err error, errCode commonpb.ErrorCode) *milvuspb.MutationResult {
+        numRows := request.NumRows
+        errIndex := make([]uint32, numRows)
+        for i := uint32(0); i < numRows; i++ {
+            errIndex[i] = i
+        }
+
+        return &milvuspb.MutationResult{
+            Status: &commonpb.Status{
+                ErrorCode: errCode,
+                Reason: err.Error(),
+            },
+            ErrIndex: errIndex,
+        }
+    }
+
+    log.Debug("Enqueue upsert request in Proxy",
+        zap.Int("len(FieldsData)", len(request.FieldsData)),
+        zap.Int("len(HashKeys)", len(request.HashKeys)))
+
+    if err := node.sched.dmQueue.Enqueue(it); err != nil {
+        log.Info("Failed to enqueue upsert task",
+            zap.Error(err))
+        metrics.ProxyFunctionCall.WithLabelValues(strconv.FormatInt(paramtable.GetNodeID(), 10), method,
+            metrics.AbandonLabel).Inc()
+        return &milvuspb.MutationResult{
+            Status: &commonpb.Status{
+                ErrorCode: commonpb.ErrorCode_UnexpectedError,
+                Reason: err.Error(),
+            },
+        }, nil
+    }
+
+    log.Debug("Detail of upsert request in Proxy",
+        zap.Uint64("BeginTS", it.BeginTs()),
+        zap.Uint64("EndTS", it.EndTs()))
+
+    if err := it.WaitToFinish(); err != nil {
+        log.Info("Failed to execute insert task in task scheduler",
+            zap.Error(err))
+        metrics.ProxyFunctionCall.WithLabelValues(strconv.FormatInt(paramtable.GetNodeID(), 10), method,
+            metrics.FailLabel).Inc()
+        return constructFailedResponse(err, it.result.Status.ErrorCode), nil
+    }
+
+    if it.result.Status.ErrorCode != commonpb.ErrorCode_Success {
+        setErrorIndex := func() {
+            numRows := request.NumRows
+            errIndex := make([]uint32, numRows)
+            for i := uint32(0); i < numRows; i++ {
+                errIndex[i] = i
+            }
+            it.result.ErrIndex = errIndex
+        }
+        setErrorIndex()
+    }
+
+    insertReceiveSize := proto.Size(it.upsertMsg.InsertMsg)
+    deleteReceiveSize := proto.Size(it.upsertMsg.DeleteMsg)
+
+    rateCol.Add(internalpb.RateType_DMLDelete.String(), float64(deleteReceiveSize))
+    rateCol.Add(internalpb.RateType_DMLInsert.String(), float64(insertReceiveSize))
+
+    metrics.ProxyFunctionCall.WithLabelValues(strconv.FormatInt(paramtable.GetNodeID(), 10), method,
+        metrics.SuccessLabel).Inc()
+    metrics.ProxyMutationLatency.WithLabelValues(strconv.FormatInt(paramtable.GetNodeID(), 10), metrics.UpsertLabel).Observe(float64(tr.ElapseSpan().Milliseconds()))
+    metrics.ProxyCollectionMutationLatency.WithLabelValues(strconv.FormatInt(paramtable.GetNodeID(), 10), metrics.UpsertLabel, request.CollectionName).Observe(float64(tr.ElapseSpan().Milliseconds()))
+
+    log.Debug("Finish processing upsert request in Proxy")
+    return it.result, nil
+}
+
 // Search search the most similar records of requests.
 func (node *Proxy) Search(ctx context.Context, request *milvuspb.SearchRequest) (*milvuspb.SearchResults, error) {
     receiveSize := proto.Size(request)
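For orientation, this is roughly what calling the new endpoint looks like from the server side. The helper below is illustrative and not part of the commit; the field data, which must include the primary-key column because upsert rejects autoID collections, is left to the caller:

package proxy

import (
    "context"

    "github.com/milvus-io/milvus-proto/go-api/commonpb"
    "github.com/milvus-io/milvus-proto/go-api/milvuspb"
    "github.com/milvus-io/milvus-proto/go-api/schemapb"
)

// upsertExample is a hypothetical caller of the new Proxy.Upsert.
func upsertExample(ctx context.Context, node *Proxy, fields []*schemapb.FieldData, rows uint32) (int64, error) {
    resp, err := node.Upsert(ctx, &milvuspb.UpsertRequest{
        DbName:         "default",       // placeholder
        CollectionName: "my_collection", // placeholder
        FieldsData:     fields,          // primary-key column plus payload columns
        NumRows:        rows,
    })
    if err != nil {
        return 0, err
    }
    if resp.Status.ErrorCode != commonpb.ErrorCode_Success {
        return 0, nil // resp.Status.Reason and resp.ErrIndex describe the failure
    }
    return resp.UpsertCnt, nil // equals NumRows on success
}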
@@ -691,6 +691,20 @@ func TestProxy(t *testing.T) {
         }
     }

+    constructCollectionUpsertRequest := func() *milvuspb.UpsertRequest {
+        fVecColumn := newFloatVectorFieldData(floatVecField, rowNum, dim)
+        hashKeys := generateHashKeys(rowNum)
+        return &milvuspb.UpsertRequest{
+            Base: nil,
+            DbName: dbName,
+            CollectionName: collectionName,
+            PartitionName: partitionName,
+            FieldsData: []*schemapb.FieldData{fVecColumn},
+            HashKeys: hashKeys,
+            NumRows: uint32(rowNum),
+        }
+    }
+
     constructCreateIndexRequest := func() *milvuspb.CreateIndexRequest {
         return &milvuspb.CreateIndexRequest{
             Base: nil,
@@ -1085,7 +1099,7 @@ func TestProxy(t *testing.T) {
         assert.Equal(t, int64(rowNum), resp.InsertCnt)
     })

-    // TODO(dragondriver): proxy.Delete()
+    //TODO(dragondriver): proxy.Delete()

     flushed := true
     wg.Add(1)
@@ -2155,6 +2169,19 @@ func TestProxy(t *testing.T) {
         assert.NoError(t, err)
     })

+    wg.Add(1)
+    t.Run("upsert when autoID == true", func(t *testing.T) {
+        defer wg.Done()
+        req := constructCollectionUpsertRequest()
+
+        resp, err := proxy.Upsert(ctx, req)
+        assert.NoError(t, err)
+        assert.Equal(t, commonpb.ErrorCode_UpsertAutoIDTrue, resp.Status.ErrorCode)
+        assert.Equal(t, 0, len(resp.SuccIndex))
+        assert.Equal(t, rowNum, len(resp.ErrIndex))
+        assert.Equal(t, int64(0), resp.UpsertCnt)
+    })
+
     wg.Add(1)
     t.Run("drop collection", func(t *testing.T) {
         defer wg.Done()
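The test above pins down the main restriction: upsert needs caller-supplied primary keys, so a collection whose primary field uses autoID is rejected with ErrorCode_UpsertAutoIDTrue and every row lands in ErrIndex. A sketch of the schema shape that upsert does accept, mirroring the constructCollectionSchema helper added further down in this test file (names are illustrative):

package proxy

import "github.com/milvus-io/milvus-proto/go-api/schemapb"

// upsertableSchema returns a schema with an explicit, non-autoID primary key,
// which is what the "upsert when autoID == false" test later in the diff exercises.
func upsertableSchema() *schemapb.CollectionSchema {
    return &schemapb.CollectionSchema{
        Name:   "my_collection", // illustrative
        AutoID: false,
        Fields: []*schemapb.FieldSchema{
            {
                Name:         "int64_pk", // illustrative
                IsPrimaryKey: true,
                DataType:     schemapb.DataType_Int64,
                AutoID:       false, // AutoID == true would make Upsert fail with UpsertAutoIDTrue
            },
            // plus a vector field, as in the original helper
        },
    }
}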
@@ -2607,6 +2634,14 @@ func TestProxy(t *testing.T) {
         assert.NotEqual(t, commonpb.ErrorCode_Success, resp.Status.ErrorCode)
     })

+    wg.Add(1)
+    t.Run("Upsert fail, unhealthy", func(t *testing.T) {
+        defer wg.Done()
+        resp, err := proxy.Upsert(ctx, &milvuspb.UpsertRequest{})
+        assert.NoError(t, err)
+        assert.NotEqual(t, commonpb.ErrorCode_Success, resp.Status.ErrorCode)
+    })
+
     wg.Add(1)
     t.Run("Search fail, unhealthy", func(t *testing.T) {
         defer wg.Done()
@@ -2985,6 +3020,14 @@ func TestProxy(t *testing.T) {
         assert.NotEqual(t, commonpb.ErrorCode_Success, resp.Status.ErrorCode)
     })

+    wg.Add(1)
+    t.Run("Upsert fail, dm queue full", func(t *testing.T) {
+        defer wg.Done()
+        resp, err := proxy.Upsert(ctx, &milvuspb.UpsertRequest{})
+        assert.NoError(t, err)
+        assert.NotEqual(t, commonpb.ErrorCode_Success, resp.Status.ErrorCode)
+    })
+
     proxy.sched.dmQueue.setMaxTaskNum(dmParallelism)

     dqParallelism := proxy.sched.dqQueue.getMaxTaskNum()
@@ -3220,6 +3263,14 @@ func TestProxy(t *testing.T) {
         assert.NotEqual(t, commonpb.ErrorCode_Success, resp.Status.ErrorCode)
     })

+    wg.Add(1)
+    t.Run("Update fail, timeout", func(t *testing.T) {
+        defer wg.Done()
+        resp, err := proxy.Upsert(shortCtx, &milvuspb.UpsertRequest{})
+        assert.NoError(t, err)
+        assert.NotEqual(t, commonpb.ErrorCode_Success, resp.Status.ErrorCode)
+    })
+
     wg.Add(1)
     t.Run("Search fail, timeout", func(t *testing.T) {
         defer wg.Done()
@ -3295,6 +3346,161 @@ func TestProxy(t *testing.T) {
|
|||
testProxyRoleTimeout(shortCtx, t, proxy)
|
||||
testProxyPrivilegeTimeout(shortCtx, t, proxy)
|
||||
|
||||
constructCollectionSchema = func() *schemapb.CollectionSchema {
|
||||
pk := &schemapb.FieldSchema{
|
||||
FieldID: 0,
|
||||
Name: int64Field,
|
||||
IsPrimaryKey: true,
|
||||
Description: "",
|
||||
DataType: schemapb.DataType_Int64,
|
||||
TypeParams: nil,
|
||||
IndexParams: nil,
|
||||
AutoID: false,
|
||||
}
|
||||
fVec := &schemapb.FieldSchema{
|
||||
FieldID: 0,
|
||||
Name: floatVecField,
|
||||
IsPrimaryKey: false,
|
||||
Description: "",
|
||||
DataType: schemapb.DataType_FloatVector,
|
||||
TypeParams: []*commonpb.KeyValuePair{
|
||||
{
|
||||
Key: "dim",
|
||||
Value: strconv.Itoa(dim),
|
||||
},
|
||||
},
|
||||
IndexParams: nil,
|
||||
AutoID: false,
|
||||
}
|
||||
return &schemapb.CollectionSchema{
|
||||
Name: collectionName,
|
||||
Description: "",
|
||||
AutoID: false,
|
||||
Fields: []*schemapb.FieldSchema{
|
||||
pk,
|
||||
fVec,
|
||||
},
|
||||
}
|
||||
}
|
||||
schema = constructCollectionSchema()
|
||||
|
||||
constructCreateCollectionRequest = func() *milvuspb.CreateCollectionRequest {
|
||||
bs, err := proto.Marshal(schema)
|
||||
assert.NoError(t, err)
|
||||
return &milvuspb.CreateCollectionRequest{
|
||||
Base: nil,
|
||||
DbName: dbName,
|
||||
CollectionName: collectionName,
|
||||
Schema: bs,
|
||||
ShardsNum: shardsNum,
|
||||
}
|
||||
}
|
||||
createCollectionReq = constructCreateCollectionRequest()
|
||||
|
||||
constructPartitionReqUpsertRequestValid := func() *milvuspb.UpsertRequest {
|
||||
pkFieldData := newScalarFieldData(schema.Fields[0], int64Field, rowNum)
|
||||
fVecColumn := newFloatVectorFieldData(floatVecField, rowNum, dim)
|
||||
hashKeys := generateHashKeys(rowNum)
|
||||
return &milvuspb.UpsertRequest{
|
||||
Base: nil,
|
||||
DbName: dbName,
|
||||
CollectionName: collectionName,
|
||||
PartitionName: partitionName,
|
||||
FieldsData: []*schemapb.FieldData{pkFieldData, fVecColumn},
|
||||
HashKeys: hashKeys,
|
||||
NumRows: uint32(rowNum),
|
||||
}
|
||||
}
|
||||
|
||||
constructCollectionUpsertRequestValid := func() *milvuspb.UpsertRequest {
|
||||
pkFieldData := newScalarFieldData(schema.Fields[0], int64Field, rowNum)
|
||||
fVecColumn := newFloatVectorFieldData(floatVecField, rowNum, dim)
|
||||
hashKeys := generateHashKeys(rowNum)
|
||||
return &milvuspb.UpsertRequest{
|
||||
Base: nil,
|
||||
DbName: dbName,
|
||||
CollectionName: collectionName,
|
||||
PartitionName: partitionName,
|
||||
FieldsData: []*schemapb.FieldData{pkFieldData, fVecColumn},
|
||||
HashKeys: hashKeys,
|
||||
NumRows: uint32(rowNum),
|
||||
}
|
||||
}
|
||||
|
||||
wg.Add(1)
|
||||
t.Run("create collection upsert valid", func(t *testing.T) {
|
||||
defer wg.Done()
|
||||
req := createCollectionReq
|
||||
resp, err := proxy.CreateCollection(ctx, req)
|
||||
assert.NoError(t, err)
|
||||
assert.Equal(t, commonpb.ErrorCode_Success, resp.ErrorCode)
|
||||
|
||||
reqInvalidField := constructCreateCollectionRequest()
|
||||
schema := constructCollectionSchema()
|
||||
schema.Fields = append(schema.Fields, &schemapb.FieldSchema{
|
||||
Name: "StringField",
|
||||
DataType: schemapb.DataType_String,
|
||||
})
|
||||
bs, err := proto.Marshal(schema)
|
||||
assert.NoError(t, err)
|
||||
reqInvalidField.CollectionName = "invalid_field_coll_upsert_valid"
|
||||
reqInvalidField.Schema = bs
|
||||
|
||||
resp, err = proxy.CreateCollection(ctx, reqInvalidField)
|
||||
assert.NoError(t, err)
|
||||
assert.NotEqual(t, commonpb.ErrorCode_Success, resp.ErrorCode)
|
||||
|
||||
})
|
||||
|
||||
wg.Add(1)
|
||||
t.Run("create partition", func(t *testing.T) {
|
||||
defer wg.Done()
|
||||
resp, err := proxy.CreatePartition(ctx, &milvuspb.CreatePartitionRequest{
|
||||
Base: nil,
|
||||
DbName: dbName,
|
||||
CollectionName: collectionName,
|
||||
PartitionName: partitionName,
|
||||
})
|
||||
assert.NoError(t, err)
|
||||
assert.Equal(t, commonpb.ErrorCode_Success, resp.ErrorCode)
|
||||
|
||||
// create partition with non-exist collection -> fail
|
||||
resp, err = proxy.CreatePartition(ctx, &milvuspb.CreatePartitionRequest{
|
||||
Base: nil,
|
||||
DbName: dbName,
|
||||
CollectionName: otherCollectionName,
|
||||
PartitionName: partitionName,
|
||||
})
|
||||
assert.NoError(t, err)
|
||||
assert.NotEqual(t, commonpb.ErrorCode_Success, resp.ErrorCode)
|
||||
})
|
||||
|
||||
wg.Add(1)
|
||||
t.Run("upsert partition", func(t *testing.T) {
|
||||
defer wg.Done()
|
||||
req := constructPartitionReqUpsertRequestValid()
|
||||
|
||||
resp, err := proxy.Upsert(ctx, req)
|
||||
assert.NoError(t, err)
|
||||
assert.Equal(t, commonpb.ErrorCode_Success, resp.Status.ErrorCode)
|
||||
assert.Equal(t, rowNum, len(resp.SuccIndex))
|
||||
assert.Equal(t, 0, len(resp.ErrIndex))
|
||||
assert.Equal(t, int64(rowNum), resp.UpsertCnt)
|
||||
})
|
||||
|
||||
wg.Add(1)
|
||||
t.Run("upsert when autoID == false", func(t *testing.T) {
|
||||
defer wg.Done()
|
||||
req := constructCollectionUpsertRequestValid()
|
||||
|
||||
resp, err := proxy.Upsert(ctx, req)
|
||||
assert.NoError(t, err)
|
||||
assert.Equal(t, commonpb.ErrorCode_Success, resp.Status.ErrorCode)
|
||||
assert.Equal(t, rowNum, len(resp.SuccIndex))
|
||||
assert.Equal(t, 0, len(resp.ErrIndex))
|
||||
assert.Equal(t, int64(rowNum), resp.UpsertCnt)
|
||||
})
|
||||
|
||||
testServer.gracefulStop()
|
||||
|
||||
wg.Wait()
|
||||
|
|
|
@@ -65,11 +65,12 @@ const (
     ReleaseCollectionTaskName = "ReleaseCollectionTask"
     LoadPartitionTaskName = "LoadPartitionsTask"
     ReleasePartitionTaskName = "ReleasePartitionsTask"
-    deleteTaskName = "DeleteTask"
+    DeleteTaskName = "DeleteTask"
     CreateAliasTaskName = "CreateAliasTask"
     DropAliasTaskName = "DropAliasTask"
     AlterAliasTaskName = "AlterAliasTask"
     AlterCollectionTaskName = "AlterCollectionTask"
+    UpsertTaskName = "UpsertTask"

     // minFloat32 minimum float.
     minFloat32 = -1 * float32(math.MaxFloat32)
@@ -58,7 +58,7 @@ func (dt *deleteTask) Type() commonpb.MsgType {
 }

 func (dt *deleteTask) Name() string {
-    return deleteTaskName
+    return DeleteTaskName
 }

 func (dt *deleteTask) BeginTs() Timestamp {
@@ -126,12 +126,12 @@ func (it *insertTask) PreExecute(ctx context.Context) error {
         return err
     }

-    collSchema, err := globalMetaCache.GetCollectionSchema(ctx, collectionName)
+    schema, err := globalMetaCache.GetCollectionSchema(ctx, collectionName)
     if err != nil {
         log.Error("get collection schema from global meta cache failed", zap.String("collectionName", collectionName), zap.Error(err))
         return err
     }
-    it.schema = collSchema
+    it.schema = schema

     rowNums := uint32(it.insertMsg.NRows())
     // set insertTask.rowIDs
@@ -172,7 +172,7 @@ func (it *insertTask) PreExecute(ctx context.Context) error {
     }

     // set field ID to insert field data
-    err = fillFieldIDBySchema(it.insertMsg.GetFieldsData(), collSchema)
+    err = fillFieldIDBySchema(it.insertMsg.GetFieldsData(), schema)
     if err != nil {
         log.Error("set fieldID to fieldData failed",
             zap.Error(err))
@ -1698,6 +1698,94 @@ func TestTask_VarCharPrimaryKey(t *testing.T) {
|
|||
assert.NoError(t, task.PostExecute(ctx))
|
||||
})
|
||||
|
||||
t.Run("upsert", func(t *testing.T) {
|
||||
hash := generateHashKeys(nb)
|
||||
task := &upsertTask{
|
||||
upsertMsg: &msgstream.UpsertMsg{
|
||||
InsertMsg: &BaseInsertTask{
|
||||
BaseMsg: msgstream.BaseMsg{
|
||||
HashValues: hash,
|
||||
},
|
||||
InsertRequest: internalpb.InsertRequest{
|
||||
Base: &commonpb.MsgBase{
|
||||
MsgType: commonpb.MsgType_Insert,
|
||||
MsgID: 0,
|
||||
SourceID: paramtable.GetNodeID(),
|
||||
},
|
||||
DbName: dbName,
|
||||
CollectionName: collectionName,
|
||||
PartitionName: partitionName,
|
||||
NumRows: uint64(nb),
|
||||
Version: internalpb.InsertDataVersion_ColumnBased,
|
||||
},
|
||||
},
|
||||
DeleteMsg: &msgstream.DeleteMsg{
|
||||
BaseMsg: msgstream.BaseMsg{
|
||||
HashValues: hash,
|
||||
},
|
||||
DeleteRequest: internalpb.DeleteRequest{
|
||||
Base: &commonpb.MsgBase{
|
||||
MsgType: commonpb.MsgType_Delete,
|
||||
MsgID: 0,
|
||||
Timestamp: 0,
|
||||
SourceID: paramtable.GetNodeID(),
|
||||
},
|
||||
DbName: dbName,
|
||||
CollectionName: collectionName,
|
||||
PartitionName: partitionName,
|
||||
},
|
||||
},
|
||||
},
|
||||
|
||||
Condition: NewTaskCondition(ctx),
|
||||
req: &milvuspb.UpsertRequest{
|
||||
Base: &commonpb.MsgBase{
|
||||
MsgType: commonpb.MsgType_Insert,
|
||||
MsgID: 0,
|
||||
SourceID: paramtable.GetNodeID(),
|
||||
},
|
||||
DbName: dbName,
|
||||
CollectionName: collectionName,
|
||||
PartitionName: partitionName,
|
||||
HashKeys: hash,
|
||||
NumRows: uint32(nb),
|
||||
},
|
||||
ctx: ctx,
|
||||
result: &milvuspb.MutationResult{
|
||||
Status: &commonpb.Status{
|
||||
ErrorCode: commonpb.ErrorCode_Success,
|
||||
Reason: "",
|
||||
},
|
||||
IDs: nil,
|
||||
SuccIndex: nil,
|
||||
ErrIndex: nil,
|
||||
Acknowledged: false,
|
||||
InsertCnt: 0,
|
||||
DeleteCnt: 0,
|
||||
UpsertCnt: 0,
|
||||
Timestamp: 0,
|
||||
},
|
||||
idAllocator: idAllocator,
|
||||
segIDAssigner: segAllocator,
|
||||
chMgr: chMgr,
|
||||
chTicker: ticker,
|
||||
vChannels: nil,
|
||||
pChannels: nil,
|
||||
schema: nil,
|
||||
}
|
||||
|
||||
fieldID := common.StartOfUserFieldID
|
||||
for fieldName, dataType := range fieldName2Types {
|
||||
task.req.FieldsData = append(task.req.FieldsData, generateFieldData(dataType, fieldName, nb))
|
||||
fieldID++
|
||||
}
|
||||
|
||||
assert.NoError(t, task.OnEnqueue())
|
||||
assert.NoError(t, task.PreExecute(ctx))
|
||||
assert.NoError(t, task.Execute(ctx))
|
||||
assert.NoError(t, task.PostExecute(ctx))
|
||||
})
|
||||
|
||||
t.Run("delete", func(t *testing.T) {
|
||||
task := &deleteTask{
|
||||
Condition: NewTaskCondition(ctx),
|
||||
|
|
|
@ -0,0 +1,506 @@
|
|||
// // Licensed to the LF AI & Data foundation under one
|
||||
// // or more contributor license agreements. See the NOTICE file
|
||||
// // distributed with this work for additional information
|
||||
// // regarding copyright ownership. The ASF licenses this file
|
||||
// // to you under the Apache License, Version 2.0 (the
|
||||
// // "License"); you may not use this file except in compliance
|
||||
// // with the License. You may obtain a copy of the License at
|
||||
// //
|
||||
// // http://www.apache.org/licenses/LICENSE-2.0
|
||||
// //
|
||||
// // Unless required by applicable law or agreed to in writing, software
|
||||
// // distributed under the License is distributed on an "AS IS" BASIS,
|
||||
// // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
// // See the License for the specific language governing permissions and
|
||||
// // limitations under the License.
|
||||
package proxy
|
||||
|
||||
import (
|
||||
"context"
|
||||
"fmt"
|
||||
"strconv"
|
||||
|
||||
"github.com/milvus-io/milvus-proto/go-api/commonpb"
|
||||
"github.com/milvus-io/milvus-proto/go-api/milvuspb"
|
||||
"github.com/milvus-io/milvus-proto/go-api/schemapb"
|
||||
"github.com/milvus-io/milvus/internal/allocator"
|
||||
"github.com/milvus-io/milvus/internal/common"
|
||||
"github.com/milvus-io/milvus/internal/log"
|
||||
"github.com/milvus-io/milvus/internal/metrics"
|
||||
"github.com/milvus-io/milvus/internal/mq/msgstream"
|
||||
"github.com/milvus-io/milvus/internal/proto/internalpb"
|
||||
"github.com/milvus-io/milvus/internal/util/commonpbutil"
|
||||
"github.com/milvus-io/milvus/internal/util/paramtable"
|
||||
"github.com/milvus-io/milvus/internal/util/timerecord"
|
||||
"github.com/milvus-io/milvus/internal/util/trace"
|
||||
"github.com/milvus-io/milvus/internal/util/typeutil"
|
||||
"go.uber.org/zap"
|
||||
)
|
||||
|
||||
type upsertTask struct {
|
||||
Condition
|
||||
|
||||
upsertMsg *msgstream.UpsertMsg
|
||||
req *milvuspb.UpsertRequest
|
||||
baseMsg msgstream.BaseMsg
|
||||
|
||||
ctx context.Context
|
||||
|
||||
timestamps []uint64
|
||||
rowIDs []int64
|
||||
result *milvuspb.MutationResult
|
||||
idAllocator *allocator.IDAllocator
|
||||
segIDAssigner *segIDAssigner
|
||||
collectionID UniqueID
|
||||
chMgr channelsMgr
|
||||
chTicker channelsTimeTicker
|
||||
vChannels []vChan
|
||||
pChannels []pChan
|
||||
schema *schemapb.CollectionSchema
|
||||
}
|
||||
|
||||
// TraceCtx returns upsertTask context
|
||||
func (it *upsertTask) TraceCtx() context.Context {
|
||||
return it.ctx
|
||||
}
|
||||
|
||||
func (it *upsertTask) ID() UniqueID {
|
||||
return it.req.Base.MsgID
|
||||
}
|
||||
|
||||
func (it *upsertTask) SetID(uid UniqueID) {
|
||||
it.req.Base.MsgID = uid
|
||||
}
|
||||
|
||||
func (it *upsertTask) Name() string {
|
||||
return UpsertTaskName
|
||||
}
|
||||
|
||||
func (it *upsertTask) Type() commonpb.MsgType {
|
||||
return it.req.Base.MsgType
|
||||
}
|
||||
|
||||
func (it *upsertTask) BeginTs() Timestamp {
|
||||
return it.baseMsg.BeginTimestamp
|
||||
}
|
||||
|
||||
func (it *upsertTask) SetTs(ts Timestamp) {
|
||||
it.baseMsg.BeginTimestamp = ts
|
||||
it.baseMsg.EndTimestamp = ts
|
||||
}
|
||||
|
||||
func (it *upsertTask) EndTs() Timestamp {
|
||||
return it.baseMsg.EndTimestamp
|
||||
}
|
||||
|
||||
func (it *upsertTask) getPChanStats() (map[pChan]pChanStatistics, error) {
|
||||
ret := make(map[pChan]pChanStatistics)
|
||||
|
||||
channels, err := it.getChannels()
|
||||
if err != nil {
|
||||
return ret, err
|
||||
}
|
||||
|
||||
beginTs := it.BeginTs()
|
||||
endTs := it.EndTs()
|
||||
|
||||
for _, channel := range channels {
|
||||
ret[channel] = pChanStatistics{
|
||||
minTs: beginTs,
|
||||
maxTs: endTs,
|
||||
}
|
||||
}
|
||||
return ret, nil
|
||||
}
|
||||
|
||||
func (it *upsertTask) getChannels() ([]pChan, error) {
|
||||
collID, err := globalMetaCache.GetCollectionID(it.ctx, it.req.CollectionName)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return it.chMgr.getChannels(collID)
|
||||
}
|
||||
|
||||
func (it *upsertTask) OnEnqueue() error {
|
||||
return nil
|
||||
}
|
||||
|
||||
func (it *upsertTask) insertPreExecute(ctx context.Context) error {
|
||||
collectionName := it.upsertMsg.InsertMsg.CollectionName
|
||||
if err := validateCollectionName(collectionName); err != nil {
|
||||
log.Error("valid collection name failed", zap.String("collectionName", collectionName), zap.Error(err))
|
||||
return err
|
||||
}
|
||||
|
||||
partitionTag := it.upsertMsg.InsertMsg.PartitionName
|
||||
if err := validatePartitionTag(partitionTag, true); err != nil {
|
||||
log.Error("valid partition name failed", zap.String("partition name", partitionTag), zap.Error(err))
|
||||
return err
|
||||
}
|
||||
rowNums := uint32(it.upsertMsg.InsertMsg.NRows())
|
||||
// set upsertTask.insertRequest.rowIDs
|
||||
tr := timerecord.NewTimeRecorder("applyPK")
|
||||
rowIDBegin, rowIDEnd, _ := it.idAllocator.Alloc(rowNums)
|
||||
metrics.ProxyApplyPrimaryKeyLatency.WithLabelValues(strconv.FormatInt(paramtable.GetNodeID(), 10)).Observe(float64(tr.ElapseSpan()))
|
||||
|
||||
it.upsertMsg.InsertMsg.RowIDs = make([]UniqueID, rowNums)
|
||||
it.rowIDs = make([]UniqueID, rowNums)
|
||||
for i := rowIDBegin; i < rowIDEnd; i++ {
|
||||
offset := i - rowIDBegin
|
||||
it.upsertMsg.InsertMsg.RowIDs[offset] = i
|
||||
it.rowIDs[offset] = i
|
||||
}
|
||||
// set upsertTask.insertRequest.timeStamps
|
||||
rowNum := it.upsertMsg.InsertMsg.NRows()
|
||||
it.upsertMsg.InsertMsg.Timestamps = make([]uint64, rowNum)
|
||||
it.timestamps = make([]uint64, rowNum)
|
||||
for index := range it.timestamps {
|
||||
it.upsertMsg.InsertMsg.Timestamps[index] = it.BeginTs()
|
||||
it.timestamps[index] = it.BeginTs()
|
||||
}
|
||||
// set result.SuccIndex
|
||||
sliceIndex := make([]uint32, rowNums)
|
||||
for i := uint32(0); i < rowNums; i++ {
|
||||
sliceIndex[i] = i
|
||||
}
|
||||
it.result.SuccIndex = sliceIndex
|
||||
|
||||
// check primaryFieldData whether autoID is true or not
|
||||
// only allow support autoID == false
|
||||
var err error
|
||||
it.result.IDs, err = upsertCheckPrimaryFieldData(it.schema, it.result, it.upsertMsg.InsertMsg)
|
||||
log := log.Ctx(ctx).With(zap.String("collectionName", it.upsertMsg.InsertMsg.CollectionName))
|
||||
if err != nil {
|
||||
log.Error("check primary field data and hash primary key failed when upsert",
|
||||
zap.Error(err))
|
||||
return err
|
||||
}
|
||||
// set field ID to insert field data
|
||||
err = fillFieldIDBySchema(it.upsertMsg.InsertMsg.GetFieldsData(), it.schema)
|
||||
if err != nil {
|
||||
log.Error("insert set fieldID to fieldData failed when upsert",
|
||||
zap.Error(err))
|
||||
return err
|
||||
}
|
||||
|
||||
// check that all field's number rows are equal
|
||||
if err = it.upsertMsg.InsertMsg.CheckAligned(); err != nil {
|
||||
log.Error("field data is not aligned when upsert",
|
||||
zap.Error(err))
|
||||
return err
|
||||
}
|
||||
log.Debug("Proxy Upsert insertPreExecute done")
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
func (it *upsertTask) deletePreExecute(ctx context.Context) error {
|
||||
collName := it.upsertMsg.DeleteMsg.CollectionName
|
||||
log := log.Ctx(ctx).With(
|
||||
zap.String("collectionName", collName))
|
||||
|
||||
if err := validateCollectionName(collName); err != nil {
|
||||
log.Info("Invalid collection name", zap.Error(err))
|
||||
return err
|
||||
}
|
||||
collID, err := globalMetaCache.GetCollectionID(ctx, collName)
|
||||
if err != nil {
|
||||
log.Info("Failed to get collection id", zap.Error(err))
|
||||
return err
|
||||
}
|
||||
it.upsertMsg.DeleteMsg.CollectionID = collID
|
||||
it.collectionID = collID
|
||||
|
||||
// If partitionName is not empty, partitionID will be set.
|
||||
if len(it.upsertMsg.DeleteMsg.PartitionName) > 0 {
|
||||
partName := it.upsertMsg.DeleteMsg.PartitionName
|
||||
if err := validatePartitionTag(partName, true); err != nil {
|
||||
log.Info("Invalid partition name", zap.String("partitionName", partName), zap.Error(err))
|
||||
return err
|
||||
}
|
||||
partID, err := globalMetaCache.GetPartitionID(ctx, collName, partName)
|
||||
if err != nil {
|
||||
log.Info("Failed to get partition id", zap.String("collectionName", collName), zap.String("partitionName", partName), zap.Error(err))
|
||||
return err
|
||||
}
|
||||
it.upsertMsg.DeleteMsg.PartitionID = partID
|
||||
} else {
|
||||
it.upsertMsg.DeleteMsg.PartitionID = common.InvalidPartitionID
|
||||
}
|
||||
|
||||
it.upsertMsg.DeleteMsg.Timestamps = make([]uint64, it.upsertMsg.DeleteMsg.NumRows)
|
||||
for index := range it.upsertMsg.DeleteMsg.Timestamps {
|
||||
it.upsertMsg.DeleteMsg.Timestamps[index] = it.BeginTs()
|
||||
}
|
||||
log.Debug("Proxy Upsert deletePreExecute done")
|
||||
return nil
|
||||
}
|
||||
|
||||
func (it *upsertTask) PreExecute(ctx context.Context) error {
|
||||
sp, ctx := trace.StartSpanFromContextWithOperationName(it.ctx, "Proxy-Upsert-PreExecute")
|
||||
defer sp.Finish()
|
||||
log := log.Ctx(ctx).With(zap.String("collectionName", it.req.CollectionName))
|
||||
|
||||
it.req.Base = commonpbutil.NewMsgBase(
|
||||
commonpbutil.WithMsgType(commonpb.MsgType(commonpb.MsgType_Upsert)),
|
||||
commonpbutil.WithSourceID(paramtable.GetNodeID()),
|
||||
)
|
||||
|
||||
it.result = &milvuspb.MutationResult{
|
||||
Status: &commonpb.Status{
|
||||
ErrorCode: commonpb.ErrorCode_Success,
|
||||
},
|
||||
IDs: &schemapb.IDs{
|
||||
IdField: nil,
|
||||
},
|
||||
Timestamp: it.EndTs(),
|
||||
}
|
||||
|
||||
schema, err := globalMetaCache.GetCollectionSchema(ctx, it.req.CollectionName)
|
||||
if err != nil {
|
||||
log.Info("Failed to get collection schema", zap.Error(err))
|
||||
return err
|
||||
}
|
||||
it.schema = schema
|
||||
|
||||
it.upsertMsg = &msgstream.UpsertMsg{
|
||||
InsertMsg: &msgstream.InsertMsg{
|
||||
InsertRequest: internalpb.InsertRequest{
|
||||
Base: commonpbutil.NewMsgBase(
|
||||
commonpbutil.WithMsgType(commonpb.MsgType_Insert),
|
||||
commonpbutil.WithSourceID(paramtable.GetNodeID()),
|
||||
),
|
||||
CollectionName: it.req.CollectionName,
|
||||
PartitionName: it.req.PartitionName,
|
||||
FieldsData: it.req.FieldsData,
|
||||
NumRows: uint64(it.req.NumRows),
|
||||
Version: internalpb.InsertDataVersion_ColumnBased,
|
||||
},
|
||||
},
|
||||
DeleteMsg: &msgstream.DeleteMsg{
|
||||
DeleteRequest: internalpb.DeleteRequest{
|
||||
Base: commonpbutil.NewMsgBase(
|
||||
commonpbutil.WithMsgType(commonpb.MsgType_Delete),
|
||||
commonpbutil.WithSourceID(paramtable.GetNodeID()),
|
||||
),
|
||||
DbName: it.req.DbName,
|
||||
CollectionName: it.req.CollectionName,
|
||||
NumRows: int64(it.req.NumRows),
|
||||
PartitionName: it.req.PartitionName,
|
||||
CollectionID: it.collectionID,
|
||||
},
|
||||
},
|
||||
}
|
||||
err = it.insertPreExecute(ctx)
|
||||
if err != nil {
|
||||
log.Info("Fail to insertPreExecute", zap.Error(err))
|
||||
return err
|
||||
}
|
||||
|
||||
err = it.deletePreExecute(ctx)
|
||||
if err != nil {
|
||||
log.Info("Fail to deletePreExecute", zap.Error(err))
|
||||
return err
|
||||
}
|
||||
|
||||
it.result.DeleteCnt = it.upsertMsg.DeleteMsg.NumRows
|
||||
it.result.InsertCnt = int64(it.upsertMsg.InsertMsg.NumRows)
|
||||
if it.result.DeleteCnt != it.result.InsertCnt {
|
||||
log.Error("DeleteCnt and InsertCnt are not the same when upsert",
|
||||
zap.Int64("DeleteCnt", it.result.DeleteCnt),
|
||||
zap.Int64("InsertCnt", it.result.InsertCnt))
|
||||
}
|
||||
it.result.UpsertCnt = it.result.InsertCnt
|
||||
log.Debug("Proxy Upsert PreExecute done")
|
||||
return nil
|
||||
}
|
||||
|
||||
func (it *upsertTask) insertExecute(ctx context.Context, msgPack *msgstream.MsgPack) error {
|
||||
tr := timerecord.NewTimeRecorder(fmt.Sprintf("proxy insertExecute upsert %d", it.ID()))
|
||||
defer tr.Elapse("insert execute done when insertExecute")
|
||||
|
||||
collectionName := it.upsertMsg.InsertMsg.CollectionName
|
||||
collID, err := globalMetaCache.GetCollectionID(ctx, collectionName)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
it.upsertMsg.InsertMsg.CollectionID = collID
|
||||
log := log.Ctx(ctx).With(
|
||||
zap.Int64("collectionID", collID))
|
||||
var partitionID UniqueID
|
||||
if len(it.upsertMsg.InsertMsg.PartitionName) > 0 {
|
||||
partitionID, err = globalMetaCache.GetPartitionID(ctx, collectionName, it.req.PartitionName)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
} else {
|
||||
partitionID, err = globalMetaCache.GetPartitionID(ctx, collectionName, Params.CommonCfg.DefaultPartitionName.GetValue())
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
}
|
||||
it.upsertMsg.InsertMsg.PartitionID = partitionID
|
||||
tr.Record("get collection id & partition id from cache when insertExecute")
|
||||
|
||||
_, err = it.chMgr.getOrCreateDmlStream(collID)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
tr.Record("get used message stream when insertExecute")
|
||||
|
||||
channelNames, err := it.chMgr.getVChannels(collID)
|
||||
if err != nil {
|
||||
log.Error("get vChannels failed when insertExecute",
|
||||
zap.Error(err))
|
||||
it.result.Status.ErrorCode = commonpb.ErrorCode_UnexpectedError
|
||||
it.result.Status.Reason = err.Error()
|
||||
return err
|
||||
}
|
||||
|
||||
log.Debug("send insert request to virtual channels when insertExecute",
|
||||
zap.String("collection", it.req.GetCollectionName()),
|
||||
zap.String("partition", it.req.GetPartitionName()),
|
||||
zap.Int64("collection_id", collID),
|
||||
zap.Int64("partition_id", partitionID),
|
||||
zap.Strings("virtual_channels", channelNames),
|
||||
zap.Int64("task_id", it.ID()))
|
||||
|
||||
// assign segmentID for insert data and repack data by segmentID
|
||||
insertMsgPack, err := assignSegmentID(it.TraceCtx(), it.upsertMsg.InsertMsg, it.result, channelNames, it.idAllocator, it.segIDAssigner)
|
||||
if err != nil {
|
||||
log.Error("assign segmentID and repack insert data failed when insertExecute",
|
||||
zap.Error(err))
|
||||
it.result.Status.ErrorCode = commonpb.ErrorCode_UnexpectedError
|
||||
it.result.Status.Reason = err.Error()
|
||||
return err
|
||||
}
|
||||
log.Debug("assign segmentID for insert data success when insertExecute",
|
||||
zap.String("collectionName", it.req.CollectionName))
|
||||
tr.Record("assign segment id")
|
||||
msgPack.Msgs = append(msgPack.Msgs, insertMsgPack.Msgs...)
|
||||
|
||||
log.Debug("Proxy Insert Execute done when upsert",
|
||||
zap.String("collectionName", collectionName))
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
func (it *upsertTask) deleteExecute(ctx context.Context, msgPack *msgstream.MsgPack) (err error) {
|
||||
tr := timerecord.NewTimeRecorder(fmt.Sprintf("proxy deleteExecute upsert %d", it.ID()))
|
||||
defer tr.Elapse("delete execute done when upsert")
|
||||
|
||||
collID := it.upsertMsg.DeleteMsg.CollectionID
|
||||
log := log.Ctx(ctx).With(
|
||||
zap.Int64("collectionID", collID))
|
||||
// hash primary keys to channels
|
||||
channelNames, err := it.chMgr.getVChannels(collID)
|
||||
if err != nil {
|
||||
log.Warn("get vChannels failed when deleteExecute", zap.Error(err))
|
||||
it.result.Status.ErrorCode = commonpb.ErrorCode_UnexpectedError
|
||||
it.result.Status.Reason = err.Error()
|
||||
return err
|
||||
}
|
||||
it.upsertMsg.DeleteMsg.PrimaryKeys = it.result.IDs
|
||||
it.upsertMsg.DeleteMsg.HashValues = typeutil.HashPK2Channels(it.upsertMsg.DeleteMsg.PrimaryKeys, channelNames)
|
||||
|
||||
log.Debug("send delete request to virtual channels when deleteExecute",
|
||||
zap.Int64("collection_id", collID),
|
||||
zap.Strings("virtual_channels", channelNames))
|
||||
|
||||
tr.Record("get vchannels")
|
||||
// repack delete msg by dmChannel
|
||||
result := make(map[uint32]msgstream.TsMsg)
|
||||
collectionName := it.upsertMsg.DeleteMsg.CollectionName
|
||||
collectionID := it.upsertMsg.DeleteMsg.CollectionID
|
||||
partitionID := it.upsertMsg.DeleteMsg.PartitionID
|
||||
partitionName := it.upsertMsg.DeleteMsg.PartitionName
|
||||
proxyID := it.upsertMsg.DeleteMsg.Base.SourceID
|
||||
for index, key := range it.upsertMsg.DeleteMsg.HashValues {
|
||||
ts := it.upsertMsg.DeleteMsg.Timestamps[index]
|
||||
_, ok := result[key]
|
||||
if !ok {
|
||||
sliceRequest := internalpb.DeleteRequest{
|
||||
Base: commonpbutil.NewMsgBase(
|
||||
commonpbutil.WithMsgType(commonpb.MsgType_Delete),
|
||||
commonpbutil.WithTimeStamp(ts),
|
||||
commonpbutil.WithSourceID(proxyID),
|
||||
),
|
||||
CollectionID: collectionID,
|
||||
PartitionID: partitionID,
|
||||
CollectionName: collectionName,
|
||||
PartitionName: partitionName,
|
||||
PrimaryKeys: &schemapb.IDs{},
|
||||
}
|
||||
deleteMsg := &msgstream.DeleteMsg{
|
||||
BaseMsg: msgstream.BaseMsg{
|
||||
Ctx: ctx,
|
||||
},
|
||||
DeleteRequest: sliceRequest,
|
||||
}
|
||||
result[key] = deleteMsg
|
||||
}
|
||||
curMsg := result[key].(*msgstream.DeleteMsg)
|
||||
curMsg.HashValues = append(curMsg.HashValues, it.upsertMsg.DeleteMsg.HashValues[index])
|
||||
curMsg.Timestamps = append(curMsg.Timestamps, it.upsertMsg.DeleteMsg.Timestamps[index])
|
||||
typeutil.AppendIDs(curMsg.PrimaryKeys, it.upsertMsg.DeleteMsg.PrimaryKeys, index)
|
||||
curMsg.NumRows++
|
||||
}
|
||||
|
||||
// send delete request to log broker
|
||||
deleteMsgPack := &msgstream.MsgPack{
|
||||
BeginTs: it.upsertMsg.DeleteMsg.BeginTs(),
|
||||
EndTs: it.upsertMsg.DeleteMsg.EndTs(),
|
||||
}
|
||||
for _, msg := range result {
|
||||
if msg != nil {
|
||||
deleteMsgPack.Msgs = append(deleteMsgPack.Msgs, msg)
|
||||
}
|
||||
}
|
||||
msgPack.Msgs = append(msgPack.Msgs, deleteMsgPack.Msgs...)
|
||||
|
||||
log.Debug("Proxy Upsert deleteExecute done")
|
||||
return nil
|
||||
}
|
||||
|
||||
func (it *upsertTask) Execute(ctx context.Context) (err error) {
|
||||
sp, ctx := trace.StartSpanFromContextWithOperationName(it.ctx, "Proxy-Upsert-Execute")
|
||||
defer sp.Finish()
|
||||
log := log.Ctx(ctx).With(zap.String("collectionName", it.req.CollectionName))
|
||||
|
||||
tr := timerecord.NewTimeRecorder(fmt.Sprintf("proxy execute upsert %d", it.ID()))
|
||||
stream, err := it.chMgr.getOrCreateDmlStream(it.collectionID)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
msgPack := &msgstream.MsgPack{
|
||||
BeginTs: it.BeginTs(),
|
||||
EndTs: it.EndTs(),
|
||||
}
|
||||
err = it.insertExecute(ctx, msgPack)
|
||||
if err != nil {
|
||||
log.Info("Fail to insertExecute", zap.Error(err))
|
||||
return err
|
||||
}
|
||||
|
||||
err = it.deleteExecute(ctx, msgPack)
|
||||
if err != nil {
|
||||
log.Info("Fail to deleteExecute", zap.Error(err))
|
||||
return err
|
||||
}
|
||||
|
||||
tr.Record("pack messages in upsert")
|
||||
err = stream.Produce(msgPack)
|
||||
if err != nil {
|
||||
it.result.Status.ErrorCode = commonpb.ErrorCode_UnexpectedError
|
||||
it.result.Status.Reason = err.Error()
|
||||
return err
|
||||
}
|
||||
sendMsgDur := tr.Record("send upsert request to dml channels")
|
||||
metrics.ProxySendMutationReqLatency.WithLabelValues(strconv.FormatInt(paramtable.GetNodeID(), 10), metrics.UpsertLabel).Observe(float64(sendMsgDur.Milliseconds()))
|
||||
log.Debug("Proxy Upsert Execute done")
|
||||
return nil
|
||||
}
|
||||
|
||||
func (it *upsertTask) PostExecute(ctx context.Context) error {
|
||||
return nil
|
||||
}
|
|
@ -0,0 +1,684 @@
|
|||
// // Licensed to the LF AI & Data foundation under one
|
||||
// // or more contributor license agreements. See the NOTICE file
|
||||
// // distributed with this work for additional information
|
||||
// // regarding copyright ownership. The ASF licenses this file
|
||||
// // to you under the Apache License, Version 2.0 (the
|
||||
// // "License"); you may not use this file except in compliance
|
||||
// // with the License. You may obtain a copy of the License at
|
||||
// //
|
||||
// // http://www.apache.org/licenses/LICENSE-2.0
|
||||
// //
|
||||
// // Unless required by applicable law or agreed to in writing, software
|
||||
// // distributed under the License is distributed on an "AS IS" BASIS,
|
||||
// // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
// // See the License for the specific language governing permissions and
|
||||
// // limitations under the License.
|
||||
package proxy
|
||||
|
||||
import (
|
||||
"testing"
|
||||
|
||||
"github.com/milvus-io/milvus-proto/go-api/commonpb"
|
||||
"github.com/milvus-io/milvus-proto/go-api/milvuspb"
|
||||
"github.com/milvus-io/milvus-proto/go-api/schemapb"
|
||||
"github.com/milvus-io/milvus/internal/mq/msgstream"
|
||||
"github.com/milvus-io/milvus/internal/proto/internalpb"
|
||||
"github.com/milvus-io/milvus/internal/util/commonpbutil"
|
||||
"github.com/stretchr/testify/assert"
|
||||
)
|
||||
|
||||
func TestUpsertTask_CheckAligned(t *testing.T) {
|
||||
var err error
|
||||
|
||||
// passed NumRows is less than 0
|
||||
case1 := upsertTask{
|
||||
req: &milvuspb.UpsertRequest{
|
||||
NumRows: 0,
|
||||
},
|
||||
upsertMsg: &msgstream.UpsertMsg{
|
||||
InsertMsg: &msgstream.InsertMsg{
|
||||
InsertRequest: internalpb.InsertRequest{},
|
||||
},
|
||||
},
|
||||
}
|
||||
case1.upsertMsg.InsertMsg.InsertRequest = internalpb.InsertRequest{
|
||||
Base: commonpbutil.NewMsgBase(
|
||||
commonpbutil.WithMsgType(commonpb.MsgType_Insert),
|
||||
),
|
||||
CollectionName: case1.req.CollectionName,
|
||||
PartitionName: case1.req.PartitionName,
|
||||
FieldsData: case1.req.FieldsData,
|
||||
NumRows: uint64(case1.req.NumRows),
|
||||
Version: internalpb.InsertDataVersion_ColumnBased,
|
||||
}
|
||||
|
||||
err = case1.upsertMsg.InsertMsg.CheckAligned()
|
||||
assert.NoError(t, err)
|
||||
|
||||
// checkLengthOfFieldsData was already checked by TestUpsertTask_checkLengthOfFieldsData
|
||||
|
||||
boolFieldSchema := &schemapb.FieldSchema{DataType: schemapb.DataType_Bool}
|
||||
int8FieldSchema := &schemapb.FieldSchema{DataType: schemapb.DataType_Int8}
|
||||
int16FieldSchema := &schemapb.FieldSchema{DataType: schemapb.DataType_Int16}
|
||||
int32FieldSchema := &schemapb.FieldSchema{DataType: schemapb.DataType_Int32}
|
||||
int64FieldSchema := &schemapb.FieldSchema{DataType: schemapb.DataType_Int64}
|
||||
floatFieldSchema := &schemapb.FieldSchema{DataType: schemapb.DataType_Float}
|
||||
doubleFieldSchema := &schemapb.FieldSchema{DataType: schemapb.DataType_Double}
|
||||
floatVectorFieldSchema := &schemapb.FieldSchema{DataType: schemapb.DataType_FloatVector}
|
||||
binaryVectorFieldSchema := &schemapb.FieldSchema{DataType: schemapb.DataType_BinaryVector}
|
||||
varCharFieldSchema := &schemapb.FieldSchema{DataType: schemapb.DataType_VarChar}
|
||||
|
||||
numRows := 20
|
||||
dim := 128
|
||||
case2 := upsertTask{
|
||||
req: &milvuspb.UpsertRequest{
|
||||
NumRows: uint32(numRows),
|
||||
FieldsData: []*schemapb.FieldData{},
|
||||
},
|
||||
        rowIDs: generateInt64Array(numRows),
        timestamps: generateUint64Array(numRows),
        schema: &schemapb.CollectionSchema{
            Name: "TestUpsertTask_checkRowNums",
            Description: "TestUpsertTask_checkRowNums",
            AutoID: false,
            Fields: []*schemapb.FieldSchema{
                boolFieldSchema,
                int8FieldSchema,
                int16FieldSchema,
                int32FieldSchema,
                int64FieldSchema,
                floatFieldSchema,
                doubleFieldSchema,
                floatVectorFieldSchema,
                binaryVectorFieldSchema,
                varCharFieldSchema,
            },
        },
        upsertMsg: &msgstream.UpsertMsg{
            InsertMsg: &msgstream.InsertMsg{
                InsertRequest: internalpb.InsertRequest{},
            },
        },
    }

    // satisfied
    case2.req.FieldsData = []*schemapb.FieldData{
        newScalarFieldData(boolFieldSchema, "Bool", numRows),
        newScalarFieldData(int8FieldSchema, "Int8", numRows),
        newScalarFieldData(int16FieldSchema, "Int16", numRows),
        newScalarFieldData(int32FieldSchema, "Int32", numRows),
        newScalarFieldData(int64FieldSchema, "Int64", numRows),
        newScalarFieldData(floatFieldSchema, "Float", numRows),
        newScalarFieldData(doubleFieldSchema, "Double", numRows),
        newFloatVectorFieldData("FloatVector", numRows, dim),
        newBinaryVectorFieldData("BinaryVector", numRows, dim),
        newScalarFieldData(varCharFieldSchema, "VarChar", numRows),
    }
    case2.upsertMsg.InsertMsg.InsertRequest = internalpb.InsertRequest{
        Base: commonpbutil.NewMsgBase(
            commonpbutil.WithMsgType(commonpb.MsgType_Insert),
        ),
        CollectionName: case2.req.CollectionName,
        PartitionName: case2.req.PartitionName,
        FieldsData: case2.req.FieldsData,
        NumRows: uint64(case2.req.NumRows),
        RowIDs: case2.rowIDs,
        Timestamps: case2.timestamps,
        Version: internalpb.InsertDataVersion_ColumnBased,
    }
    err = case2.upsertMsg.InsertMsg.CheckAligned()
    assert.NoError(t, err)

    // less bool data
    case2.req.FieldsData[0] = newScalarFieldData(boolFieldSchema, "Bool", numRows/2)
    case2.upsertMsg.InsertMsg.FieldsData = case2.req.FieldsData
    err = case2.upsertMsg.InsertMsg.CheckAligned()
    assert.Error(t, err)
    // more bool data
    case2.req.FieldsData[0] = newScalarFieldData(boolFieldSchema, "Bool", numRows*2)
    case2.upsertMsg.InsertMsg.FieldsData = case2.req.FieldsData
    err = case2.upsertMsg.InsertMsg.CheckAligned()
    assert.Error(t, err)
    // revert
    case2.req.FieldsData[0] = newScalarFieldData(boolFieldSchema, "Bool", numRows)
    case2.upsertMsg.InsertMsg.FieldsData = case2.req.FieldsData
    err = case2.upsertMsg.InsertMsg.CheckAligned()
    assert.NoError(t, err)

    // less int8 data
    case2.req.FieldsData[1] = newScalarFieldData(int8FieldSchema, "Int8", numRows/2)
    case2.upsertMsg.InsertMsg.FieldsData = case2.req.FieldsData
    err = case2.upsertMsg.InsertMsg.CheckAligned()
    assert.Error(t, err)
    // more int8 data
    case2.req.FieldsData[1] = newScalarFieldData(int8FieldSchema, "Int8", numRows*2)
    case2.upsertMsg.InsertMsg.FieldsData = case2.req.FieldsData
    err = case2.upsertMsg.InsertMsg.CheckAligned()
    assert.Error(t, err)
    // revert
    case2.req.FieldsData[1] = newScalarFieldData(int8FieldSchema, "Int8", numRows)
    case2.upsertMsg.InsertMsg.FieldsData = case2.req.FieldsData
    err = case2.upsertMsg.InsertMsg.CheckAligned()
    assert.NoError(t, err)

    // less int16 data
    case2.req.FieldsData[2] = newScalarFieldData(int16FieldSchema, "Int16", numRows/2)
    case2.upsertMsg.InsertMsg.FieldsData = case2.req.FieldsData
    err = case2.upsertMsg.InsertMsg.CheckAligned()
    assert.Error(t, err)
    // more int16 data
    case2.req.FieldsData[2] = newScalarFieldData(int16FieldSchema, "Int16", numRows*2)
    case2.upsertMsg.InsertMsg.FieldsData = case2.req.FieldsData
    err = case2.upsertMsg.InsertMsg.CheckAligned()
    assert.Error(t, err)
    // revert
    case2.req.FieldsData[2] = newScalarFieldData(int16FieldSchema, "Int16", numRows)
    case2.upsertMsg.InsertMsg.FieldsData = case2.req.FieldsData
    err = case2.upsertMsg.InsertMsg.CheckAligned()
    assert.NoError(t, err)

    // less int32 data
    case2.req.FieldsData[3] = newScalarFieldData(int32FieldSchema, "Int32", numRows/2)
    case2.upsertMsg.InsertMsg.FieldsData = case2.req.FieldsData
    err = case2.upsertMsg.InsertMsg.CheckAligned()
    assert.Error(t, err)
    // more int32 data
    case2.req.FieldsData[3] = newScalarFieldData(int32FieldSchema, "Int32", numRows*2)
    case2.upsertMsg.InsertMsg.FieldsData = case2.req.FieldsData
    err = case2.upsertMsg.InsertMsg.CheckAligned()
    assert.Error(t, err)
    // revert
    case2.req.FieldsData[3] = newScalarFieldData(int32FieldSchema, "Int32", numRows)
    case2.upsertMsg.InsertMsg.FieldsData = case2.req.FieldsData
    err = case2.upsertMsg.InsertMsg.CheckAligned()
    assert.NoError(t, err)

    // less int64 data
    case2.req.FieldsData[4] = newScalarFieldData(int64FieldSchema, "Int64", numRows/2)
    case2.upsertMsg.InsertMsg.FieldsData = case2.req.FieldsData
    err = case2.upsertMsg.InsertMsg.CheckAligned()
    assert.Error(t, err)
    // more int64 data
    case2.req.FieldsData[4] = newScalarFieldData(int64FieldSchema, "Int64", numRows*2)
    case2.upsertMsg.InsertMsg.FieldsData = case2.req.FieldsData
    err = case2.upsertMsg.InsertMsg.CheckAligned()
    assert.Error(t, err)
    // revert
    case2.req.FieldsData[4] = newScalarFieldData(int64FieldSchema, "Int64", numRows)
    case2.upsertMsg.InsertMsg.FieldsData = case2.req.FieldsData
    err = case2.upsertMsg.InsertMsg.CheckAligned()
    assert.NoError(t, err)

    // less float data
    case2.req.FieldsData[5] = newScalarFieldData(floatFieldSchema, "Float", numRows/2)
    case2.upsertMsg.InsertMsg.FieldsData = case2.req.FieldsData
    err = case2.upsertMsg.InsertMsg.CheckAligned()
    assert.Error(t, err)
    // more float data
    case2.req.FieldsData[5] = newScalarFieldData(floatFieldSchema, "Float", numRows*2)
    case2.upsertMsg.InsertMsg.FieldsData = case2.req.FieldsData
    err = case2.upsertMsg.InsertMsg.CheckAligned()
    assert.Error(t, err)
    // revert
    case2.req.FieldsData[5] = newScalarFieldData(floatFieldSchema, "Float", numRows)
    case2.upsertMsg.InsertMsg.FieldsData = case2.req.FieldsData
    err = case2.upsertMsg.InsertMsg.CheckAligned()
    assert.NoError(t, err)

    // less double data
    case2.req.FieldsData[6] = newScalarFieldData(doubleFieldSchema, "Double", numRows/2)
    case2.upsertMsg.InsertMsg.FieldsData = case2.req.FieldsData
    err = case2.upsertMsg.InsertMsg.CheckAligned()
    assert.Error(t, err)
    // more double data
    case2.req.FieldsData[6] = newScalarFieldData(doubleFieldSchema, "Double", numRows*2)
    case2.upsertMsg.InsertMsg.FieldsData = case2.req.FieldsData
    err = case2.upsertMsg.InsertMsg.CheckAligned()
    assert.Error(t, err)
    // revert
    case2.req.FieldsData[6] = newScalarFieldData(doubleFieldSchema, "Double", numRows)
    case2.upsertMsg.InsertMsg.FieldsData = case2.req.FieldsData
    err = case2.upsertMsg.InsertMsg.CheckAligned()
    assert.NoError(t, err)

    // less float vectors
    case2.req.FieldsData[7] = newFloatVectorFieldData("FloatVector", numRows/2, dim)
    case2.upsertMsg.InsertMsg.FieldsData = case2.req.FieldsData
    err = case2.upsertMsg.InsertMsg.CheckAligned()
    assert.Error(t, err)
    // more float vectors
    case2.req.FieldsData[7] = newFloatVectorFieldData("FloatVector", numRows*2, dim)
    case2.upsertMsg.InsertMsg.FieldsData = case2.req.FieldsData
    err = case2.upsertMsg.InsertMsg.CheckAligned()
    assert.Error(t, err)
    // revert
    case2.req.FieldsData[7] = newFloatVectorFieldData("FloatVector", numRows, dim)
    case2.upsertMsg.InsertMsg.FieldsData = case2.req.FieldsData
    err = case2.upsertMsg.InsertMsg.CheckAligned()
    assert.NoError(t, err)

    // less binary vectors
    case2.req.FieldsData[8] = newBinaryVectorFieldData("BinaryVector", numRows/2, dim)
    case2.upsertMsg.InsertMsg.FieldsData = case2.req.FieldsData
    err = case2.upsertMsg.InsertMsg.CheckAligned()
    assert.Error(t, err)
    // more binary vectors
    case2.req.FieldsData[8] = newBinaryVectorFieldData("BinaryVector", numRows*2, dim)
    case2.upsertMsg.InsertMsg.FieldsData = case2.req.FieldsData
    err = case2.upsertMsg.InsertMsg.CheckAligned()
    assert.Error(t, err)
    // revert
    case2.req.FieldsData[8] = newBinaryVectorFieldData("BinaryVector", numRows, dim)
    case2.upsertMsg.InsertMsg.FieldsData = case2.req.FieldsData
    err = case2.upsertMsg.InsertMsg.CheckAligned()
    assert.NoError(t, err)

    // less varChar data
    case2.req.FieldsData[9] = newScalarFieldData(varCharFieldSchema, "VarChar", numRows/2)
    case2.upsertMsg.InsertMsg.FieldsData = case2.req.FieldsData
    err = case2.upsertMsg.InsertMsg.CheckAligned()
    assert.Error(t, err)
    // more varChar data
    case2.req.FieldsData[9] = newScalarFieldData(varCharFieldSchema, "VarChar", numRows*2)
    case2.upsertMsg.InsertMsg.FieldsData = case2.req.FieldsData
    err = case2.upsertMsg.InsertMsg.CheckAligned()
    assert.Error(t, err)
    // revert
    case2.req.FieldsData[9] = newScalarFieldData(varCharFieldSchema, "VarChar", numRows)
    case2.upsertMsg.InsertMsg.FieldsData = case2.req.FieldsData
    err = case2.upsertMsg.InsertMsg.CheckAligned()
    assert.NoError(t, err)
}

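// For illustration only: the alignment rule exercised above is, roughly, that RowIDs,
// Timestamps, and every column's row count must agree with NumRows. This sketch is a
// simplified stand-in added for orientation, not the actual msgstream CheckAligned code.
func checkAlignedSketch(numRows, rowIDCount, tsCount uint64, fieldRowCounts []uint64) error {
    if rowIDCount != numRows || tsCount != numRows {
        return fmt.Errorf("rowIDs(%d)/timestamps(%d) not aligned with NumRows(%d)", rowIDCount, tsCount, numRows)
    }
    for i, c := range fieldRowCounts {
        if c != numRows {
            // Mirrors the "less/more data" cases above: any mismatched column fails the check.
            return fmt.Errorf("field %d carries %d rows, expected %d", i, c, numRows)
        }
    }
    return nil
}
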
// func TestProxyUpsertValid(t *testing.T) {
//     var err error
//     var wg sync.WaitGroup
//     paramtable.Init()

//     path := "/tmp/milvus/rocksmq" + funcutil.GenRandomStr()
//     t.Setenv("ROCKSMQ_PATH", path)
//     defer os.RemoveAll(path)

//     ctx, cancel := context.WithCancel(context.Background())
//     ctx = GetContext(ctx, "root:123456")
//     localMsg := true
//     factory := dependency.NewDefaultFactory(localMsg)
//     alias := "TestProxyUpsertValid"

//     log.Info("Initialize parameter table of Proxy")

//     rc := runRootCoord(ctx, localMsg)
//     log.Info("running RootCoord ...")

//     if rc != nil {
//         defer func() {
//             err := rc.Stop()
//             assert.NoError(t, err)
//             log.Info("stop RootCoord")
//         }()
//     }

//     dc := runDataCoord(ctx, localMsg)
//     log.Info("running DataCoord ...")

//     if dc != nil {
//         defer func() {
//             err := dc.Stop()
//             assert.NoError(t, err)
//             log.Info("stop DataCoord")
//         }()
//     }

//     dn := runDataNode(ctx, localMsg, alias)
//     log.Info("running DataNode ...")

//     if dn != nil {
//         defer func() {
//             err := dn.Stop()
//             assert.NoError(t, err)
//             log.Info("stop DataNode")
//         }()
//     }

//     qc := runQueryCoord(ctx, localMsg)
//     log.Info("running QueryCoord ...")

//     if qc != nil {
//         defer func() {
//             err := qc.Stop()
//             assert.NoError(t, err)
//             log.Info("stop QueryCoord")
//         }()
//     }

//     qn := runQueryNode(ctx, localMsg, alias)
//     log.Info("running QueryNode ...")

//     if qn != nil {
//         defer func() {
//             err := qn.Stop()
//             assert.NoError(t, err)
//             log.Info("stop query node")
//         }()
//     }

//     ic := runIndexCoord(ctx, localMsg)
//     log.Info("running IndexCoord ...")

//     if ic != nil {
//         defer func() {
//             err := ic.Stop()
//             assert.NoError(t, err)
//             log.Info("stop IndexCoord")
//         }()
//     }

//     in := runIndexNode(ctx, localMsg, alias)
//     log.Info("running IndexNode ...")

//     if in != nil {
//         defer func() {
//             err := in.Stop()
//             assert.NoError(t, err)
//             log.Info("stop IndexNode")
//         }()
//     }

//     time.Sleep(10 * time.Millisecond)

//     proxy, err := NewProxy(ctx, factory)
//     assert.NoError(t, err)
//     assert.NotNil(t, proxy)

//     etcdcli, err := etcd.GetEtcdClient(
//         Params.EtcdCfg.UseEmbedEtcd.GetAsBool(),
//         Params.EtcdCfg.EtcdUseSSL.GetAsBool(),
//         Params.EtcdCfg.Endpoints.GetAsStrings(),
//         Params.EtcdCfg.EtcdTLSCert.GetValue(),
//         Params.EtcdCfg.EtcdTLSKey.GetValue(),
//         Params.EtcdCfg.EtcdTLSCACert.GetValue(),
//         Params.EtcdCfg.EtcdTLSMinVersion.GetValue())
//     defer etcdcli.Close()
//     assert.NoError(t, err)
//     proxy.SetEtcdClient(etcdcli)

//     testServer := newProxyTestServer(proxy)
//     wg.Add(1)

//     base := paramtable.BaseTable{}
//     base.Init(0)
//     var p paramtable.GrpcServerConfig
//     p.Init(typeutil.ProxyRole, &base)

//     go testServer.startGrpc(ctx, &wg, &p)
//     assert.NoError(t, testServer.waitForGrpcReady())

//     rootCoordClient, err := rcc.NewClient(ctx, Params.EtcdCfg.MetaRootPath.GetValue(), etcdcli)
//     assert.NoError(t, err)
//     err = rootCoordClient.Init()
//     assert.NoError(t, err)
//     err = funcutil.WaitForComponentHealthy(ctx, rootCoordClient, typeutil.RootCoordRole, attempts, sleepDuration)
//     assert.NoError(t, err)
//     proxy.SetRootCoordClient(rootCoordClient)
//     log.Info("Proxy set root coordinator client")

//     dataCoordClient, err := grpcdatacoordclient2.NewClient(ctx, Params.EtcdCfg.MetaRootPath.GetValue(), etcdcli)
//     assert.NoError(t, err)
//     err = dataCoordClient.Init()
//     assert.NoError(t, err)
//     err = funcutil.WaitForComponentHealthy(ctx, dataCoordClient, typeutil.DataCoordRole, attempts, sleepDuration)
//     assert.NoError(t, err)
//     proxy.SetDataCoordClient(dataCoordClient)
//     log.Info("Proxy set data coordinator client")

//     queryCoordClient, err := grpcquerycoordclient.NewClient(ctx, Params.EtcdCfg.MetaRootPath.GetValue(), etcdcli)
//     assert.NoError(t, err)
//     err = queryCoordClient.Init()
//     assert.NoError(t, err)
//     err = funcutil.WaitForComponentHealthy(ctx, queryCoordClient, typeutil.QueryCoordRole, attempts, sleepDuration)
//     assert.NoError(t, err)
//     proxy.SetQueryCoordClient(queryCoordClient)
//     log.Info("Proxy set query coordinator client")

//     indexCoordClient, err := grpcindexcoordclient.NewClient(ctx, Params.EtcdCfg.MetaRootPath.GetValue(), etcdcli)
//     assert.NoError(t, err)
//     err = indexCoordClient.Init()
//     assert.NoError(t, err)
//     err = funcutil.WaitForComponentHealthy(ctx, indexCoordClient, typeutil.IndexCoordRole, attempts, sleepDuration)
//     assert.NoError(t, err)
//     proxy.SetIndexCoordClient(indexCoordClient)
//     log.Info("Proxy set index coordinator client")

//     proxy.UpdateStateCode(commonpb.StateCode_Initializing)
//     err = proxy.Init()
//     assert.NoError(t, err)

//     err = proxy.Start()
//     assert.NoError(t, err)
//     assert.Equal(t, commonpb.StateCode_Healthy, proxy.stateCode.Load().(commonpb.StateCode))

//     // register proxy
//     err = proxy.Register()
//     assert.NoError(t, err)
//     log.Info("Register proxy done")
//     defer func() {
//         err := proxy.Stop()
//         assert.NoError(t, err)
//     }()

//     prefix := "test_proxy_"
//     partitionPrefix := "test_proxy_partition_"
//     dbName := ""
//     collectionName := prefix + funcutil.GenRandomStr()
//     otherCollectionName := collectionName + "_other_" + funcutil.GenRandomStr()
//     partitionName := partitionPrefix + funcutil.GenRandomStr()
//     // otherPartitionName := partitionPrefix + "_other_" + funcutil.GenRandomStr()
//     shardsNum := int32(2)
//     int64Field := "int64"
//     floatVecField := "fVec"
//     dim := 128
//     rowNum := 30
//     // indexName := "_default"
//     // nlist := 10
//     // nprobe := 10
//     // topk := 10
//     // add a test parameter
//     // roundDecimal := 6
//     // nq := 10
//     // expr := fmt.Sprintf("%s > 0", int64Field)
//     // var segmentIDs []int64

//     constructCollectionSchema := func() *schemapb.CollectionSchema {
//         pk := &schemapb.FieldSchema{
//             FieldID: 0,
//             Name: int64Field,
//             IsPrimaryKey: true,
//             Description: "",
//             DataType: schemapb.DataType_Int64,
//             TypeParams: nil,
//             IndexParams: nil,
//             AutoID: false,
//         }
//         fVec := &schemapb.FieldSchema{
//             FieldID: 0,
//             Name: floatVecField,
//             IsPrimaryKey: false,
//             Description: "",
//             DataType: schemapb.DataType_FloatVector,
//             TypeParams: []*commonpb.KeyValuePair{
//                 {
//                     Key: "dim",
//                     Value: strconv.Itoa(dim),
//                 },
//             },
//             IndexParams: nil,
//             AutoID: false,
//         }
//         return &schemapb.CollectionSchema{
//             Name: collectionName,
//             Description: "",
//             AutoID: false,
//             Fields: []*schemapb.FieldSchema{
//                 pk,
//                 fVec,
//             },
//         }
//     }
//     schema := constructCollectionSchema()

//     constructCreateCollectionRequest := func() *milvuspb.CreateCollectionRequest {
//         bs, err := proto.Marshal(schema)
//         assert.NoError(t, err)
//         return &milvuspb.CreateCollectionRequest{
//             Base: nil,
//             DbName: dbName,
//             CollectionName: collectionName,
//             Schema: bs,
//             ShardsNum: shardsNum,
//         }
//     }
//     createCollectionReq := constructCreateCollectionRequest()

//     constructPartitionReqUpsertRequestValid := func() *milvuspb.UpsertRequest {
//         pkFieldData := newScalarFieldData(schema.Fields[0], int64Field, rowNum)
//         fVecColumn := newFloatVectorFieldData(floatVecField, rowNum, dim)
//         hashKeys := generateHashKeys(rowNum)
//         return &milvuspb.UpsertRequest{
//             Base: nil,
//             DbName: dbName,
//             CollectionName: collectionName,
//             PartitionName: partitionName,
//             FieldsData: []*schemapb.FieldData{pkFieldData, fVecColumn},
//             HashKeys: hashKeys,
//             NumRows: uint32(rowNum),
//         }
//     }

//     constructCollectionUpsertRequestValid := func() *milvuspb.UpsertRequest {
//         pkFieldData := newScalarFieldData(schema.Fields[0], int64Field, rowNum)
//         fVecColumn := newFloatVectorFieldData(floatVecField, rowNum, dim)
//         hashKeys := generateHashKeys(rowNum)
//         return &milvuspb.UpsertRequest{
//             Base: nil,
//             DbName: dbName,
//             CollectionName: collectionName,
//             PartitionName: partitionName,
//             FieldsData: []*schemapb.FieldData{pkFieldData, fVecColumn},
//             HashKeys: hashKeys,
//             NumRows: uint32(rowNum),
//         }
//     }

//     wg.Add(1)
//     t.Run("create collection upsert valid", func(t *testing.T) {
//         defer wg.Done()
//         req := createCollectionReq
//         resp, err := proxy.CreateCollection(ctx, req)
//         assert.NoError(t, err)
//         assert.Equal(t, commonpb.ErrorCode_Success, resp.ErrorCode)

//         reqInvalidField := constructCreateCollectionRequest()
//         schema := constructCollectionSchema()
//         schema.Fields = append(schema.Fields, &schemapb.FieldSchema{
//             Name: "StringField",
//             DataType: schemapb.DataType_String,
//         })
//         bs, err := proto.Marshal(schema)
//         assert.NoError(t, err)
//         reqInvalidField.CollectionName = "invalid_field_coll_upsert_valid"
//         reqInvalidField.Schema = bs

//         resp, err = proxy.CreateCollection(ctx, reqInvalidField)
//         assert.NoError(t, err)
//         assert.NotEqual(t, commonpb.ErrorCode_Success, resp.ErrorCode)

//     })

//     wg.Add(1)
//     t.Run("create partition", func(t *testing.T) {
//         defer wg.Done()
//         resp, err := proxy.CreatePartition(ctx, &milvuspb.CreatePartitionRequest{
//             Base: nil,
//             DbName: dbName,
//             CollectionName: collectionName,
//             PartitionName: partitionName,
//         })
//         assert.NoError(t, err)
//         assert.Equal(t, commonpb.ErrorCode_Success, resp.ErrorCode)

//         // create partition with non-exist collection -> fail
//         resp, err = proxy.CreatePartition(ctx, &milvuspb.CreatePartitionRequest{
//             Base: nil,
//             DbName: dbName,
//             CollectionName: otherCollectionName,
//             PartitionName: partitionName,
//         })
//         assert.NoError(t, err)
//         assert.NotEqual(t, commonpb.ErrorCode_Success, resp.ErrorCode)
//     })

//     wg.Add(1)
//     t.Run("upsert partition", func(t *testing.T) {
//         defer wg.Done()
//         req := constructPartitionReqUpsertRequestValid()

//         resp, err := proxy.Upsert(ctx, req)
//         assert.NoError(t, err)
//         assert.Equal(t, commonpb.ErrorCode_Success, resp.Status.ErrorCode)
//         assert.Equal(t, rowNum, len(resp.SuccIndex))
//         assert.Equal(t, 0, len(resp.ErrIndex))
//         assert.Equal(t, int64(rowNum), resp.UpsertCnt)
//     })

//     wg.Add(1)
//     t.Run("upsert when autoID == false", func(t *testing.T) {
//         defer wg.Done()
//         req := constructCollectionUpsertRequestValid()

//         resp, err := proxy.Upsert(ctx, req)
//         assert.NoError(t, err)
//         assert.Equal(t, commonpb.ErrorCode_Success, resp.Status.ErrorCode)
//         assert.Equal(t, rowNum, len(resp.SuccIndex))
//         assert.Equal(t, 0, len(resp.ErrIndex))
//         assert.Equal(t, int64(rowNum), resp.UpsertCnt)
//     })

//     proxy.UpdateStateCode(commonpb.StateCode_Abnormal)

//     wg.Add(1)
//     t.Run("Upsert fail, unhealthy", func(t *testing.T) {
//         defer wg.Done()
//         resp, err := proxy.Upsert(ctx, &milvuspb.UpsertRequest{})
//         assert.NoError(t, err)
//         assert.NotEqual(t, commonpb.ErrorCode_Success, resp.Status.ErrorCode)
//     })

//     dmParallelism := proxy.sched.dmQueue.getMaxTaskNum()
//     proxy.sched.dmQueue.setMaxTaskNum(0)

//     wg.Add(1)
//     t.Run("Upsert fail, dm queue full", func(t *testing.T) {
//         defer wg.Done()
//         resp, err := proxy.Upsert(ctx, &milvuspb.UpsertRequest{})
//         assert.NoError(t, err)
//         assert.NotEqual(t, commonpb.ErrorCode_Success, resp.Status.ErrorCode)
//     })
//     proxy.sched.dmQueue.setMaxTaskNum(dmParallelism)

//     timeout := time.Nanosecond
//     shortCtx, shortCancel := context.WithTimeout(ctx, timeout)
//     defer shortCancel()
//     time.Sleep(timeout)

//     wg.Add(1)
//     t.Run("Update fail, timeout", func(t *testing.T) {
//         defer wg.Done()
//         resp, err := proxy.Upsert(shortCtx, &milvuspb.UpsertRequest{})
//         assert.NoError(t, err)
//         assert.NotEqual(t, commonpb.ErrorCode_Success, resp.Status.ErrorCode)
//     })

//     testServer.gracefulStop()
//     wg.Wait()
//     cancel()
// }

@ -35,6 +35,7 @@ import (
    "google.golang.org/grpc/metadata"

    "github.com/milvus-io/milvus-proto/go-api/commonpb"
    "github.com/milvus-io/milvus-proto/go-api/milvuspb"
    "github.com/milvus-io/milvus-proto/go-api/schemapb"
    "github.com/milvus-io/milvus/internal/log"
    "github.com/milvus-io/milvus/internal/util"
@ -944,6 +945,49 @@ func checkPrimaryFieldData(schema *schemapb.CollectionSchema, insertMsg *msgstre
    return ids, nil
}

// TODO(smellthemoon): can merge it with checkPrimaryFieldData
func upsertCheckPrimaryFieldData(schema *schemapb.CollectionSchema, result *milvuspb.MutationResult, insertMsg *msgstream.InsertMsg) (*schemapb.IDs, error) {
    rowNums := uint32(insertMsg.NRows())
    // TODO(dragondriver): in fact, NumRows is not trustable, we should check all input fields
    if insertMsg.NRows() <= 0 {
        return nil, errNumRowsLessThanOrEqualToZero(rowNums)
    }

    if err := checkLengthOfFieldsData(schema, insertMsg); err != nil {
        return nil, err
    }

    primaryFieldSchema, err := typeutil.GetPrimaryFieldSchema(schema)
    if err != nil {
        log.Error("get primary field schema failed", zap.String("collectionName", insertMsg.CollectionName), zap.Any("schema", schema), zap.Error(err))
        return nil, err
    }

    // get primaryFieldData whether autoID is true or not
    var primaryFieldData *schemapb.FieldData
    if primaryFieldSchema.AutoID {
        // upsert is not supported when autoID == true
        log.Info("can not upsert when auto id enabled",
            zap.String("primaryFieldSchemaName", primaryFieldSchema.Name))
        result.Status.ErrorCode = commonpb.ErrorCode_UpsertAutoIDTrue
        return nil, fmt.Errorf("upsert can not assign primary field data when auto id enabled %v", primaryFieldSchema.Name)
    }
    primaryFieldData, err = typeutil.GetPrimaryFieldData(insertMsg.GetFieldsData(), primaryFieldSchema)
    if err != nil {
        log.Error("get primary field data failed when upsert", zap.String("collectionName", insertMsg.CollectionName), zap.Error(err))
        return nil, err
    }

    // parse primaryFieldData into result.IDs, which is also returned as the upserted primary keys
    ids, err := parsePrimaryFieldData2IDs(primaryFieldData)
    if err != nil {
        log.Error("parse primary field data to IDs failed", zap.String("collectionName", insertMsg.CollectionName), zap.Error(err))
        return nil, err
    }

    return ids, nil
}

func getCollectionProgress(ctx context.Context, queryCoord types.QueryCoord,
    msgBase *commonpb.MsgBase, collectionID int64) (int64, error) {
    resp, err := queryCoord.ShowCollections(ctx, &querypb.ShowCollectionsRequest{

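// For illustration only: a hypothetical caller showing how upsertCheckPrimaryFieldData
// feeds a MutationResult. Only the helper itself and the proto fields used here come
// from this change; the wrapper function below is an assumption, not actual code.
func buildUpsertResultSketch(schema *schemapb.CollectionSchema, insertMsg *msgstream.InsertMsg) (*milvuspb.MutationResult, error) {
    result := &milvuspb.MutationResult{
        Status: &commonpb.Status{ErrorCode: commonpb.ErrorCode_Success},
    }
    // Rejects autoID primary keys (setting ErrorCode_UpsertAutoIDTrue) and otherwise
    // extracts the caller-supplied primary keys from the insert message.
    ids, err := upsertCheckPrimaryFieldData(schema, result, insertMsg)
    if err != nil {
        return result, err
    }
    result.IDs = ids
    result.UpsertCnt = int64(insertMsg.NRows())
    return result, nil
}
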
@ -29,6 +29,7 @@ import (
    "google.golang.org/grpc/metadata"

    "github.com/milvus-io/milvus-proto/go-api/commonpb"
    "github.com/milvus-io/milvus-proto/go-api/milvuspb"
    "github.com/milvus-io/milvus-proto/go-api/schemapb"
    "github.com/milvus-io/milvus/internal/proto/internalpb"
    "github.com/milvus-io/milvus/internal/proto/querypb"
@ -1226,3 +1227,264 @@ func Test_InsertTaskCheckPrimaryFieldData(t *testing.T) {
    _, err = checkPrimaryFieldData(case4.schema, case4.insertMsg)
    assert.NotEqual(t, nil, err)
}

func Test_UpsertTaskCheckPrimaryFieldData(t *testing.T) {
    // schema is empty, though won't happen in system
    // num_rows(0) should be greater than 0
    case1 := insertTask{
        schema: &schemapb.CollectionSchema{
            Name: "TestUpsertTask_checkPrimaryFieldData",
            Description: "TestUpsertTask_checkPrimaryFieldData",
            AutoID: false,
            Fields: []*schemapb.FieldSchema{},
        },
        insertMsg: &BaseInsertTask{
            InsertRequest: internalpb.InsertRequest{
                Base: &commonpb.MsgBase{
                    MsgType: commonpb.MsgType_Insert,
                },
                DbName: "TestUpsertTask_checkPrimaryFieldData",
                CollectionName: "TestUpsertTask_checkPrimaryFieldData",
                PartitionName: "TestUpsertTask_checkPrimaryFieldData",
            },
        },
        result: &milvuspb.MutationResult{
            Status: &commonpb.Status{
                ErrorCode: commonpb.ErrorCode_Success,
            },
        },
    }
    _, err := upsertCheckPrimaryFieldData(case1.schema, case1.result, case1.insertMsg)
    assert.NotEqual(t, nil, err)

    // the num of passed fields is less than needed
    case2 := insertTask{
        schema: &schemapb.CollectionSchema{
            Name: "TestUpsertTask_checkPrimaryFieldData",
            Description: "TestUpsertTask_checkPrimaryFieldData",
            AutoID: false,
            Fields: []*schemapb.FieldSchema{
                {
                    Name: "int64Field",
                    FieldID: 1,
                    DataType: schemapb.DataType_Int64,
                },
                {
                    Name: "floatField",
                    FieldID: 2,
                    DataType: schemapb.DataType_Float,
                },
            },
        },
        insertMsg: &BaseInsertTask{
            InsertRequest: internalpb.InsertRequest{
                Base: &commonpb.MsgBase{
                    MsgType: commonpb.MsgType_Insert,
                },
                RowData: []*commonpb.Blob{
                    {},
                    {},
                },
                FieldsData: []*schemapb.FieldData{
                    {
                        Type: schemapb.DataType_Int64,
                        FieldName: "int64Field",
                    },
                },
            },
        },
        result: &milvuspb.MutationResult{
            Status: &commonpb.Status{
                ErrorCode: commonpb.ErrorCode_Success,
            },
        },
    }
    _, err = upsertCheckPrimaryFieldData(case2.schema, case2.result, case2.insertMsg)
    assert.NotEqual(t, nil, err)

    // autoID == false, no primary field schema
    // primary field is not found
    case3 := insertTask{
        schema: &schemapb.CollectionSchema{
            Name: "TestUpsertTask_checkPrimaryFieldData",
            Description: "TestUpsertTask_checkPrimaryFieldData",
            AutoID: false,
            Fields: []*schemapb.FieldSchema{
                {
                    Name: "int64Field",
                    DataType: schemapb.DataType_Int64,
                },
                {
                    Name: "floatField",
                    DataType: schemapb.DataType_Float,
                },
            },
        },
        insertMsg: &BaseInsertTask{
            InsertRequest: internalpb.InsertRequest{
                Base: &commonpb.MsgBase{
                    MsgType: commonpb.MsgType_Insert,
                },
                RowData: []*commonpb.Blob{
                    {},
                    {},
                },
                FieldsData: []*schemapb.FieldData{
                    {},
                    {},
                },
            },
        },
        result: &milvuspb.MutationResult{
            Status: &commonpb.Status{
                ErrorCode: commonpb.ErrorCode_Success,
            },
        },
    }
    _, err = upsertCheckPrimaryFieldData(case3.schema, case3.result, case3.insertMsg)
    assert.NotEqual(t, nil, err)

    // autoID == true, which upsert does not support
    case4 := insertTask{
        schema: &schemapb.CollectionSchema{
            Name: "TestUpsertTask_checkPrimaryFieldData",
            Description: "TestUpsertTask_checkPrimaryFieldData",
            AutoID: false,
            Fields: []*schemapb.FieldSchema{
                {
                    Name: "int64Field",
                    FieldID: 1,
                    DataType: schemapb.DataType_Int64,
                },
                {
                    Name: "floatField",
                    FieldID: 2,
                    DataType: schemapb.DataType_Float,
                },
            },
        },
        insertMsg: &BaseInsertTask{
            InsertRequest: internalpb.InsertRequest{
                Base: &commonpb.MsgBase{
                    MsgType: commonpb.MsgType_Insert,
                },
                RowData: []*commonpb.Blob{
                    {},
                    {},
                },
                FieldsData: []*schemapb.FieldData{
                    {
                        Type: schemapb.DataType_Int64,
                        FieldName: "int64Field",
                    },
                },
            },
        },
        result: &milvuspb.MutationResult{
            Status: &commonpb.Status{
                ErrorCode: commonpb.ErrorCode_Success,
            },
        },
    }
    case4.schema.Fields[0].IsPrimaryKey = true
    case4.schema.Fields[0].AutoID = true
    _, err = upsertCheckPrimaryFieldData(case4.schema, case4.result, case4.insertMsg)
    assert.Equal(t, commonpb.ErrorCode_UpsertAutoIDTrue, case4.result.Status.ErrorCode)
    assert.NotEqual(t, nil, err)

    // primary field data is nil, GetPrimaryFieldData fails
    case5 := insertTask{
        schema: &schemapb.CollectionSchema{
            Name: "TestUpsertTask_checkPrimaryFieldData",
            Description: "TestUpsertTask_checkPrimaryFieldData",
            AutoID: false,
            Fields: []*schemapb.FieldSchema{
                {
                    Name: "int64Field",
                    FieldID: 1,
                    DataType: schemapb.DataType_Int64,
                },
                {
                    Name: "floatField",
                    FieldID: 2,
                    DataType: schemapb.DataType_Float,
                },
            },
        },
        insertMsg: &BaseInsertTask{
            InsertRequest: internalpb.InsertRequest{
                Base: &commonpb.MsgBase{
                    MsgType: commonpb.MsgType_Insert,
                },
                RowData: []*commonpb.Blob{
                    {},
                    {},
                },
                FieldsData: []*schemapb.FieldData{
                    {},
                    {},
                },
            },
        },
        result: &milvuspb.MutationResult{
            Status: &commonpb.Status{
                ErrorCode: commonpb.ErrorCode_Success,
            },
        },
    }
    case5.schema.Fields[0].IsPrimaryKey = true
    case5.schema.Fields[0].AutoID = false
    _, err = upsertCheckPrimaryFieldData(case5.schema, case5.result, case5.insertMsg)
    assert.NotEqual(t, nil, err)

    // only DataType Int64 or VarChar is supported as the primary field
    case6 := insertTask{
        schema: &schemapb.CollectionSchema{
            Name: "TestUpsertTask_checkPrimaryFieldData",
            Description: "TestUpsertTask_checkPrimaryFieldData",
            AutoID: false,
            Fields: []*schemapb.FieldSchema{
                {
                    Name: "floatVectorField",
                    FieldID: 1,
                    DataType: schemapb.DataType_FloatVector,
                },
                {
                    Name: "floatField",
                    FieldID: 2,
                    DataType: schemapb.DataType_Float,
                },
            },
        },
        insertMsg: &BaseInsertTask{
            InsertRequest: internalpb.InsertRequest{
                Base: &commonpb.MsgBase{
                    MsgType: commonpb.MsgType_Insert,
                },
                RowData: []*commonpb.Blob{
                    {},
                    {},
                },
                FieldsData: []*schemapb.FieldData{
                    {
                        Type: schemapb.DataType_FloatVector,
                        FieldName: "floatVectorField",
                    },
                    {
                        Type: schemapb.DataType_Int64,
                        FieldName: "floatField",
                    },
                },
            },
        },
        result: &milvuspb.MutationResult{
            Status: &commonpb.Status{
                ErrorCode: commonpb.ErrorCode_Success,
            },
        },
    }
    case6.schema.Fields[0].IsPrimaryKey = true
    case6.schema.Fields[0].AutoID = false
    _, err = upsertCheckPrimaryFieldData(case6.schema, case6.result, case6.insertMsg)
    assert.NotEqual(t, nil, err)
}

@ -135,7 +135,7 @@ func (fdmNode *filterDmNode) Operate(in []flowgraph.Msg) []flowgraph.Msg {
// filterInvalidDeleteMessage would filter out invalid delete messages
func (fdmNode *filterDmNode) filterInvalidDeleteMessage(msg *msgstream.DeleteMsg, loadType loadType) (*msgstream.DeleteMsg, error) {
    if err := msg.CheckAligned(); err != nil {
        return nil, fmt.Errorf("CheckAligned failed, err = %s", err)
        return nil, fmt.Errorf("DeleteMessage CheckAligned failed, err = %s", err)
    }

    if len(msg.Timestamps) <= 0 {

@ -168,7 +168,7 @@ func (fdmNode *filterDmNode) filterInvalidDeleteMessage(msg *msgstream.DeleteMsg
// filterInvalidInsertMessage would filter out invalid insert messages
func (fdmNode *filterDmNode) filterInvalidInsertMessage(msg *msgstream.InsertMsg, loadType loadType) (*msgstream.InsertMsg, error) {
    if err := msg.CheckAligned(); err != nil {
        return nil, fmt.Errorf("CheckAligned failed, err = %s", err)
        return nil, fmt.Errorf("InsertMessage CheckAligned failed, err = %s", err)
    }

    if len(msg.Timestamps) <= 0 {

@ -1145,6 +1145,18 @@ type ProxyComponent interface {
    // error is always nil
    Delete(ctx context.Context, request *milvuspb.DeleteRequest) (*milvuspb.MutationResult, error)

    // Upsert notifies Proxy to upsert rows
    //
    // ctx is the context to control request deadline and cancellation
    // req contains the request params, including database name (reserved), collection name, partition name (optional), and fields data
    //
    // The `Status` in the response struct `MutationResult` indicates whether this operation succeeded or the cause of failure;
    // the `IDs` in `MutationResult` returns the id list of the upserted rows;
    // the `SuccIndex` in `MutationResult` returns the indexes of the rows upserted successfully;
    // the `ErrIndex` in `MutationResult` returns the indexes of the rows that failed to upsert.
    // error is always nil
    Upsert(ctx context.Context, request *milvuspb.UpsertRequest) (*milvuspb.MutationResult, error)

    // Search notifies Proxy to do search
    //
    // ctx is the context to control request deadline and cancellation
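// For illustration only: a minimal sketch of driving the new Upsert method from calling
// code. The proxy value, collection name, partition name, and the two prebuilt columns
// are placeholders/assumptions; only the UpsertRequest and MutationResult shapes come
// from this change. The target collection is assumed to use a non-autoID primary key.
func upsertExampleSketch(ctx context.Context, proxy types.ProxyComponent,
    pkColumn, vecColumn *schemapb.FieldData, rowNum int) error {
    req := &milvuspb.UpsertRequest{
        DbName:         "", // reserved
        CollectionName: "example_collection",
        PartitionName:  "_default",
        FieldsData:     []*schemapb.FieldData{pkColumn, vecColumn},
        NumRows:        uint32(rowNum),
    }
    resp, err := proxy.Upsert(ctx, req)
    if err != nil {
        return err
    }
    if resp.GetStatus().GetErrorCode() != commonpb.ErrorCode_Success {
        return fmt.Errorf("upsert failed: %s", resp.GetStatus().GetReason())
    }
    // resp.IDs echoes the upserted primary keys; resp.UpsertCnt reports the row count.
    return nil
}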