mirror of https://github.com/milvus-io/milvus.git
parent
227889b0d0
commit
205c92e54b
6
go.sum
6
go.sum
|
@ -326,6 +326,12 @@ github.com/jarcoal/httpmock v1.0.8 h1:8kI16SoO6LQKgPE7PvQuV+YuD/inwHd7fOOe2zMbo4
|
|||
github.com/jarcoal/httpmock v1.0.8/go.mod h1:ATjnClrvW/3tijVmpL/va5Z3aAyGvqU3gCT8nX0Txik=
|
||||
github.com/jawher/mow.cli v1.0.4/go.mod h1:5hQj2V8g+qYmLUVWqu4Wuja1pI57M83EChYLVZ0sMKk=
|
||||
github.com/jawher/mow.cli v1.2.0/go.mod h1:y+pcA3jBAdo/GIZx/0rFjw/K2bVEODP9rfZOfaiq8Ko=
|
||||
github.com/jhump/gopoet v0.0.0-20190322174617-17282ff210b3/go.mod h1:me9yfT6IJSlOL3FCfrg+L6yzUEZ+5jW6WHt4Sk+UPUI=
|
||||
github.com/jhump/gopoet v0.1.0/go.mod h1:me9yfT6IJSlOL3FCfrg+L6yzUEZ+5jW6WHt4Sk+UPUI=
|
||||
github.com/jhump/goprotoc v0.5.0/go.mod h1:VrbvcYrQOrTi3i0Vf+m+oqQWk9l72mjkJCYo7UvLHRQ=
|
||||
github.com/jhump/protoreflect v1.11.0/go.mod h1:U7aMIjN0NWq9swDP7xDdoMfRHb35uiuTd3Z9nFXJf5E=
|
||||
github.com/jhump/protoreflect v1.12.0 h1:1NQ4FpWMgn3by/n1X0fbeKEUxP1wBt7+Oitpv01HR10=
|
||||
github.com/jhump/protoreflect v1.12.0/go.mod h1:JytZfP5d0r8pVNLZvai7U/MCuTWITgrI4tTg7puQFKI=
|
||||
github.com/jmespath/go-jmespath v0.3.0/go.mod h1:9QtRXoHjLGCJ5IBSaohpXITPlowMeeYCZ7fLUTSywik=
|
||||
github.com/jonboulle/clockwork v0.1.0/go.mod h1:Ii8DK3G1RaLaWxj9trq07+26W01tbo22gdxWY5EU2bo=
|
||||
github.com/jonboulle/clockwork v0.2.2 h1:UOGuzwb1PwsrDAObMuhUnj0p5ULPj8V/xJ7Kx9qUBdQ=
|
||||
|
|
|
@ -25,7 +25,9 @@ import (
|
|||
|
||||
"github.com/milvus-io/milvus/internal/common"
|
||||
"github.com/milvus-io/milvus/internal/log"
|
||||
"github.com/milvus-io/milvus/internal/proto/schemapb"
|
||||
"github.com/milvus-io/milvus/internal/storage"
|
||||
"github.com/milvus-io/milvus/internal/util/typeutil"
|
||||
"github.com/stretchr/testify/assert"
|
||||
"github.com/stretchr/testify/require"
|
||||
"go.uber.org/zap"
|
||||
|
@ -41,7 +43,7 @@ func TestBinlogIOInterfaceMethods(t *testing.T) {
|
|||
b := &binlogIO{cm, alloc}
|
||||
t.Run("Test upload", func(t *testing.T) {
|
||||
f := &MetaFactory{}
|
||||
meta := f.GetCollectionMeta(UniqueID(10001), "uploads")
|
||||
meta := f.GetCollectionMeta(UniqueID(10001), "uploads", schemapb.DataType_Int64)
|
||||
|
||||
iData := genInsertData()
|
||||
dData := &DeleteData{
|
||||
|
@ -52,16 +54,16 @@ func TestBinlogIOInterfaceMethods(t *testing.T) {
|
|||
|
||||
p, err := b.upload(context.TODO(), 1, 10, []*InsertData{iData}, dData, meta)
|
||||
assert.NoError(t, err)
|
||||
assert.Equal(t, 11, len(p.inPaths))
|
||||
assert.Equal(t, 3, len(p.statsPaths))
|
||||
assert.Equal(t, 12, len(p.inPaths))
|
||||
assert.Equal(t, 1, len(p.statsPaths))
|
||||
assert.Equal(t, 1, len(p.inPaths[0].GetBinlogs()))
|
||||
assert.Equal(t, 1, len(p.statsPaths[0].GetBinlogs()))
|
||||
assert.NotNil(t, p.deltaInfo)
|
||||
|
||||
p, err = b.upload(context.TODO(), 1, 10, []*InsertData{iData, iData}, dData, meta)
|
||||
assert.NoError(t, err)
|
||||
assert.Equal(t, 11, len(p.inPaths))
|
||||
assert.Equal(t, 3, len(p.statsPaths))
|
||||
assert.Equal(t, 12, len(p.inPaths))
|
||||
assert.Equal(t, 1, len(p.statsPaths))
|
||||
assert.Equal(t, 2, len(p.inPaths[0].GetBinlogs()))
|
||||
assert.Equal(t, 2, len(p.statsPaths[0].GetBinlogs()))
|
||||
assert.NotNil(t, p.deltaInfo)
|
||||
|
@ -76,7 +78,7 @@ func TestBinlogIOInterfaceMethods(t *testing.T) {
|
|||
|
||||
t.Run("Test upload error", func(t *testing.T) {
|
||||
f := &MetaFactory{}
|
||||
meta := f.GetCollectionMeta(UniqueID(10001), "uploads")
|
||||
meta := f.GetCollectionMeta(UniqueID(10001), "uploads", schemapb.DataType_Int64)
|
||||
dData := &DeleteData{
|
||||
Pks: []int64{},
|
||||
Tss: []uint64{},
|
||||
|
@ -197,7 +199,7 @@ func TestBinlogIOInnerMethods(t *testing.T) {
|
|||
|
||||
t.Run("Test genDeltaBlobs", func(t *testing.T) {
|
||||
f := &MetaFactory{}
|
||||
meta := f.GetCollectionMeta(UniqueID(10002), "test_gen_blobs")
|
||||
meta := f.GetCollectionMeta(UniqueID(10002), "test_gen_blobs", schemapb.DataType_Int64)
|
||||
|
||||
tests := []struct {
|
||||
isvalid bool
|
||||
|
@ -246,19 +248,36 @@ func TestBinlogIOInnerMethods(t *testing.T) {
|
|||
|
||||
t.Run("Test genInsertBlobs", func(t *testing.T) {
|
||||
f := &MetaFactory{}
|
||||
meta := f.GetCollectionMeta(UniqueID(10001), "test_gen_blobs")
|
||||
tests := []struct {
|
||||
pkType schemapb.DataType
|
||||
description string
|
||||
}{
|
||||
{schemapb.DataType_Int64, "int64PrimaryField"},
|
||||
{schemapb.DataType_VarChar, "varCharPrimaryField"},
|
||||
}
|
||||
|
||||
kvs, pin, pstats, err := b.genInsertBlobs(genInsertData(), 10, 1, meta)
|
||||
for _, test := range tests {
|
||||
t.Run(test.description, func(t *testing.T) {
|
||||
meta := f.GetCollectionMeta(UniqueID(10001), "test_gen_blobs", test.pkType)
|
||||
helper, err := typeutil.CreateSchemaHelper(meta.Schema)
|
||||
assert.NoError(t, err)
|
||||
primaryKeyFieldSchema, err := helper.GetPrimaryKeyField()
|
||||
assert.NoError(t, err)
|
||||
primaryKeyFieldID := primaryKeyFieldSchema.GetFieldID()
|
||||
|
||||
assert.NoError(t, err)
|
||||
assert.Equal(t, 3, len(pstats))
|
||||
assert.Equal(t, 11, len(pin))
|
||||
assert.Equal(t, 14, len(kvs))
|
||||
kvs, pin, pstats, err := b.genInsertBlobs(genInsertData(), 10, 1, meta)
|
||||
|
||||
log.Debug("test paths",
|
||||
zap.Any("kvs no.", len(kvs)),
|
||||
zap.String("insert paths field0", pin[common.TimeStampField].GetBinlogs()[0].GetLogPath()),
|
||||
zap.String("stats paths field0", pstats[common.TimeStampField].GetBinlogs()[0].GetLogPath()))
|
||||
assert.NoError(t, err)
|
||||
assert.Equal(t, 1, len(pstats))
|
||||
assert.Equal(t, 12, len(pin))
|
||||
assert.Equal(t, 13, len(kvs))
|
||||
|
||||
log.Debug("test paths",
|
||||
zap.Any("kvs no.", len(kvs)),
|
||||
zap.String("insert paths field0", pin[common.TimeStampField].GetBinlogs()[0].GetLogPath()),
|
||||
zap.String("stats paths field0", pstats[primaryKeyFieldID].GetBinlogs()[0].GetLogPath()))
|
||||
})
|
||||
}
|
||||
})
|
||||
|
||||
t.Run("Test genInsertBlobs error", func(t *testing.T) {
|
||||
|
@ -269,7 +288,7 @@ func TestBinlogIOInnerMethods(t *testing.T) {
|
|||
assert.Empty(t, pstats)
|
||||
|
||||
f := &MetaFactory{}
|
||||
meta := f.GetCollectionMeta(UniqueID(10001), "test_gen_blobs")
|
||||
meta := f.GetCollectionMeta(UniqueID(10001), "test_gen_blobs", schemapb.DataType_Int64)
|
||||
|
||||
kvs, pin, pstats, err = b.genInsertBlobs(genEmptyInsertData(), 10, 1, meta)
|
||||
assert.Error(t, err)
|
||||
|
|
|
@ -527,7 +527,11 @@ func (t *compactionTask) compact() error {
|
|||
}
|
||||
// no need to shorten the PK range of a segment, deleting dup PKs is valid
|
||||
} else {
|
||||
t.mergeFlushedSegments(targetSegID, collID, partID, t.plan.GetPlanID(), segIDs, t.plan.GetChannel(), numRows)
|
||||
err = t.mergeFlushedSegments(targetSegID, collID, partID, t.plan.GetPlanID(), segIDs, t.plan.GetChannel(), numRows)
|
||||
if err != nil {
|
||||
log.Error("compact wrong", zap.Int64("planID", t.plan.GetPlanID()), zap.Error(err))
|
||||
return err
|
||||
}
|
||||
}
|
||||
|
||||
uninjectStart := time.Now()
|
||||
|
@ -660,6 +664,21 @@ func interface2FieldData(schemaDataType schemapb.DataType, content []interface{}
|
|||
}
|
||||
rst = data
|
||||
|
||||
case schemapb.DataType_String, schemapb.DataType_VarChar:
|
||||
var data = &storage.StringFieldData{
|
||||
NumRows: numOfRows,
|
||||
Data: make([]string, 0, len(content)),
|
||||
}
|
||||
|
||||
for _, c := range content {
|
||||
r, ok := c.(string)
|
||||
if !ok {
|
||||
return nil, errTransferType
|
||||
}
|
||||
data.Data = append(data.Data, r)
|
||||
}
|
||||
rst = data
|
||||
|
||||
case schemapb.DataType_FloatVector:
|
||||
var data = &storage.FloatVectorFieldData{
|
||||
NumRows: numOfRows,
|
||||
|
|
|
@ -38,7 +38,9 @@ func TestCompactionTaskInnerMethods(t *testing.T) {
|
|||
cm := storage.NewLocalChunkManager(storage.RootPath(compactTestDir))
|
||||
defer cm.RemoveWithPrefix("")
|
||||
t.Run("Test getSegmentMeta", func(t *testing.T) {
|
||||
rc := &RootCoordFactory{}
|
||||
rc := &RootCoordFactory{
|
||||
pkType: schemapb.DataType_Int64,
|
||||
}
|
||||
replica, err := newReplica(context.TODO(), rc, cm, 1)
|
||||
require.NoError(t, err)
|
||||
|
||||
|
@ -80,6 +82,7 @@ func TestCompactionTaskInnerMethods(t *testing.T) {
|
|||
{true, schemapb.DataType_Int64, []interface{}{int64(1), int64(2)}, "valid int64"},
|
||||
{true, schemapb.DataType_Float, []interface{}{float32(1), float32(2)}, "valid float32"},
|
||||
{true, schemapb.DataType_Double, []interface{}{float64(1), float64(2)}, "valid float64"},
|
||||
{true, schemapb.DataType_VarChar, []interface{}{"test1", "test2"}, "valid varChar"},
|
||||
{true, schemapb.DataType_FloatVector, []interface{}{[]float32{1.0, 2.0}}, "valid floatvector"},
|
||||
{true, schemapb.DataType_BinaryVector, []interface{}{[]byte{255}}, "valid binaryvector"},
|
||||
{false, schemapb.DataType_Bool, []interface{}{1, 2}, "invalid bool"},
|
||||
|
@ -89,9 +92,10 @@ func TestCompactionTaskInnerMethods(t *testing.T) {
|
|||
{false, schemapb.DataType_Int64, []interface{}{nil, nil}, "invalid int64"},
|
||||
{false, schemapb.DataType_Float, []interface{}{nil, nil}, "invalid float32"},
|
||||
{false, schemapb.DataType_Double, []interface{}{nil, nil}, "invalid float64"},
|
||||
{false, schemapb.DataType_VarChar, []interface{}{nil, nil}, "invalid varChar"},
|
||||
{false, schemapb.DataType_FloatVector, []interface{}{nil, nil}, "invalid floatvector"},
|
||||
{false, schemapb.DataType_BinaryVector, []interface{}{nil, nil}, "invalid binaryvector"},
|
||||
{false, schemapb.DataType_String, nil, "invalid data type"},
|
||||
{false, schemapb.DataType_None, nil, "invalid data type"},
|
||||
}
|
||||
|
||||
for _, test := range tests {
|
||||
|
@ -243,7 +247,7 @@ func TestCompactionTaskInnerMethods(t *testing.T) {
|
|||
t.Run("Merge without expiration", func(t *testing.T) {
|
||||
Params.DataCoordCfg.CompactionEntityExpiration = math.MaxInt64
|
||||
iData := genInsertDataWithExpiredTS()
|
||||
meta := NewMetaFactory().GetCollectionMeta(1, "test")
|
||||
meta := NewMetaFactory().GetCollectionMeta(1, "test", schemapb.DataType_Int64)
|
||||
|
||||
iblobs, err := getInsertBlobs(100, iData, meta)
|
||||
require.NoError(t, err)
|
||||
|
@ -272,7 +276,7 @@ func TestCompactionTaskInnerMethods(t *testing.T) {
|
|||
}()
|
||||
Params.DataNodeCfg.FlushInsertBufferSize = 128
|
||||
iData := genInsertDataWithExpiredTS()
|
||||
meta := NewMetaFactory().GetCollectionMeta(1, "test")
|
||||
meta := NewMetaFactory().GetCollectionMeta(1, "test", schemapb.DataType_Int64)
|
||||
|
||||
iblobs, err := getInsertBlobs(100, iData, meta)
|
||||
require.NoError(t, err)
|
||||
|
@ -295,7 +299,7 @@ func TestCompactionTaskInnerMethods(t *testing.T) {
|
|||
t.Run("Merge with expiration", func(t *testing.T) {
|
||||
Params.DataCoordCfg.CompactionEntityExpiration = 864000 // 10 days in seconds
|
||||
iData := genInsertDataWithExpiredTS()
|
||||
meta := NewMetaFactory().GetCollectionMeta(1, "test")
|
||||
meta := NewMetaFactory().GetCollectionMeta(1, "test", schemapb.DataType_Int64)
|
||||
|
||||
iblobs, err := getInsertBlobs(100, iData, meta)
|
||||
require.NoError(t, err)
|
||||
|
@ -448,16 +452,18 @@ func TestCompactorInterfaceMethods(t *testing.T) {
|
|||
var collID, partID, segID UniqueID = 1, 10, 100
|
||||
|
||||
alloc := NewAllocatorFactory(1)
|
||||
rc := &RootCoordFactory{}
|
||||
rc := &RootCoordFactory{
|
||||
pkType: schemapb.DataType_Int64,
|
||||
}
|
||||
dc := &DataCoordFactory{}
|
||||
mockfm := &mockFlushManager{}
|
||||
mockbIO := &binlogIO{cm, alloc}
|
||||
replica, err := newReplica(context.TODO(), rc, cm, collID)
|
||||
require.NoError(t, err)
|
||||
replica.addFlushedSegmentWithPKs(segID, collID, partID, "channelname", 2, []UniqueID{1, 2})
|
||||
replica.addFlushedSegmentWithPKs(segID, collID, partID, "channelname", 2, &storage.Int64FieldData{Data: []UniqueID{1, 2}})
|
||||
|
||||
iData := genInsertData()
|
||||
meta := NewMetaFactory().GetCollectionMeta(collID, "test_compact_coll_name")
|
||||
meta := NewMetaFactory().GetCollectionMeta(collID, "test_compact_coll_name", schemapb.DataType_Int64)
|
||||
dData := &DeleteData{
|
||||
Pks: []UniqueID{1},
|
||||
Tss: []Timestamp{20000},
|
||||
|
@ -466,7 +472,7 @@ func TestCompactorInterfaceMethods(t *testing.T) {
|
|||
|
||||
cpaths, err := mockbIO.upload(context.TODO(), segID, partID, []*InsertData{iData}, dData, meta)
|
||||
require.NoError(t, err)
|
||||
require.Equal(t, 11, len(cpaths.inPaths))
|
||||
require.Equal(t, 12, len(cpaths.inPaths))
|
||||
segBinlogs := []*datapb.CompactionSegmentBinlogs{
|
||||
{
|
||||
SegmentID: segID,
|
||||
|
@ -524,7 +530,7 @@ func TestCompactorInterfaceMethods(t *testing.T) {
|
|||
assert.False(t, replica.hasSegment(segID, true))
|
||||
|
||||
// re-add the segment
|
||||
replica.addFlushedSegmentWithPKs(segID, collID, partID, "channelname", 2, []UniqueID{1, 2})
|
||||
replica.addFlushedSegmentWithPKs(segID, collID, partID, "channelname", 2, &storage.Int64FieldData{Data: []UniqueID{1, 2}})
|
||||
|
||||
// Compact empty segment
|
||||
err = cm.RemoveWithPrefix("/")
|
||||
|
@ -575,7 +581,9 @@ func TestCompactorInterfaceMethods(t *testing.T) {
|
|||
var collID, partID, segID1, segID2 UniqueID = 1, 10, 200, 201
|
||||
|
||||
alloc := NewAllocatorFactory(1)
|
||||
rc := &RootCoordFactory{}
|
||||
rc := &RootCoordFactory{
|
||||
pkType: schemapb.DataType_Int64,
|
||||
}
|
||||
dc := &DataCoordFactory{}
|
||||
mockfm := &mockFlushManager{}
|
||||
mockKv := memkv.NewMemoryKV()
|
||||
|
@ -583,12 +591,12 @@ func TestCompactorInterfaceMethods(t *testing.T) {
|
|||
replica, err := newReplica(context.TODO(), rc, cm, collID)
|
||||
require.NoError(t, err)
|
||||
|
||||
replica.addFlushedSegmentWithPKs(segID1, collID, partID, "channelname", 2, []UniqueID{1})
|
||||
replica.addFlushedSegmentWithPKs(segID2, collID, partID, "channelname", 2, []UniqueID{9})
|
||||
replica.addFlushedSegmentWithPKs(segID1, collID, partID, "channelname", 2, &storage.Int64FieldData{Data: []UniqueID{1}})
|
||||
replica.addFlushedSegmentWithPKs(segID2, collID, partID, "channelname", 2, &storage.Int64FieldData{Data: []UniqueID{9}})
|
||||
require.True(t, replica.hasSegment(segID1, true))
|
||||
require.True(t, replica.hasSegment(segID2, true))
|
||||
|
||||
meta := NewMetaFactory().GetCollectionMeta(collID, "test_compact_coll_name")
|
||||
meta := NewMetaFactory().GetCollectionMeta(collID, "test_compact_coll_name", schemapb.DataType_Int64)
|
||||
iData1 := genInsertDataWithPKs([2]int64{1, 2})
|
||||
dData1 := &DeleteData{
|
||||
Pks: []UniqueID{1},
|
||||
|
@ -604,11 +612,11 @@ func TestCompactorInterfaceMethods(t *testing.T) {
|
|||
|
||||
cpaths1, err := mockbIO.upload(context.TODO(), segID1, partID, []*InsertData{iData1}, dData1, meta)
|
||||
require.NoError(t, err)
|
||||
require.Equal(t, 11, len(cpaths1.inPaths))
|
||||
require.Equal(t, 12, len(cpaths1.inPaths))
|
||||
|
||||
cpaths2, err := mockbIO.upload(context.TODO(), segID2, partID, []*InsertData{iData2}, dData2, meta)
|
||||
require.NoError(t, err)
|
||||
require.Equal(t, 11, len(cpaths2.inPaths))
|
||||
require.Equal(t, 12, len(cpaths2.inPaths))
|
||||
|
||||
plan := &datapb.CompactionPlan{
|
||||
PlanID: 10080,
|
||||
|
@ -652,8 +660,8 @@ func TestCompactorInterfaceMethods(t *testing.T) {
|
|||
plan.PlanID++
|
||||
|
||||
plan.Timetravel = Timestamp(25000)
|
||||
replica.addFlushedSegmentWithPKs(segID1, collID, partID, "channelname", 2, []UniqueID{1})
|
||||
replica.addFlushedSegmentWithPKs(segID2, collID, partID, "channelname", 2, []UniqueID{9})
|
||||
replica.addFlushedSegmentWithPKs(segID1, collID, partID, "channelname", 2, &storage.Int64FieldData{Data: []UniqueID{1}})
|
||||
replica.addFlushedSegmentWithPKs(segID2, collID, partID, "channelname", 2, &storage.Int64FieldData{Data: []UniqueID{9}})
|
||||
replica.removeSegments(19530)
|
||||
require.True(t, replica.hasSegment(segID1, true))
|
||||
require.True(t, replica.hasSegment(segID2, true))
|
||||
|
@ -676,8 +684,8 @@ func TestCompactorInterfaceMethods(t *testing.T) {
|
|||
plan.PlanID++
|
||||
|
||||
plan.Timetravel = Timestamp(10000)
|
||||
replica.addFlushedSegmentWithPKs(segID1, collID, partID, "channelname", 2, []UniqueID{1})
|
||||
replica.addFlushedSegmentWithPKs(segID2, collID, partID, "channelname", 2, []UniqueID{9})
|
||||
replica.addFlushedSegmentWithPKs(segID1, collID, partID, "channelname", 2, &storage.Int64FieldData{Data: []UniqueID{1}})
|
||||
replica.addFlushedSegmentWithPKs(segID2, collID, partID, "channelname", 2, &storage.Int64FieldData{Data: []UniqueID{9}})
|
||||
replica.removeSegments(19530)
|
||||
require.True(t, replica.hasSegment(segID1, true))
|
||||
require.True(t, replica.hasSegment(segID2, true))
|
||||
|
@ -701,19 +709,21 @@ func TestCompactorInterfaceMethods(t *testing.T) {
|
|||
var collID, partID, segID1, segID2 UniqueID = 1, 10, 200, 201
|
||||
|
||||
alloc := NewAllocatorFactory(1)
|
||||
rc := &RootCoordFactory{}
|
||||
rc := &RootCoordFactory{
|
||||
pkType: schemapb.DataType_Int64,
|
||||
}
|
||||
dc := &DataCoordFactory{}
|
||||
mockfm := &mockFlushManager{}
|
||||
mockbIO := &binlogIO{cm, alloc}
|
||||
replica, err := newReplica(context.TODO(), rc, cm, collID)
|
||||
require.NoError(t, err)
|
||||
|
||||
replica.addFlushedSegmentWithPKs(segID1, collID, partID, "channelname", 2, []UniqueID{1})
|
||||
replica.addFlushedSegmentWithPKs(segID2, collID, partID, "channelname", 2, []UniqueID{1})
|
||||
replica.addFlushedSegmentWithPKs(segID1, collID, partID, "channelname", 2, &storage.Int64FieldData{Data: []UniqueID{1}})
|
||||
replica.addFlushedSegmentWithPKs(segID2, collID, partID, "channelname", 2, &storage.Int64FieldData{Data: []UniqueID{1}})
|
||||
require.True(t, replica.hasSegment(segID1, true))
|
||||
require.True(t, replica.hasSegment(segID2, true))
|
||||
|
||||
meta := NewMetaFactory().GetCollectionMeta(collID, "test_compact_coll_name")
|
||||
meta := NewMetaFactory().GetCollectionMeta(collID, "test_compact_coll_name", schemapb.DataType_Int64)
|
||||
// the same pk for segmentI and segmentII
|
||||
iData1 := genInsertDataWithPKs([2]int64{1, 2})
|
||||
iData2 := genInsertDataWithPKs([2]int64{1, 2})
|
||||
|
@ -732,11 +742,11 @@ func TestCompactorInterfaceMethods(t *testing.T) {
|
|||
|
||||
cpaths1, err := mockbIO.upload(context.TODO(), segID1, partID, []*InsertData{iData1}, dData1, meta)
|
||||
require.NoError(t, err)
|
||||
require.Equal(t, 11, len(cpaths1.inPaths))
|
||||
require.Equal(t, 12, len(cpaths1.inPaths))
|
||||
|
||||
cpaths2, err := mockbIO.upload(context.TODO(), segID2, partID, []*InsertData{iData2}, dData2, meta)
|
||||
require.NoError(t, err)
|
||||
require.Equal(t, 11, len(cpaths2.inPaths))
|
||||
require.Equal(t, 12, len(cpaths2.inPaths))
|
||||
|
||||
plan := &datapb.CompactionPlan{
|
||||
PlanID: 20080,
|
||||
|
|
|
@ -38,6 +38,7 @@ import (
|
|||
"github.com/milvus-io/milvus/internal/proto/datapb"
|
||||
"github.com/milvus-io/milvus/internal/proto/internalpb"
|
||||
"github.com/milvus-io/milvus/internal/proto/milvuspb"
|
||||
"github.com/milvus-io/milvus/internal/proto/schemapb"
|
||||
|
||||
"github.com/milvus-io/milvus/internal/util/etcd"
|
||||
"github.com/milvus-io/milvus/internal/util/metricsinfo"
|
||||
|
@ -68,7 +69,7 @@ func TestDataNode(t *testing.T) {
|
|||
ctx, cancel := context.WithCancel(context.Background())
|
||||
defer cancel()
|
||||
|
||||
node := newIDLEDataNodeMock(ctx)
|
||||
node := newIDLEDataNodeMock(ctx, schemapb.DataType_Int64)
|
||||
etcdCli, err := etcd.GetEtcdClient(&Params.EtcdCfg)
|
||||
assert.Nil(t, err)
|
||||
defer etcdCli.Close()
|
||||
|
@ -141,7 +142,7 @@ func TestDataNode(t *testing.T) {
|
|||
t.Run("Test FlushSegments", func(t *testing.T) {
|
||||
dmChannelName := "fake-by-dev-rootcoord-dml-channel-test-FlushSegments"
|
||||
|
||||
node1 := newIDLEDataNodeMock(context.TODO())
|
||||
node1 := newIDLEDataNodeMock(context.TODO(), schemapb.DataType_Int64)
|
||||
node1.SetEtcdClient(etcdCli)
|
||||
err = node1.Init()
|
||||
assert.Nil(t, err)
|
||||
|
@ -326,7 +327,7 @@ func TestDataNode(t *testing.T) {
|
|||
|
||||
t.Run("Test BackGroundGC", func(t *testing.T) {
|
||||
ctx, cancel := context.WithCancel(context.Background())
|
||||
node := newIDLEDataNodeMock(ctx)
|
||||
node := newIDLEDataNodeMock(ctx, schemapb.DataType_Int64)
|
||||
|
||||
vchanNameCh := make(chan string)
|
||||
node.clearSignal = vchanNameCh
|
||||
|
@ -355,7 +356,7 @@ func TestDataNode(t *testing.T) {
|
|||
|
||||
func TestWatchChannel(t *testing.T) {
|
||||
ctx, cancel := context.WithCancel(context.Background())
|
||||
node := newIDLEDataNodeMock(ctx)
|
||||
node := newIDLEDataNodeMock(ctx, schemapb.DataType_Int64)
|
||||
etcdCli, err := etcd.GetEtcdClient(&Params.EtcdCfg)
|
||||
assert.Nil(t, err)
|
||||
defer etcdCli.Close()
|
||||
|
|
|
@ -24,7 +24,6 @@ import (
|
|||
"testing"
|
||||
"time"
|
||||
|
||||
"github.com/milvus-io/milvus/internal/storage"
|
||||
"github.com/stretchr/testify/assert"
|
||||
"go.uber.org/zap"
|
||||
|
||||
|
@ -34,6 +33,8 @@ import (
|
|||
"github.com/milvus-io/milvus/internal/proto/commonpb"
|
||||
"github.com/milvus-io/milvus/internal/proto/datapb"
|
||||
"github.com/milvus-io/milvus/internal/proto/internalpb"
|
||||
"github.com/milvus-io/milvus/internal/proto/schemapb"
|
||||
"github.com/milvus-io/milvus/internal/storage"
|
||||
)
|
||||
|
||||
var dataSyncServiceTestDir = "/tmp/milvus_test/data_sync_service"
|
||||
|
@ -136,7 +137,7 @@ func TestDataSyncService_newDataSyncService(te *testing.T) {
|
|||
te.Run(test.description, func(t *testing.T) {
|
||||
df := &DataCoordFactory{}
|
||||
|
||||
replica, err := newReplica(context.Background(), &RootCoordFactory{}, cm, test.collID)
|
||||
replica, err := newReplica(context.Background(), &RootCoordFactory{pkType: schemapb.DataType_Int64}, cm, test.collID)
|
||||
assert.Nil(t, err)
|
||||
if test.replicaNil {
|
||||
replica = nil
|
||||
|
@ -183,8 +184,10 @@ func TestDataSyncService_Start(t *testing.T) {
|
|||
// init data node
|
||||
|
||||
Factory := &MetaFactory{}
|
||||
collMeta := Factory.GetCollectionMeta(UniqueID(0), "coll1")
|
||||
mockRootCoord := &RootCoordFactory{}
|
||||
collMeta := Factory.GetCollectionMeta(UniqueID(0), "coll1", schemapb.DataType_Int64)
|
||||
mockRootCoord := &RootCoordFactory{
|
||||
pkType: schemapb.DataType_Int64,
|
||||
}
|
||||
collectionID := UniqueID(1)
|
||||
|
||||
flushChan := make(chan flushMsg, 100)
|
||||
|
|
|
@ -127,6 +127,7 @@ func newBufferData(dimension int64) (*BufferData, error) {
|
|||
|
||||
limit := Params.DataNodeCfg.FlushInsertBufferSize / (dimension * 4)
|
||||
|
||||
//TODO::xige-16 eval vec and string field
|
||||
return &BufferData{&InsertData{Data: make(map[UniqueID]storage.FieldData)}, 0, limit}, nil
|
||||
}
|
||||
|
||||
|
@ -417,8 +418,8 @@ func (ibNode *insertBufferNode) updateSegStatesInReplica(insertMsgs []*msgstream
|
|||
// 1.3 Put back into buffer
|
||||
// 1.4 Update related statistics
|
||||
func (ibNode *insertBufferNode) bufferInsertMsg(msg *msgstream.InsertMsg, endPos *internalpb.MsgPosition) error {
|
||||
if !msg.CheckAligned() {
|
||||
return errors.New("misaligned messages detected")
|
||||
if err := msg.CheckAligned(); err != nil {
|
||||
return err
|
||||
}
|
||||
currentSegID := msg.GetSegmentID()
|
||||
collectionID := msg.GetCollectionID()
|
||||
|
|
|
@ -68,8 +68,10 @@ func TestFlowGraphInsertBufferNodeCreate(t *testing.T) {
|
|||
Params.EtcdCfg.MetaRootPath = testPath
|
||||
|
||||
Factory := &MetaFactory{}
|
||||
collMeta := Factory.GetCollectionMeta(UniqueID(0), "coll1")
|
||||
mockRootCoord := &RootCoordFactory{}
|
||||
collMeta := Factory.GetCollectionMeta(UniqueID(0), "coll1", schemapb.DataType_Int64)
|
||||
mockRootCoord := &RootCoordFactory{
|
||||
pkType: schemapb.DataType_Int64,
|
||||
}
|
||||
|
||||
replica, err := newReplica(ctx, mockRootCoord, cm, collMeta.ID)
|
||||
assert.Nil(t, err)
|
||||
|
@ -154,8 +156,10 @@ func TestFlowGraphInsertBufferNode_Operate(t *testing.T) {
|
|||
Params.EtcdCfg.MetaRootPath = testPath
|
||||
|
||||
Factory := &MetaFactory{}
|
||||
collMeta := Factory.GetCollectionMeta(UniqueID(0), "coll1")
|
||||
mockRootCoord := &RootCoordFactory{}
|
||||
collMeta := Factory.GetCollectionMeta(UniqueID(0), "coll1", schemapb.DataType_Int64)
|
||||
mockRootCoord := &RootCoordFactory{
|
||||
pkType: schemapb.DataType_Int64,
|
||||
}
|
||||
|
||||
replica, err := newReplica(ctx, mockRootCoord, cm, collMeta.ID)
|
||||
assert.Nil(t, err)
|
||||
|
@ -349,10 +353,12 @@ func TestFlowGraphInsertBufferNode_AutoFlush(t *testing.T) {
|
|||
Params.EtcdCfg.MetaRootPath = testPath
|
||||
|
||||
Factory := &MetaFactory{}
|
||||
collMeta := Factory.GetCollectionMeta(UniqueID(0), "coll1")
|
||||
collMeta := Factory.GetCollectionMeta(UniqueID(0), "coll1", schemapb.DataType_Int64)
|
||||
dataFactory := NewDataFactory()
|
||||
|
||||
mockRootCoord := &RootCoordFactory{}
|
||||
mockRootCoord := &RootCoordFactory{
|
||||
pkType: schemapb.DataType_Int64,
|
||||
}
|
||||
|
||||
colRep := &SegmentReplica{
|
||||
collectionID: collMeta.ID,
|
||||
|
@ -597,7 +603,7 @@ type CompactedRootCoord struct {
|
|||
}
|
||||
|
||||
func (m *CompactedRootCoord) DescribeCollection(ctx context.Context, in *milvuspb.DescribeCollectionRequest) (*milvuspb.DescribeCollectionResponse, error) {
|
||||
if in.GetTimeStamp() <= m.compactTs {
|
||||
if in.TimeStamp != 0 && in.GetTimeStamp() <= m.compactTs {
|
||||
return &milvuspb.DescribeCollectionResponse{
|
||||
Status: &commonpb.Status{
|
||||
ErrorCode: commonpb.ErrorCode_UnexpectedError,
|
||||
|
@ -620,50 +626,62 @@ func TestInsertBufferNode_bufferInsertMsg(t *testing.T) {
|
|||
Params.EtcdCfg.MetaRootPath = testPath
|
||||
|
||||
Factory := &MetaFactory{}
|
||||
collMeta := Factory.GetCollectionMeta(UniqueID(0), "coll1")
|
||||
|
||||
rcf := &RootCoordFactory{}
|
||||
mockRootCoord := &CompactedRootCoord{
|
||||
RootCoord: rcf,
|
||||
compactTs: 100,
|
||||
tests := []struct {
|
||||
collID UniqueID
|
||||
pkType schemapb.DataType
|
||||
description string
|
||||
}{
|
||||
{0, schemapb.DataType_Int64, "int64PrimaryData"},
|
||||
{0, schemapb.DataType_VarChar, "varCharPrimaryData"},
|
||||
}
|
||||
|
||||
cm := storage.NewLocalChunkManager(storage.RootPath(insertNodeTestDir))
|
||||
defer cm.RemoveWithPrefix("")
|
||||
replica, err := newReplica(ctx, mockRootCoord, cm, collMeta.ID)
|
||||
assert.Nil(t, err)
|
||||
for _, test := range tests {
|
||||
collMeta := Factory.GetCollectionMeta(test.collID, "collection", test.pkType)
|
||||
rcf := &RootCoordFactory{
|
||||
pkType: test.pkType,
|
||||
}
|
||||
mockRootCoord := &CompactedRootCoord{
|
||||
RootCoord: rcf,
|
||||
compactTs: 100,
|
||||
}
|
||||
|
||||
err = replica.addNewSegment(1, collMeta.ID, 0, insertChannelName, &internalpb.MsgPosition{}, &internalpb.MsgPosition{})
|
||||
require.NoError(t, err)
|
||||
|
||||
msFactory := msgstream.NewPmsFactory()
|
||||
err = msFactory.Init(&Params)
|
||||
assert.Nil(t, err)
|
||||
|
||||
fm := NewRendezvousFlushManager(&allocator{}, cm, replica, func(*segmentFlushPack) {}, emptyFlushAndDropFunc)
|
||||
|
||||
flushChan := make(chan flushMsg, 100)
|
||||
c := &nodeConfig{
|
||||
replica: replica,
|
||||
msFactory: msFactory,
|
||||
allocator: NewAllocatorFactory(),
|
||||
vChannelName: "string",
|
||||
}
|
||||
iBNode, err := newInsertBufferNode(ctx, collMeta.ID, flushChan, fm, newCache(), c)
|
||||
require.NoError(t, err)
|
||||
|
||||
inMsg := genFlowGraphInsertMsg(insertChannelName)
|
||||
for _, msg := range inMsg.insertMessages {
|
||||
msg.EndTimestamp = 101 // ts valid
|
||||
err = iBNode.bufferInsertMsg(msg, &internalpb.MsgPosition{})
|
||||
replica, err := newReplica(ctx, mockRootCoord, cm, collMeta.ID)
|
||||
assert.Nil(t, err)
|
||||
}
|
||||
|
||||
for _, msg := range inMsg.insertMessages {
|
||||
msg.EndTimestamp = 101 // ts valid
|
||||
msg.RowIDs = []int64{} //misaligned data
|
||||
err = iBNode.bufferInsertMsg(msg, &internalpb.MsgPosition{})
|
||||
assert.NotNil(t, err)
|
||||
err = replica.addNewSegment(1, collMeta.ID, 0, insertChannelName, &internalpb.MsgPosition{}, &internalpb.MsgPosition{Timestamp: 101})
|
||||
require.NoError(t, err)
|
||||
|
||||
msFactory := msgstream.NewPmsFactory()
|
||||
err = msFactory.Init(&Params)
|
||||
assert.Nil(t, err)
|
||||
|
||||
fm := NewRendezvousFlushManager(&allocator{}, cm, replica, func(*segmentFlushPack) {}, emptyFlushAndDropFunc)
|
||||
|
||||
flushChan := make(chan flushMsg, 100)
|
||||
c := &nodeConfig{
|
||||
replica: replica,
|
||||
msFactory: msFactory,
|
||||
allocator: NewAllocatorFactory(),
|
||||
vChannelName: "string",
|
||||
}
|
||||
iBNode, err := newInsertBufferNode(ctx, collMeta.ID, flushChan, fm, newCache(), c)
|
||||
require.NoError(t, err)
|
||||
|
||||
inMsg := genFlowGraphInsertMsg(insertChannelName)
|
||||
for _, msg := range inMsg.insertMessages {
|
||||
msg.EndTimestamp = 101 // ts valid
|
||||
err = iBNode.bufferInsertMsg(msg, &internalpb.MsgPosition{})
|
||||
assert.Nil(t, err)
|
||||
}
|
||||
|
||||
for _, msg := range inMsg.insertMessages {
|
||||
msg.EndTimestamp = 101 // ts valid
|
||||
msg.RowIDs = []int64{} //misaligned data
|
||||
err = iBNode.bufferInsertMsg(msg, &internalpb.MsgPosition{})
|
||||
assert.NotNil(t, err)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
|
@ -681,7 +699,7 @@ func TestInsertBufferNode_updateSegStatesInReplica(te *testing.T) {
|
|||
}
|
||||
|
||||
for _, test := range invalideTests {
|
||||
replica, err := newReplica(context.Background(), &RootCoordFactory{}, cm, test.replicaCollID)
|
||||
replica, err := newReplica(context.Background(), &RootCoordFactory{pkType: schemapb.DataType_Int64}, cm, test.replicaCollID)
|
||||
assert.Nil(te, err)
|
||||
|
||||
ibNode := &insertBufferNode{
|
||||
|
|
|
@ -22,6 +22,7 @@ import (
|
|||
|
||||
"github.com/milvus-io/milvus/internal/proto/datapb"
|
||||
"github.com/milvus-io/milvus/internal/proto/internalpb"
|
||||
"github.com/milvus-io/milvus/internal/proto/schemapb"
|
||||
"github.com/milvus-io/milvus/internal/util/etcd"
|
||||
|
||||
"github.com/stretchr/testify/assert"
|
||||
|
@ -36,7 +37,7 @@ func TestFlowGraphManager(t *testing.T) {
|
|||
assert.Nil(t, err)
|
||||
defer etcdCli.Close()
|
||||
|
||||
node := newIDLEDataNodeMock(ctx)
|
||||
node := newIDLEDataNodeMock(ctx, schemapb.DataType_Int64)
|
||||
node.SetEtcdClient(etcdCli)
|
||||
err = node.Init()
|
||||
require.Nil(t, err)
|
||||
|
|
|
@ -26,6 +26,7 @@ import (
|
|||
|
||||
"github.com/milvus-io/milvus/internal/proto/datapb"
|
||||
"github.com/milvus-io/milvus/internal/proto/internalpb"
|
||||
"github.com/milvus-io/milvus/internal/proto/schemapb"
|
||||
"github.com/milvus-io/milvus/internal/storage"
|
||||
"github.com/milvus-io/milvus/internal/util/retry"
|
||||
"github.com/stretchr/testify/assert"
|
||||
|
@ -516,7 +517,9 @@ func TestRendezvousFlushManager_close(t *testing.T) {
|
|||
|
||||
func TestFlushNotifyFunc(t *testing.T) {
|
||||
ctx := context.Background()
|
||||
rcf := &RootCoordFactory{}
|
||||
rcf := &RootCoordFactory{
|
||||
pkType: schemapb.DataType_Int64,
|
||||
}
|
||||
cm := storage.NewLocalChunkManager(storage.RootPath(flushTestDir))
|
||||
|
||||
replica, err := newReplica(ctx, rcf, cm, 1)
|
||||
|
@ -568,7 +571,10 @@ func TestFlushNotifyFunc(t *testing.T) {
|
|||
|
||||
func TestDropVirtualChannelFunc(t *testing.T) {
|
||||
ctx := context.Background()
|
||||
rcf := &RootCoordFactory{}
|
||||
rcf := &RootCoordFactory{
|
||||
pkType: schemapb.DataType_Int64,
|
||||
}
|
||||
|
||||
cm := storage.NewLocalChunkManager(storage.RootPath(flushTestDir))
|
||||
replica, err := newReplica(ctx, rcf, cm, 1)
|
||||
require.NoError(t, err)
|
||||
|
|
|
@ -23,6 +23,7 @@ import (
|
|||
|
||||
"github.com/milvus-io/milvus/internal/proto/commonpb"
|
||||
"github.com/milvus-io/milvus/internal/proto/milvuspb"
|
||||
"github.com/milvus-io/milvus/internal/proto/schemapb"
|
||||
"github.com/stretchr/testify/assert"
|
||||
)
|
||||
|
||||
|
@ -37,7 +38,9 @@ func TestMetaService_All(t *testing.T) {
|
|||
ctx, cancel := context.WithCancel(context.Background())
|
||||
defer cancel()
|
||||
|
||||
mFactory := &RootCoordFactory{}
|
||||
mFactory := &RootCoordFactory{
|
||||
pkType: schemapb.DataType_Int64,
|
||||
}
|
||||
mFactory.setCollectionID(collectionID0)
|
||||
mFactory.setCollectionName(collectionName0)
|
||||
ms := newMetaService(mFactory, collectionID0)
|
||||
|
@ -52,7 +55,7 @@ func TestMetaService_All(t *testing.T) {
|
|||
|
||||
t.Run("Test printCollectionStruct", func(t *testing.T) {
|
||||
mf := &MetaFactory{}
|
||||
collectionMeta := mf.GetCollectionMeta(collectionID0, collectionName0)
|
||||
collectionMeta := mf.GetCollectionMeta(collectionID0, collectionName0, schemapb.DataType_Int64)
|
||||
printCollectionStruct(collectionMeta)
|
||||
})
|
||||
}
|
||||
|
|
|
@ -52,7 +52,7 @@ const debug = false
|
|||
|
||||
var emptyFlushAndDropFunc flushAndDropFunc = func(_ []*segmentFlushPack) {}
|
||||
|
||||
func newIDLEDataNodeMock(ctx context.Context) *DataNode {
|
||||
func newIDLEDataNodeMock(ctx context.Context, pkType schemapb.DataType) *DataNode {
|
||||
msFactory := msgstream.NewRmsFactory()
|
||||
node := NewDataNode(ctx, msFactory)
|
||||
|
||||
|
@ -60,6 +60,7 @@ func newIDLEDataNodeMock(ctx context.Context) *DataNode {
|
|||
ID: 0,
|
||||
collectionID: 1,
|
||||
collectionName: "collection-1",
|
||||
pkType: pkType,
|
||||
}
|
||||
node.rootCoord = rc
|
||||
|
||||
|
@ -146,7 +147,8 @@ func NewMetaFactory() *MetaFactory {
|
|||
}
|
||||
|
||||
type DataFactory struct {
|
||||
rawData []byte
|
||||
rawData []byte
|
||||
columnData []*schemapb.FieldData
|
||||
}
|
||||
|
||||
type RootCoordFactory struct {
|
||||
|
@ -154,6 +156,7 @@ type RootCoordFactory struct {
|
|||
ID UniqueID
|
||||
collectionName string
|
||||
collectionID UniqueID
|
||||
pkType schemapb.DataType
|
||||
}
|
||||
|
||||
type DataCoordFactory struct {
|
||||
|
@ -209,144 +212,166 @@ func (ds *DataCoordFactory) DropVirtualChannel(ctx context.Context, req *datapb.
|
|||
}, nil
|
||||
}
|
||||
|
||||
func (mf *MetaFactory) GetCollectionMeta(collectionID UniqueID, collectionName string) *etcdpb.CollectionMeta {
|
||||
func (mf *MetaFactory) GetCollectionMeta(collectionID UniqueID, collectionName string, pkDataType schemapb.DataType) *etcdpb.CollectionMeta {
|
||||
sch := schemapb.CollectionSchema{
|
||||
Name: collectionName,
|
||||
Description: "test collection by meta factory",
|
||||
AutoID: true,
|
||||
Fields: []*schemapb.FieldSchema{
|
||||
{
|
||||
FieldID: 0,
|
||||
Name: "RowID",
|
||||
Description: "RowID field",
|
||||
DataType: schemapb.DataType_Int64,
|
||||
TypeParams: []*commonpb.KeyValuePair{
|
||||
{
|
||||
Key: "f0_tk1",
|
||||
Value: "f0_tv1",
|
||||
},
|
||||
},
|
||||
},
|
||||
{
|
||||
FieldID: 1,
|
||||
Name: "Timestamp",
|
||||
Description: "Timestamp field",
|
||||
DataType: schemapb.DataType_Int64,
|
||||
TypeParams: []*commonpb.KeyValuePair{
|
||||
{
|
||||
Key: "f1_tk1",
|
||||
Value: "f1_tv1",
|
||||
},
|
||||
},
|
||||
},
|
||||
{
|
||||
FieldID: 100,
|
||||
Name: "float_vector_field",
|
||||
Description: "field 100",
|
||||
DataType: schemapb.DataType_FloatVector,
|
||||
TypeParams: []*commonpb.KeyValuePair{
|
||||
{
|
||||
Key: "dim",
|
||||
Value: "2",
|
||||
},
|
||||
},
|
||||
IndexParams: []*commonpb.KeyValuePair{
|
||||
{
|
||||
Key: "indexkey",
|
||||
Value: "indexvalue",
|
||||
},
|
||||
},
|
||||
},
|
||||
{
|
||||
FieldID: 101,
|
||||
Name: "binary_vector_field",
|
||||
Description: "field 101",
|
||||
DataType: schemapb.DataType_BinaryVector,
|
||||
TypeParams: []*commonpb.KeyValuePair{
|
||||
{
|
||||
Key: "dim",
|
||||
Value: "32",
|
||||
},
|
||||
},
|
||||
IndexParams: []*commonpb.KeyValuePair{
|
||||
{
|
||||
Key: "indexkey",
|
||||
Value: "indexvalue",
|
||||
},
|
||||
},
|
||||
},
|
||||
{
|
||||
FieldID: 102,
|
||||
Name: "bool_field",
|
||||
Description: "field 102",
|
||||
DataType: schemapb.DataType_Bool,
|
||||
TypeParams: []*commonpb.KeyValuePair{},
|
||||
IndexParams: []*commonpb.KeyValuePair{},
|
||||
},
|
||||
{
|
||||
FieldID: 103,
|
||||
Name: "int8_field",
|
||||
Description: "field 103",
|
||||
DataType: schemapb.DataType_Int8,
|
||||
TypeParams: []*commonpb.KeyValuePair{},
|
||||
IndexParams: []*commonpb.KeyValuePair{},
|
||||
},
|
||||
{
|
||||
FieldID: 104,
|
||||
Name: "int16_field",
|
||||
Description: "field 104",
|
||||
DataType: schemapb.DataType_Int16,
|
||||
TypeParams: []*commonpb.KeyValuePair{},
|
||||
IndexParams: []*commonpb.KeyValuePair{},
|
||||
},
|
||||
{
|
||||
FieldID: 105,
|
||||
Name: "int32_field",
|
||||
Description: "field 105",
|
||||
DataType: schemapb.DataType_Int32,
|
||||
TypeParams: []*commonpb.KeyValuePair{},
|
||||
IndexParams: []*commonpb.KeyValuePair{},
|
||||
},
|
||||
{
|
||||
FieldID: 106,
|
||||
Name: "int64_field",
|
||||
Description: "field 106",
|
||||
DataType: schemapb.DataType_Int64,
|
||||
TypeParams: []*commonpb.KeyValuePair{},
|
||||
IndexParams: []*commonpb.KeyValuePair{},
|
||||
IsPrimaryKey: true,
|
||||
},
|
||||
{
|
||||
FieldID: 107,
|
||||
Name: "float32_field",
|
||||
Description: "field 107",
|
||||
DataType: schemapb.DataType_Float,
|
||||
TypeParams: []*commonpb.KeyValuePair{},
|
||||
IndexParams: []*commonpb.KeyValuePair{},
|
||||
},
|
||||
{
|
||||
FieldID: 108,
|
||||
Name: "float64_field",
|
||||
Description: "field 108",
|
||||
DataType: schemapb.DataType_Double,
|
||||
TypeParams: []*commonpb.KeyValuePair{},
|
||||
IndexParams: []*commonpb.KeyValuePair{},
|
||||
},
|
||||
},
|
||||
}
|
||||
sch.Fields = mf.GetFieldSchema()
|
||||
for _, field := range sch.Fields {
|
||||
if field.GetDataType() == pkDataType && field.FieldID >= 100 {
|
||||
field.IsPrimaryKey = true
|
||||
}
|
||||
}
|
||||
|
||||
collection := etcdpb.CollectionMeta{
|
||||
return &etcdpb.CollectionMeta{
|
||||
ID: collectionID,
|
||||
Schema: &sch,
|
||||
CreateTime: Timestamp(1),
|
||||
SegmentIDs: make([]UniqueID, 0),
|
||||
PartitionIDs: []UniqueID{0},
|
||||
}
|
||||
return &collection
|
||||
}
|
||||
|
||||
func (mf *MetaFactory) GetFieldSchema() []*schemapb.FieldSchema {
|
||||
fields := []*schemapb.FieldSchema{
|
||||
{
|
||||
FieldID: 0,
|
||||
Name: "RowID",
|
||||
Description: "RowID field",
|
||||
DataType: schemapb.DataType_Int64,
|
||||
TypeParams: []*commonpb.KeyValuePair{
|
||||
{
|
||||
Key: "f0_tk1",
|
||||
Value: "f0_tv1",
|
||||
},
|
||||
},
|
||||
},
|
||||
{
|
||||
FieldID: 1,
|
||||
Name: "Timestamp",
|
||||
Description: "Timestamp field",
|
||||
DataType: schemapb.DataType_Int64,
|
||||
TypeParams: []*commonpb.KeyValuePair{
|
||||
{
|
||||
Key: "f1_tk1",
|
||||
Value: "f1_tv1",
|
||||
},
|
||||
},
|
||||
},
|
||||
{
|
||||
FieldID: 100,
|
||||
Name: "float_vector_field",
|
||||
Description: "field 100",
|
||||
DataType: schemapb.DataType_FloatVector,
|
||||
TypeParams: []*commonpb.KeyValuePair{
|
||||
{
|
||||
Key: "dim",
|
||||
Value: "2",
|
||||
},
|
||||
},
|
||||
IndexParams: []*commonpb.KeyValuePair{
|
||||
{
|
||||
Key: "indexkey",
|
||||
Value: "indexvalue",
|
||||
},
|
||||
},
|
||||
},
|
||||
{
|
||||
FieldID: 101,
|
||||
Name: "binary_vector_field",
|
||||
Description: "field 101",
|
||||
DataType: schemapb.DataType_BinaryVector,
|
||||
TypeParams: []*commonpb.KeyValuePair{
|
||||
{
|
||||
Key: "dim",
|
||||
Value: "32",
|
||||
},
|
||||
},
|
||||
IndexParams: []*commonpb.KeyValuePair{
|
||||
{
|
||||
Key: "indexkey",
|
||||
Value: "indexvalue",
|
||||
},
|
||||
},
|
||||
},
|
||||
{
|
||||
FieldID: 102,
|
||||
Name: "bool_field",
|
||||
Description: "field 102",
|
||||
DataType: schemapb.DataType_Bool,
|
||||
TypeParams: []*commonpb.KeyValuePair{},
|
||||
IndexParams: []*commonpb.KeyValuePair{},
|
||||
},
|
||||
{
|
||||
FieldID: 103,
|
||||
Name: "int8_field",
|
||||
Description: "field 103",
|
||||
DataType: schemapb.DataType_Int8,
|
||||
TypeParams: []*commonpb.KeyValuePair{},
|
||||
IndexParams: []*commonpb.KeyValuePair{},
|
||||
},
|
||||
{
|
||||
FieldID: 104,
|
||||
Name: "int16_field",
|
||||
Description: "field 104",
|
||||
DataType: schemapb.DataType_Int16,
|
||||
TypeParams: []*commonpb.KeyValuePair{},
|
||||
IndexParams: []*commonpb.KeyValuePair{},
|
||||
},
|
||||
{
|
||||
FieldID: 105,
|
||||
Name: "int32_field",
|
||||
Description: "field 105",
|
||||
DataType: schemapb.DataType_Int32,
|
||||
TypeParams: []*commonpb.KeyValuePair{},
|
||||
IndexParams: []*commonpb.KeyValuePair{},
|
||||
},
|
||||
{
|
||||
FieldID: 106,
|
||||
Name: "int64_field",
|
||||
Description: "field 106",
|
||||
DataType: schemapb.DataType_Int64,
|
||||
TypeParams: []*commonpb.KeyValuePair{},
|
||||
IndexParams: []*commonpb.KeyValuePair{},
|
||||
},
|
||||
{
|
||||
FieldID: 107,
|
||||
Name: "float32_field",
|
||||
Description: "field 107",
|
||||
DataType: schemapb.DataType_Float,
|
||||
TypeParams: []*commonpb.KeyValuePair{},
|
||||
IndexParams: []*commonpb.KeyValuePair{},
|
||||
},
|
||||
{
|
||||
FieldID: 108,
|
||||
Name: "float64_field",
|
||||
Description: "field 108",
|
||||
DataType: schemapb.DataType_Double,
|
||||
TypeParams: []*commonpb.KeyValuePair{},
|
||||
IndexParams: []*commonpb.KeyValuePair{},
|
||||
},
|
||||
{
|
||||
FieldID: 109,
|
||||
Name: "varChar_field",
|
||||
Description: "field 109",
|
||||
DataType: schemapb.DataType_VarChar,
|
||||
TypeParams: []*commonpb.KeyValuePair{
|
||||
{
|
||||
Key: "max_length_per_row",
|
||||
Value: "100",
|
||||
},
|
||||
},
|
||||
IndexParams: []*commonpb.KeyValuePair{},
|
||||
},
|
||||
}
|
||||
|
||||
return fields
|
||||
}
|
||||
|
||||
func NewDataFactory() *DataFactory {
|
||||
return &DataFactory{rawData: GenRowData()}
|
||||
return &DataFactory{rawData: GenRowData(), columnData: GenColumnData()}
|
||||
}
|
||||
|
||||
func GenRowData() (rawData []byte) {
|
||||
|
@ -427,6 +452,192 @@ func GenRowData() (rawData []byte) {
|
|||
return
|
||||
}
|
||||
|
||||
func GenColumnData() (fieldsData []*schemapb.FieldData) {
|
||||
// Float vector
|
||||
var fVector = []float32{1, 2}
|
||||
floatVectorData := &schemapb.FieldData{
|
||||
Type: schemapb.DataType_FloatVector,
|
||||
FieldName: "float_vector_field",
|
||||
FieldId: 100,
|
||||
Field: &schemapb.FieldData_Vectors{
|
||||
Vectors: &schemapb.VectorField{
|
||||
Dim: 2,
|
||||
Data: &schemapb.VectorField_FloatVector{
|
||||
FloatVector: &schemapb.FloatArray{
|
||||
Data: fVector,
|
||||
},
|
||||
},
|
||||
},
|
||||
},
|
||||
}
|
||||
fieldsData = append(fieldsData, floatVectorData)
|
||||
|
||||
// Binary vector
|
||||
// Dimension of binary vector is 32
|
||||
// size := 4, = 32 / 8
|
||||
binaryVector := []byte{255, 255, 255, 0}
|
||||
binaryVectorData := &schemapb.FieldData{
|
||||
Type: schemapb.DataType_BinaryVector,
|
||||
FieldName: "binary_vector_field",
|
||||
FieldId: 101,
|
||||
Field: &schemapb.FieldData_Vectors{
|
||||
Vectors: &schemapb.VectorField{
|
||||
Dim: 32,
|
||||
Data: &schemapb.VectorField_BinaryVector{
|
||||
BinaryVector: binaryVector,
|
||||
},
|
||||
},
|
||||
},
|
||||
}
|
||||
fieldsData = append(fieldsData, binaryVectorData)
|
||||
|
||||
// bool
|
||||
boolData := []bool{true}
|
||||
boolFieldData := &schemapb.FieldData{
|
||||
Type: schemapb.DataType_Bool,
|
||||
FieldName: "bool_field",
|
||||
FieldId: 102,
|
||||
Field: &schemapb.FieldData_Scalars{
|
||||
Scalars: &schemapb.ScalarField{
|
||||
Data: &schemapb.ScalarField_BoolData{
|
||||
BoolData: &schemapb.BoolArray{
|
||||
Data: boolData,
|
||||
},
|
||||
},
|
||||
},
|
||||
},
|
||||
}
|
||||
fieldsData = append(fieldsData, boolFieldData)
|
||||
|
||||
// int8
|
||||
int8Data := []int32{100}
|
||||
int8FieldData := &schemapb.FieldData{
|
||||
Type: schemapb.DataType_Int8,
|
||||
FieldName: "int8_field",
|
||||
FieldId: 103,
|
||||
Field: &schemapb.FieldData_Scalars{
|
||||
Scalars: &schemapb.ScalarField{
|
||||
Data: &schemapb.ScalarField_IntData{
|
||||
IntData: &schemapb.IntArray{
|
||||
Data: int8Data,
|
||||
},
|
||||
},
|
||||
},
|
||||
},
|
||||
}
|
||||
fieldsData = append(fieldsData, int8FieldData)
|
||||
|
||||
// int16
|
||||
int16Data := []int32{200}
|
||||
int16FieldData := &schemapb.FieldData{
|
||||
Type: schemapb.DataType_Int16,
|
||||
FieldName: "int16_field",
|
||||
FieldId: 104,
|
||||
Field: &schemapb.FieldData_Scalars{
|
||||
Scalars: &schemapb.ScalarField{
|
||||
Data: &schemapb.ScalarField_IntData{
|
||||
IntData: &schemapb.IntArray{
|
||||
Data: int16Data,
|
||||
},
|
||||
},
|
||||
},
|
||||
},
|
||||
}
|
||||
fieldsData = append(fieldsData, int16FieldData)
|
||||
|
||||
// int32
|
||||
int32Data := []int32{300}
|
||||
int32FieldData := &schemapb.FieldData{
|
||||
Type: schemapb.DataType_Int32,
|
||||
FieldName: "int32_field",
|
||||
FieldId: 105,
|
||||
Field: &schemapb.FieldData_Scalars{
|
||||
Scalars: &schemapb.ScalarField{
|
||||
Data: &schemapb.ScalarField_IntData{
|
||||
IntData: &schemapb.IntArray{
|
||||
Data: int32Data,
|
||||
},
|
||||
},
|
||||
},
|
||||
},
|
||||
}
|
||||
fieldsData = append(fieldsData, int32FieldData)
|
||||
|
||||
// int64
|
||||
int64Data := []int64{400}
|
||||
int64FieldData := &schemapb.FieldData{
|
||||
Type: schemapb.DataType_Int64,
|
||||
FieldName: "int64_field",
|
||||
FieldId: 106,
|
||||
Field: &schemapb.FieldData_Scalars{
|
||||
Scalars: &schemapb.ScalarField{
|
||||
Data: &schemapb.ScalarField_LongData{
|
||||
LongData: &schemapb.LongArray{
|
||||
Data: int64Data,
|
||||
},
|
||||
},
|
||||
},
|
||||
},
|
||||
}
|
||||
fieldsData = append(fieldsData, int64FieldData)
|
||||
|
||||
// float
|
||||
floatData := []float32{1.1}
|
||||
floatFieldData := &schemapb.FieldData{
|
||||
Type: schemapb.DataType_Float,
|
||||
FieldName: "float32_field",
|
||||
FieldId: 107,
|
||||
Field: &schemapb.FieldData_Scalars{
|
||||
Scalars: &schemapb.ScalarField{
|
||||
Data: &schemapb.ScalarField_FloatData{
|
||||
FloatData: &schemapb.FloatArray{
|
||||
Data: floatData,
|
||||
},
|
||||
},
|
||||
},
|
||||
},
|
||||
}
|
||||
fieldsData = append(fieldsData, floatFieldData)
|
||||
|
||||
//double
|
||||
doubleData := []float64{2.2}
|
||||
doubleFieldData := &schemapb.FieldData{
|
||||
Type: schemapb.DataType_Double,
|
||||
FieldName: "float64_field",
|
||||
FieldId: 108,
|
||||
Field: &schemapb.FieldData_Scalars{
|
||||
Scalars: &schemapb.ScalarField{
|
||||
Data: &schemapb.ScalarField_DoubleData{
|
||||
DoubleData: &schemapb.DoubleArray{
|
||||
Data: doubleData,
|
||||
},
|
||||
},
|
||||
},
|
||||
},
|
||||
}
|
||||
fieldsData = append(fieldsData, doubleFieldData)
|
||||
|
||||
//var char
|
||||
varCharData := []string{"test"}
|
||||
varCharFieldData := &schemapb.FieldData{
|
||||
Type: schemapb.DataType_VarChar,
|
||||
FieldName: "varChar_field",
|
||||
FieldId: 109,
|
||||
Field: &schemapb.FieldData_Scalars{
|
||||
Scalars: &schemapb.ScalarField{
|
||||
Data: &schemapb.ScalarField_StringData{
|
||||
StringData: &schemapb.StringArray{
|
||||
Data: varCharData,
|
||||
},
|
||||
},
|
||||
},
|
||||
},
|
||||
}
|
||||
fieldsData = append(fieldsData, varCharFieldData)
|
||||
|
||||
return
|
||||
}
|
||||
|
||||
func (df *DataFactory) GenMsgStreamInsertMsg(idx int, chanName string) *msgstream.InsertMsg {
|
||||
var msg = &msgstream.InsertMsg{
|
||||
BaseMsg: msgstream.BaseMsg{
|
||||
|
@ -446,7 +657,10 @@ func (df *DataFactory) GenMsgStreamInsertMsg(idx int, chanName string) *msgstrea
|
|||
ShardName: chanName,
|
||||
Timestamps: []Timestamp{Timestamp(idx + 1000)},
|
||||
RowIDs: []UniqueID{UniqueID(idx)},
|
||||
RowData: []*commonpb.Blob{{Value: df.rawData}},
|
||||
// RowData: []*commonpb.Blob{{Value: df.rawData}},
|
||||
FieldsData: df.columnData,
|
||||
Version: internalpb.InsertDataVersion_ColumnBased,
|
||||
NumRows: 1,
|
||||
},
|
||||
}
|
||||
return msg
|
||||
|
@ -662,7 +876,7 @@ func (m *RootCoordFactory) ShowCollections(ctx context.Context, in *milvuspb.Sho
|
|||
|
||||
func (m *RootCoordFactory) DescribeCollection(ctx context.Context, in *milvuspb.DescribeCollectionRequest) (*milvuspb.DescribeCollectionResponse, error) {
|
||||
f := MetaFactory{}
|
||||
meta := f.GetCollectionMeta(m.collectionID, m.collectionName)
|
||||
meta := f.GetCollectionMeta(m.collectionID, m.collectionName, m.pkType)
|
||||
resp := &milvuspb.DescribeCollectionResponse{
|
||||
Status: &commonpb.Status{
|
||||
ErrorCode: commonpb.ErrorCode_UnexpectedError,
|
||||
|
@ -764,6 +978,10 @@ func genInsertData() *InsertData {
|
|||
NumRows: []int64{2},
|
||||
Data: []float64{3.333, 3.334},
|
||||
},
|
||||
109: &s.StringFieldData{
|
||||
NumRows: []int64{2},
|
||||
Data: []string{"test1", "test2"},
|
||||
},
|
||||
}}
|
||||
}
|
||||
|
||||
|
@ -816,6 +1034,10 @@ func genEmptyInsertData() *InsertData {
|
|||
NumRows: []int64{0},
|
||||
Data: []float64{},
|
||||
},
|
||||
109: &s.StringFieldData{
|
||||
NumRows: []int64{0},
|
||||
Data: []string{},
|
||||
},
|
||||
}}
|
||||
}
|
||||
|
||||
|
@ -868,6 +1090,10 @@ func genInsertDataWithExpiredTS() *InsertData {
|
|||
NumRows: []int64{2},
|
||||
Data: []float64{3.333, 3.334},
|
||||
},
|
||||
109: &s.StringFieldData{
|
||||
NumRows: []int64{2},
|
||||
Data: []string{"test1", "test2"},
|
||||
},
|
||||
}}
|
||||
}
|
||||
|
||||
|
|
|
@ -19,7 +19,6 @@ package datanode
|
|||
import (
|
||||
"context"
|
||||
"fmt"
|
||||
"math"
|
||||
"sync"
|
||||
"sync/atomic"
|
||||
|
||||
|
@ -42,6 +41,10 @@ const (
|
|||
maxBloomFalsePositive float64 = 0.005
|
||||
)
|
||||
|
||||
type PrimaryKey = storage.PrimaryKey
|
||||
type Int64PrimaryKey = storage.Int64PrimaryKey
|
||||
type StringPrimaryKey = storage.StringPrimaryKey
|
||||
|
||||
// Replica is DataNode unique replication
|
||||
type Replica interface {
|
||||
getCollectionID() UniqueID
|
||||
|
@ -57,8 +60,8 @@ type Replica interface {
|
|||
listSegmentsCheckPoints() map[UniqueID]segmentCheckPoint
|
||||
updateSegmentEndPosition(segID UniqueID, endPos *internalpb.MsgPosition)
|
||||
updateSegmentCheckPoint(segID UniqueID)
|
||||
updateSegmentPKRange(segID UniqueID, rowIDs []int64)
|
||||
mergeFlushedSegments(segID, collID, partID, planID UniqueID, compactedFrom []UniqueID, channelName string, numOfRows int64)
|
||||
updateSegmentPKRange(segID UniqueID, ids storage.FieldData)
|
||||
mergeFlushedSegments(segID, collID, partID, planID UniqueID, compactedFrom []UniqueID, channelName string, numOfRows int64) error
|
||||
hasSegment(segID UniqueID, countFlushed bool) bool
|
||||
removeSegments(segID ...UniqueID)
|
||||
listCompactedSegmentIDs() map[UniqueID][]UniqueID
|
||||
|
@ -87,8 +90,8 @@ type Segment struct {
|
|||
|
||||
pkFilter *bloom.BloomFilter // bloom filter of pk inside a segment
|
||||
// TODO silverxia, needs to change to interface to support `string` type PK
|
||||
minPK int64 // minimal pk value, shortcut for checking whether a pk is inside this segment
|
||||
maxPK int64 // maximal pk value, same above
|
||||
minPK PrimaryKey // minimal pk value, shortcut for checking whether a pk is inside this segment
|
||||
maxPK PrimaryKey // maximal pk value, same above
|
||||
}
|
||||
|
||||
// SegmentReplica is the data replication of persistent data in datanode.
|
||||
|
@ -107,23 +110,58 @@ type SegmentReplica struct {
|
|||
chunkManager storage.ChunkManager
|
||||
}
|
||||
|
||||
func (s *Segment) updatePKRange(pks []int64) {
|
||||
buf := make([]byte, 8)
|
||||
for _, pk := range pks {
|
||||
common.Endian.PutUint64(buf, uint64(pk))
|
||||
s.pkFilter.Add(buf)
|
||||
if pk > s.maxPK {
|
||||
s.maxPK = pk
|
||||
func (s *Segment) updatePk(pk PrimaryKey) error {
|
||||
if s.minPK == nil {
|
||||
s.minPK = pk
|
||||
} else if s.minPK.GT(pk) {
|
||||
s.minPK = pk
|
||||
}
|
||||
|
||||
if s.maxPK == nil {
|
||||
s.maxPK = pk
|
||||
} else if s.maxPK.LT(pk) {
|
||||
s.maxPK = pk
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
func (s *Segment) updatePKRange(ids storage.FieldData) error {
|
||||
switch pks := ids.(type) {
|
||||
case *storage.Int64FieldData:
|
||||
buf := make([]byte, 8)
|
||||
for _, pk := range pks.Data {
|
||||
id := &Int64PrimaryKey{
|
||||
Value: pk,
|
||||
}
|
||||
err := s.updatePk(id)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
common.Endian.PutUint64(buf, uint64(pk))
|
||||
s.pkFilter.Add(buf)
|
||||
}
|
||||
if pk < s.minPK {
|
||||
s.minPK = pk
|
||||
case *storage.StringFieldData:
|
||||
for _, pk := range pks.Data {
|
||||
id := &StringPrimaryKey{
|
||||
Value: pk,
|
||||
}
|
||||
err := s.updatePk(id)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
s.pkFilter.Add([]byte(pk))
|
||||
}
|
||||
default:
|
||||
//TODO::
|
||||
}
|
||||
|
||||
log.Info("update pk range",
|
||||
zap.Int64("collectionID", s.collectionID), zap.Int64("partitionID", s.partitionID), zap.Int64("segmentID", s.segmentID),
|
||||
zap.String("channel", s.channelName),
|
||||
zap.Int64("num_rows", s.numRows), zap.Int64("minPK", s.minPK), zap.Int64("maxPK", s.maxPK))
|
||||
zap.Int64("num_rows", s.numRows), zap.Any("minPK", s.minPK), zap.Any("maxPK", s.maxPK))
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
var _ Replica = &SegmentReplica{}
|
||||
|
@ -241,8 +279,6 @@ func (replica *SegmentReplica) addNewSegment(segID, collID, partitionID UniqueID
|
|||
endPos: endPos,
|
||||
|
||||
pkFilter: bloom.NewWithEstimates(bloomFilterSize, maxBloomFalsePositive),
|
||||
minPK: math.MaxInt64, // use max value, represents no value
|
||||
maxPK: math.MinInt64, // use min value represents no value
|
||||
}
|
||||
|
||||
seg.isNew.Store(true)
|
||||
|
@ -328,9 +364,8 @@ func (replica *SegmentReplica) addNormalSegment(segID, collID, partitionID Uniqu
|
|||
numRows: numOfRows,
|
||||
|
||||
pkFilter: bloom.NewWithEstimates(bloomFilterSize, maxBloomFalsePositive),
|
||||
minPK: math.MaxInt64, // use max value, represents no value
|
||||
maxPK: math.MinInt64, // use min value represents no value
|
||||
}
|
||||
|
||||
if cp != nil {
|
||||
seg.checkPoint = *cp
|
||||
seg.endPos = &cp.pos
|
||||
|
@ -378,8 +413,6 @@ func (replica *SegmentReplica) addFlushedSegment(segID, collID, partitionID Uniq
|
|||
|
||||
//TODO silverxia, normal segments bloom filter and pk range should be loaded from serialized files
|
||||
pkFilter: bloom.NewWithEstimates(bloomFilterSize, maxBloomFalsePositive),
|
||||
minPK: math.MaxInt64, // use max value, represents no value
|
||||
maxPK: math.MinInt64, // use min value represents no value
|
||||
}
|
||||
|
||||
err := replica.initPKBloomFilter(seg, statsBinlogs, recoverTs)
|
||||
|
@ -444,13 +477,8 @@ func (replica *SegmentReplica) initPKBloomFilter(s *Segment, statsBinlogs []*dat
|
|||
if err != nil {
|
||||
return err
|
||||
}
|
||||
if s.minPK > stat.Min {
|
||||
s.minPK = stat.Min
|
||||
}
|
||||
|
||||
if s.maxPK < stat.Max {
|
||||
s.maxPK = stat.Max
|
||||
}
|
||||
s.updatePk(stat.MinPk)
|
||||
s.updatePk(stat.MaxPk)
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
@ -513,25 +541,25 @@ func (replica *SegmentReplica) updateSegmentEndPosition(segID UniqueID, endPos *
|
|||
log.Warn("No match segment", zap.Int64("ID", segID))
|
||||
}
|
||||
|
||||
func (replica *SegmentReplica) updateSegmentPKRange(segID UniqueID, pks []int64) {
|
||||
func (replica *SegmentReplica) updateSegmentPKRange(segID UniqueID, ids storage.FieldData) {
|
||||
replica.segMu.Lock()
|
||||
defer replica.segMu.Unlock()
|
||||
|
||||
seg, ok := replica.newSegments[segID]
|
||||
if ok {
|
||||
seg.updatePKRange(pks)
|
||||
seg.updatePKRange(ids)
|
||||
return
|
||||
}
|
||||
|
||||
seg, ok = replica.normalSegments[segID]
|
||||
if ok {
|
||||
seg.updatePKRange(pks)
|
||||
seg.updatePKRange(ids)
|
||||
return
|
||||
}
|
||||
|
||||
seg, ok = replica.flushedSegments[segID]
|
||||
if ok {
|
||||
seg.updatePKRange(pks)
|
||||
seg.updatePKRange(ids)
|
||||
return
|
||||
}
|
||||
|
||||
|
@ -684,12 +712,12 @@ func (replica *SegmentReplica) updateSegmentCheckPoint(segID UniqueID) {
|
|||
log.Warn("There's no segment", zap.Int64("ID", segID))
|
||||
}
|
||||
|
||||
func (replica *SegmentReplica) mergeFlushedSegments(segID, collID, partID, planID UniqueID, compactedFrom []UniqueID, channelName string, numOfRows int64) {
|
||||
func (replica *SegmentReplica) mergeFlushedSegments(segID, collID, partID, planID UniqueID, compactedFrom []UniqueID, channelName string, numOfRows int64) error {
|
||||
if collID != replica.collectionID {
|
||||
log.Warn("Mismatch collection",
|
||||
zap.Int64("input ID", collID),
|
||||
zap.Int64("expected ID", replica.collectionID))
|
||||
return
|
||||
return fmt.Errorf("mismatch collection, ID=%d", collID)
|
||||
}
|
||||
|
||||
log.Info("merge flushed segments",
|
||||
|
@ -708,8 +736,6 @@ func (replica *SegmentReplica) mergeFlushedSegments(segID, collID, partID, planI
|
|||
numRows: numOfRows,
|
||||
|
||||
pkFilter: bloom.NewWithEstimates(bloomFilterSize, maxBloomFalsePositive),
|
||||
minPK: math.MaxInt64, // use max value, represents no value
|
||||
maxPK: math.MinInt64, // use min value represents no value
|
||||
}
|
||||
|
||||
replica.segMu.Lock()
|
||||
|
@ -735,15 +761,17 @@ func (replica *SegmentReplica) mergeFlushedSegments(segID, collID, partID, planI
|
|||
replica.segMu.Lock()
|
||||
replica.flushedSegments[segID] = seg
|
||||
replica.segMu.Unlock()
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// for tests only
|
||||
func (replica *SegmentReplica) addFlushedSegmentWithPKs(segID, collID, partID UniqueID, channelName string, numOfRows int64, pks []UniqueID) {
|
||||
func (replica *SegmentReplica) addFlushedSegmentWithPKs(segID, collID, partID UniqueID, channelName string, numOfRows int64, ids storage.FieldData) error {
|
||||
if collID != replica.collectionID {
|
||||
log.Warn("Mismatch collection",
|
||||
zap.Int64("input ID", collID),
|
||||
zap.Int64("expected ID", replica.collectionID))
|
||||
return
|
||||
return fmt.Errorf("mismatch collection, ID=%d", collID)
|
||||
}
|
||||
|
||||
log.Info("Add Flushed segment",
|
||||
|
@ -761,11 +789,9 @@ func (replica *SegmentReplica) addFlushedSegmentWithPKs(segID, collID, partID Un
|
|||
numRows: numOfRows,
|
||||
|
||||
pkFilter: bloom.NewWithEstimates(bloomFilterSize, maxBloomFalsePositive),
|
||||
minPK: math.MaxInt64, // use max value, represents no value
|
||||
maxPK: math.MinInt64, // use min value represents no value
|
||||
}
|
||||
|
||||
seg.updatePKRange(pks)
|
||||
seg.updatePKRange(ids)
|
||||
|
||||
seg.isNew.Store(false)
|
||||
seg.isFlushed.Store(true)
|
||||
|
@ -773,6 +799,8 @@ func (replica *SegmentReplica) addFlushedSegmentWithPKs(segID, collID, partID Un
|
|||
replica.segMu.Lock()
|
||||
replica.flushedSegments[segID] = seg
|
||||
replica.segMu.Unlock()
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
func (replica *SegmentReplica) listAllSegmentIDs() []UniqueID {
@ -20,7 +20,6 @@ import (
|
|||
"context"
|
||||
"encoding/json"
|
||||
"fmt"
|
||||
"math"
|
||||
"math/rand"
|
||||
"testing"
|
||||
|
||||
|
@ -31,6 +30,7 @@ import (
|
|||
"github.com/milvus-io/milvus/internal/common"
|
||||
"github.com/milvus-io/milvus/internal/proto/datapb"
|
||||
"github.com/milvus-io/milvus/internal/proto/internalpb"
|
||||
"github.com/milvus-io/milvus/internal/proto/schemapb"
|
||||
"github.com/milvus-io/milvus/internal/storage"
|
||||
)
|
||||
|
||||
|
@ -50,7 +50,7 @@ type mockDataCM struct {
|
|||
}
|
||||
|
||||
func (kv *mockDataCM) MultiRead(keys []string) ([][]byte, error) {
|
||||
stats := &storage.Int64Stats{
|
||||
stats := &storage.PrimaryKeyStats{
|
||||
FieldID: common.RowIDField,
|
||||
Min: 0,
|
||||
Max: 10,
|
||||
|
@ -65,7 +65,7 @@ type mockPkfilterMergeError struct {
|
|||
}
|
||||
|
||||
func (kv *mockPkfilterMergeError) MultiRead(keys []string) ([][]byte, error) {
|
||||
stats := &storage.Int64Stats{
|
||||
stats := &storage.PrimaryKeyStats{
|
||||
FieldID: common.RowIDField,
|
||||
Min: 0,
|
||||
Max: 10,
|
||||
|
@ -170,7 +170,9 @@ func TestSegmentReplica_getCollectionAndPartitionID(te *testing.T) {
|
|||
}
|
||||
|
||||
func TestSegmentReplica(t *testing.T) {
|
||||
rc := &RootCoordFactory{}
|
||||
rc := &RootCoordFactory{
|
||||
pkType: schemapb.DataType_Int64,
|
||||
}
|
||||
collID := UniqueID(1)
|
||||
cm := storage.NewLocalChunkManager(storage.RootPath(segmentReplicaNodeTestDir))
|
||||
defer cm.RemoveWithPrefix("")
|
||||
|
@ -252,7 +254,9 @@ func TestSegmentReplica(t *testing.T) {
|
|||
}
|
||||
|
||||
func TestSegmentReplica_InterfaceMethod(t *testing.T) {
|
||||
rc := &RootCoordFactory{}
|
||||
rc := &RootCoordFactory{
|
||||
pkType: schemapb.DataType_Int64,
|
||||
}
|
||||
cm := storage.NewLocalChunkManager(storage.RootPath(segmentReplicaNodeTestDir))
|
||||
defer cm.RemoveWithPrefix("")
|
||||
|
||||
|
@ -268,17 +272,20 @@ func TestSegmentReplica_InterfaceMethod(t *testing.T) {
|
|||
{false, 1, 2, "invalid input collection with replica collection"},
|
||||
}
|
||||
|
||||
primaryKeyData := &storage.Int64FieldData{
|
||||
Data: []int64{9},
|
||||
}
|
||||
for _, test := range tests {
|
||||
t.Run(test.description, func(t *testing.T) {
|
||||
replica, err := newReplica(context.TODO(), rc, cm, test.replicaCollID)
|
||||
require.NoError(t, err)
|
||||
if test.isvalid {
|
||||
replica.addFlushedSegmentWithPKs(100, test.incollID, 10, "a", 1, []int64{9})
|
||||
replica.addFlushedSegmentWithPKs(100, test.incollID, 10, "a", 1, primaryKeyData)
|
||||
|
||||
assert.True(t, replica.hasSegment(100, true))
|
||||
assert.False(t, replica.hasSegment(100, false))
|
||||
} else {
|
||||
replica.addFlushedSegmentWithPKs(100, test.incollID, 10, "a", 1, []int64{9})
|
||||
replica.addFlushedSegmentWithPKs(100, test.incollID, 10, "a", 1, primaryKeyData)
|
||||
assert.False(t, replica.hasSegment(100, true))
|
||||
assert.False(t, replica.hasSegment(100, false))
|
||||
}
|
||||
|
@ -572,6 +579,7 @@ func TestSegmentReplica_InterfaceMethod(t *testing.T) {
|
|||
}
|
||||
})
|
||||
}
|
||||
rc.setCollectionID(1)
|
||||
})
|
||||
|
||||
t.Run("Test listAllSegmentIDs", func(t *testing.T) {
|
||||
|
@ -628,8 +636,11 @@ func TestSegmentReplica_InterfaceMethod(t *testing.T) {
|
|||
sr, err := newReplica(context.Background(), rc, cm, 1)
|
||||
assert.Nil(t, err)
|
||||
|
||||
sr.addFlushedSegmentWithPKs(1, 1, 0, "channel", 10, []UniqueID{1})
|
||||
sr.addFlushedSegmentWithPKs(2, 1, 0, "channel", 10, []UniqueID{1})
|
||||
primaryKeyData := &storage.Int64FieldData{
|
||||
Data: []UniqueID{1},
|
||||
}
|
||||
sr.addFlushedSegmentWithPKs(1, 1, 0, "channel", 10, primaryKeyData)
|
||||
sr.addFlushedSegmentWithPKs(2, 1, 0, "channel", 10, primaryKeyData)
|
||||
require.True(t, sr.hasSegment(1, true))
|
||||
require.True(t, sr.hasSegment(2, true))
|
||||
|
||||
|
@ -648,7 +659,9 @@ func TestSegmentReplica_InterfaceMethod(t *testing.T) {
|
|||
|
||||
}
|
||||
func TestInnerFunctionSegment(t *testing.T) {
|
||||
rc := &RootCoordFactory{}
|
||||
rc := &RootCoordFactory{
|
||||
pkType: schemapb.DataType_Int64,
|
||||
}
|
||||
collID := UniqueID(1)
|
||||
cm := storage.NewLocalChunkManager(storage.RootPath(segmentReplicaNodeTestDir))
|
||||
defer cm.RemoveWithPrefix("")
|
||||
|
@ -747,8 +760,6 @@ func TestInnerFunctionSegment(t *testing.T) {
|
|||
func TestSegmentReplica_UpdatePKRange(t *testing.T) {
|
||||
seg := &Segment{
|
||||
pkFilter: bloom.NewWithEstimates(100000, 0.005),
|
||||
maxPK: math.MinInt64,
|
||||
minPK: math.MaxInt64,
|
||||
}
|
||||
|
||||
cases := make([]int64, 0, 100)
|
||||
|
@ -757,10 +768,16 @@ func TestSegmentReplica_UpdatePKRange(t *testing.T) {
|
|||
}
|
||||
buf := make([]byte, 8)
|
||||
for _, c := range cases {
|
||||
seg.updatePKRange([]int64{c})
|
||||
seg.updatePKRange(&storage.Int64FieldData{
|
||||
Data: []int64{c},
|
||||
})
|
||||
|
||||
assert.LessOrEqual(t, seg.minPK, c)
|
||||
assert.GreaterOrEqual(t, seg.maxPK, c)
|
||||
pk := &Int64PrimaryKey{
|
||||
Value: c,
|
||||
}
|
||||
|
||||
assert.Equal(t, true, seg.minPK.LE(pk))
|
||||
assert.Equal(t, true, seg.maxPK.GE(pk))
|
||||
|
||||
common.Endian.PutUint64(buf, uint64(c))
|
||||
assert.True(t, seg.pkFilter.Test(buf))
|
||||
|
@ -768,7 +785,9 @@ func TestSegmentReplica_UpdatePKRange(t *testing.T) {
|
|||
}
|
||||
|
||||
func TestReplica_UpdatePKRange(t *testing.T) {
|
||||
rc := &RootCoordFactory{}
|
||||
rc := &RootCoordFactory{
|
||||
pkType: schemapb.DataType_Int64,
|
||||
}
|
||||
collID := UniqueID(1)
|
||||
partID := UniqueID(2)
|
||||
chanName := "insert-02"
|
||||
|
@ -797,14 +816,19 @@ func TestReplica_UpdatePKRange(t *testing.T) {
|
|||
}
|
||||
buf := make([]byte, 8)
|
||||
for _, c := range cases {
|
||||
replica.updateSegmentPKRange(1, []int64{c}) // new segment
|
||||
replica.updateSegmentPKRange(2, []int64{c}) // normal segment
|
||||
replica.updateSegmentPKRange(3, []int64{c}) // non-exist segment
|
||||
replica.updateSegmentPKRange(1, &storage.Int64FieldData{Data: []int64{c}}) // new segment
|
||||
replica.updateSegmentPKRange(2, &storage.Int64FieldData{Data: []int64{c}}) // normal segment
|
||||
replica.updateSegmentPKRange(3, &storage.Int64FieldData{Data: []int64{c}}) // non-exist segment
|
||||
|
||||
assert.LessOrEqual(t, segNew.minPK, c)
|
||||
assert.GreaterOrEqual(t, segNew.maxPK, c)
|
||||
assert.LessOrEqual(t, segNormal.minPK, c)
|
||||
assert.GreaterOrEqual(t, segNormal.maxPK, c)
|
||||
pk := &Int64PrimaryKey{
|
||||
Value: c,
|
||||
}
|
||||
|
||||
assert.Equal(t, true, segNew.minPK.LE(pk))
|
||||
assert.Equal(t, true, segNew.maxPK.GE(pk))
|
||||
|
||||
assert.Equal(t, true, segNormal.minPK.LE(pk))
|
||||
assert.Equal(t, true, segNormal.maxPK.GE(pk))
|
||||
|
||||
common.Endian.PutUint64(buf, uint64(c))
|
||||
assert.True(t, segNew.pkFilter.Test(buf))
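Note: the tests now treat minPK/maxPK as PrimaryKey values compared through LE/GE rather than raw int64 bounds, so the same assertions can later cover VarChar keys. A minimal comparison sketch using the types exercised above (the value is illustrative):

    pk := &Int64PrimaryKey{Value: 42}
    inRange := seg.minPK.LE(pk) && seg.maxPK.GE(pk) // true once 42 has been folded into the segment's PK range
    _ = inRange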
@ -19,9 +19,11 @@ package msgstream
|
|||
import (
|
||||
"context"
|
||||
"errors"
|
||||
"fmt"
|
||||
"time"
|
||||
|
||||
"github.com/milvus-io/milvus/internal/proto/schemapb"
|
||||
"github.com/milvus-io/milvus/internal/util/funcutil"
|
||||
"github.com/milvus-io/milvus/internal/util/typeutil"
|
||||
|
||||
"github.com/golang/protobuf/proto"
|
||||
|
@@ -189,9 +191,32 @@ func (it *InsertMsg) NRows() uint64 {
    return it.InsertRequest.GetNumRows()
}

func (it *InsertMsg) CheckAligned() bool {
    return len(it.GetRowIDs()) == len(it.GetTimestamps()) &&
        uint64(len(it.GetRowIDs())) == it.NRows()
func (it *InsertMsg) CheckAligned() error {
    numRowsOfFieldDataMismatch := func(fieldName string, fieldNumRows, passedNumRows uint64) error {
        return fmt.Errorf("the num_rows(%d) of %sth field is not equal to passed NumRows(%d)", fieldNumRows, fieldName, passedNumRows)
    }
    rowNums := it.NRows()
    if it.IsColumnBased() {
        for _, field := range it.FieldsData {
            fieldNumRows, err := funcutil.GetNumRowOfFieldData(field)
            if err != nil {
                return err
            }
            if fieldNumRows != rowNums {
                return numRowsOfFieldDataMismatch(field.FieldName, fieldNumRows, rowNums)
            }
        }
    }

    if len(it.GetRowIDs()) != len(it.GetTimestamps()) {
        return fmt.Errorf("the num_rows(%d) of rowIDs is not equal to the num_rows(%d) of timestamps", len(it.GetRowIDs()), len(it.GetTimestamps()))
    }

    if uint64(len(it.GetRowIDs())) != it.NRows() {
        return fmt.Errorf("the num_rows(%d) of rowIDs is not equal to passed NumRows(%d)", len(it.GetRowIDs()), it.NRows())
    }

    return nil
}
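Note: CheckAligned now returns a descriptive error instead of a bool, and for column-based messages it additionally verifies every field's row count against NumRows. Callers switch from a boolean test to error propagation; a minimal sketch (the surrounding function is assumed):

    if err := insertMsg.CheckAligned(); err != nil {
        // surface which field or count was out of step instead of a bare "not aligned"
        return fmt.Errorf("insert message is not aligned: %w", err)
    }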
|
||||
|
||||
func (it *InsertMsg) rowBasedIndexRequest(index int) internalpb.InsertRequest {
@ -191,14 +191,26 @@ func TestInsertMsg_CheckAligned(t *testing.T) {
|
|||
Version: internalpb.InsertDataVersion_RowBased,
|
||||
},
|
||||
}
|
||||
assert.True(t, msg1.CheckAligned())
|
||||
msg1.InsertRequest.NumRows = 1
|
||||
assert.NoError(t, msg1.CheckAligned())
|
||||
msg1.InsertRequest.RowData = nil
|
||||
msg1.InsertRequest.FieldsData = []*schemapb.FieldData{
|
||||
{},
|
||||
{
|
||||
Type: schemapb.DataType_Int64,
|
||||
Field: &schemapb.FieldData_Scalars{
|
||||
Scalars: &schemapb.ScalarField{
|
||||
Data: &schemapb.ScalarField_LongData{
|
||||
LongData: &schemapb.LongArray{
|
||||
Data: []int64{1},
|
||||
},
|
||||
},
|
||||
},
|
||||
},
|
||||
},
|
||||
}
|
||||
msg1.InsertRequest.NumRows = 1
|
||||
|
||||
msg1.Version = internalpb.InsertDataVersion_ColumnBased
|
||||
assert.True(t, msg1.CheckAligned())
|
||||
assert.NoError(t, msg1.CheckAligned())
|
||||
}
|
||||
|
||||
func TestInsertMsg_IndexMsg(t *testing.T) {
@ -37,7 +37,10 @@ func InsertRepackFunc(tsMsgs []TsMsg, hashKeys [][]int32) (map[int32]*MsgPack, e
|
|||
|
||||
keysLen := len(keys)
|
||||
|
||||
if !insertRequest.CheckAligned() || insertRequest.NRows() != uint64(keysLen) {
|
||||
if err := insertRequest.CheckAligned(); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
if insertRequest.NRows() != uint64(keysLen) {
|
||||
return nil, errors.New("the length of hashValue, timestamps, rowIDs, RowData are not equal")
|
||||
}
|
||||
for index, key := range keys {
@ -2177,18 +2177,22 @@ func (node *Proxy) Insert(ctx context.Context, request *milvuspb.InsertRequest)
|
|||
it := &insertTask{
|
||||
ctx: ctx,
|
||||
Condition: NewTaskCondition(ctx),
|
||||
req: request,
|
||||
// req: request,
|
||||
BaseInsertTask: BaseInsertTask{
|
||||
BaseMsg: msgstream.BaseMsg{
|
||||
HashValues: request.HashKeys,
|
||||
},
|
||||
InsertRequest: internalpb.InsertRequest{
|
||||
Base: &commonpb.MsgBase{
|
||||
MsgType: commonpb.MsgType_Insert,
|
||||
MsgID: 0,
|
||||
MsgType: commonpb.MsgType_Insert,
|
||||
MsgID: 0,
|
||||
SourceID: Params.ProxyCfg.ProxyID,
|
||||
},
|
||||
CollectionName: request.CollectionName,
|
||||
PartitionName: request.PartitionName,
|
||||
FieldsData: request.FieldsData,
|
||||
NumRows: uint64(request.NumRows),
|
||||
Version: internalpb.InsertDataVersion_ColumnBased,
|
||||
// RowData: transfer column based request to this
|
||||
},
|
||||
},
|
||||
|
@ -2203,7 +2207,7 @@ func (node *Proxy) Insert(ctx context.Context, request *milvuspb.InsertRequest)
|
|||
}
|
||||
|
||||
constructFailedResponse := func(err error) *milvuspb.MutationResult {
|
||||
numRows := it.req.NumRows
|
||||
numRows := request.NumRows
|
||||
errIndex := make([]uint32, numRows)
|
||||
for i := uint32(0); i < numRows; i++ {
|
||||
errIndex[i] = i
|
||||
|
@ -2256,7 +2260,7 @@ func (node *Proxy) Insert(ctx context.Context, request *milvuspb.InsertRequest)
|
|||
|
||||
if it.result.Status.ErrorCode != commonpb.ErrorCode_Success {
|
||||
setErrorIndex := func() {
|
||||
numRows := it.req.NumRows
|
||||
numRows := request.NumRows
|
||||
errIndex := make([]uint32, numRows)
|
||||
for i := uint32(0); i < numRows; i++ {
|
||||
errIndex[i] = i
|
||||
|
@ -2268,7 +2272,7 @@ func (node *Proxy) Insert(ctx context.Context, request *milvuspb.InsertRequest)
|
|||
}
|
||||
|
||||
// InsertCnt always equals to the number of entities in the request
|
||||
it.result.InsertCnt = int64(it.req.NumRows)
|
||||
it.result.InsertCnt = int64(request.NumRows)
|
||||
|
||||
metrics.ProxyInsertCount.WithLabelValues(strconv.FormatInt(Params.ProxyCfg.ProxyID, 10),
|
||||
metrics.SuccessLabel).Inc()
@ -395,6 +395,99 @@ func newSimpleMockMsgStreamFactory() *simpleMockMsgStreamFactory {
|
|||
return &simpleMockMsgStreamFactory{}
|
||||
}
|
||||
|
||||
func generateFieldData(dataType schemapb.DataType, fieldName string, fieldID int64, numRows int) *schemapb.FieldData {
|
||||
fieldData := &schemapb.FieldData{
|
||||
Type: dataType,
|
||||
FieldName: fieldName,
|
||||
FieldId: fieldID,
|
||||
}
|
||||
switch dataType {
|
||||
case schemapb.DataType_Bool:
|
||||
fieldData.FieldName = testBoolField
|
||||
fieldData.Field = &schemapb.FieldData_Scalars{
|
||||
Scalars: &schemapb.ScalarField{
|
||||
Data: &schemapb.ScalarField_BoolData{
|
||||
BoolData: &schemapb.BoolArray{
|
||||
Data: generateBoolArray(numRows),
|
||||
},
|
||||
},
|
||||
},
|
||||
}
|
||||
case schemapb.DataType_Int32:
|
||||
fieldData.FieldName = testInt32Field
|
||||
fieldData.Field = &schemapb.FieldData_Scalars{
|
||||
Scalars: &schemapb.ScalarField{
|
||||
Data: &schemapb.ScalarField_IntData{
|
||||
IntData: &schemapb.IntArray{
|
||||
Data: generateInt32Array(numRows),
|
||||
},
|
||||
},
|
||||
},
|
||||
}
|
||||
case schemapb.DataType_Int64:
|
||||
fieldData.FieldName = testInt64Field
|
||||
fieldData.Field = &schemapb.FieldData_Scalars{
|
||||
Scalars: &schemapb.ScalarField{
|
||||
Data: &schemapb.ScalarField_LongData{
|
||||
LongData: &schemapb.LongArray{
|
||||
Data: generateInt64Array(numRows),
|
||||
},
|
||||
},
|
||||
},
|
||||
}
|
||||
case schemapb.DataType_Float:
|
||||
fieldData.FieldName = testFloatField
|
||||
fieldData.Field = &schemapb.FieldData_Scalars{
|
||||
Scalars: &schemapb.ScalarField{
|
||||
Data: &schemapb.ScalarField_FloatData{
|
||||
FloatData: &schemapb.FloatArray{
|
||||
Data: generateFloat32Array(numRows),
|
||||
},
|
||||
},
|
||||
},
|
||||
}
|
||||
case schemapb.DataType_Double:
|
||||
fieldData.FieldName = testDoubleField
|
||||
fieldData.Field = &schemapb.FieldData_Scalars{
|
||||
Scalars: &schemapb.ScalarField{
|
||||
Data: &schemapb.ScalarField_DoubleData{
|
||||
DoubleData: &schemapb.DoubleArray{
|
||||
Data: generateFloat64Array(numRows),
|
||||
},
|
||||
},
|
||||
},
|
||||
}
|
||||
case schemapb.DataType_VarChar:
|
||||
//TODO::
|
||||
case schemapb.DataType_FloatVector:
|
||||
fieldData.FieldName = testFloatVecField
|
||||
fieldData.Field = &schemapb.FieldData_Vectors{
|
||||
Vectors: &schemapb.VectorField{
|
||||
Dim: int64(testVecDim),
|
||||
Data: &schemapb.VectorField_FloatVector{
|
||||
FloatVector: &schemapb.FloatArray{
|
||||
Data: generateFloatVectors(numRows, testVecDim),
|
||||
},
|
||||
},
|
||||
},
|
||||
}
|
||||
case schemapb.DataType_BinaryVector:
|
||||
fieldData.FieldName = testBinaryVecField
|
||||
fieldData.Field = &schemapb.FieldData_Vectors{
|
||||
Vectors: &schemapb.VectorField{
|
||||
Dim: int64(testVecDim),
|
||||
Data: &schemapb.VectorField_BinaryVector{
|
||||
BinaryVector: generateBinaryVectors(numRows, testVecDim),
|
||||
},
|
||||
},
|
||||
}
|
||||
default:
|
||||
//TODO::
|
||||
}
|
||||
|
||||
return fieldData
|
||||
}
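Note: generateFieldData is a new test helper that builds a column-based schemapb.FieldData for the requested type. A possible use in a test, assuming the testInt64Field/testFloatVecField constants defined elsewhere in this test package; the field IDs and row count are arbitrary:

    fields := []*schemapb.FieldData{
        generateFieldData(schemapb.DataType_Int64, testInt64Field, 106, 10),
        generateFieldData(schemapb.DataType_FloatVector, testFloatVecField, 101, 10),
    }
    // fields can then back a column-based internalpb.InsertRequest with NumRows = 10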
|
||||
|
||||
func generateBoolArray(numRows int) []bool {
|
||||
ret := make([]bool, 0, numRows)
|
||||
for i := 0; i < numRows; i++ {
|
||||
|
@ -435,6 +528,14 @@ func generateInt64Array(numRows int) []int64 {
|
|||
return ret
|
||||
}
|
||||
|
||||
func generateUint64Array(numRows int) []uint64 {
|
||||
ret := make([]uint64, 0, numRows)
|
||||
for i := 0; i < numRows; i++ {
|
||||
ret = append(ret, rand.Uint64())
|
||||
}
|
||||
return ret
|
||||
}
|
||||
|
||||
func generateFloat32Array(numRows int) []float32 {
|
||||
ret := make([]float32, 0, numRows)
|
||||
for i := 0; i < numRows; i++ {
|
||||
|
@ -470,14 +571,23 @@ func generateBinaryVectors(numRows, dim int) []byte {
|
|||
return ret
|
||||
}
|
||||
|
||||
func newScalarFieldData(dType schemapb.DataType, fieldName string, numRows int) *schemapb.FieldData {
|
||||
func generateVarCharArray(numRows int, maxLen int) []string {
|
||||
ret := make([]string, numRows)
|
||||
for i := 0; i < numRows; i++ {
|
||||
ret[i] = funcutil.RandomString(rand.Intn(maxLen))
|
||||
}
|
||||
|
||||
return ret
|
||||
}
|
||||
|
||||
func newScalarFieldData(fieldSchema *schemapb.FieldSchema, fieldName string, numRows int) *schemapb.FieldData {
|
||||
ret := &schemapb.FieldData{
|
||||
Type: dType,
|
||||
Type: fieldSchema.DataType,
|
||||
FieldName: fieldName,
|
||||
Field: nil,
|
||||
}
|
||||
|
||||
switch dType {
|
||||
switch fieldSchema.DataType {
|
||||
case schemapb.DataType_Bool:
|
||||
ret.Field = &schemapb.FieldData_Scalars{
|
||||
Scalars: &schemapb.ScalarField{
|
||||
|
@ -548,6 +658,16 @@ func newScalarFieldData(dType schemapb.DataType, fieldName string, numRows int)
|
|||
},
|
||||
},
|
||||
}
|
||||
case schemapb.DataType_VarChar:
|
||||
ret.Field = &schemapb.FieldData_Scalars{
|
||||
Scalars: &schemapb.ScalarField{
|
||||
Data: &schemapb.ScalarField_StringData{
|
||||
StringData: &schemapb.StringArray{
|
||||
Data: generateVarCharArray(numRows, testMaxVarCharLength),
|
||||
},
|
||||
},
|
||||
},
|
||||
}
|
||||
}
|
||||
|
||||
return ret
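Note: newScalarFieldData now takes the whole *schemapb.FieldSchema so it can also emit VarChar string data. A hedged call sketch; the schema literal is fabricated for illustration:

    pkField := &schemapb.FieldSchema{
        Name:         "pk",
        DataType:     schemapb.DataType_VarChar,
        IsPrimaryKey: true,
    }
    fieldData := newScalarFieldData(pkField, "pk", 10) // 10 random strings capped at testMaxVarCharLength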
@ -21,13 +21,10 @@ import (
|
|||
"errors"
|
||||
"fmt"
|
||||
"math"
|
||||
"reflect"
|
||||
"regexp"
|
||||
"sort"
|
||||
"strconv"
|
||||
"strings"
|
||||
"time"
|
||||
"unsafe"
|
||||
|
||||
"github.com/golang/protobuf/proto"
|
||||
"go.uber.org/zap"
|
||||
|
@ -122,7 +119,7 @@ type BaseInsertTask = msgstream.InsertMsg
|
|||
|
||||
type insertTask struct {
|
||||
BaseInsertTask
|
||||
req *milvuspb.InsertRequest
|
||||
// req *milvuspb.InsertRequest
|
||||
Condition
|
||||
ctx context.Context
|
||||
|
||||
|
@ -208,40 +205,9 @@ func (it *insertTask) getChannels() ([]pChan, error) {
|
|||
}
|
||||
|
||||
func (it *insertTask) OnEnqueue() error {
|
||||
it.BaseInsertTask.InsertRequest.Base = &commonpb.MsgBase{}
|
||||
return nil
|
||||
}
|
||||
|
||||
func getNumRowsOfScalarField(datas interface{}) uint32 {
|
||||
realTypeDatas := reflect.ValueOf(datas)
|
||||
return uint32(realTypeDatas.Len())
|
||||
}
|
||||
|
||||
func getNumRowsOfFloatVectorField(fDatas []float32, dim int64) (uint32, error) {
|
||||
if dim <= 0 {
|
||||
return 0, errDimLessThanOrEqualToZero(int(dim))
|
||||
}
|
||||
l := len(fDatas)
|
||||
if int64(l)%dim != 0 {
|
||||
return 0, fmt.Errorf("the length(%d) of float data should divide the dim(%d)", l, dim)
|
||||
}
|
||||
return uint32(int(int64(l) / dim)), nil
|
||||
}
|
||||
|
||||
func getNumRowsOfBinaryVectorField(bDatas []byte, dim int64) (uint32, error) {
|
||||
if dim <= 0 {
|
||||
return 0, errDimLessThanOrEqualToZero(int(dim))
|
||||
}
|
||||
if dim%8 != 0 {
|
||||
return 0, errDimShouldDivide8(int(dim))
|
||||
}
|
||||
l := len(bDatas)
|
||||
if (8*int64(l))%dim != 0 {
|
||||
return 0, fmt.Errorf("the num(%d) of all bits should divide the dim(%d)", 8*l, dim)
|
||||
}
|
||||
return uint32(int((8 * int64(l)) / dim)), nil
|
||||
}
|
||||
|
||||
func (it *insertTask) checkLengthOfFieldsData() error {
|
||||
neededFieldsNum := 0
|
||||
for _, field := range it.schema.Fields {
|
||||
|
@ -250,236 +216,62 @@ func (it *insertTask) checkLengthOfFieldsData() error {
|
|||
}
|
||||
}
|
||||
|
||||
if len(it.req.FieldsData) < neededFieldsNum {
|
||||
return errFieldsLessThanNeeded(len(it.req.FieldsData), neededFieldsNum)
|
||||
if len(it.FieldsData) < neededFieldsNum {
|
||||
return errFieldsLessThanNeeded(len(it.FieldsData), neededFieldsNum)
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
func (it *insertTask) checkRowNums() error {
|
||||
if it.req.NumRows <= 0 {
|
||||
return errNumRowsLessThanOrEqualToZero(it.req.NumRows)
|
||||
}
|
||||
|
||||
if err := it.checkLengthOfFieldsData(); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
rowNums := it.req.NumRows
|
||||
|
||||
for i, field := range it.req.FieldsData {
|
||||
switch field.Field.(type) {
|
||||
case *schemapb.FieldData_Scalars:
|
||||
scalarField := field.GetScalars()
|
||||
switch scalarField.Data.(type) {
|
||||
case *schemapb.ScalarField_BoolData:
|
||||
fieldNumRows := getNumRowsOfScalarField(scalarField.GetBoolData().Data)
|
||||
if fieldNumRows != rowNums {
|
||||
return errNumRowsOfFieldDataMismatchPassed(i, fieldNumRows, rowNums)
|
||||
}
|
||||
case *schemapb.ScalarField_IntData:
|
||||
fieldNumRows := getNumRowsOfScalarField(scalarField.GetIntData().Data)
|
||||
if fieldNumRows != rowNums {
|
||||
return errNumRowsOfFieldDataMismatchPassed(i, fieldNumRows, rowNums)
|
||||
}
|
||||
case *schemapb.ScalarField_LongData:
|
||||
fieldNumRows := getNumRowsOfScalarField(scalarField.GetLongData().Data)
|
||||
if fieldNumRows != rowNums {
|
||||
return errNumRowsOfFieldDataMismatchPassed(i, fieldNumRows, rowNums)
|
||||
}
|
||||
case *schemapb.ScalarField_FloatData:
|
||||
fieldNumRows := getNumRowsOfScalarField(scalarField.GetFloatData().Data)
|
||||
if fieldNumRows != rowNums {
|
||||
return errNumRowsOfFieldDataMismatchPassed(i, fieldNumRows, rowNums)
|
||||
}
|
||||
case *schemapb.ScalarField_DoubleData:
|
||||
fieldNumRows := getNumRowsOfScalarField(scalarField.GetDoubleData().Data)
|
||||
if fieldNumRows != rowNums {
|
||||
return errNumRowsOfFieldDataMismatchPassed(i, fieldNumRows, rowNums)
|
||||
}
|
||||
case *schemapb.ScalarField_BytesData:
|
||||
return errUnsupportedDType("bytes")
|
||||
case *schemapb.ScalarField_StringData:
|
||||
return errUnsupportedDType("string")
|
||||
case nil:
|
||||
continue
|
||||
default:
|
||||
continue
|
||||
}
|
||||
case *schemapb.FieldData_Vectors:
|
||||
vectorField := field.GetVectors()
|
||||
switch vectorField.Data.(type) {
|
||||
case *schemapb.VectorField_FloatVector:
|
||||
dim := vectorField.GetDim()
|
||||
fieldNumRows, err := getNumRowsOfFloatVectorField(vectorField.GetFloatVector().Data, dim)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
if fieldNumRows != rowNums {
|
||||
return errNumRowsOfFieldDataMismatchPassed(i, fieldNumRows, rowNums)
|
||||
}
|
||||
case *schemapb.VectorField_BinaryVector:
|
||||
dim := vectorField.GetDim()
|
||||
fieldNumRows, err := getNumRowsOfBinaryVectorField(vectorField.GetBinaryVector(), dim)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
if fieldNumRows != rowNums {
|
||||
return errNumRowsOfFieldDataMismatchPassed(i, fieldNumRows, rowNums)
|
||||
}
|
||||
case nil:
|
||||
continue
|
||||
default:
|
||||
continue
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
func (it *insertTask) checkFieldAutoIDAndHashPK() error {
|
||||
func (it *insertTask) checkPrimaryFieldData() error {
|
||||
rowNums := uint32(it.NRows())
|
||||
// TODO(dragondriver): in fact, NumRows is not trustable, we should check all input fields
|
||||
if it.req.NumRows <= 0 {
|
||||
return errNumRowsLessThanOrEqualToZero(it.req.NumRows)
|
||||
if it.NRows() <= 0 {
|
||||
return errNumRowsLessThanOrEqualToZero(rowNums)
|
||||
}
|
||||
|
||||
if err := it.checkLengthOfFieldsData(); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
rowNums := it.req.NumRows
|
||||
|
||||
primaryFieldName := ""
|
||||
autoIDFieldName := ""
|
||||
autoIDLoc := -1
|
||||
primaryLoc := -1
|
||||
fields := it.schema.Fields
|
||||
|
||||
for loc, field := range fields {
|
||||
if field.AutoID {
|
||||
autoIDLoc = loc
|
||||
autoIDFieldName = field.Name
|
||||
}
|
||||
if field.IsPrimaryKey {
|
||||
primaryLoc = loc
|
||||
primaryFieldName = field.Name
|
||||
}
|
||||
primaryFieldSchema, err := typeutil.GetPrimaryFieldSchema(it.schema)
|
||||
if err != nil {
|
||||
log.Error("get primary field schema failed", zap.String("collection name", it.CollectionName), zap.Any("schema", it.schema), zap.Error(err))
|
||||
return err
|
||||
}
|
||||
|
||||
if primaryLoc < 0 {
|
||||
return fmt.Errorf("primary field is not found")
|
||||
}
|
||||
|
||||
if autoIDLoc >= 0 && autoIDLoc != primaryLoc {
|
||||
return fmt.Errorf("currently auto id field is only supported on primary field")
|
||||
}
|
||||
|
||||
var primaryField *schemapb.FieldData
|
||||
var primaryData []int64
|
||||
for _, field := range it.req.FieldsData {
|
||||
if field.FieldName == autoIDFieldName {
|
||||
return fmt.Errorf("autoID field (%v) does not require data", autoIDFieldName)
|
||||
// get primaryFieldData whether autoID is true or not
|
||||
var primaryFieldData *schemapb.FieldData
|
||||
if !primaryFieldSchema.AutoID {
|
||||
primaryFieldData, err = getPrimaryFieldData(it.GetFieldsData(), primaryFieldSchema)
|
||||
if err != nil {
|
||||
log.Error("get primary field data failed", zap.String("collection name", it.CollectionName), zap.Error(err))
|
||||
return err
|
||||
}
|
||||
if field.FieldName == primaryFieldName {
|
||||
primaryField = field
|
||||
}
|
||||
}
|
||||
|
||||
if primaryField != nil {
|
||||
if primaryField.Type != schemapb.DataType_Int64 {
|
||||
return fmt.Errorf("currently only support DataType Int64 as PrimaryField and Enable autoID")
|
||||
}
|
||||
switch primaryField.Field.(type) {
|
||||
case *schemapb.FieldData_Scalars:
|
||||
scalarField := primaryField.GetScalars()
|
||||
switch scalarField.Data.(type) {
|
||||
case *schemapb.ScalarField_LongData:
|
||||
primaryData = scalarField.GetLongData().Data
|
||||
default:
|
||||
return fmt.Errorf("currently only support DataType Int64 as PrimaryField and Enable autoID")
|
||||
}
|
||||
default:
|
||||
return fmt.Errorf("currently only support DataType Int64 as PrimaryField and Enable autoID")
|
||||
}
|
||||
it.result.IDs.IdField = &schemapb.IDs_IntId{
|
||||
IntId: &schemapb.LongArray{
|
||||
Data: primaryData,
|
||||
},
|
||||
}
|
||||
}
|
||||
|
||||
var rowIDBegin UniqueID
|
||||
var rowIDEnd UniqueID
|
||||
|
||||
tr := timerecord.NewTimeRecorder("applyPK")
|
||||
rowIDBegin, rowIDEnd, _ = it.rowIDAllocator.Alloc(rowNums)
|
||||
metrics.ProxyApplyPrimaryKeyLatency.WithLabelValues(strconv.FormatInt(Params.ProxyCfg.ProxyID, 10)).Observe(float64(tr.ElapseSpan()))
|
||||
|
||||
it.BaseInsertTask.RowIDs = make([]UniqueID, rowNums)
|
||||
for i := rowIDBegin; i < rowIDEnd; i++ {
|
||||
offset := i - rowIDBegin
|
||||
it.BaseInsertTask.RowIDs[offset] = i
|
||||
}
|
||||
|
||||
if autoIDLoc >= 0 {
|
||||
fieldData := schemapb.FieldData{
|
||||
FieldName: primaryFieldName,
|
||||
FieldId: -1,
|
||||
Type: schemapb.DataType_Int64,
|
||||
Field: &schemapb.FieldData_Scalars{
|
||||
Scalars: &schemapb.ScalarField{
|
||||
Data: &schemapb.ScalarField_LongData{
|
||||
LongData: &schemapb.LongArray{
|
||||
Data: it.BaseInsertTask.RowIDs,
|
||||
},
|
||||
},
|
||||
},
|
||||
},
|
||||
}
|
||||
|
||||
// TODO(dragondriver): when we can ignore the order of input fields, use append directly
|
||||
// it.req.FieldsData = append(it.req.FieldsData, &fieldData)
|
||||
it.req.FieldsData = append(it.req.FieldsData, &schemapb.FieldData{})
|
||||
copy(it.req.FieldsData[autoIDLoc+1:], it.req.FieldsData[autoIDLoc:])
|
||||
it.req.FieldsData[autoIDLoc] = &fieldData
|
||||
|
||||
it.result.IDs.IdField = &schemapb.IDs_IntId{
|
||||
IntId: &schemapb.LongArray{
|
||||
Data: it.BaseInsertTask.RowIDs,
|
||||
},
|
||||
}
|
||||
it.HashPK(it.BaseInsertTask.RowIDs)
|
||||
} else {
|
||||
it.HashPK(primaryData)
|
||||
// if autoID == true, currently only support autoID for int64 PrimaryField
|
||||
primaryFieldData, err = autoGenPrimaryFieldData(primaryFieldSchema, it.RowIDs)
|
||||
if err != nil {
|
||||
log.Error("generate primary field data failed when autoID == true", zap.String("collection name", it.CollectionName), zap.Error(err))
|
||||
return err
|
||||
}
|
||||
// if autoID == true, set the primary field data
|
||||
it.FieldsData = append(it.FieldsData, primaryFieldData)
|
||||
}
|
||||
|
||||
sliceIndex := make([]uint32, rowNums)
|
||||
for i := uint32(0); i < rowNums; i++ {
|
||||
sliceIndex[i] = i
|
||||
// parse primaryFieldData to result.IDs, and as returned primary keys
|
||||
it.result.IDs, err = parsePrimaryFieldData2IDs(primaryFieldData)
|
||||
if err != nil {
|
||||
log.Error("parse primary field data to IDs failed", zap.String("collection name", it.CollectionName), zap.Error(err))
|
||||
return err
|
||||
}
|
||||
it.result.SuccIndex = sliceIndex
|
||||
|
||||
return nil
|
||||
}
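Note: checkPrimaryFieldData condenses the old autoID/primary-key handling: look up the primary field schema, take the user-supplied column when autoID is false, otherwise generate one from the allocated row IDs, then translate that column into the result IDs. A compressed sketch of the flow (error handling elided; the helper names are the ones added elsewhere in this patch):

    primaryFieldSchema, _ := typeutil.GetPrimaryFieldSchema(it.schema)
    var primaryFieldData *schemapb.FieldData
    if !primaryFieldSchema.AutoID {
        // the client must supply the primary column when autoID is disabled
        primaryFieldData, _ = getPrimaryFieldData(it.GetFieldsData(), primaryFieldSchema)
    } else {
        // autoID: derive the primary column from the allocated row IDs and append it
        primaryFieldData, _ = autoGenPrimaryFieldData(primaryFieldSchema, it.RowIDs)
        it.FieldsData = append(it.FieldsData, primaryFieldData)
    }
    it.result.IDs, _ = parsePrimaryFieldData2IDs(primaryFieldData)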
|
||||
|
||||
func (it *insertTask) HashPK(pks []int64) {
|
||||
if len(it.HashValues) != 0 {
|
||||
log.Warn("the hashvalues passed through client is not supported now, and will be overwritten")
|
||||
}
|
||||
it.HashValues = make([]uint32, 0, len(pks))
|
||||
for _, pk := range pks {
|
||||
hash, _ := typeutil.Hash32Int64(pk)
|
||||
it.HashValues = append(it.HashValues, hash)
|
||||
}
|
||||
}
|
||||
|
||||
func (it *insertTask) PreExecute(ctx context.Context) error {
|
||||
sp, ctx := trace.StartSpanFromContextWithOperationName(it.ctx, "Proxy-Insert-PreExecute")
|
||||
defer sp.Finish()
|
||||
it.Base.MsgType = commonpb.MsgType_Insert
|
||||
it.Base.SourceID = Params.ProxyCfg.ProxyID
|
||||
|
||||
it.result = &milvuspb.MutationResult{
|
||||
Status: &commonpb.Status{
|
||||
|
@ -491,275 +283,199 @@ func (it *insertTask) PreExecute(ctx context.Context) error {
|
|||
Timestamp: it.EndTs(),
|
||||
}
|
||||
|
||||
collectionName := it.BaseInsertTask.CollectionName
|
||||
collectionName := it.CollectionName
|
||||
if err := validateCollectionName(collectionName); err != nil {
|
||||
log.Error("valid collection name failed", zap.String("collection name", collectionName), zap.Error(err))
|
||||
return err
|
||||
}
|
||||
|
||||
partitionTag := it.BaseInsertTask.PartitionName
|
||||
partitionTag := it.PartitionName
|
||||
if err := validatePartitionTag(partitionTag, true); err != nil {
|
||||
log.Error("valid partition name failed", zap.String("partition name", partitionTag), zap.Error(err))
|
||||
return err
|
||||
}
|
||||
|
||||
collSchema, err := globalMetaCache.GetCollectionSchema(ctx, collectionName)
|
||||
log.Debug("Proxy Insert PreExecute", zap.Any("collSchema", collSchema))
|
||||
if err != nil {
|
||||
log.Error("get collection schema from global meta cache failed", zap.String("collection name", collectionName), zap.Error(err))
|
||||
return err
|
||||
}
|
||||
it.schema = collSchema
|
||||
|
||||
err = it.checkRowNums()
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
rowNums := uint32(it.NRows())
|
||||
// set insertTask.rowIDs
|
||||
var rowIDBegin UniqueID
|
||||
var rowIDEnd UniqueID
|
||||
tr := timerecord.NewTimeRecorder("applyPK")
|
||||
rowIDBegin, rowIDEnd, _ = it.rowIDAllocator.Alloc(rowNums)
|
||||
metrics.ProxyApplyPrimaryKeyLatency.WithLabelValues(strconv.FormatInt(Params.ProxyCfg.ProxyID, 10)).Observe(float64(tr.ElapseSpan()))
|
||||
|
||||
err = it.checkFieldAutoIDAndHashPK()
|
||||
if err != nil {
|
||||
return err
|
||||
it.RowIDs = make([]UniqueID, rowNums)
|
||||
for i := rowIDBegin; i < rowIDEnd; i++ {
|
||||
offset := i - rowIDBegin
|
||||
it.RowIDs[offset] = i
|
||||
}
|
||||
|
||||
it.BaseInsertTask.InsertRequest.Version = internalpb.InsertDataVersion_ColumnBased
|
||||
it.BaseInsertTask.InsertRequest.FieldsData = it.req.GetFieldsData()
|
||||
it.BaseInsertTask.InsertRequest.NumRows = uint64(it.req.GetNumRows())
|
||||
err = typeutil.FillFieldBySchema(it.BaseInsertTask.InsertRequest.GetFieldsData(), collSchema)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
rowNum := it.req.NumRows
|
||||
// set insertTask.timeStamps
|
||||
rowNum := it.NRows()
|
||||
it.Timestamps = make([]uint64, rowNum)
|
||||
for index := range it.Timestamps {
|
||||
it.Timestamps[index] = it.BeginTimestamp
|
||||
}
|
||||
|
||||
// set result.SuccIndex
|
||||
sliceIndex := make([]uint32, rowNums)
|
||||
for i := uint32(0); i < rowNums; i++ {
|
||||
sliceIndex[i] = i
|
||||
}
|
||||
it.result.SuccIndex = sliceIndex
|
||||
|
||||
// check primaryFieldData whether autoID is true or not
|
||||
// set rowIDs as primary data if autoID == true
|
||||
err = it.checkPrimaryFieldData()
|
||||
if err != nil {
|
||||
log.Error("check primary field data and hash primary key failed", zap.Int64("msgID", it.Base.MsgID), zap.String("collection name", collectionName), zap.Error(err))
|
||||
return err
|
||||
}
|
||||
|
||||
// set field ID to insert field data
|
||||
err = fillFieldIDBySchema(it.GetFieldsData(), collSchema)
|
||||
if err != nil {
|
||||
log.Error("set fieldID to fieldData failed", zap.Int64("msgID", it.Base.MsgID), zap.String("collection name", collectionName), zap.Error(err))
|
||||
return err
|
||||
}
|
||||
|
||||
// check that all field's number rows are equal
|
||||
if err = it.CheckAligned(); err != nil {
|
||||
log.Error("field data is not aligned", zap.Int64("msgID", it.Base.MsgID), zap.String("collection name", collectionName), zap.Error(err))
|
||||
return err
|
||||
}
|
||||
|
||||
log.Debug("Proxy Insert PreExecute done", zap.Int64("msgID", it.Base.MsgID), zap.String("collection name", collectionName))
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
func (it *insertTask) _assignSegmentID(stream msgstream.MsgStream, pack *msgstream.MsgPack) (*msgstream.MsgPack, error) {
|
||||
newPack := &msgstream.MsgPack{
|
||||
BeginTs: pack.BeginTs,
|
||||
EndTs: pack.EndTs,
|
||||
StartPositions: pack.StartPositions,
|
||||
EndPositions: pack.EndPositions,
|
||||
Msgs: nil,
|
||||
}
|
||||
tsMsgs := pack.Msgs
|
||||
hashKeys := stream.ComputeProduceChannelIndexes(tsMsgs)
|
||||
if len(hashKeys) == 0 {
|
||||
return nil, fmt.Errorf("the length of hashKeys is 0")
|
||||
}
|
||||
reqID := it.Base.MsgID
|
||||
channelCountMap := make(map[int32]uint32) // channelID to count
|
||||
channelMaxTSMap := make(map[int32]Timestamp) // channelID to max Timestamp
|
||||
channelNames, err := it.chMgr.getVChannels(it.GetCollectionID())
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
log.Debug("_assignSemgentID, produceChannels:", zap.Any("Channels", channelNames))
|
||||
func (it *insertTask) assignSegmentID(channelNames []string) (*msgstream.MsgPack, error) {
|
||||
threshold := Params.PulsarCfg.MaxMessageSize
|
||||
|
||||
for i, request := range tsMsgs {
|
||||
if request.Type() != commonpb.MsgType_Insert {
|
||||
return nil, fmt.Errorf("msg's must be Insert")
|
||||
}
|
||||
insertRequest, ok := request.(*msgstream.InsertMsg)
|
||||
if !ok {
|
||||
return nil, fmt.Errorf("msg's must be Insert")
|
||||
}
|
||||
|
||||
keys := hashKeys[i]
|
||||
keysLen := len(keys)
|
||||
|
||||
if !insertRequest.CheckAligned() {
|
||||
return nil,
|
||||
fmt.Errorf("the length of timestamps(%d), rowIDs(%d) and num_rows(%d) are not equal",
|
||||
len(insertRequest.GetTimestamps()),
|
||||
len(insertRequest.GetRowIDs()),
|
||||
insertRequest.NRows())
|
||||
}
|
||||
if uint64(keysLen) != insertRequest.NRows() {
|
||||
return nil,
|
||||
fmt.Errorf(
|
||||
"the length of hashValue(%d), num_rows(%d) are not equal",
|
||||
keysLen, insertRequest.NRows())
|
||||
}
|
||||
for idx, channelID := range keys {
|
||||
channelCountMap[channelID]++
|
||||
if _, ok := channelMaxTSMap[channelID]; !ok {
|
||||
channelMaxTSMap[channelID] = typeutil.ZeroTimestamp
|
||||
}
|
||||
ts := insertRequest.Timestamps[idx]
|
||||
if channelMaxTSMap[channelID] < ts {
|
||||
channelMaxTSMap[channelID] = ts
|
||||
}
|
||||
}
|
||||
result := &msgstream.MsgPack{
|
||||
BeginTs: it.BeginTs(),
|
||||
EndTs: it.EndTs(),
|
||||
}
|
||||
|
||||
reqSegCountMap := make(map[int32]map[UniqueID]uint32)
|
||||
for channelID, count := range channelCountMap {
|
||||
ts, ok := channelMaxTSMap[channelID]
|
||||
if !ok {
|
||||
ts = typeutil.ZeroTimestamp
|
||||
log.Debug("Warning: did not get max Timestamp!")
|
||||
}
|
||||
// generate hash value for every primary key
|
||||
if len(it.HashValues) != 0 {
|
||||
log.Warn("the hashvalues passed through client is not supported now, and will be overwritten")
|
||||
}
|
||||
it.HashValues = typeutil.HashPK2Channels(it.result.IDs, channelNames)
|
||||
// groupedHashKeys represents the dmChannel index
|
||||
channel2RowOffsets := make(map[string][]int) // channelName to count
|
||||
channelMaxTSMap := make(map[string]Timestamp) // channelName to max Timestamp
|
||||
|
||||
// assert len(it.hashValues) < maxInt
|
||||
for offset, channelID := range it.HashValues {
|
||||
channelName := channelNames[channelID]
|
||||
if channelName == "" {
|
||||
return nil, fmt.Errorf("proxy, repack_func, can not found channelName")
|
||||
if _, ok := channel2RowOffsets[channelName]; !ok {
|
||||
channel2RowOffsets[channelName] = []int{}
|
||||
}
|
||||
mapInfo, err := it.segIDAssigner.GetSegmentID(it.CollectionID, it.PartitionID, channelName, count, ts)
|
||||
channel2RowOffsets[channelName] = append(channel2RowOffsets[channelName], offset)
|
||||
|
||||
if _, ok := channelMaxTSMap[channelName]; !ok {
|
||||
channelMaxTSMap[channelName] = typeutil.ZeroTimestamp
|
||||
}
|
||||
ts := it.Timestamps[offset]
|
||||
if channelMaxTSMap[channelName] < ts {
|
||||
channelMaxTSMap[channelName] = ts
|
||||
}
|
||||
}
|
||||
|
||||
// create empty insert message
|
||||
createInsertMsg := func(segmentID UniqueID, channelName string) *msgstream.InsertMsg {
|
||||
insertReq := internalpb.InsertRequest{
|
||||
Base: &commonpb.MsgBase{
|
||||
MsgType: commonpb.MsgType_Insert,
|
||||
MsgID: it.Base.MsgID,
|
||||
Timestamp: it.BeginTimestamp, // entity's timestamp was set to equal it.BeginTimestamp in preExecute()
|
||||
SourceID: it.Base.SourceID,
|
||||
},
|
||||
CollectionID: it.CollectionID,
|
||||
PartitionID: it.PartitionID,
|
||||
CollectionName: it.CollectionName,
|
||||
PartitionName: it.PartitionName,
|
||||
SegmentID: segmentID,
|
||||
ShardName: channelName,
|
||||
Version: internalpb.InsertDataVersion_ColumnBased,
|
||||
}
|
||||
insertReq.FieldsData = make([]*schemapb.FieldData, len(it.GetFieldsData()))
|
||||
|
||||
insertMsg := &msgstream.InsertMsg{
|
||||
BaseMsg: msgstream.BaseMsg{
|
||||
Ctx: it.TraceCtx(),
|
||||
},
|
||||
InsertRequest: insertReq,
|
||||
}
|
||||
|
||||
return insertMsg
|
||||
}
|
||||
|
||||
// repack the row data corresponding to the offset to insertMsg
|
||||
getInsertMsgsBySegmentID := func(segmentID UniqueID, rowOffsets []int, channelName string, mexMessageSize int) ([]msgstream.TsMsg, error) {
|
||||
repackedMsgs := make([]msgstream.TsMsg, 0)
|
||||
requestSize := 0
|
||||
insertMsg := createInsertMsg(segmentID, channelName)
|
||||
for _, offset := range rowOffsets {
|
||||
curRowMessageSize, err := typeutil.EstimateEntitySize(it.InsertRequest.GetFieldsData(), offset)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
// if insertMsg's size is greater than the threshold, split into multiple insertMsgs
|
||||
if requestSize+curRowMessageSize >= mexMessageSize {
|
||||
repackedMsgs = append(repackedMsgs, insertMsg)
|
||||
insertMsg = createInsertMsg(segmentID, channelName)
|
||||
requestSize = 0
|
||||
}
|
||||
|
||||
typeutil.AppendFieldData(insertMsg.FieldsData, it.GetFieldsData(), int64(offset))
|
||||
insertMsg.HashValues = append(insertMsg.HashValues, it.HashValues[offset])
|
||||
insertMsg.Timestamps = append(insertMsg.Timestamps, it.Timestamps[offset])
|
||||
insertMsg.RowIDs = append(insertMsg.RowIDs, it.RowIDs[offset])
|
||||
insertMsg.NumRows++
|
||||
requestSize += curRowMessageSize
|
||||
}
|
||||
repackedMsgs = append(repackedMsgs, insertMsg)
|
||||
|
||||
return repackedMsgs, nil
|
||||
}
|
||||
|
||||
// get allocated segmentID info for every dmChannel and repack insertMsgs for every segmentID
|
||||
for channelName, rowOffsets := range channel2RowOffsets {
|
||||
assignedSegmentInfos, err := it.segIDAssigner.GetSegmentID(it.CollectionID, it.PartitionID, channelName, uint32(len(rowOffsets)), channelMaxTSMap[channelName])
|
||||
if err != nil {
|
||||
log.Debug("insertTask.go", zap.Any("MapInfo", mapInfo),
|
||||
log.Error("allocate segmentID for insert data failed",
|
||||
zap.Int64("collectionID", it.CollectionID),
|
||||
zap.String("channel name", channelName),
|
||||
zap.Int("allocate count", len(rowOffsets)),
|
||||
zap.Error(err))
|
||||
return nil, err
|
||||
}
|
||||
reqSegCountMap[channelID] = make(map[UniqueID]uint32)
|
||||
reqSegCountMap[channelID] = mapInfo
|
||||
log.Debug("Proxy", zap.Int64("repackFunc, reqSegCountMap, reqID", reqID), zap.Any("mapinfo", mapInfo))
|
||||
}
|
||||
|
||||
reqSegAccumulateCountMap := make(map[int32][]uint32)
|
||||
reqSegIDMap := make(map[int32][]UniqueID)
|
||||
reqSegAllocateCounter := make(map[int32]uint32)
|
||||
|
||||
for channelID, segInfo := range reqSegCountMap {
|
||||
reqSegAllocateCounter[channelID] = 0
|
||||
keys := make([]UniqueID, len(segInfo))
|
||||
i := 0
|
||||
for key := range segInfo {
|
||||
keys[i] = key
|
||||
i++
|
||||
}
|
||||
sort.Slice(keys, func(i, j int) bool { return keys[i] < keys[j] })
|
||||
accumulate := uint32(0)
|
||||
for _, key := range keys {
|
||||
accumulate += segInfo[key]
|
||||
if _, ok := reqSegAccumulateCountMap[channelID]; !ok {
|
||||
reqSegAccumulateCountMap[channelID] = make([]uint32, 0)
|
||||
startPos := 0
|
||||
for segmentID, count := range assignedSegmentInfos {
|
||||
subRowOffsets := rowOffsets[startPos : startPos+int(count)]
|
||||
insertMsgs, err := getInsertMsgsBySegmentID(segmentID, subRowOffsets, channelName, threshold)
|
||||
if err != nil {
|
||||
log.Error("repack insert data to insert msgs failed",
|
||||
zap.Int64("collectionID", it.CollectionID),
|
||||
zap.Error(err))
|
||||
return nil, err
|
||||
}
|
||||
reqSegAccumulateCountMap[channelID] = append(
|
||||
reqSegAccumulateCountMap[channelID],
|
||||
accumulate,
|
||||
)
|
||||
if _, ok := reqSegIDMap[channelID]; !ok {
|
||||
reqSegIDMap[channelID] = make([]UniqueID, 0)
|
||||
}
|
||||
reqSegIDMap[channelID] = append(
|
||||
reqSegIDMap[channelID],
|
||||
key,
|
||||
)
|
||||
result.Msgs = append(result.Msgs, insertMsgs...)
|
||||
startPos += int(count)
|
||||
}
|
||||
}
|
||||
|
||||
var getSegmentID = func(channelID int32) UniqueID {
|
||||
reqSegAllocateCounter[channelID]++
|
||||
cur := reqSegAllocateCounter[channelID]
|
||||
accumulateSlice := reqSegAccumulateCountMap[channelID]
|
||||
segIDSlice := reqSegIDMap[channelID]
|
||||
for index, count := range accumulateSlice {
|
||||
if cur <= count {
|
||||
return segIDSlice[index]
|
||||
}
|
||||
}
|
||||
log.Warn("Can't Found SegmentID", zap.Any("reqSegAllocateCounter", reqSegAllocateCounter))
|
||||
return 0
|
||||
}
|
||||
|
||||
threshold := Params.PulsarCfg.MaxMessageSize
|
||||
// not accurate
|
||||
/* #nosec G103 */
|
||||
getFixedSizeOfInsertMsg := func(msg *msgstream.InsertMsg) int {
|
||||
size := 0
|
||||
|
||||
size += int(unsafe.Sizeof(*msg.Base))
|
||||
size += int(unsafe.Sizeof(msg.DbName))
|
||||
size += int(unsafe.Sizeof(msg.CollectionName))
|
||||
size += int(unsafe.Sizeof(msg.PartitionName))
|
||||
size += int(unsafe.Sizeof(msg.DbID))
|
||||
size += int(unsafe.Sizeof(msg.CollectionID))
|
||||
size += int(unsafe.Sizeof(msg.PartitionID))
|
||||
size += int(unsafe.Sizeof(msg.SegmentID))
|
||||
size += int(unsafe.Sizeof(msg.ShardName))
|
||||
size += int(unsafe.Sizeof(msg.Timestamps))
|
||||
size += int(unsafe.Sizeof(msg.RowIDs))
|
||||
return size
|
||||
}
|
||||
|
||||
sizePerRow, _ := typeutil.EstimateSizePerRecord(it.schema)
|
||||
result := make(map[int32]msgstream.TsMsg)
|
||||
curMsgSizeMap := make(map[int32]int)
|
||||
|
||||
for i, request := range tsMsgs {
|
||||
insertRequest := request.(*msgstream.InsertMsg)
|
||||
keys := hashKeys[i]
|
||||
collectionName := insertRequest.CollectionName
|
||||
collectionID := insertRequest.CollectionID
|
||||
partitionID := insertRequest.PartitionID
|
||||
partitionName := insertRequest.PartitionName
|
||||
proxyID := insertRequest.Base.SourceID
|
||||
for index, key := range keys {
|
||||
ts := insertRequest.Timestamps[index]
|
||||
rowID := insertRequest.RowIDs[index]
|
||||
segmentID := getSegmentID(key)
|
||||
if segmentID == 0 {
|
||||
return nil, fmt.Errorf("get SegmentID failed, segmentID is zero")
|
||||
}
|
||||
_, ok := result[key]
|
||||
if !ok {
|
||||
sliceRequest := internalpb.InsertRequest{
|
||||
Base: &commonpb.MsgBase{
|
||||
MsgType: commonpb.MsgType_Insert,
|
||||
MsgID: reqID,
|
||||
Timestamp: ts,
|
||||
SourceID: proxyID,
|
||||
},
|
||||
CollectionID: collectionID,
|
||||
PartitionID: partitionID,
|
||||
CollectionName: collectionName,
|
||||
PartitionName: partitionName,
|
||||
SegmentID: segmentID,
|
||||
ShardName: channelNames[key],
|
||||
}
|
||||
|
||||
sliceRequest.Version = internalpb.InsertDataVersion_ColumnBased
|
||||
sliceRequest.NumRows = 0
|
||||
sliceRequest.FieldsData = make([]*schemapb.FieldData, len(it.BaseInsertTask.InsertRequest.GetFieldsData()))
|
||||
|
||||
insertMsg := &msgstream.InsertMsg{
|
||||
BaseMsg: msgstream.BaseMsg{
|
||||
Ctx: request.TraceCtx(),
|
||||
},
|
||||
InsertRequest: sliceRequest,
|
||||
}
|
||||
result[key] = insertMsg
|
||||
curMsgSizeMap[key] = getFixedSizeOfInsertMsg(insertMsg)
|
||||
}
|
||||
|
||||
curMsg := result[key].(*msgstream.InsertMsg)
|
||||
curMsgSize := curMsgSizeMap[key]
|
||||
|
||||
curMsg.HashValues = append(curMsg.HashValues, insertRequest.HashValues[index])
|
||||
curMsg.Timestamps = append(curMsg.Timestamps, ts)
|
||||
curMsg.RowIDs = append(curMsg.RowIDs, rowID)
|
||||
|
||||
typeutil.AppendFieldData(curMsg.FieldsData, it.BaseInsertTask.InsertRequest.GetFieldsData(), int64(index))
|
||||
curMsg.NumRows++
|
||||
curMsgSize += sizePerRow
|
||||
|
||||
if curMsgSize >= threshold {
|
||||
newPack.Msgs = append(newPack.Msgs, curMsg)
|
||||
delete(result, key)
|
||||
curMsgSize = 0
|
||||
}
|
||||
|
||||
curMsgSizeMap[key] = curMsgSize
|
||||
}
|
||||
}
|
||||
for _, msg := range result {
|
||||
if msg != nil {
|
||||
newPack.Msgs = append(newPack.Msgs, msg)
|
||||
}
|
||||
}
|
||||
|
||||
return newPack, nil
|
||||
return result, nil
|
||||
}
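Note: assignSegmentID replaces the old _assignSegmentID repack loop. Rows are first hashed to DML channels by primary key, grouped per channel, and then carved into per-segment InsertMsgs that respect the Pulsar message-size threshold. The core bucketing step, reduced to a sketch (variable names follow the patch):

    hashValues := typeutil.HashPK2Channels(it.result.IDs, channelNames)
    channel2RowOffsets := make(map[string][]int)
    for offset, channelIdx := range hashValues {
        name := channelNames[channelIdx]
        channel2RowOffsets[name] = append(channel2RowOffsets[name], offset)
    }
    // each channel's offsets are later split into InsertMsgs by segment assignment and size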
|
||||
|
||||
func (it *insertTask) Execute(ctx context.Context) error {
|
||||
|
@ -769,7 +485,7 @@ func (it *insertTask) Execute(ctx context.Context) error {
|
|||
tr := timerecord.NewTimeRecorder(fmt.Sprintf("proxy execute insert %d", it.ID()))
|
||||
defer tr.Elapse("done")
|
||||
|
||||
collectionName := it.BaseInsertTask.CollectionName
|
||||
collectionName := it.CollectionName
|
||||
collID, err := globalMetaCache.GetCollectionID(ctx, collectionName)
|
||||
if err != nil {
|
||||
return err
|
||||
|
@ -790,16 +506,6 @@ func (it *insertTask) Execute(ctx context.Context) error {
|
|||
it.PartitionID = partitionID
|
||||
tr.Record("get collection id & partition id from cache")
|
||||
|
||||
var tsMsg msgstream.TsMsg = &it.BaseInsertTask
|
||||
it.BaseMsg.Ctx = ctx
|
||||
msgPack := msgstream.MsgPack{
|
||||
BeginTs: it.BeginTs(),
|
||||
EndTs: it.EndTs(),
|
||||
Msgs: make([]msgstream.TsMsg, 1),
|
||||
}
|
||||
|
||||
msgPack.Msgs[0] = tsMsg
|
||||
|
||||
stream, err := it.chMgr.getDMLStream(collID)
|
||||
if err != nil {
|
||||
err = it.chMgr.createDMLMsgStream(collID)
|
||||
|
@ -817,16 +523,27 @@ func (it *insertTask) Execute(ctx context.Context) error {
|
|||
}
|
||||
tr.Record("get used message stream")
|
||||
|
||||
// Assign SegmentID
|
||||
var pack *msgstream.MsgPack
|
||||
pack, err = it._assignSegmentID(stream, &msgPack)
|
||||
channelNames, err := it.chMgr.getVChannels(collID)
|
||||
if err != nil {
|
||||
log.Error("get vChannels failed", zap.Int64("msgID", it.Base.MsgID), zap.Int64("collectionID", collID), zap.Error(err))
|
||||
it.result.Status.ErrorCode = commonpb.ErrorCode_UnexpectedError
|
||||
it.result.Status.Reason = err.Error()
|
||||
return err
|
||||
}
|
||||
|
||||
// assign segmentID for insert data and repack data by segmentID
|
||||
msgPack, err := it.assignSegmentID(channelNames)
|
||||
if err != nil {
|
||||
log.Error("assign segmentID and repack insert data failed", zap.Int64("msgID", it.Base.MsgID), zap.Int64("collectionID", collID), zap.Error(err))
|
||||
it.result.Status.ErrorCode = commonpb.ErrorCode_UnexpectedError
|
||||
it.result.Status.Reason = err.Error()
|
||||
return err
|
||||
}
|
||||
log.Debug("assign segmentID for insert data success", zap.Int64("msgID", it.Base.MsgID), zap.Int64("collectionID", collID), zap.String("collection name", it.CollectionName))
|
||||
tr.Record("assign segment id")
|
||||
|
||||
tr.Record("sendInsertMsg")
|
||||
err = stream.Produce(pack)
|
||||
err = stream.Produce(msgPack)
|
||||
if err != nil {
|
||||
it.result.Status.ErrorCode = commonpb.ErrorCode_UnexpectedError
|
||||
it.result.Status.Reason = err.Error()
|
||||
|
@ -835,6 +552,8 @@ func (it *insertTask) Execute(ctx context.Context) error {
|
|||
sendMsgDur := tr.Record("send insert request to message stream")
|
||||
metrics.ProxySendInsertReqLatency.WithLabelValues(strconv.FormatInt(Params.ProxyCfg.ProxyID, 10)).Observe(float64(sendMsgDur.Milliseconds()))
|
||||
|
||||
log.Debug("Proxy Insert Execute done", zap.Int64("msgID", it.Base.MsgID), zap.String("collection name", collectionName))
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
|
@ -4549,8 +4268,6 @@ func (dt *deleteTask) PreExecute(ctx context.Context) error {
|
|||
}
|
||||
dt.result.DeleteCnt = int64(len(primaryKeys))
|
||||
|
||||
dt.HashPK(primaryKeys)
|
||||
|
||||
rowNum := len(primaryKeys)
|
||||
dt.Timestamps = make([]uint64, rowNum)
|
||||
for index := range dt.Timestamps {
|
||||
|
@ -4564,14 +4281,6 @@ func (dt *deleteTask) Execute(ctx context.Context) (err error) {
|
|||
sp, ctx := trace.StartSpanFromContextWithOperationName(dt.ctx, "Proxy-Delete-Execute")
|
||||
defer sp.Finish()
|
||||
|
||||
var tsMsg msgstream.TsMsg = &dt.BaseDeleteTask
|
||||
msgPack := msgstream.MsgPack{
|
||||
BeginTs: dt.BeginTs(),
|
||||
EndTs: dt.EndTs(),
|
||||
Msgs: make([]msgstream.TsMsg, 1),
|
||||
}
|
||||
msgPack.Msgs[0] = tsMsg
|
||||
|
||||
collID := dt.DeleteRequest.CollectionID
|
||||
stream, err := dt.chMgr.getDMLStream(collID)
|
||||
if err != nil {
|
||||
|
@ -4588,64 +4297,66 @@ func (dt *deleteTask) Execute(ctx context.Context) (err error) {
|
|||
return err
|
||||
}
|
||||
}
|
||||
result := make(map[int32]msgstream.TsMsg)
|
||||
hashKeys := stream.ComputeProduceChannelIndexes(msgPack.Msgs)
|
||||
// For each msg, assign PK to different message buckets by hash value of PK.
|
||||
for i, request := range msgPack.Msgs {
|
||||
deleteRequest := request.(*msgstream.DeleteMsg)
|
||||
keys := hashKeys[i]
|
||||
collectionName := deleteRequest.CollectionName
|
||||
collectionID := deleteRequest.CollectionID
|
||||
partitionID := deleteRequest.PartitionID
|
||||
partitionName := deleteRequest.PartitionName
|
||||
proxyID := deleteRequest.Base.SourceID
|
||||
for index, key := range keys {
|
||||
ts := deleteRequest.Timestamps[index]
|
||||
pks := deleteRequest.PrimaryKeys[index]
|
||||
_, ok := result[key]
|
||||
if !ok {
|
||||
sliceRequest := internalpb.DeleteRequest{
|
||||
Base: &commonpb.MsgBase{
|
||||
MsgType: commonpb.MsgType_Delete,
|
||||
MsgID: dt.Base.MsgID,
|
||||
Timestamp: ts,
|
||||
SourceID: proxyID,
|
||||
},
|
||||
CollectionID: collectionID,
|
||||
PartitionID: partitionID,
|
||||
CollectionName: collectionName,
|
||||
PartitionName: partitionName,
|
||||
}
|
||||
deleteMsg := &msgstream.DeleteMsg{
|
||||
BaseMsg: msgstream.BaseMsg{
|
||||
Ctx: ctx,
|
||||
},
|
||||
DeleteRequest: sliceRequest,
|
||||
}
|
||||
result[key] = deleteMsg
|
||||
|
||||
// hash primary keys to channels
|
||||
channelNames, err := dt.chMgr.getVChannels(collID)
|
||||
if err != nil {
|
||||
log.Error("get vChannels failed", zap.Int64("collectionID", collID), zap.Error(err))
|
||||
dt.result.Status.ErrorCode = commonpb.ErrorCode_UnexpectedError
|
||||
dt.result.Status.Reason = err.Error()
|
||||
return err
|
||||
}
|
||||
dt.HashValues = typeutil.HashPK2Channels(dt.result.IDs, channelNames)
|
||||
|
||||
// repack delete msg by dmChannel
|
||||
result := make(map[uint32]msgstream.TsMsg)
|
||||
collectionName := dt.CollectionName
|
||||
collectionID := dt.CollectionID
|
||||
partitionID := dt.PartitionID
|
||||
partitionName := dt.PartitionName
|
||||
proxyID := dt.Base.SourceID
|
||||
for index, key := range dt.HashValues {
|
||||
ts := dt.Timestamps[index]
|
||||
_, ok := result[key]
|
||||
if !ok {
|
||||
sliceRequest := internalpb.DeleteRequest{
|
||||
Base: &commonpb.MsgBase{
|
||||
MsgType: commonpb.MsgType_Delete,
|
||||
MsgID: dt.Base.MsgID,
|
||||
Timestamp: ts,
|
||||
SourceID: proxyID,
|
||||
},
|
||||
CollectionID: collectionID,
|
||||
PartitionID: partitionID,
|
||||
CollectionName: collectionName,
|
||||
PartitionName: partitionName,
|
||||
}
|
||||
curMsg := result[key].(*msgstream.DeleteMsg)
|
||||
curMsg.HashValues = append(curMsg.HashValues, deleteRequest.HashValues[index])
|
||||
curMsg.Timestamps = append(curMsg.Timestamps, ts)
|
||||
curMsg.PrimaryKeys = append(curMsg.PrimaryKeys, pks)
|
||||
deleteMsg := &msgstream.DeleteMsg{
|
||||
BaseMsg: msgstream.BaseMsg{
|
||||
Ctx: ctx,
|
||||
},
|
||||
DeleteRequest: sliceRequest,
|
||||
}
|
||||
result[key] = deleteMsg
|
||||
}
|
||||
curMsg := result[key].(*msgstream.DeleteMsg)
|
||||
curMsg.HashValues = append(curMsg.HashValues, dt.HashValues[index])
|
||||
curMsg.Timestamps = append(curMsg.Timestamps, dt.Timestamps[index])
|
||||
curMsg.PrimaryKeys = append(curMsg.PrimaryKeys, dt.PrimaryKeys[index])
|
||||
}
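Note: deleteTask.Execute now reuses typeutil.HashPK2Channels instead of its own HashPK helper, bucketing primary keys into one DeleteMsg per hashed channel before producing. A trimmed sketch of the grouping; the per-channel message construction is abbreviated:

    dt.HashValues = typeutil.HashPK2Channels(dt.result.IDs, channelNames)
    result := make(map[uint32]msgstream.TsMsg)
    for index, key := range dt.HashValues {
        if _, ok := result[key]; !ok {
            // lazily create the channel's DeleteMsg, copying the base fields from dt
            result[key] = &msgstream.DeleteMsg{DeleteRequest: internalpb.DeleteRequest{}}
        }
        curMsg := result[key].(*msgstream.DeleteMsg)
        curMsg.Timestamps = append(curMsg.Timestamps, dt.Timestamps[index])
        curMsg.PrimaryKeys = append(curMsg.PrimaryKeys, dt.PrimaryKeys[index])
    }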
|
||||
|
||||
newPack := &msgstream.MsgPack{
|
||||
BeginTs: msgPack.BeginTs,
|
||||
EndTs: msgPack.EndTs,
|
||||
StartPositions: msgPack.StartPositions,
|
||||
EndPositions: msgPack.EndPositions,
|
||||
Msgs: make([]msgstream.TsMsg, 0),
|
||||
// send delete request to log broker
|
||||
msgPack := &msgstream.MsgPack{
|
||||
BeginTs: dt.BeginTs(),
|
||||
EndTs: dt.EndTs(),
|
||||
}
|
||||
|
||||
for _, msg := range result {
|
||||
if msg != nil {
|
||||
newPack.Msgs = append(newPack.Msgs, msg)
|
||||
msgPack.Msgs = append(msgPack.Msgs, msg)
|
||||
}
|
||||
}
|
||||
|
||||
err = stream.Produce(newPack)
|
||||
err = stream.Produce(msgPack)
|
||||
if err != nil {
|
||||
dt.result.Status.ErrorCode = commonpb.ErrorCode_UnexpectedError
|
||||
dt.result.Status.Reason = err.Error()
|
||||
|
@ -4658,17 +4369,6 @@ func (dt *deleteTask) PostExecute(ctx context.Context) error {
|
|||
return nil
|
||||
}
|
||||
|
||||
func (dt *deleteTask) HashPK(pks []int64) {
|
||||
if len(dt.HashValues) != 0 {
|
||||
log.Warn("the hashvalues passed through client is not supported now, and will be overwritten")
|
||||
}
|
||||
dt.HashValues = make([]uint32, 0, len(pks))
|
||||
for _, pk := range pks {
|
||||
hash, _ := typeutil.Hash32Int64(pk)
|
||||
dt.HashValues = append(dt.HashValues, hash)
|
||||
}
|
||||
}
|
||||
|
||||
// CreateAliasTask contains task information of CreateAlias
|
||||
type CreateAliasTask struct {
|
||||
Condition
|
||||
File diff suppressed because it is too large
@ -398,7 +398,7 @@ func validateMultipleVectorFields(schema *schemapb.CollectionSchema) error {
|
|||
for i := range schema.Fields {
|
||||
name := schema.Fields[i].Name
|
||||
dType := schema.Fields[i].DataType
|
||||
isVec := (dType == schemapb.DataType_BinaryVector || dType == schemapb.DataType_FloatVector)
|
||||
isVec := dType == schemapb.DataType_BinaryVector || dType == schemapb.DataType_FloatVector
|
||||
if isVec && vecExist && !enableMultipleVectorFields {
|
||||
return fmt.Errorf(
|
||||
"multiple vector fields is not supported, fields name: %s, %s",
|
||||
|
@ -413,3 +413,98 @@ func validateMultipleVectorFields(schema *schemapb.CollectionSchema) error {

	return nil
}

// getPrimaryFieldData gets the primary field data from all field data inserted from the SDK
func getPrimaryFieldData(datas []*schemapb.FieldData, primaryFieldSchema *schemapb.FieldSchema) (*schemapb.FieldData, error) {
	primaryFieldName := primaryFieldSchema.Name

	var primaryFieldData *schemapb.FieldData
	for _, field := range datas {
		if field.FieldName == primaryFieldName {
			if primaryFieldSchema.AutoID {
				return nil, fmt.Errorf("autoID field %v does not require data", primaryFieldName)
			}
			primaryFieldData = field
		}
	}

	if primaryFieldData == nil {
		return nil, fmt.Errorf("can't find data for primary field %v", primaryFieldName)
	}

	return primaryFieldData, nil
}

// parsePrimaryFieldData2IDs gets IDs to fill the gRPC result, e.g. for insert and delete requests
func parsePrimaryFieldData2IDs(fieldData *schemapb.FieldData) (*schemapb.IDs, error) {
	primaryData := &schemapb.IDs{}
	switch fieldData.Field.(type) {
	case *schemapb.FieldData_Scalars:
		scalarField := fieldData.GetScalars()
		switch scalarField.Data.(type) {
		case *schemapb.ScalarField_LongData:
			primaryData.IdField = &schemapb.IDs_IntId{
				IntId: scalarField.GetLongData(),
			}
		case *schemapb.ScalarField_StringData:
			primaryData.IdField = &schemapb.IDs_StrId{
				StrId: scalarField.GetStringData(),
			}
		default:
			return nil, errors.New("currently only support DataType Int64 or VarChar as PrimaryField")
		}
	default:
		return nil, errors.New("vector field is not supported as PrimaryField")
	}

	return primaryData, nil
}

// autoGenPrimaryFieldData generates primary field data when autoID == true
func autoGenPrimaryFieldData(fieldSchema *schemapb.FieldSchema, data interface{}) (*schemapb.FieldData, error) {
	var fieldData schemapb.FieldData
	fieldData.FieldName = fieldSchema.Name
	fieldData.Type = fieldSchema.DataType
	switch data := data.(type) {
	case []int64:
		if fieldSchema.DataType != schemapb.DataType_Int64 {
			return nil, errors.New("the data type of the data and the schema do not match")
		}
		fieldData.Field = &schemapb.FieldData_Scalars{
			Scalars: &schemapb.ScalarField{
				Data: &schemapb.ScalarField_LongData{
					LongData: &schemapb.LongArray{
						Data: data,
					},
				},
			},
		}
	default:
		return nil, errors.New("currently only support autoID for int64 PrimaryField")
	}

	return &fieldData, nil
}

// fillFieldIDBySchema sets the fieldID on each fieldData according to the collection's FieldSchemas
func fillFieldIDBySchema(columns []*schemapb.FieldData, schema *schemapb.CollectionSchema) error {
	if len(columns) != len(schema.GetFields()) {
		return fmt.Errorf("len(columns) mismatch the len(fields), len(columns): %d, len(fields): %d",
			len(columns), len(schema.GetFields()))
	}
	fieldName2Schema := make(map[string]*schemapb.FieldSchema)
	for _, field := range schema.GetFields() {
		fieldName2Schema[field.Name] = field
	}

	for _, fieldData := range columns {
		if fieldSchema, ok := fieldName2Schema[fieldData.FieldName]; ok {
			fieldData.FieldId = fieldSchema.FieldID
			fieldData.Type = fieldSchema.DataType
		} else {
			return fmt.Errorf("fieldName %v not exist in collection schema", fieldData.FieldName)
		}
	}

	return nil
}
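The helpers above are the proxy-side glue for Int64 and VarChar primary keys. As a rough, hedged illustration of what parsePrimaryFieldData2IDs produces, the following standalone sketch (compiled inside the milvus repo, since schemapb is an internal package) builds an int64 primary-key column and converts it to a schemapb.IDs value; the switch mirrors the helper above but is not the proxy code itself:

package main

import (
	"fmt"

	"github.com/milvus-io/milvus/internal/proto/schemapb"
)

func main() {
	// an int64 primary-key column as it arrives from the SDK
	pkColumn := &schemapb.FieldData{
		FieldName: "pk",
		Type:      schemapb.DataType_Int64,
		Field: &schemapb.FieldData_Scalars{
			Scalars: &schemapb.ScalarField{
				Data: &schemapb.ScalarField_LongData{
					LongData: &schemapb.LongArray{Data: []int64{1, 2, 3}},
				},
			},
		},
	}

	// same conversion idea: scalar long data becomes IDs_IntId, scalar
	// string data would become IDs_StrId
	ids := &schemapb.IDs{}
	switch scalars := pkColumn.GetScalars().Data.(type) {
	case *schemapb.ScalarField_LongData:
		ids.IdField = &schemapb.IDs_IntId{IntId: scalars.LongData}
	case *schemapb.ScalarField_StringData:
		ids.IdField = &schemapb.IDs_StrId{StrId: scalars.StringData}
	}
	fmt.Println(ids.GetIntId().GetData()) // [1 2 3]
}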
@ -502,3 +502,28 @@ func TestValidateMultipleVectorFields(t *testing.T) {
		assert.Error(t, validateMultipleVectorFields(schema3))
	}
}

func TestFillFieldIDBySchema(t *testing.T) {
	schema := &schemapb.CollectionSchema{}
	columns := []*schemapb.FieldData{
		{
			FieldName: "TestFillFieldIDBySchema",
		},
	}

	// length mismatch
	assert.Error(t, fillFieldIDBySchema(columns, schema))
	schema = &schemapb.CollectionSchema{
		Fields: []*schemapb.FieldSchema{
			{
				Name:     "TestFillFieldIDBySchema",
				DataType: schemapb.DataType_Int64,
				FieldID:  1,
			},
		},
	}
	assert.NoError(t, fillFieldIDBySchema(columns, schema))
	assert.Equal(t, "TestFillFieldIDBySchema", columns[0].FieldName)
	assert.Equal(t, schemapb.DataType_Int64, columns[0].Type)
	assert.Equal(t, int64(1), columns[0].FieldId)
}
@ -145,7 +145,7 @@ func (fdmNode *filterDmNode) filterInvalidDeleteMessage(msg *msgstream.DeleteMsg

// filterInvalidInsertMessage would filter out invalid insert messages
func (fdmNode *filterDmNode) filterInvalidInsertMessage(msg *msgstream.InsertMsg) *msgstream.InsertMsg {
	if !msg.CheckAligned() {
	if err := msg.CheckAligned(); err != nil {
		// TODO: what if the messages are misaligned? Here, we ignore those messages and print error
		log.Warn("Error, misaligned messages detected")
		return nil
@ -102,15 +102,6 @@ func (iNode *insertNode) Operate(in []flowgraph.Msg) []flowgraph.Msg {

	// 1. hash insertMessages to insertData
	for _, insertMsg := range iMsg.insertMessages {
		if insertMsg.IsColumnBased() {
			var err error
			insertMsg.RowData, err = typeutil.TransferColumnBasedDataToRowBasedData(insertMsg.FieldsData)
			if err != nil {
				log.Error("failed to transfer column-based data to row-based data", zap.Error(err))
				return []Msg{}
			}
		}

		// if loadType is loadCollection, check if partition exists, if not, create partition
		col, err := iNode.streamingReplica.getCollectionByID(insertMsg.CollectionID)
		if err != nil {

@ -134,6 +125,15 @@ func (iNode *insertNode) Operate(in []flowgraph.Msg) []flowgraph.Msg {
			}
		}

		// trans column field data to row data
		if insertMsg.IsColumnBased() {
			insertMsg.RowData, err = typeutil.TransferColumnBasedDataToRowBasedData(col.schema, insertMsg.FieldsData)
			if err != nil {
				log.Error("failed to transfer column-based data to row-based data", zap.Error(err))
				return []Msg{}
			}
		}

		iData.insertIDs[insertMsg.SegmentID] = append(iData.insertIDs[insertMsg.SegmentID], insertMsg.RowIDs...)
		iData.insertTimestamps[insertMsg.SegmentID] = append(iData.insertTimestamps[insertMsg.SegmentID], insertMsg.Timestamps...)
		// using insertMsg.RowData is valid here, since we have already transferred the column-based data.
@ -345,7 +345,7 @@ func (iNode *insertNode) delete(deleteData *deleteData, segmentID UniqueID, wg *
// TODO: remove this function to proper file
// getPrimaryKeys would get primary keys by insert messages
func getPrimaryKeys(msg *msgstream.InsertMsg, streamingReplica ReplicaInterface) ([]int64, error) {
	if !msg.CheckAligned() {
	if err := msg.CheckAligned(); err != nil {
		log.Warn("misaligned messages detected")
		return nil, errors.New("misaligned messages detected")
	}
@ -25,6 +25,7 @@ enum ColumnType : int {
|
|||
FLOAT = 10,
|
||||
DOUBLE = 11,
|
||||
STRING = 20,
|
||||
VARCHAR = 21,
|
||||
VECTOR_BINARY = 100,
|
||||
VECTOR_FLOAT = 101
|
||||
};
|
||||
|
|
|
@ -76,6 +76,7 @@ CPayloadWriter NewPayloadWriter(int columnType) {
|
|||
p->schema = arrow::schema({arrow::field("val", arrow::float64())});
|
||||
break;
|
||||
}
|
||||
case ColumnType::VARCHAR:
|
||||
case ColumnType::STRING : {
|
||||
p->columnType = ColumnType::STRING;
|
||||
p->builder = std::make_shared<arrow::StringBuilder>();
|
||||
|
@ -384,6 +385,7 @@ CPayloadReader NewPayloadReader(int columnType, uint8_t *buffer, int64_t buf_siz
|
|||
case ColumnType::FLOAT :
|
||||
case ColumnType::DOUBLE :
|
||||
case ColumnType::STRING :
|
||||
case ColumnType::VARCHAR:
|
||||
case ColumnType::VECTOR_BINARY :
|
||||
case ColumnType::VECTOR_FLOAT : {
|
||||
break;
|
||||
|
|
|
@ -360,7 +360,7 @@ func (insertCodec *InsertCodec) Serialize(partitionID UniqueID, segmentID Unique
|
|||
return nil, nil, err
|
||||
}
|
||||
writer.AddExtra(originalSizeKey, fmt.Sprintf("%v", singleData.(*DoubleFieldData).GetMemorySize()))
|
||||
case schemapb.DataType_String:
|
||||
case schemapb.DataType_String, schemapb.DataType_VarChar:
|
||||
for _, singleString := range singleData.(*StringFieldData).Data {
|
||||
err = eventWriter.AddOneStringToPayload(singleString)
|
||||
if err != nil {
|
||||
|
@ -416,10 +416,9 @@ func (insertCodec *InsertCodec) Serialize(partitionID UniqueID, segmentID Unique
|
|||
writer.Close()
|
||||
|
||||
// stats fields
|
||||
switch field.DataType {
|
||||
case schemapb.DataType_Int64:
|
||||
if field.GetIsPrimaryKey() {
|
||||
statsWriter := &StatsWriter{}
|
||||
err = statsWriter.StatsInt64(field.FieldID, field.IsPrimaryKey, singleData.(*Int64FieldData).Data)
|
||||
err = statsWriter.generatePrimaryKeyStats(field.FieldID, field.DataType, singleData)
|
||||
if err != nil {
|
||||
return nil, nil, err
|
||||
}
|
||||
|
@ -621,7 +620,7 @@ func (insertCodec *InsertCodec) DeserializeAll(blobs []*Blob) (
|
|||
totalLength += length
|
||||
doubleFieldData.NumRows = append(doubleFieldData.NumRows, int64(length))
|
||||
resultData.Data[fieldID] = doubleFieldData
|
||||
case schemapb.DataType_String:
|
||||
case schemapb.DataType_String, schemapb.DataType_VarChar:
|
||||
if resultData.Data[fieldID] == nil {
|
||||
resultData.Data[fieldID] = &StringFieldData{}
|
||||
}
|
||||
|
|
|
@ -77,7 +77,7 @@ func (ds *DataSorter) Swap(i, j int) {
|
|||
case schemapb.DataType_Double:
|
||||
data := singleData.(*DoubleFieldData).Data
|
||||
data[i], data[j] = data[j], data[i]
|
||||
case schemapb.DataType_String:
|
||||
case schemapb.DataType_String, schemapb.DataType_VarChar:
|
||||
data := singleData.(*StringFieldData).Data
|
||||
data[i], data[j] = data[j], data[i]
|
||||
case schemapb.DataType_BinaryVector:
|
||||
|
|
|
@ -145,7 +145,7 @@ func (w *PayloadWriter) AddDataToPayload(msgs interface{}, dim ...int) error {
|
|||
return errors.New("incorrect data type")
|
||||
}
|
||||
return w.AddDoubleToPayload(val)
|
||||
case schemapb.DataType_String:
|
||||
case schemapb.DataType_String, schemapb.DataType_VarChar:
|
||||
val, ok := msgs.(string)
|
||||
if !ok {
|
||||
return errors.New("incorrect data type")
|
||||
|
@ -387,7 +387,7 @@ func (r *PayloadReader) GetDataFromPayload(idx ...int) (interface{}, int, error)
|
|||
switch len(idx) {
|
||||
case 1:
|
||||
switch r.colType {
|
||||
case schemapb.DataType_String:
|
||||
case schemapb.DataType_String, schemapb.DataType_VarChar:
|
||||
val, err := r.GetOneStringFromPayload(idx[0])
|
||||
return val, 0, err
|
||||
default:
|
||||
|
@ -573,7 +573,7 @@ func (r *PayloadReader) GetDoubleFromPayload() ([]float64, error) {
|
|||
}
|
||||
|
||||
func (r *PayloadReader) GetOneStringFromPayload(idx int) (string, error) {
|
||||
if r.colType != schemapb.DataType_String {
|
||||
if r.colType != schemapb.DataType_String && r.colType != schemapb.DataType_VarChar {
|
||||
return "", errors.New("incorrect data type")
|
||||
}
|
||||
|
||||
|
|
|
@ -0,0 +1,233 @@
|
|||
// Licensed to the LF AI & Data foundation under one
|
||||
// or more contributor license agreements. See the NOTICE file
|
||||
// distributed with this work for additional information
|
||||
// regarding copyright ownership. The ASF licenses this file
|
||||
// to you under the Apache License, Version 2.0 (the
|
||||
// "License"); you may not use this file except in compliance
|
||||
// with the License. You may obtain a copy of the License at
|
||||
//
|
||||
// http://www.apache.org/licenses/LICENSE-2.0
|
||||
//
|
||||
// Unless required by applicable law or agreed to in writing, software
|
||||
// distributed under the License is distributed on an "AS IS" BASIS,
|
||||
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
// See the License for the specific language governing permissions and
|
||||
// limitations under the License.
|
||||
|
||||
package storage
|
||||
|
||||
import (
|
||||
"encoding/json"
|
||||
"fmt"
|
||||
"strings"
|
||||
|
||||
"github.com/milvus-io/milvus/internal/log"
|
||||
)
|
||||
|
||||
type PrimaryKey interface {
|
||||
GT(key PrimaryKey) bool
|
||||
GE(key PrimaryKey) bool
|
||||
LT(key PrimaryKey) bool
|
||||
LE(key PrimaryKey) bool
|
||||
EQ(key PrimaryKey) bool
|
||||
MarshalJSON() ([]byte, error)
|
||||
UnmarshalJSON(data []byte) error
|
||||
SetValue(interface{}) error
|
||||
}
|
||||
|
||||
type Int64PrimaryKey struct {
|
||||
Value int64 `json:"pkValue"`
|
||||
}
|
||||
|
||||
func (ip *Int64PrimaryKey) GT(key PrimaryKey) bool {
|
||||
pk, ok := key.(*Int64PrimaryKey)
|
||||
if !ok {
|
||||
log.Warn("type of compared pk is not int64")
|
||||
return false
|
||||
}
|
||||
if ip.Value > pk.Value {
|
||||
return true
|
||||
}
|
||||
|
||||
return false
|
||||
}
|
||||
|
||||
func (ip *Int64PrimaryKey) GE(key PrimaryKey) bool {
|
||||
pk, ok := key.(*Int64PrimaryKey)
|
||||
if !ok {
|
||||
log.Warn("type of compared pk is not int64")
|
||||
return false
|
||||
}
|
||||
if ip.Value >= pk.Value {
|
||||
return true
|
||||
}
|
||||
|
||||
return false
|
||||
}
|
||||
|
||||
func (ip *Int64PrimaryKey) LT(key PrimaryKey) bool {
|
||||
pk, ok := key.(*Int64PrimaryKey)
|
||||
if !ok {
|
||||
log.Warn("type of compared pk is not int64")
|
||||
return false
|
||||
}
|
||||
|
||||
if ip.Value < pk.Value {
|
||||
return true
|
||||
}
|
||||
|
||||
return false
|
||||
}
|
||||
|
||||
func (ip *Int64PrimaryKey) LE(key PrimaryKey) bool {
|
||||
pk, ok := key.(*Int64PrimaryKey)
|
||||
if !ok {
|
||||
log.Warn("type of compared pk is not int64")
|
||||
return false
|
||||
}
|
||||
|
||||
if ip.Value <= pk.Value {
|
||||
return true
|
||||
}
|
||||
|
||||
return false
|
||||
}
|
||||
|
||||
func (ip *Int64PrimaryKey) EQ(key PrimaryKey) bool {
|
||||
pk, ok := key.(*Int64PrimaryKey)
|
||||
if !ok {
|
||||
log.Warn("type of compared pk is not int64")
|
||||
return false
|
||||
}
|
||||
|
||||
if ip.Value == pk.Value {
|
||||
return true
|
||||
}
|
||||
|
||||
return false
|
||||
}
|
||||
|
||||
func (ip *Int64PrimaryKey) MarshalJSON() ([]byte, error) {
|
||||
ret, err := json.Marshal(ip.Value)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
return ret, nil
|
||||
}
|
||||
|
||||
func (ip *Int64PrimaryKey) UnmarshalJSON(data []byte) error {
|
||||
err := json.Unmarshal(data, &ip.Value)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
func (ip *Int64PrimaryKey) SetValue(data interface{}) error {
|
||||
value, ok := data.(int64)
|
||||
if !ok {
|
||||
return fmt.Errorf("wrong type value when setValue for Int64PrimaryKey")
|
||||
}
|
||||
|
||||
ip.Value = value
|
||||
return nil
|
||||
}
|
||||
|
||||
type StringPrimaryKey struct {
|
||||
Value string
|
||||
}
|
||||
|
||||
func (sp *StringPrimaryKey) GT(key PrimaryKey) bool {
|
||||
pk, ok := key.(*StringPrimaryKey)
|
||||
if !ok {
|
||||
log.Warn("type of compared pk is not string")
|
||||
return false
|
||||
}
|
||||
if strings.Compare(sp.Value, pk.Value) > 0 {
|
||||
return true
|
||||
}
|
||||
|
||||
return false
|
||||
}
|
||||
|
||||
func (sp *StringPrimaryKey) GE(key PrimaryKey) bool {
|
||||
pk, ok := key.(*StringPrimaryKey)
|
||||
if !ok {
|
||||
log.Warn("type of compared pk is not string")
|
||||
return false
|
||||
}
|
||||
if strings.Compare(sp.Value, pk.Value) >= 0 {
|
||||
return true
|
||||
}
|
||||
|
||||
return false
|
||||
}
|
||||
|
||||
func (sp *StringPrimaryKey) LT(key PrimaryKey) bool {
|
||||
pk, ok := key.(*StringPrimaryKey)
|
||||
if !ok {
|
||||
log.Warn("type of compared pk is not string")
|
||||
return false
|
||||
}
|
||||
if strings.Compare(sp.Value, pk.Value) < 0 {
|
||||
return true
|
||||
}
|
||||
|
||||
return false
|
||||
}
|
||||
|
||||
func (sp *StringPrimaryKey) LE(key PrimaryKey) bool {
|
||||
pk, ok := key.(*StringPrimaryKey)
|
||||
if !ok {
|
||||
log.Warn("type of compared pk is not string")
|
||||
return false
|
||||
}
|
||||
if strings.Compare(sp.Value, pk.Value) <= 0 {
|
||||
return true
|
||||
}
|
||||
|
||||
return false
|
||||
}
|
||||
|
||||
func (sp *StringPrimaryKey) EQ(key PrimaryKey) bool {
|
||||
pk, ok := key.(*StringPrimaryKey)
|
||||
if !ok {
|
||||
log.Warn("type of compared pk is not string")
|
||||
return false
|
||||
}
|
||||
if strings.Compare(sp.Value, pk.Value) == 0 {
|
||||
return true
|
||||
}
|
||||
|
||||
return false
|
||||
}
|
||||
|
||||
func (sp *StringPrimaryKey) MarshalJSON() ([]byte, error) {
|
||||
ret, err := json.Marshal(sp.Value)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
return ret, nil
|
||||
}
|
||||
|
||||
func (sp *StringPrimaryKey) UnmarshalJSON(data []byte) error {
|
||||
err := json.Unmarshal(data, &sp.Value)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
func (sp *StringPrimaryKey) SetValue(data interface{}) error {
|
||||
value, ok := data.(string)
|
||||
if !ok {
|
||||
return fmt.Errorf("wrong type value when setValue for StringPrimaryKey")
|
||||
}
|
||||
|
||||
sp.Value = value
|
||||
return nil
|
||||
}
|
|
@ -0,0 +1,82 @@
|
|||
package storage
|
||||
|
||||
import (
|
||||
"encoding/json"
|
||||
"testing"
|
||||
|
||||
"github.com/stretchr/testify/assert"
|
||||
)
|
||||
|
||||
func TestStringPrimaryKey(t *testing.T) {
|
||||
pk := &StringPrimaryKey{
|
||||
Value: "milvus",
|
||||
}
|
||||
|
||||
testPk := &StringPrimaryKey{
|
||||
Value: "milvus",
|
||||
}
|
||||
|
||||
// test GE
|
||||
assert.Equal(t, true, pk.GE(testPk))
|
||||
// test LE
|
||||
assert.Equal(t, true, pk.LE(testPk))
|
||||
// test EQ
|
||||
assert.Equal(t, true, pk.EQ(testPk))
|
||||
|
||||
// test GT
|
||||
err := testPk.SetValue("bivlus")
|
||||
assert.Nil(t, err)
|
||||
assert.Equal(t, true, pk.GT(testPk))
|
||||
|
||||
// test LT
|
||||
err = testPk.SetValue("mivlut")
|
||||
assert.Nil(t, err)
|
||||
assert.Equal(t, true, pk.LT(testPk))
|
||||
|
||||
t.Run("unmarshal", func(t *testing.T) {
|
||||
blob, err := json.Marshal(pk)
|
||||
assert.Nil(t, err)
|
||||
|
||||
unmarshalledPk := &StringPrimaryKey{}
|
||||
err = json.Unmarshal(blob, unmarshalledPk)
|
||||
assert.Nil(t, err)
|
||||
assert.Equal(t, pk.Value, unmarshalledPk.Value)
|
||||
})
|
||||
}
|
||||
|
||||
func TestInt64PrimaryKey(t *testing.T) {
|
||||
pk := &Int64PrimaryKey{
|
||||
Value: 100,
|
||||
}
|
||||
|
||||
testPk := &Int64PrimaryKey{
|
||||
Value: 100,
|
||||
}
|
||||
|
||||
// test GE
|
||||
assert.Equal(t, true, pk.GE(testPk))
|
||||
// test LE
|
||||
assert.Equal(t, true, pk.LE(testPk))
|
||||
// test EQ
|
||||
assert.Equal(t, true, pk.EQ(testPk))
|
||||
|
||||
// test GT
|
||||
err := testPk.SetValue(int64(10))
|
||||
assert.Nil(t, err)
|
||||
assert.Equal(t, true, pk.GT(testPk))
|
||||
|
||||
// test LT
|
||||
err = testPk.SetValue(int64(200))
|
||||
assert.Nil(t, err)
|
||||
assert.Equal(t, true, pk.LT(testPk))
|
||||
|
||||
t.Run("unmarshal", func(t *testing.T) {
|
||||
blob, err := json.Marshal(pk)
|
||||
assert.Nil(t, err)
|
||||
|
||||
unmarshalledPk := &Int64PrimaryKey{}
|
||||
err = json.Unmarshal(blob, unmarshalledPk)
|
||||
assert.Nil(t, err)
|
||||
assert.Equal(t, pk.Value, unmarshalledPk.Value)
|
||||
})
|
||||
}
|
|
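The new primary_key.go file and its test above define an ordered-comparison interface over Int64 and VarChar keys. As a hedged sketch of how that interface composes (not part of the commit, and it assumes it is compiled inside the milvus repo so the internal storage package is importable), here is a small helper that scans a slice of storage.PrimaryKey values and keeps the running min and max, the same pattern PrimaryKeyStats.updatePk uses later in this diff:

package main

import (
	"fmt"

	"github.com/milvus-io/milvus/internal/storage"
)

// minMax returns the smallest and largest key in pks, assuming all keys
// share one concrete type (Int64PrimaryKey or StringPrimaryKey).
func minMax(pks []storage.PrimaryKey) (min, max storage.PrimaryKey) {
	for _, pk := range pks {
		if min == nil || min.GT(pk) {
			min = pk
		}
		if max == nil || max.LT(pk) {
			max = pk
		}
	}
	return min, max
}

func main() {
	pks := []storage.PrimaryKey{
		&storage.StringPrimaryKey{Value: "bc"},
		&storage.StringPrimaryKey{Value: "abd"},
		&storage.StringPrimaryKey{Value: "milvus"},
	}
	min, max := minMax(pks)
	fmt.Println(min, max) // &{abd} &{milvus}, matching the stats test expectations
}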
@ -277,7 +277,7 @@ func printPayloadValues(colType schemapb.DataType, reader PayloadReaderInterface
|
|||
for i, v := range val {
|
||||
fmt.Printf("\t\t%d : %v\n", i, v)
|
||||
}
|
||||
case schemapb.DataType_String:
|
||||
case schemapb.DataType_String, schemapb.DataType_VarChar:
|
||||
rows, err := reader.GetPayloadLengthFromReader()
|
||||
if err != nil {
|
||||
return err
|
||||
|
|
|
@ -20,7 +20,9 @@ import (
|
|||
"encoding/json"
|
||||
|
||||
"github.com/bits-and-blooms/bloom/v3"
|
||||
|
||||
"github.com/milvus-io/milvus/internal/common"
|
||||
"github.com/milvus-io/milvus/internal/proto/schemapb"
|
||||
)
|
||||
|
||||
const (
|
||||
|
@ -29,12 +31,107 @@ const (
|
|||
maxBloomFalsePositive float64 = 0.005
|
||||
)
|
||||
|
||||
// Int64Stats contains statistics data for int64 column
|
||||
type Int64Stats struct {
|
||||
// PrimaryKeyStats contains statistics data for pk column
|
||||
type PrimaryKeyStats struct {
|
||||
FieldID int64 `json:"fieldID"`
|
||||
Max int64 `json:"max"`
|
||||
Min int64 `json:"min"`
|
||||
Max int64 `json:"max"` // useless, will delete
|
||||
Min int64 `json:"min"` //useless, will delete
|
||||
BF *bloom.BloomFilter `json:"bf"`
|
||||
PkType int64 `json:"pkType"`
|
||||
MaxPk PrimaryKey `json:"maxPk"`
|
||||
MinPk PrimaryKey `json:"minPk"`
|
||||
}
|
||||
|
||||
// UnmarshalJSON unmarshal bytes to PrimaryKeyStats
|
||||
func (stats *PrimaryKeyStats) UnmarshalJSON(data []byte) error {
|
||||
var messageMap map[string]*json.RawMessage
|
||||
err := json.Unmarshal(data, &messageMap)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
err = json.Unmarshal(*messageMap["fieldID"], &stats.FieldID)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
stats.PkType = int64(schemapb.DataType_Int64)
|
||||
var typeValue int64
|
||||
err = json.Unmarshal(*messageMap["pkType"], &typeValue)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
// valid pkType
|
||||
if typeValue > 0 {
|
||||
stats.PkType = typeValue
|
||||
}
|
||||
|
||||
switch schemapb.DataType(stats.PkType) {
|
||||
case schemapb.DataType_Int64:
|
||||
stats.MaxPk = &Int64PrimaryKey{}
|
||||
stats.MinPk = &Int64PrimaryKey{}
|
||||
|
||||
// Compatible with versions that only support int64 type primary keys
|
||||
err = json.Unmarshal(*messageMap["max"], &stats.Max)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
err = stats.MaxPk.SetValue(stats.Max)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
err = json.Unmarshal(*messageMap["min"], &stats.Min)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
err = stats.MinPk.SetValue(stats.Min)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
case schemapb.DataType_VarChar:
|
||||
stats.MaxPk = &StringPrimaryKey{}
|
||||
stats.MinPk = &StringPrimaryKey{}
|
||||
}
|
||||
|
||||
if maxPkMessage, ok := messageMap["maxPk"]; ok && maxPkMessage != nil {
|
||||
err = json.Unmarshal(*maxPkMessage, stats.MaxPk)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
}
|
||||
|
||||
if minPkMessage, ok := messageMap["minPk"]; ok && minPkMessage != nil {
|
||||
err = json.Unmarshal(*minPkMessage, stats.MinPk)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
}
|
||||
|
||||
stats.BF = bloom.NewWithEstimates(bloomFilterSize, maxBloomFalsePositive)
|
||||
if bfMessage, ok := messageMap["bf"]; ok && bfMessage != nil {
|
||||
err = stats.BF.UnmarshalJSON(*bfMessage)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// updatePk update minPk and maxPk value
|
||||
func (stats *PrimaryKeyStats) updatePk(pk PrimaryKey) {
|
||||
if stats.MinPk == nil {
|
||||
stats.MinPk = pk
|
||||
} else if stats.MinPk.GT(pk) {
|
||||
stats.MinPk = pk
|
||||
}
|
||||
|
||||
if stats.MaxPk == nil {
|
||||
stats.MaxPk = pk
|
||||
} else if stats.MaxPk.LT(pk) {
|
||||
stats.MaxPk = pk
|
||||
}
|
||||
}
|
||||
|
||||
// StatsWriter writes stats to buffer
|
||||
|
@ -47,26 +144,49 @@ func (sw *StatsWriter) GetBuffer() []byte {
|
|||
return sw.buffer
|
||||
}
|
||||
|
||||
// StatsInt64 writes Int64Stats from @msgs with @fieldID to @buffer
|
||||
func (sw *StatsWriter) StatsInt64(fieldID int64, isPrimaryKey bool, msgs []int64) error {
|
||||
if len(msgs) < 1 {
|
||||
// return error: msgs must has one element at least
|
||||
return nil
|
||||
// generatePrimaryKeyStats writes Int64Stats from @msgs with @fieldID to @buffer
|
||||
func (sw *StatsWriter) generatePrimaryKeyStats(fieldID int64, pkType schemapb.DataType, msgs FieldData) error {
|
||||
stats := &PrimaryKeyStats{
|
||||
FieldID: fieldID,
|
||||
PkType: int64(pkType),
|
||||
}
|
||||
|
||||
stats := &Int64Stats{
|
||||
FieldID: fieldID,
|
||||
Max: msgs[len(msgs)-1],
|
||||
Min: msgs[0],
|
||||
}
|
||||
if isPrimaryKey {
|
||||
stats.BF = bloom.NewWithEstimates(bloomFilterSize, maxBloomFalsePositive)
|
||||
stats.BF = bloom.NewWithEstimates(bloomFilterSize, maxBloomFalsePositive)
|
||||
switch fieldData := msgs.(type) {
|
||||
case *Int64FieldData:
|
||||
data := fieldData.Data
|
||||
if len(data) < 1 {
|
||||
// return error: msgs must has one element at least
|
||||
return nil
|
||||
}
|
||||
|
||||
b := make([]byte, 8)
|
||||
for _, msg := range msgs {
|
||||
common.Endian.PutUint64(b, uint64(msg))
|
||||
for _, int64Value := range data {
|
||||
pk := &Int64PrimaryKey{
|
||||
Value: int64Value,
|
||||
}
|
||||
stats.updatePk(pk)
|
||||
common.Endian.PutUint64(b, uint64(int64Value))
|
||||
stats.BF.Add(b)
|
||||
}
|
||||
case *StringFieldData:
|
||||
data := fieldData.Data
|
||||
if len(data) < 1 {
|
||||
// return error: msgs must has one element at least
|
||||
return nil
|
||||
}
|
||||
|
||||
for _, str := range data {
|
||||
pk := &StringPrimaryKey{
|
||||
Value: str,
|
||||
}
|
||||
stats.updatePk(pk)
|
||||
stats.BF.Add([]byte(str))
|
||||
}
|
||||
default:
|
||||
//TODO::
|
||||
}
|
||||
|
||||
b, err := json.Marshal(stats)
|
||||
if err != nil {
|
||||
return err
|
||||
|
@ -86,26 +206,27 @@ func (sr *StatsReader) SetBuffer(buffer []byte) {
|
|||
sr.buffer = buffer
|
||||
}
|
||||
|
||||
// GetInt64Stats returns buffer as Int64Stats
|
||||
func (sr *StatsReader) GetInt64Stats() (*Int64Stats, error) {
|
||||
stats := &Int64Stats{}
|
||||
// GetInt64Stats returns buffer as PrimaryKeyStats
|
||||
func (sr *StatsReader) GetPrimaryKeyStats() (*PrimaryKeyStats, error) {
|
||||
stats := &PrimaryKeyStats{}
|
||||
err := json.Unmarshal(sr.buffer, &stats)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
return stats, nil
|
||||
}
|
||||
|
||||
// DeserializeStats deserialize @blobs as []*Int64Stats
|
||||
func DeserializeStats(blobs []*Blob) ([]*Int64Stats, error) {
|
||||
results := make([]*Int64Stats, 0, len(blobs))
|
||||
// DeserializeStats deserialize @blobs as []*PrimaryKeyStats
|
||||
func DeserializeStats(blobs []*Blob) ([]*PrimaryKeyStats, error) {
|
||||
results := make([]*PrimaryKeyStats, 0, len(blobs))
|
||||
for _, blob := range blobs {
|
||||
if blob.Value == nil {
|
||||
continue
|
||||
}
|
||||
sr := &StatsReader{}
|
||||
sr.SetBuffer(blob.Value)
|
||||
stats, err := sr.GetInt64Stats()
|
||||
stats, err := sr.GetPrimaryKeyStats()
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
|
|
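PrimaryKeyStats above couples the min/max PrimaryKey values with a Bloom filter sized by bloomFilterSize and maxBloomFalsePositive. A minimal sketch of that Bloom-filter usage with the same bits-and-blooms library (the capacity and false-positive rate here are illustrative, not the production constants):

package main

import (
	"encoding/binary"
	"fmt"

	"github.com/bits-and-blooms/bloom/v3"
)

func main() {
	// roughly what generatePrimaryKeyStats does per segment: size the filter
	// for an expected key count and a target false-positive rate
	bf := bloom.NewWithEstimates(100000, 0.005)

	// int64 keys are added as their 8-byte encoding, string keys as raw bytes
	b := make([]byte, 8)
	for _, pk := range []int64{1, 2, 3} {
		binary.LittleEndian.PutUint64(b, uint64(pk))
		bf.Add(b)
	}
	bf.AddString("milvus")

	binary.LittleEndian.PutUint64(b, uint64(2))
	fmt.Println(bf.Test(b), bf.TestString("milvus"), bf.TestString("absent"))
	// true true false (the last can rarely be a false positive)
}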
@ -17,33 +17,117 @@
|
|||
package storage
|
||||
|
||||
import (
|
||||
"encoding/json"
|
||||
"testing"
|
||||
|
||||
"github.com/milvus-io/milvus/internal/common"
|
||||
"github.com/milvus-io/milvus/internal/rootcoord"
|
||||
"github.com/bits-and-blooms/bloom/v3"
|
||||
"github.com/stretchr/testify/assert"
|
||||
|
||||
"github.com/milvus-io/milvus/internal/common"
|
||||
"github.com/milvus-io/milvus/internal/proto/schemapb"
|
||||
"github.com/milvus-io/milvus/internal/rootcoord"
|
||||
)
|
||||
|
||||
func TestStatsWriter_StatsInt64(t *testing.T) {
|
||||
data := []int64{1, 2, 3, 4, 5, 6, 7, 8, 9}
|
||||
func TestStatsWriter_Int64PrimaryKey(t *testing.T) {
|
||||
data := &Int64FieldData{
|
||||
Data: []int64{1, 2, 3, 4, 5, 6, 7, 8, 9},
|
||||
}
|
||||
sw := &StatsWriter{}
|
||||
err := sw.StatsInt64(common.RowIDField, true, data)
|
||||
err := sw.generatePrimaryKeyStats(common.RowIDField, schemapb.DataType_Int64, data)
|
||||
assert.NoError(t, err)
|
||||
b := sw.GetBuffer()
|
||||
|
||||
sr := &StatsReader{}
|
||||
sr.SetBuffer(b)
|
||||
stats, err := sr.GetInt64Stats()
|
||||
stats, err := sr.GetPrimaryKeyStats()
|
||||
assert.Nil(t, err)
|
||||
assert.Equal(t, stats.Max, int64(9))
|
||||
assert.Equal(t, stats.Min, int64(1))
|
||||
maxPk := &Int64PrimaryKey{
|
||||
Value: 9,
|
||||
}
|
||||
minPk := &Int64PrimaryKey{
|
||||
Value: 1,
|
||||
}
|
||||
assert.Equal(t, true, stats.MaxPk.EQ(maxPk))
|
||||
assert.Equal(t, true, stats.MinPk.EQ(minPk))
|
||||
buffer := make([]byte, 8)
|
||||
for _, id := range data {
|
||||
for _, id := range data.Data {
|
||||
common.Endian.PutUint64(buffer, uint64(id))
|
||||
assert.True(t, stats.BF.Test(buffer))
|
||||
}
|
||||
|
||||
msgs := []int64{}
|
||||
err = sw.StatsInt64(rootcoord.RowIDField, true, msgs)
|
||||
msgs := &Int64FieldData{
|
||||
Data: []int64{},
|
||||
}
|
||||
err = sw.generatePrimaryKeyStats(rootcoord.RowIDField, schemapb.DataType_Int64, msgs)
|
||||
assert.Nil(t, err)
|
||||
}
|
||||
|
||||
func TestStatsWriter_StringPrimaryKey(t *testing.T) {
|
||||
data := &StringFieldData{
|
||||
Data: []string{"bc", "ac", "abd", "cd", "milvus"},
|
||||
}
|
||||
sw := &StatsWriter{}
|
||||
err := sw.generatePrimaryKeyStats(common.RowIDField, schemapb.DataType_VarChar, data)
|
||||
assert.NoError(t, err)
|
||||
b := sw.GetBuffer()
|
||||
|
||||
sr := &StatsReader{}
|
||||
sr.SetBuffer(b)
|
||||
stats, err := sr.GetPrimaryKeyStats()
|
||||
assert.Nil(t, err)
|
||||
maxPk := &StringPrimaryKey{
|
||||
Value: "milvus",
|
||||
}
|
||||
minPk := &StringPrimaryKey{
|
||||
Value: "abd",
|
||||
}
|
||||
assert.Equal(t, true, stats.MaxPk.EQ(maxPk))
|
||||
assert.Equal(t, true, stats.MinPk.EQ(minPk))
|
||||
for _, id := range data.Data {
|
||||
assert.True(t, stats.BF.TestString(id))
|
||||
}
|
||||
|
||||
msgs := &Int64FieldData{
|
||||
Data: []int64{},
|
||||
}
|
||||
err = sw.generatePrimaryKeyStats(rootcoord.RowIDField, schemapb.DataType_Int64, msgs)
|
||||
assert.Nil(t, err)
|
||||
}
|
||||
|
||||
func TestStatsWriter_UpgradePrimaryKey(t *testing.T) {
|
||||
data := &Int64FieldData{
|
||||
Data: []int64{1, 2, 3, 4, 5, 6, 7, 8, 9},
|
||||
}
|
||||
|
||||
stats := &PrimaryKeyStats{
|
||||
FieldID: common.RowIDField,
|
||||
Min: 1,
|
||||
Max: 9,
|
||||
BF: bloom.NewWithEstimates(bloomFilterSize, maxBloomFalsePositive),
|
||||
}
|
||||
|
||||
b := make([]byte, 8)
|
||||
for _, int64Value := range data.Data {
|
||||
common.Endian.PutUint64(b, uint64(int64Value))
|
||||
stats.BF.Add(b)
|
||||
}
|
||||
blob, err := json.Marshal(stats)
|
||||
assert.Nil(t, err)
|
||||
sr := &StatsReader{}
|
||||
sr.SetBuffer(blob)
|
||||
unmarshaledStats, err := sr.GetPrimaryKeyStats()
|
||||
assert.Nil(t, err)
|
||||
maxPk := &Int64PrimaryKey{
|
||||
Value: 9,
|
||||
}
|
||||
minPk := &Int64PrimaryKey{
|
||||
Value: 1,
|
||||
}
|
||||
assert.Equal(t, true, unmarshaledStats.MaxPk.EQ(maxPk))
|
||||
assert.Equal(t, true, unmarshaledStats.MinPk.EQ(minPk))
|
||||
buffer := make([]byte, 8)
|
||||
for _, id := range data.Data {
|
||||
common.Endian.PutUint64(buffer, uint64(id))
|
||||
assert.True(t, unmarshaledStats.BF.Test(buffer))
|
||||
}
|
||||
}
|
||||
|
|
|
@ -617,6 +617,16 @@ func ColumnBasedInsertMsgToInsertData(msg *msgstream.InsertMsg, collSchema *sche
|
|||
}
|
||||
fieldData.Data = append(fieldData.Data, srcData...)
|
||||
|
||||
idata.Data[field.FieldID] = fieldData
|
||||
case schemapb.DataType_String, schemapb.DataType_VarChar:
|
||||
srcData := srcFields[field.FieldID].GetScalars().GetStringData().GetData()
|
||||
|
||||
fieldData := &StringFieldData{
|
||||
NumRows: []int64{int64(msg.NumRows)},
|
||||
Data: make([]string, 0, len(srcData)),
|
||||
}
|
||||
|
||||
fieldData.Data = append(fieldData.Data, srcData...)
|
||||
idata.Data[field.FieldID] = fieldData
|
||||
}
|
||||
}
|
||||
|
@ -812,7 +822,7 @@ func MergeInsertData(datas ...*InsertData) *InsertData {
|
|||
}
|
||||
|
||||
// TODO: string type.
|
||||
func GetPkFromInsertData(collSchema *schemapb.CollectionSchema, data *InsertData) ([]int64, error) {
|
||||
func GetPkFromInsertData(collSchema *schemapb.CollectionSchema, data *InsertData) (FieldData, error) {
|
||||
helper, err := typeutil.CreateSchemaHelper(collSchema)
|
||||
if err != nil {
|
||||
log.Error("failed to create schema helper", zap.Error(err))
|
||||
|
@ -831,13 +841,21 @@ func GetPkFromInsertData(collSchema *schemapb.CollectionSchema, data *InsertData
|
|||
return nil, errors.New("no primary field found in insert msg")
|
||||
}
|
||||
|
||||
realPfData, ok := pfData.(*Int64FieldData)
|
||||
var realPfData FieldData
|
||||
switch pf.DataType {
|
||||
case schemapb.DataType_Int64:
|
||||
realPfData, ok = pfData.(*Int64FieldData)
|
||||
case schemapb.DataType_VarChar:
|
||||
realPfData, ok = pfData.(*StringFieldData)
|
||||
default:
|
||||
//TODO
|
||||
}
|
||||
if !ok {
|
||||
log.Warn("primary field not in int64 format", zap.Int64("fieldID", pf.FieldID))
|
||||
return nil, errors.New("primary field not in int64 format")
|
||||
log.Warn("primary field not in Int64 or VarChar format", zap.Int64("fieldID", pf.FieldID))
|
||||
return nil, errors.New("primary field not in Int64 or VarChar format")
|
||||
}
|
||||
|
||||
return realPfData.Data, nil
|
||||
return realPfData, nil
|
||||
}
|
||||
|
||||
func boolFieldDataToPbBytes(field *BoolFieldData) ([]byte, error) {
|
||||
|
|
|
@ -1423,7 +1423,7 @@ func TestGetPkFromInsertData(t *testing.T) {
|
|||
}
|
||||
d, err := GetPkFromInsertData(pfSchema, realInt64Data)
|
||||
assert.NoError(t, err)
|
||||
assert.Equal(t, []int64{1, 2, 3}, d)
|
||||
assert.Equal(t, []int64{1, 2, 3}, d.(*Int64FieldData).Data)
|
||||
}
|
||||
|
||||
func Test_boolFieldDataToBytes(t *testing.T) {
|
||||
|
|
|
@ -24,6 +24,7 @@ import (
|
|||
"io/ioutil"
|
||||
"net"
|
||||
"net/http"
|
||||
"reflect"
|
||||
"strconv"
|
||||
"time"
|
||||
|
||||
|
@ -265,3 +266,81 @@ func ConvertChannelName(chanName string, tokenFrom string, tokenTo string) (stri
|
|||
}
|
||||
return "", fmt.Errorf("cannot find token '%s' in '%s'", tokenFrom, chanName)
|
||||
}
|
||||
|
||||
func getNumRowsOfScalarField(datas interface{}) uint64 {
|
||||
realTypeDatas := reflect.ValueOf(datas)
|
||||
return uint64(realTypeDatas.Len())
|
||||
}
|
||||
|
||||
func getNumRowsOfFloatVectorField(fDatas []float32, dim int64) (uint64, error) {
|
||||
if dim <= 0 {
|
||||
return 0, fmt.Errorf("dim(%d) should be greater than 0", dim)
|
||||
}
|
||||
l := len(fDatas)
|
||||
if int64(l)%dim != 0 {
|
||||
return 0, fmt.Errorf("the length(%d) of float data should divide the dim(%d)", l, dim)
|
||||
}
|
||||
return uint64(int64(l) / dim), nil
|
||||
}
|
||||
|
||||
func getNumRowsOfBinaryVectorField(bDatas []byte, dim int64) (uint64, error) {
|
||||
if dim <= 0 {
|
||||
return 0, fmt.Errorf("dim(%d) should be greater than 0", dim)
|
||||
}
|
||||
if dim%8 != 0 {
|
||||
return 0, fmt.Errorf("dim(%d) should divide 8", dim)
|
||||
}
|
||||
l := len(bDatas)
|
||||
if (8*int64(l))%dim != 0 {
|
||||
return 0, fmt.Errorf("the num(%d) of all bits should divide the dim(%d)", 8*l, dim)
|
||||
}
|
||||
return uint64((8 * int64(l)) / dim), nil
|
||||
}
|
||||
|
||||
// GetNumRowOfFieldData return num rows of the field data
|
||||
func GetNumRowOfFieldData(fieldData *schemapb.FieldData) (uint64, error) {
|
||||
var fieldNumRows uint64
|
||||
var err error
|
||||
switch fieldType := fieldData.Field.(type) {
|
||||
case *schemapb.FieldData_Scalars:
|
||||
scalarField := fieldData.GetScalars()
|
||||
switch scalarType := scalarField.Data.(type) {
|
||||
case *schemapb.ScalarField_BoolData:
|
||||
fieldNumRows = getNumRowsOfScalarField(scalarField.GetBoolData().Data)
|
||||
case *schemapb.ScalarField_IntData:
|
||||
fieldNumRows = getNumRowsOfScalarField(scalarField.GetIntData().Data)
|
||||
case *schemapb.ScalarField_LongData:
|
||||
fieldNumRows = getNumRowsOfScalarField(scalarField.GetLongData().Data)
|
||||
case *schemapb.ScalarField_FloatData:
|
||||
fieldNumRows = getNumRowsOfScalarField(scalarField.GetFloatData().Data)
|
||||
case *schemapb.ScalarField_DoubleData:
|
||||
fieldNumRows = getNumRowsOfScalarField(scalarField.GetDoubleData().Data)
|
||||
case *schemapb.ScalarField_StringData:
|
||||
fieldNumRows = getNumRowsOfScalarField(scalarField.GetStringData().Data)
|
||||
default:
|
||||
return 0, fmt.Errorf("%s is not supported now", scalarType)
|
||||
}
|
||||
case *schemapb.FieldData_Vectors:
|
||||
vectorField := fieldData.GetVectors()
|
||||
switch vectorFieldType := vectorField.Data.(type) {
|
||||
case *schemapb.VectorField_FloatVector:
|
||||
dim := vectorField.GetDim()
|
||||
fieldNumRows, err = getNumRowsOfFloatVectorField(vectorField.GetFloatVector().Data, dim)
|
||||
if err != nil {
|
||||
return 0, err
|
||||
}
|
||||
case *schemapb.VectorField_BinaryVector:
|
||||
dim := vectorField.GetDim()
|
||||
fieldNumRows, err = getNumRowsOfBinaryVectorField(vectorField.GetBinaryVector(), dim)
|
||||
if err != nil {
|
||||
return 0, err
|
||||
}
|
||||
default:
|
||||
return 0, fmt.Errorf("%s is not supported now", vectorFieldType)
|
||||
}
|
||||
default:
|
||||
return 0, fmt.Errorf("%s is not supported now", fieldType)
|
||||
}
|
||||
|
||||
return fieldNumRows, nil
|
||||
}
|
||||
|
|
|
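GetNumRowOfFieldData above derives a row count from whichever oneof branch the FieldData carries; for vectors the flat data length is divided by the dimension, with the divisibility checks shown in getNumRowsOfFloatVectorField and getNumRowsOfBinaryVectorField. A self-contained sketch of that arithmetic (plain Go, not the repo's helper):

package main

import (
	"errors"
	"fmt"
)

// floatVectorRows mirrors getNumRowsOfFloatVectorField: a flat []float32 of
// length l holds l/dim rows, provided dim > 0 and l is divisible by dim.
func floatVectorRows(data []float32, dim int64) (uint64, error) {
	if dim <= 0 {
		return 0, errors.New("dim must be greater than 0")
	}
	l := int64(len(data))
	if l%dim != 0 {
		return 0, errors.New("data length must be divisible by dim")
	}
	return uint64(l / dim), nil
}

func main() {
	rows, err := floatVectorRows(make([]float32, 6), 2)
	fmt.Println(rows, err) // 3 <nil>: six values with dim 2 make three rows
}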
@ -314,3 +314,88 @@ func Test_ConvertChannelName(t *testing.T) {
|
|||
assert.Nil(t, err)
|
||||
assert.Equal(t, deltaChanName, str)
|
||||
}
|
||||
|
||||
func TestGetNumRowsOfScalarField(t *testing.T) {
|
||||
cases := []struct {
|
||||
datas interface{}
|
||||
want uint64
|
||||
}{
|
||||
{[]bool{}, 0},
|
||||
{[]bool{true, false}, 2},
|
||||
{[]int32{}, 0},
|
||||
{[]int32{1, 2}, 2},
|
||||
{[]int64{}, 0},
|
||||
{[]int64{1, 2}, 2},
|
||||
{[]float32{}, 0},
|
||||
{[]float32{1.0, 2.0}, 2},
|
||||
{[]float64{}, 0},
|
||||
{[]float64{1.0, 2.0}, 2},
|
||||
}
|
||||
|
||||
for _, test := range cases {
|
||||
if got := getNumRowsOfScalarField(test.datas); got != test.want {
|
||||
t.Errorf("getNumRowsOfScalarField(%v) = %v", test.datas, test.want)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func TestGetNumRowsOfFloatVectorField(t *testing.T) {
|
||||
cases := []struct {
|
||||
fDatas []float32
|
||||
dim int64
|
||||
want uint64
|
||||
errIsNil bool
|
||||
}{
|
||||
{[]float32{}, -1, 0, false}, // dim <= 0
|
||||
{[]float32{}, 0, 0, false}, // dim <= 0
|
||||
{[]float32{1.0}, 128, 0, false}, // length % dim != 0
|
||||
{[]float32{}, 128, 0, true},
|
||||
{[]float32{1.0, 2.0}, 2, 1, true},
|
||||
{[]float32{1.0, 2.0, 3.0, 4.0}, 2, 2, true},
|
||||
}
|
||||
|
||||
for _, test := range cases {
|
||||
got, err := getNumRowsOfFloatVectorField(test.fDatas, test.dim)
|
||||
if test.errIsNil {
|
||||
assert.Equal(t, nil, err)
|
||||
if got != test.want {
|
||||
t.Errorf("getNumRowsOfFloatVectorField(%v, %v) = %v, %v", test.fDatas, test.dim, test.want, nil)
|
||||
}
|
||||
} else {
|
||||
assert.NotEqual(t, nil, err)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func TestGetNumRowsOfBinaryVectorField(t *testing.T) {
|
||||
cases := []struct {
|
||||
bDatas []byte
|
||||
dim int64
|
||||
want uint64
|
||||
errIsNil bool
|
||||
}{
|
||||
{[]byte{}, -1, 0, false}, // dim <= 0
|
||||
{[]byte{}, 0, 0, false}, // dim <= 0
|
||||
{[]byte{1.0}, 128, 0, false}, // length % dim != 0
|
||||
{[]byte{}, 128, 0, true},
|
||||
{[]byte{1.0}, 1, 0, false}, // dim % 8 != 0
|
||||
{[]byte{1.0}, 4, 0, false}, // dim % 8 != 0
|
||||
{[]byte{1.0, 2.0}, 8, 2, true},
|
||||
{[]byte{1.0, 2.0}, 16, 1, true},
|
||||
{[]byte{1.0, 2.0, 3.0, 4.0}, 8, 4, true},
|
||||
{[]byte{1.0, 2.0, 3.0, 4.0}, 16, 2, true},
|
||||
{[]byte{1.0}, 128, 0, false}, // (8*l) % dim != 0
|
||||
}
|
||||
|
||||
for _, test := range cases {
|
||||
got, err := getNumRowsOfBinaryVectorField(test.bDatas, test.dim)
|
||||
if test.errIsNil {
|
||||
assert.Equal(t, nil, err)
|
||||
if got != test.want {
|
||||
t.Errorf("getNumRowsOfBinaryVectorField(%v, %v) = %v, %v", test.bDatas, test.dim, test.want, nil)
|
||||
}
|
||||
} else {
|
||||
assert.NotEqual(t, nil, err)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
|
|
@ -20,6 +20,7 @@ import (
|
|||
"bytes"
|
||||
"encoding/binary"
|
||||
"errors"
|
||||
"fmt"
|
||||
"io"
|
||||
"reflect"
|
||||
|
||||
|
@ -102,15 +103,29 @@ func writeToBuffer(w io.Writer, endian binary.ByteOrder, d interface{}) error {
|
|||
return binary.Write(w, endian, d)
|
||||
}
|
||||
|
||||
func TransferColumnBasedDataToRowBasedData(columns []*schemapb.FieldData) (rows []*commonpb.Blob, err error) {
|
||||
func TransferColumnBasedDataToRowBasedData(schema *schemapb.CollectionSchema, columns []*schemapb.FieldData) (rows []*commonpb.Blob, err error) {
|
||||
dTypes := make([]schemapb.DataType, 0, len(columns))
|
||||
datas := make([][]interface{}, 0, len(columns))
|
||||
rowNum := 0
|
||||
|
||||
fieldID2FieldData := make(map[int64]schemapb.FieldData)
|
||||
for _, field := range columns {
|
||||
switch field.Field.(type) {
|
||||
fieldID2FieldData[field.FieldId] = *field
|
||||
}
|
||||
|
||||
// reorder field data by schema field orider
|
||||
for _, field := range schema.Fields {
|
||||
if field.FieldID == common.RowIDField || field.FieldID == common.TimeStampField {
|
||||
continue
|
||||
}
|
||||
fieldData, ok := fieldID2FieldData[field.FieldID]
|
||||
if !ok {
|
||||
return nil, fmt.Errorf("field %s data not exist", field.Name)
|
||||
}
|
||||
|
||||
switch fieldData.Field.(type) {
|
||||
case *schemapb.FieldData_Scalars:
|
||||
scalarField := field.GetScalars()
|
||||
scalarField := fieldData.GetScalars()
|
||||
switch scalarField.Data.(type) {
|
||||
case *schemapb.ScalarField_BoolData:
|
||||
err := appendScalarField(&datas, &rowNum, func() interface{} {
|
||||
|
@ -157,7 +172,7 @@ func TransferColumnBasedDataToRowBasedData(columns []*schemapb.FieldData) (rows
|
|||
continue
|
||||
}
|
||||
case *schemapb.FieldData_Vectors:
|
||||
vectorField := field.GetVectors()
|
||||
vectorField := fieldData.GetVectors()
|
||||
switch vectorField.Data.(type) {
|
||||
case *schemapb.VectorField_FloatVector:
|
||||
floatVectorFieldData := vectorField.GetFloatVector().Data
|
||||
|
@ -184,8 +199,7 @@ func TransferColumnBasedDataToRowBasedData(columns []*schemapb.FieldData) (rows
|
|||
continue
|
||||
}
|
||||
|
||||
dTypes = append(dTypes, field.Type)
|
||||
|
||||
dTypes = append(dTypes, field.DataType)
|
||||
}
|
||||
|
||||
rows = make([]*commonpb.Blob, 0, rowNum)
|
||||
|
|
|
@ -24,13 +24,75 @@ import (
|
|||
|
||||
"github.com/stretchr/testify/assert"
|
||||
|
||||
"github.com/milvus-io/milvus/internal/proto/commonpb"
|
||||
"github.com/milvus-io/milvus/internal/proto/schemapb"
|
||||
)
|
||||
|
||||
func TestTransferColumnBasedDataToRowBasedData(t *testing.T) {
|
||||
fieldSchema := []*schemapb.FieldSchema{
|
||||
{
|
||||
FieldID: 100,
|
||||
Name: "bool_field",
|
||||
DataType: schemapb.DataType_Bool,
|
||||
},
|
||||
{
|
||||
FieldID: 101,
|
||||
Name: "int8_field",
|
||||
DataType: schemapb.DataType_Int8,
|
||||
},
|
||||
{
|
||||
FieldID: 102,
|
||||
Name: "int16_field",
|
||||
DataType: schemapb.DataType_Int16,
|
||||
},
|
||||
{
|
||||
FieldID: 103,
|
||||
Name: "int32_field",
|
||||
DataType: schemapb.DataType_Int32,
|
||||
},
|
||||
{
|
||||
FieldID: 104,
|
||||
Name: "int64_field",
|
||||
DataType: schemapb.DataType_Int64,
|
||||
},
|
||||
{
|
||||
FieldID: 105,
|
||||
Name: "float32_field",
|
||||
DataType: schemapb.DataType_Float,
|
||||
},
|
||||
{
|
||||
FieldID: 106,
|
||||
Name: "float64_field",
|
||||
DataType: schemapb.DataType_Double,
|
||||
},
|
||||
{
|
||||
FieldID: 107,
|
||||
Name: "float_vector_field",
|
||||
DataType: schemapb.DataType_FloatVector,
|
||||
TypeParams: []*commonpb.KeyValuePair{
|
||||
{
|
||||
Key: "dim",
|
||||
Value: "1",
|
||||
},
|
||||
},
|
||||
},
|
||||
{
|
||||
FieldID: 108,
|
||||
Name: "binary_vector_field",
|
||||
DataType: schemapb.DataType_BinaryVector,
|
||||
TypeParams: []*commonpb.KeyValuePair{
|
||||
{
|
||||
Key: "dim",
|
||||
Value: "8",
|
||||
},
|
||||
},
|
||||
},
|
||||
}
|
||||
|
||||
columns := []*schemapb.FieldData{
|
||||
{
|
||||
Type: schemapb.DataType_Bool,
|
||||
FieldId: 100,
|
||||
Type: schemapb.DataType_Bool,
|
||||
Field: &schemapb.FieldData_Scalars{
|
||||
Scalars: &schemapb.ScalarField{
|
||||
Data: &schemapb.ScalarField_BoolData{
|
||||
|
@ -42,7 +104,8 @@ func TestTransferColumnBasedDataToRowBasedData(t *testing.T) {
|
|||
},
|
||||
},
|
||||
{
|
||||
Type: schemapb.DataType_Int8,
|
||||
FieldId: 101,
|
||||
Type: schemapb.DataType_Int8,
|
||||
Field: &schemapb.FieldData_Scalars{
|
||||
Scalars: &schemapb.ScalarField{
|
||||
Data: &schemapb.ScalarField_IntData{
|
||||
|
@ -54,7 +117,8 @@ func TestTransferColumnBasedDataToRowBasedData(t *testing.T) {
|
|||
},
|
||||
},
|
||||
{
|
||||
Type: schemapb.DataType_Int16,
|
||||
FieldId: 102,
|
||||
Type: schemapb.DataType_Int16,
|
||||
Field: &schemapb.FieldData_Scalars{
|
||||
Scalars: &schemapb.ScalarField{
|
||||
Data: &schemapb.ScalarField_IntData{
|
||||
|
@ -66,7 +130,8 @@ func TestTransferColumnBasedDataToRowBasedData(t *testing.T) {
|
|||
},
|
||||
},
|
||||
{
|
||||
Type: schemapb.DataType_Int32,
|
||||
FieldId: 103,
|
||||
Type: schemapb.DataType_Int32,
|
||||
Field: &schemapb.FieldData_Scalars{
|
||||
Scalars: &schemapb.ScalarField{
|
||||
Data: &schemapb.ScalarField_IntData{
|
||||
|
@ -78,7 +143,8 @@ func TestTransferColumnBasedDataToRowBasedData(t *testing.T) {
|
|||
},
|
||||
},
|
||||
{
|
||||
Type: schemapb.DataType_Int64,
|
||||
FieldId: 104,
|
||||
Type: schemapb.DataType_Int64,
|
||||
Field: &schemapb.FieldData_Scalars{
|
||||
Scalars: &schemapb.ScalarField{
|
||||
Data: &schemapb.ScalarField_LongData{
|
||||
|
@ -90,7 +156,8 @@ func TestTransferColumnBasedDataToRowBasedData(t *testing.T) {
|
|||
},
|
||||
},
|
||||
{
|
||||
Type: schemapb.DataType_Float,
|
||||
FieldId: 105,
|
||||
Type: schemapb.DataType_Float,
|
||||
Field: &schemapb.FieldData_Scalars{
|
||||
Scalars: &schemapb.ScalarField{
|
||||
Data: &schemapb.ScalarField_FloatData{
|
||||
|
@ -102,7 +169,8 @@ func TestTransferColumnBasedDataToRowBasedData(t *testing.T) {
|
|||
},
|
||||
},
|
||||
{
|
||||
Type: schemapb.DataType_Double,
|
||||
FieldId: 106,
|
||||
Type: schemapb.DataType_Double,
|
||||
Field: &schemapb.FieldData_Scalars{
|
||||
Scalars: &schemapb.ScalarField{
|
||||
Data: &schemapb.ScalarField_DoubleData{
|
||||
|
@ -114,7 +182,8 @@ func TestTransferColumnBasedDataToRowBasedData(t *testing.T) {
|
|||
},
|
||||
},
|
||||
{
|
||||
Type: schemapb.DataType_FloatVector,
|
||||
FieldId: 107,
|
||||
Type: schemapb.DataType_FloatVector,
|
||||
Field: &schemapb.FieldData_Vectors{
|
||||
Vectors: &schemapb.VectorField{
|
||||
Dim: 1,
|
||||
|
@ -127,7 +196,8 @@ func TestTransferColumnBasedDataToRowBasedData(t *testing.T) {
|
|||
},
|
||||
},
|
||||
{
|
||||
Type: schemapb.DataType_BinaryVector,
|
||||
FieldId: 108,
|
||||
Type: schemapb.DataType_BinaryVector,
|
||||
Field: &schemapb.FieldData_Vectors{
|
||||
Vectors: &schemapb.VectorField{
|
||||
Dim: 8,
|
||||
|
@ -138,7 +208,7 @@ func TestTransferColumnBasedDataToRowBasedData(t *testing.T) {
|
|||
},
|
||||
},
|
||||
}
|
||||
rows, err := TransferColumnBasedDataToRowBasedData(columns)
|
||||
rows, err := TransferColumnBasedDataToRowBasedData(&schemapb.CollectionSchema{Fields: fieldSchema}, columns)
|
||||
assert.NoError(t, err)
|
||||
assert.Equal(t, 3, len(rows))
|
||||
if common.Endian == binary.LittleEndian {
|
||||
|
|
|
@ -17,12 +17,17 @@
|
|||
package typeutil
|
||||
|
||||
import (
|
||||
"hash/crc32"
|
||||
"unsafe"
|
||||
|
||||
"github.com/milvus-io/milvus/internal/common"
|
||||
"github.com/spaolacci/murmur3"
|
||||
|
||||
"github.com/milvus-io/milvus/internal/common"
|
||||
"github.com/milvus-io/milvus/internal/proto/schemapb"
|
||||
)
|
||||
|
||||
const substringLengthForCRC = 100
|
||||
|
||||
// Hash32Bytes hashing a byte array to uint32
|
||||
func Hash32Bytes(b []byte) (uint32, error) {
|
||||
h := murmur3.New32()
|
||||
|
@ -55,3 +60,37 @@ func Hash32String(s string) (int64, error) {
|
|||
}
|
||||
return int64(v), nil
|
||||
}
|
||||
|
||||
// HashString2Uint32 hashing a string to uint32
|
||||
func HashString2Uint32(v string) uint32 {
|
||||
subString := v
|
||||
if len(v) > substringLengthForCRC {
|
||||
subString = v[:substringLengthForCRC]
|
||||
}
|
||||
|
||||
return crc32.ChecksumIEEE([]byte(subString))
|
||||
}
|
||||
|
||||
// HashPK2Channels hash primary keys to channels
|
||||
func HashPK2Channels(primaryKeys *schemapb.IDs, shardNames []string) []uint32 {
|
||||
numShard := uint32(len(shardNames))
|
||||
var hashValues []uint32
|
||||
switch primaryKeys.IdField.(type) {
|
||||
case *schemapb.IDs_IntId:
|
||||
pks := primaryKeys.GetIntId().Data
|
||||
for _, pk := range pks {
|
||||
value, _ := Hash32Int64(pk)
|
||||
hashValues = append(hashValues, value%numShard)
|
||||
}
|
||||
case *schemapb.IDs_StrId:
|
||||
pks := primaryKeys.GetStrId().Data
|
||||
for _, pk := range pks {
|
||||
hash := HashString2Uint32(pk)
|
||||
hashValues = append(hashValues, hash%numShard)
|
||||
}
|
||||
default:
|
||||
//TODO::
|
||||
}
|
||||
|
||||
return hashValues
|
||||
}
|
||||
|
|
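HashPK2Channels above routes both int64 and string primary keys onto shard channels: murmur3 (via Hash32Int64) for integers and a CRC32 over the first 100 characters for strings. A brief usage sketch, assuming the milvus repo context for the internal imports; the shard names are illustrative:

package main

import (
	"fmt"

	"github.com/milvus-io/milvus/internal/proto/schemapb"
	"github.com/milvus-io/milvus/internal/util/typeutil"
)

func main() {
	shards := []string{"by-dev-dml_0", "by-dev-dml_1"}

	// varchar primary keys: identical strings always map to the same shard
	ids := &schemapb.IDs{
		IdField: &schemapb.IDs_StrId{
			StrId: &schemapb.StringArray{Data: []string{"ab", "bc", "bc"}},
		},
	}
	hashes := typeutil.HashPK2Channels(ids, shards)
	fmt.Println(hashes) // e.g. [1 0 0]; hashes[1] == hashes[2]
}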
|
@ -21,6 +21,7 @@ import (
|
|||
"testing"
|
||||
"unsafe"
|
||||
|
||||
"github.com/milvus-io/milvus/internal/proto/schemapb"
|
||||
"github.com/stretchr/testify/assert"
|
||||
)
|
||||
|
||||
|
@ -67,3 +68,29 @@ func TestHash32_String(t *testing.T) {
|
|||
|
||||
assert.Equal(t, uint32(h), h2)
|
||||
}
|
||||
|
||||
func TestHashPK2Channels(t *testing.T) {
|
||||
channels := []string{"test1", "test2"}
|
||||
int64IDs := &schemapb.IDs{
|
||||
IdField: &schemapb.IDs_IntId{
|
||||
IntId: &schemapb.LongArray{
|
||||
Data: []int64{100, 102, 102, 103, 104},
|
||||
},
|
||||
},
|
||||
}
|
||||
ret := HashPK2Channels(int64IDs, channels)
|
||||
assert.Equal(t, 5, len(ret))
|
||||
//same pk hash to same channel
|
||||
assert.Equal(t, ret[1], ret[2])
|
||||
|
||||
stringIDs := &schemapb.IDs{
|
||||
IdField: &schemapb.IDs_StrId{
|
||||
StrId: &schemapb.StringArray{
|
||||
Data: []string{"ab", "bc", "bc", "abd", "milvus"},
|
||||
},
|
||||
},
|
||||
}
|
||||
ret = HashPK2Channels(stringIDs, channels)
|
||||
assert.Equal(t, 5, len(ret))
|
||||
assert.Equal(t, ret[1], ret[2])
|
||||
}
|
||||
|
|
|
@ -100,6 +100,33 @@ func EstimateSizePerRecord(schema *schemapb.CollectionSchema) (int, error) {
|
|||
return res, nil
|
||||
}
|
||||
|
||||
func EstimateEntitySize(fieldsData []*schemapb.FieldData, rowOffset int) (int, error) {
|
||||
res := 0
|
||||
for _, fs := range fieldsData {
|
||||
switch fs.GetType() {
|
||||
case schemapb.DataType_Bool, schemapb.DataType_Int8:
|
||||
res++
|
||||
case schemapb.DataType_Int16:
|
||||
res += 2
|
||||
case schemapb.DataType_Int32, schemapb.DataType_Float:
|
||||
res += 4
|
||||
case schemapb.DataType_Int64, schemapb.DataType_Double:
|
||||
res += 8
|
||||
case schemapb.DataType_VarChar:
|
||||
if rowOffset >= len(fs.GetScalars().GetStringData().GetData()) {
|
||||
return 0, fmt.Errorf("offset out range of field datas")
|
||||
}
|
||||
//TODO:: check len(varChar) <= maxLengthPerRow
|
||||
res += len(fs.GetScalars().GetStringData().Data[rowOffset])
|
||||
case schemapb.DataType_BinaryVector:
|
||||
res += int(fs.GetVectors().GetDim())
|
||||
case schemapb.DataType_FloatVector:
|
||||
res += int(fs.GetVectors().GetDim() * 4)
|
||||
}
|
||||
}
|
||||
return res, nil
|
||||
}
|
||||
|
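EstimateEntitySize above sums fixed-width scalars, the byte length of the VarChar value at rowOffset, dim bytes for binary vectors, and dim*4 bytes for float vectors; a row with an Int64 key, the VarChar value "milvus", and a 128-dim float vector therefore estimates to 8 + 6 + 512 = 526 bytes. A hedged sketch of the same call, assuming the helper lives in the typeutil package as the surrounding hunks suggest:

package main

import (
	"fmt"

	"github.com/milvus-io/milvus/internal/proto/schemapb"
	"github.com/milvus-io/milvus/internal/util/typeutil"
)

func main() {
	fields := []*schemapb.FieldData{
		{
			Type: schemapb.DataType_Int64,
			Field: &schemapb.FieldData_Scalars{Scalars: &schemapb.ScalarField{
				Data: &schemapb.ScalarField_LongData{LongData: &schemapb.LongArray{Data: []int64{1}}},
			}},
		},
		{
			Type: schemapb.DataType_VarChar,
			Field: &schemapb.FieldData_Scalars{Scalars: &schemapb.ScalarField{
				Data: &schemapb.ScalarField_StringData{StringData: &schemapb.StringArray{Data: []string{"milvus"}}},
			}},
		},
		{
			Type: schemapb.DataType_FloatVector,
			Field: &schemapb.FieldData_Vectors{Vectors: &schemapb.VectorField{
				Dim: 128,
				Data: &schemapb.VectorField_FloatVector{
					FloatVector: &schemapb.FloatArray{Data: make([]float32, 128)},
				},
			}},
		},
	}

	size, err := typeutil.EstimateEntitySize(fields, 0)
	fmt.Println(size, err) // 8 (int64) + 6 ("milvus") + 128*4 (float vector) = 526
}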
||||
// SchemaHelper provides methods to get the schema of fields
|
||||
type SchemaHelper struct {
|
||||
schema *schemapb.CollectionSchema
|
||||
|
@ -288,6 +315,16 @@ func AppendFieldData(dst []*schemapb.FieldData, src []*schemapb.FieldData, idx i
|
|||
} else {
|
||||
dstScalar.GetDoubleData().Data = append(dstScalar.GetDoubleData().Data, srcScalar.DoubleData.Data[idx])
|
||||
}
|
||||
case *schemapb.ScalarField_StringData:
|
||||
if dstScalar.GetStringData() == nil {
|
||||
dstScalar.Data = &schemapb.ScalarField_StringData{
|
||||
StringData: &schemapb.StringArray{
|
||||
Data: []string{srcScalar.StringData.Data[idx]},
|
||||
},
|
||||
}
|
||||
} else {
|
||||
dstScalar.GetStringData().Data = append(dstScalar.GetStringData().Data, srcScalar.StringData.Data[idx])
|
||||
}
|
||||
default:
|
||||
log.Error("Not supported field type", zap.String("field type", fieldData.Type.String()))
|
||||
}
|
||||
|
@ -337,17 +374,13 @@ func AppendFieldData(dst []*schemapb.FieldData, src []*schemapb.FieldData, idx i
|
|||
}
|
||||
}
|
||||
|
||||
func FillFieldBySchema(columns []*schemapb.FieldData, schema *schemapb.CollectionSchema) error {
|
||||
if len(columns) != len(schema.GetFields()) {
|
||||
return fmt.Errorf("len(columns) mismatch the len(fields), len(columns): %d, len(fields): %d",
|
||||
len(columns), len(schema.GetFields()))
|
||||
// GetPrimaryFieldSchema get primary field schema from collection schema
|
||||
func GetPrimaryFieldSchema(schema *schemapb.CollectionSchema) (*schemapb.FieldSchema, error) {
|
||||
for _, fieldSchema := range schema.Fields {
|
||||
if fieldSchema.IsPrimaryKey {
|
||||
return fieldSchema, nil
|
||||
}
|
||||
}
|
||||
|
||||
for idx, f := range schema.GetFields() {
|
||||
columns[idx].FieldName = f.Name
|
||||
columns[idx].Type = f.DataType
|
||||
columns[idx].FieldId = f.FieldID
|
||||
}
|
||||
|
||||
return nil
|
||||
return nil, errors.New("primary field is not found")
|
||||
}
|
||||
|
|
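GetPrimaryFieldSchema, added in the hunk above, simply returns the first field flagged IsPrimaryKey. A short usage sketch under the same repo-internal assumption; the schema and field names are made up for illustration:

package main

import (
	"fmt"

	"github.com/milvus-io/milvus/internal/proto/schemapb"
	"github.com/milvus-io/milvus/internal/util/typeutil"
)

func main() {
	schema := &schemapb.CollectionSchema{
		Name: "demo",
		Fields: []*schemapb.FieldSchema{
			{FieldID: 100, Name: "pk", DataType: schemapb.DataType_VarChar, IsPrimaryKey: true},
			{FieldID: 101, Name: "vec", DataType: schemapb.DataType_FloatVector},
		},
	}

	pkField, err := typeutil.GetPrimaryFieldSchema(schema)
	if err != nil {
		panic(err) // returned when no field is marked as the primary key
	}
	fmt.Println(pkField.Name, pkField.DataType) // pk VarChar
}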
|
@ -512,29 +512,29 @@ func TestAppendFieldData(t *testing.T) {
|
|||
assert.Equal(t, FloatVector, result[6].GetVectors().GetFloatVector().Data)
|
||||
}
|
||||
|
||||
func TestFillFieldBySchema(t *testing.T) {
|
||||
columns := []*schemapb.FieldData{
|
||||
{},
|
||||
func TestGetPrimaryFieldSchema(t *testing.T) {
|
||||
int64Field := &schemapb.FieldSchema{
|
||||
FieldID: 1,
|
||||
Name: "int64Field",
|
||||
DataType: schemapb.DataType_Int64,
|
||||
}
|
||||
schema := &schemapb.CollectionSchema{}
|
||||
// length mismatch
|
||||
assert.Error(t, FillFieldBySchema(columns, schema))
|
||||
columns = []*schemapb.FieldData{
|
||||
{
|
||||
FieldId: 0,
|
||||
},
|
||||
|
||||
floatField := &schemapb.FieldSchema{
|
||||
FieldID: 2,
|
||||
Name: "floatField",
|
||||
DataType: schemapb.DataType_Float,
|
||||
}
|
||||
schema = &schemapb.CollectionSchema{
|
||||
Fields: []*schemapb.FieldSchema{
|
||||
{
|
||||
Name: "TestFillFieldIDBySchema",
|
||||
DataType: schemapb.DataType_Int64,
|
||||
FieldID: 1,
|
||||
},
|
||||
},
|
||||
|
||||
schema := &schemapb.CollectionSchema{
|
||||
Fields: []*schemapb.FieldSchema{int64Field, floatField},
|
||||
}
|
||||
assert.NoError(t, FillFieldBySchema(columns, schema))
|
||||
assert.Equal(t, "TestFillFieldIDBySchema", columns[0].FieldName)
|
||||
assert.Equal(t, schemapb.DataType_Int64, columns[0].Type)
|
||||
assert.Equal(t, int64(1), columns[0].FieldId)
|
||||
|
||||
// no primary field error
|
||||
_, err := GetPrimaryFieldSchema(schema)
|
||||
assert.Error(t, err)
|
||||
|
||||
int64Field.IsPrimaryKey = true
|
||||
primaryField, err := GetPrimaryFieldSchema(schema)
|
||||
assert.Nil(t, err)
|
||||
assert.Equal(t, schemapb.DataType_Int64, primaryField.DataType)
|
||||
}
|
||||
|
|