diff --git a/go.sum b/go.sum index e3592790ae..94ec46ed2d 100644 --- a/go.sum +++ b/go.sum @@ -326,6 +326,12 @@ github.com/jarcoal/httpmock v1.0.8 h1:8kI16SoO6LQKgPE7PvQuV+YuD/inwHd7fOOe2zMbo4 github.com/jarcoal/httpmock v1.0.8/go.mod h1:ATjnClrvW/3tijVmpL/va5Z3aAyGvqU3gCT8nX0Txik= github.com/jawher/mow.cli v1.0.4/go.mod h1:5hQj2V8g+qYmLUVWqu4Wuja1pI57M83EChYLVZ0sMKk= github.com/jawher/mow.cli v1.2.0/go.mod h1:y+pcA3jBAdo/GIZx/0rFjw/K2bVEODP9rfZOfaiq8Ko= +github.com/jhump/gopoet v0.0.0-20190322174617-17282ff210b3/go.mod h1:me9yfT6IJSlOL3FCfrg+L6yzUEZ+5jW6WHt4Sk+UPUI= +github.com/jhump/gopoet v0.1.0/go.mod h1:me9yfT6IJSlOL3FCfrg+L6yzUEZ+5jW6WHt4Sk+UPUI= +github.com/jhump/goprotoc v0.5.0/go.mod h1:VrbvcYrQOrTi3i0Vf+m+oqQWk9l72mjkJCYo7UvLHRQ= +github.com/jhump/protoreflect v1.11.0/go.mod h1:U7aMIjN0NWq9swDP7xDdoMfRHb35uiuTd3Z9nFXJf5E= +github.com/jhump/protoreflect v1.12.0 h1:1NQ4FpWMgn3by/n1X0fbeKEUxP1wBt7+Oitpv01HR10= +github.com/jhump/protoreflect v1.12.0/go.mod h1:JytZfP5d0r8pVNLZvai7U/MCuTWITgrI4tTg7puQFKI= github.com/jmespath/go-jmespath v0.3.0/go.mod h1:9QtRXoHjLGCJ5IBSaohpXITPlowMeeYCZ7fLUTSywik= github.com/jonboulle/clockwork v0.1.0/go.mod h1:Ii8DK3G1RaLaWxj9trq07+26W01tbo22gdxWY5EU2bo= github.com/jonboulle/clockwork v0.2.2 h1:UOGuzwb1PwsrDAObMuhUnj0p5ULPj8V/xJ7Kx9qUBdQ= diff --git a/internal/datanode/binlog_io_test.go b/internal/datanode/binlog_io_test.go index 999373e5d9..1c93c6fd00 100644 --- a/internal/datanode/binlog_io_test.go +++ b/internal/datanode/binlog_io_test.go @@ -25,7 +25,9 @@ import ( "github.com/milvus-io/milvus/internal/common" "github.com/milvus-io/milvus/internal/log" + "github.com/milvus-io/milvus/internal/proto/schemapb" "github.com/milvus-io/milvus/internal/storage" + "github.com/milvus-io/milvus/internal/util/typeutil" "github.com/stretchr/testify/assert" "github.com/stretchr/testify/require" "go.uber.org/zap" @@ -41,7 +43,7 @@ func TestBinlogIOInterfaceMethods(t *testing.T) { b := &binlogIO{cm, alloc} t.Run("Test upload", func(t *testing.T) { f := &MetaFactory{} - meta := f.GetCollectionMeta(UniqueID(10001), "uploads") + meta := f.GetCollectionMeta(UniqueID(10001), "uploads", schemapb.DataType_Int64) iData := genInsertData() dData := &DeleteData{ @@ -52,16 +54,16 @@ func TestBinlogIOInterfaceMethods(t *testing.T) { p, err := b.upload(context.TODO(), 1, 10, []*InsertData{iData}, dData, meta) assert.NoError(t, err) - assert.Equal(t, 11, len(p.inPaths)) - assert.Equal(t, 3, len(p.statsPaths)) + assert.Equal(t, 12, len(p.inPaths)) + assert.Equal(t, 1, len(p.statsPaths)) assert.Equal(t, 1, len(p.inPaths[0].GetBinlogs())) assert.Equal(t, 1, len(p.statsPaths[0].GetBinlogs())) assert.NotNil(t, p.deltaInfo) p, err = b.upload(context.TODO(), 1, 10, []*InsertData{iData, iData}, dData, meta) assert.NoError(t, err) - assert.Equal(t, 11, len(p.inPaths)) - assert.Equal(t, 3, len(p.statsPaths)) + assert.Equal(t, 12, len(p.inPaths)) + assert.Equal(t, 1, len(p.statsPaths)) assert.Equal(t, 2, len(p.inPaths[0].GetBinlogs())) assert.Equal(t, 2, len(p.statsPaths[0].GetBinlogs())) assert.NotNil(t, p.deltaInfo) @@ -76,7 +78,7 @@ func TestBinlogIOInterfaceMethods(t *testing.T) { t.Run("Test upload error", func(t *testing.T) { f := &MetaFactory{} - meta := f.GetCollectionMeta(UniqueID(10001), "uploads") + meta := f.GetCollectionMeta(UniqueID(10001), "uploads", schemapb.DataType_Int64) dData := &DeleteData{ Pks: []int64{}, Tss: []uint64{}, @@ -197,7 +199,7 @@ func TestBinlogIOInnerMethods(t *testing.T) { t.Run("Test genDeltaBlobs", func(t *testing.T) { f := &MetaFactory{} - meta := 
f.GetCollectionMeta(UniqueID(10002), "test_gen_blobs") + meta := f.GetCollectionMeta(UniqueID(10002), "test_gen_blobs", schemapb.DataType_Int64) tests := []struct { isvalid bool @@ -246,19 +248,36 @@ func TestBinlogIOInnerMethods(t *testing.T) { t.Run("Test genInsertBlobs", func(t *testing.T) { f := &MetaFactory{} - meta := f.GetCollectionMeta(UniqueID(10001), "test_gen_blobs") + tests := []struct { + pkType schemapb.DataType + description string + }{ + {schemapb.DataType_Int64, "int64PrimaryField"}, + {schemapb.DataType_VarChar, "varCharPrimaryField"}, + } - kvs, pin, pstats, err := b.genInsertBlobs(genInsertData(), 10, 1, meta) + for _, test := range tests { + t.Run(test.description, func(t *testing.T) { + meta := f.GetCollectionMeta(UniqueID(10001), "test_gen_blobs", test.pkType) + helper, err := typeutil.CreateSchemaHelper(meta.Schema) + assert.NoError(t, err) + primaryKeyFieldSchema, err := helper.GetPrimaryKeyField() + assert.NoError(t, err) + primaryKeyFieldID := primaryKeyFieldSchema.GetFieldID() - assert.NoError(t, err) - assert.Equal(t, 3, len(pstats)) - assert.Equal(t, 11, len(pin)) - assert.Equal(t, 14, len(kvs)) + kvs, pin, pstats, err := b.genInsertBlobs(genInsertData(), 10, 1, meta) - log.Debug("test paths", - zap.Any("kvs no.", len(kvs)), - zap.String("insert paths field0", pin[common.TimeStampField].GetBinlogs()[0].GetLogPath()), - zap.String("stats paths field0", pstats[common.TimeStampField].GetBinlogs()[0].GetLogPath())) + assert.NoError(t, err) + assert.Equal(t, 1, len(pstats)) + assert.Equal(t, 12, len(pin)) + assert.Equal(t, 13, len(kvs)) + + log.Debug("test paths", + zap.Any("kvs no.", len(kvs)), + zap.String("insert paths field0", pin[common.TimeStampField].GetBinlogs()[0].GetLogPath()), + zap.String("stats paths field0", pstats[primaryKeyFieldID].GetBinlogs()[0].GetLogPath())) + }) + } }) t.Run("Test genInsertBlobs error", func(t *testing.T) { @@ -269,7 +288,7 @@ func TestBinlogIOInnerMethods(t *testing.T) { assert.Empty(t, pstats) f := &MetaFactory{} - meta := f.GetCollectionMeta(UniqueID(10001), "test_gen_blobs") + meta := f.GetCollectionMeta(UniqueID(10001), "test_gen_blobs", schemapb.DataType_Int64) kvs, pin, pstats, err = b.genInsertBlobs(genEmptyInsertData(), 10, 1, meta) assert.Error(t, err) diff --git a/internal/datanode/compactor.go b/internal/datanode/compactor.go index 97d734555f..a75061cd91 100644 --- a/internal/datanode/compactor.go +++ b/internal/datanode/compactor.go @@ -527,7 +527,11 @@ func (t *compactionTask) compact() error { } // no need to shorten the PK range of a segment, deleting dup PKs is valid } else { - t.mergeFlushedSegments(targetSegID, collID, partID, t.plan.GetPlanID(), segIDs, t.plan.GetChannel(), numRows) + err = t.mergeFlushedSegments(targetSegID, collID, partID, t.plan.GetPlanID(), segIDs, t.plan.GetChannel(), numRows) + if err != nil { + log.Error("compact wrong", zap.Int64("planID", t.plan.GetPlanID()), zap.Error(err)) + return err + } } uninjectStart := time.Now() @@ -660,6 +664,21 @@ func interface2FieldData(schemaDataType schemapb.DataType, content []interface{} } rst = data + case schemapb.DataType_String, schemapb.DataType_VarChar: + var data = &storage.StringFieldData{ + NumRows: numOfRows, + Data: make([]string, 0, len(content)), + } + + for _, c := range content { + r, ok := c.(string) + if !ok { + return nil, errTransferType + } + data.Data = append(data.Data, r) + } + rst = data + case schemapb.DataType_FloatVector: var data = &storage.FloatVectorFieldData{ NumRows: numOfRows, diff --git 
a/internal/datanode/compactor_test.go b/internal/datanode/compactor_test.go index edf1de9648..3e16c64957 100644 --- a/internal/datanode/compactor_test.go +++ b/internal/datanode/compactor_test.go @@ -38,7 +38,9 @@ func TestCompactionTaskInnerMethods(t *testing.T) { cm := storage.NewLocalChunkManager(storage.RootPath(compactTestDir)) defer cm.RemoveWithPrefix("") t.Run("Test getSegmentMeta", func(t *testing.T) { - rc := &RootCoordFactory{} + rc := &RootCoordFactory{ + pkType: schemapb.DataType_Int64, + } replica, err := newReplica(context.TODO(), rc, cm, 1) require.NoError(t, err) @@ -80,6 +82,7 @@ func TestCompactionTaskInnerMethods(t *testing.T) { {true, schemapb.DataType_Int64, []interface{}{int64(1), int64(2)}, "valid int64"}, {true, schemapb.DataType_Float, []interface{}{float32(1), float32(2)}, "valid float32"}, {true, schemapb.DataType_Double, []interface{}{float64(1), float64(2)}, "valid float64"}, + {true, schemapb.DataType_VarChar, []interface{}{"test1", "test2"}, "valid varChar"}, {true, schemapb.DataType_FloatVector, []interface{}{[]float32{1.0, 2.0}}, "valid floatvector"}, {true, schemapb.DataType_BinaryVector, []interface{}{[]byte{255}}, "valid binaryvector"}, {false, schemapb.DataType_Bool, []interface{}{1, 2}, "invalid bool"}, @@ -89,9 +92,10 @@ func TestCompactionTaskInnerMethods(t *testing.T) { {false, schemapb.DataType_Int64, []interface{}{nil, nil}, "invalid int64"}, {false, schemapb.DataType_Float, []interface{}{nil, nil}, "invalid float32"}, {false, schemapb.DataType_Double, []interface{}{nil, nil}, "invalid float64"}, + {false, schemapb.DataType_VarChar, []interface{}{nil, nil}, "invalid varChar"}, {false, schemapb.DataType_FloatVector, []interface{}{nil, nil}, "invalid floatvector"}, {false, schemapb.DataType_BinaryVector, []interface{}{nil, nil}, "invalid binaryvector"}, - {false, schemapb.DataType_String, nil, "invalid data type"}, + {false, schemapb.DataType_None, nil, "invalid data type"}, } for _, test := range tests { @@ -243,7 +247,7 @@ func TestCompactionTaskInnerMethods(t *testing.T) { t.Run("Merge without expiration", func(t *testing.T) { Params.DataCoordCfg.CompactionEntityExpiration = math.MaxInt64 iData := genInsertDataWithExpiredTS() - meta := NewMetaFactory().GetCollectionMeta(1, "test") + meta := NewMetaFactory().GetCollectionMeta(1, "test", schemapb.DataType_Int64) iblobs, err := getInsertBlobs(100, iData, meta) require.NoError(t, err) @@ -272,7 +276,7 @@ func TestCompactionTaskInnerMethods(t *testing.T) { }() Params.DataNodeCfg.FlushInsertBufferSize = 128 iData := genInsertDataWithExpiredTS() - meta := NewMetaFactory().GetCollectionMeta(1, "test") + meta := NewMetaFactory().GetCollectionMeta(1, "test", schemapb.DataType_Int64) iblobs, err := getInsertBlobs(100, iData, meta) require.NoError(t, err) @@ -295,7 +299,7 @@ func TestCompactionTaskInnerMethods(t *testing.T) { t.Run("Merge with expiration", func(t *testing.T) { Params.DataCoordCfg.CompactionEntityExpiration = 864000 // 10 days in seconds iData := genInsertDataWithExpiredTS() - meta := NewMetaFactory().GetCollectionMeta(1, "test") + meta := NewMetaFactory().GetCollectionMeta(1, "test", schemapb.DataType_Int64) iblobs, err := getInsertBlobs(100, iData, meta) require.NoError(t, err) @@ -448,16 +452,18 @@ func TestCompactorInterfaceMethods(t *testing.T) { var collID, partID, segID UniqueID = 1, 10, 100 alloc := NewAllocatorFactory(1) - rc := &RootCoordFactory{} + rc := &RootCoordFactory{ + pkType: schemapb.DataType_Int64, + } dc := &DataCoordFactory{} mockfm := &mockFlushManager{} mockbIO := 
&binlogIO{cm, alloc} replica, err := newReplica(context.TODO(), rc, cm, collID) require.NoError(t, err) - replica.addFlushedSegmentWithPKs(segID, collID, partID, "channelname", 2, []UniqueID{1, 2}) + replica.addFlushedSegmentWithPKs(segID, collID, partID, "channelname", 2, &storage.Int64FieldData{Data: []UniqueID{1, 2}}) iData := genInsertData() - meta := NewMetaFactory().GetCollectionMeta(collID, "test_compact_coll_name") + meta := NewMetaFactory().GetCollectionMeta(collID, "test_compact_coll_name", schemapb.DataType_Int64) dData := &DeleteData{ Pks: []UniqueID{1}, Tss: []Timestamp{20000}, @@ -466,7 +472,7 @@ func TestCompactorInterfaceMethods(t *testing.T) { cpaths, err := mockbIO.upload(context.TODO(), segID, partID, []*InsertData{iData}, dData, meta) require.NoError(t, err) - require.Equal(t, 11, len(cpaths.inPaths)) + require.Equal(t, 12, len(cpaths.inPaths)) segBinlogs := []*datapb.CompactionSegmentBinlogs{ { SegmentID: segID, @@ -524,7 +530,7 @@ func TestCompactorInterfaceMethods(t *testing.T) { assert.False(t, replica.hasSegment(segID, true)) // re-add the segment - replica.addFlushedSegmentWithPKs(segID, collID, partID, "channelname", 2, []UniqueID{1, 2}) + replica.addFlushedSegmentWithPKs(segID, collID, partID, "channelname", 2, &storage.Int64FieldData{Data: []UniqueID{1, 2}}) // Compact empty segment err = cm.RemoveWithPrefix("/") @@ -575,7 +581,9 @@ func TestCompactorInterfaceMethods(t *testing.T) { var collID, partID, segID1, segID2 UniqueID = 1, 10, 200, 201 alloc := NewAllocatorFactory(1) - rc := &RootCoordFactory{} + rc := &RootCoordFactory{ + pkType: schemapb.DataType_Int64, + } dc := &DataCoordFactory{} mockfm := &mockFlushManager{} mockKv := memkv.NewMemoryKV() @@ -583,12 +591,12 @@ func TestCompactorInterfaceMethods(t *testing.T) { replica, err := newReplica(context.TODO(), rc, cm, collID) require.NoError(t, err) - replica.addFlushedSegmentWithPKs(segID1, collID, partID, "channelname", 2, []UniqueID{1}) - replica.addFlushedSegmentWithPKs(segID2, collID, partID, "channelname", 2, []UniqueID{9}) + replica.addFlushedSegmentWithPKs(segID1, collID, partID, "channelname", 2, &storage.Int64FieldData{Data: []UniqueID{1}}) + replica.addFlushedSegmentWithPKs(segID2, collID, partID, "channelname", 2, &storage.Int64FieldData{Data: []UniqueID{9}}) require.True(t, replica.hasSegment(segID1, true)) require.True(t, replica.hasSegment(segID2, true)) - meta := NewMetaFactory().GetCollectionMeta(collID, "test_compact_coll_name") + meta := NewMetaFactory().GetCollectionMeta(collID, "test_compact_coll_name", schemapb.DataType_Int64) iData1 := genInsertDataWithPKs([2]int64{1, 2}) dData1 := &DeleteData{ Pks: []UniqueID{1}, @@ -604,11 +612,11 @@ func TestCompactorInterfaceMethods(t *testing.T) { cpaths1, err := mockbIO.upload(context.TODO(), segID1, partID, []*InsertData{iData1}, dData1, meta) require.NoError(t, err) - require.Equal(t, 11, len(cpaths1.inPaths)) + require.Equal(t, 12, len(cpaths1.inPaths)) cpaths2, err := mockbIO.upload(context.TODO(), segID2, partID, []*InsertData{iData2}, dData2, meta) require.NoError(t, err) - require.Equal(t, 11, len(cpaths2.inPaths)) + require.Equal(t, 12, len(cpaths2.inPaths)) plan := &datapb.CompactionPlan{ PlanID: 10080, @@ -652,8 +660,8 @@ func TestCompactorInterfaceMethods(t *testing.T) { plan.PlanID++ plan.Timetravel = Timestamp(25000) - replica.addFlushedSegmentWithPKs(segID1, collID, partID, "channelname", 2, []UniqueID{1}) - replica.addFlushedSegmentWithPKs(segID2, collID, partID, "channelname", 2, []UniqueID{9}) + 
replica.addFlushedSegmentWithPKs(segID1, collID, partID, "channelname", 2, &storage.Int64FieldData{Data: []UniqueID{1}}) + replica.addFlushedSegmentWithPKs(segID2, collID, partID, "channelname", 2, &storage.Int64FieldData{Data: []UniqueID{9}}) replica.removeSegments(19530) require.True(t, replica.hasSegment(segID1, true)) require.True(t, replica.hasSegment(segID2, true)) @@ -676,8 +684,8 @@ func TestCompactorInterfaceMethods(t *testing.T) { plan.PlanID++ plan.Timetravel = Timestamp(10000) - replica.addFlushedSegmentWithPKs(segID1, collID, partID, "channelname", 2, []UniqueID{1}) - replica.addFlushedSegmentWithPKs(segID2, collID, partID, "channelname", 2, []UniqueID{9}) + replica.addFlushedSegmentWithPKs(segID1, collID, partID, "channelname", 2, &storage.Int64FieldData{Data: []UniqueID{1}}) + replica.addFlushedSegmentWithPKs(segID2, collID, partID, "channelname", 2, &storage.Int64FieldData{Data: []UniqueID{9}}) replica.removeSegments(19530) require.True(t, replica.hasSegment(segID1, true)) require.True(t, replica.hasSegment(segID2, true)) @@ -701,19 +709,21 @@ func TestCompactorInterfaceMethods(t *testing.T) { var collID, partID, segID1, segID2 UniqueID = 1, 10, 200, 201 alloc := NewAllocatorFactory(1) - rc := &RootCoordFactory{} + rc := &RootCoordFactory{ + pkType: schemapb.DataType_Int64, + } dc := &DataCoordFactory{} mockfm := &mockFlushManager{} mockbIO := &binlogIO{cm, alloc} replica, err := newReplica(context.TODO(), rc, cm, collID) require.NoError(t, err) - replica.addFlushedSegmentWithPKs(segID1, collID, partID, "channelname", 2, []UniqueID{1}) - replica.addFlushedSegmentWithPKs(segID2, collID, partID, "channelname", 2, []UniqueID{1}) + replica.addFlushedSegmentWithPKs(segID1, collID, partID, "channelname", 2, &storage.Int64FieldData{Data: []UniqueID{1}}) + replica.addFlushedSegmentWithPKs(segID2, collID, partID, "channelname", 2, &storage.Int64FieldData{Data: []UniqueID{1}}) require.True(t, replica.hasSegment(segID1, true)) require.True(t, replica.hasSegment(segID2, true)) - meta := NewMetaFactory().GetCollectionMeta(collID, "test_compact_coll_name") + meta := NewMetaFactory().GetCollectionMeta(collID, "test_compact_coll_name", schemapb.DataType_Int64) // the same pk for segmentI and segmentII iData1 := genInsertDataWithPKs([2]int64{1, 2}) iData2 := genInsertDataWithPKs([2]int64{1, 2}) @@ -732,11 +742,11 @@ func TestCompactorInterfaceMethods(t *testing.T) { cpaths1, err := mockbIO.upload(context.TODO(), segID1, partID, []*InsertData{iData1}, dData1, meta) require.NoError(t, err) - require.Equal(t, 11, len(cpaths1.inPaths)) + require.Equal(t, 12, len(cpaths1.inPaths)) cpaths2, err := mockbIO.upload(context.TODO(), segID2, partID, []*InsertData{iData2}, dData2, meta) require.NoError(t, err) - require.Equal(t, 11, len(cpaths2.inPaths)) + require.Equal(t, 12, len(cpaths2.inPaths)) plan := &datapb.CompactionPlan{ PlanID: 20080, diff --git a/internal/datanode/data_node_test.go b/internal/datanode/data_node_test.go index 69ead01bca..afaa7398e0 100644 --- a/internal/datanode/data_node_test.go +++ b/internal/datanode/data_node_test.go @@ -38,6 +38,7 @@ import ( "github.com/milvus-io/milvus/internal/proto/datapb" "github.com/milvus-io/milvus/internal/proto/internalpb" "github.com/milvus-io/milvus/internal/proto/milvuspb" + "github.com/milvus-io/milvus/internal/proto/schemapb" "github.com/milvus-io/milvus/internal/util/etcd" "github.com/milvus-io/milvus/internal/util/metricsinfo" @@ -68,7 +69,7 @@ func TestDataNode(t *testing.T) { ctx, cancel := context.WithCancel(context.Background()) defer 
cancel() - node := newIDLEDataNodeMock(ctx) + node := newIDLEDataNodeMock(ctx, schemapb.DataType_Int64) etcdCli, err := etcd.GetEtcdClient(&Params.EtcdCfg) assert.Nil(t, err) defer etcdCli.Close() @@ -141,7 +142,7 @@ func TestDataNode(t *testing.T) { t.Run("Test FlushSegments", func(t *testing.T) { dmChannelName := "fake-by-dev-rootcoord-dml-channel-test-FlushSegments" - node1 := newIDLEDataNodeMock(context.TODO()) + node1 := newIDLEDataNodeMock(context.TODO(), schemapb.DataType_Int64) node1.SetEtcdClient(etcdCli) err = node1.Init() assert.Nil(t, err) @@ -326,7 +327,7 @@ func TestDataNode(t *testing.T) { t.Run("Test BackGroundGC", func(t *testing.T) { ctx, cancel := context.WithCancel(context.Background()) - node := newIDLEDataNodeMock(ctx) + node := newIDLEDataNodeMock(ctx, schemapb.DataType_Int64) vchanNameCh := make(chan string) node.clearSignal = vchanNameCh @@ -355,7 +356,7 @@ func TestDataNode(t *testing.T) { func TestWatchChannel(t *testing.T) { ctx, cancel := context.WithCancel(context.Background()) - node := newIDLEDataNodeMock(ctx) + node := newIDLEDataNodeMock(ctx, schemapb.DataType_Int64) etcdCli, err := etcd.GetEtcdClient(&Params.EtcdCfg) assert.Nil(t, err) defer etcdCli.Close() diff --git a/internal/datanode/data_sync_service_test.go b/internal/datanode/data_sync_service_test.go index ca0082a02e..fa40f77873 100644 --- a/internal/datanode/data_sync_service_test.go +++ b/internal/datanode/data_sync_service_test.go @@ -24,7 +24,6 @@ import ( "testing" "time" - "github.com/milvus-io/milvus/internal/storage" "github.com/stretchr/testify/assert" "go.uber.org/zap" @@ -34,6 +33,8 @@ import ( "github.com/milvus-io/milvus/internal/proto/commonpb" "github.com/milvus-io/milvus/internal/proto/datapb" "github.com/milvus-io/milvus/internal/proto/internalpb" + "github.com/milvus-io/milvus/internal/proto/schemapb" + "github.com/milvus-io/milvus/internal/storage" ) var dataSyncServiceTestDir = "/tmp/milvus_test/data_sync_service" @@ -136,7 +137,7 @@ func TestDataSyncService_newDataSyncService(te *testing.T) { te.Run(test.description, func(t *testing.T) { df := &DataCoordFactory{} - replica, err := newReplica(context.Background(), &RootCoordFactory{}, cm, test.collID) + replica, err := newReplica(context.Background(), &RootCoordFactory{pkType: schemapb.DataType_Int64}, cm, test.collID) assert.Nil(t, err) if test.replicaNil { replica = nil @@ -183,8 +184,10 @@ func TestDataSyncService_Start(t *testing.T) { // init data node Factory := &MetaFactory{} - collMeta := Factory.GetCollectionMeta(UniqueID(0), "coll1") - mockRootCoord := &RootCoordFactory{} + collMeta := Factory.GetCollectionMeta(UniqueID(0), "coll1", schemapb.DataType_Int64) + mockRootCoord := &RootCoordFactory{ + pkType: schemapb.DataType_Int64, + } collectionID := UniqueID(1) flushChan := make(chan flushMsg, 100) diff --git a/internal/datanode/flow_graph_insert_buffer_node.go b/internal/datanode/flow_graph_insert_buffer_node.go index a9d385e8fc..061a19b44d 100644 --- a/internal/datanode/flow_graph_insert_buffer_node.go +++ b/internal/datanode/flow_graph_insert_buffer_node.go @@ -127,6 +127,7 @@ func newBufferData(dimension int64) (*BufferData, error) { limit := Params.DataNodeCfg.FlushInsertBufferSize / (dimension * 4) + //TODO::xige-16 eval vec and string field return &BufferData{&InsertData{Data: make(map[UniqueID]storage.FieldData)}, 0, limit}, nil } @@ -417,8 +418,8 @@ func (ibNode *insertBufferNode) updateSegStatesInReplica(insertMsgs []*msgstream // 1.3 Put back into buffer // 1.4 Update related statistics func (ibNode 
*insertBufferNode) bufferInsertMsg(msg *msgstream.InsertMsg, endPos *internalpb.MsgPosition) error { - if !msg.CheckAligned() { - return errors.New("misaligned messages detected") + if err := msg.CheckAligned(); err != nil { + return err } currentSegID := msg.GetSegmentID() collectionID := msg.GetCollectionID() diff --git a/internal/datanode/flow_graph_insert_buffer_node_test.go b/internal/datanode/flow_graph_insert_buffer_node_test.go index 1c96363d58..2579488aa9 100644 --- a/internal/datanode/flow_graph_insert_buffer_node_test.go +++ b/internal/datanode/flow_graph_insert_buffer_node_test.go @@ -68,8 +68,10 @@ func TestFlowGraphInsertBufferNodeCreate(t *testing.T) { Params.EtcdCfg.MetaRootPath = testPath Factory := &MetaFactory{} - collMeta := Factory.GetCollectionMeta(UniqueID(0), "coll1") - mockRootCoord := &RootCoordFactory{} + collMeta := Factory.GetCollectionMeta(UniqueID(0), "coll1", schemapb.DataType_Int64) + mockRootCoord := &RootCoordFactory{ + pkType: schemapb.DataType_Int64, + } replica, err := newReplica(ctx, mockRootCoord, cm, collMeta.ID) assert.Nil(t, err) @@ -154,8 +156,10 @@ func TestFlowGraphInsertBufferNode_Operate(t *testing.T) { Params.EtcdCfg.MetaRootPath = testPath Factory := &MetaFactory{} - collMeta := Factory.GetCollectionMeta(UniqueID(0), "coll1") - mockRootCoord := &RootCoordFactory{} + collMeta := Factory.GetCollectionMeta(UniqueID(0), "coll1", schemapb.DataType_Int64) + mockRootCoord := &RootCoordFactory{ + pkType: schemapb.DataType_Int64, + } replica, err := newReplica(ctx, mockRootCoord, cm, collMeta.ID) assert.Nil(t, err) @@ -349,10 +353,12 @@ func TestFlowGraphInsertBufferNode_AutoFlush(t *testing.T) { Params.EtcdCfg.MetaRootPath = testPath Factory := &MetaFactory{} - collMeta := Factory.GetCollectionMeta(UniqueID(0), "coll1") + collMeta := Factory.GetCollectionMeta(UniqueID(0), "coll1", schemapb.DataType_Int64) dataFactory := NewDataFactory() - mockRootCoord := &RootCoordFactory{} + mockRootCoord := &RootCoordFactory{ + pkType: schemapb.DataType_Int64, + } colRep := &SegmentReplica{ collectionID: collMeta.ID, @@ -597,7 +603,7 @@ type CompactedRootCoord struct { } func (m *CompactedRootCoord) DescribeCollection(ctx context.Context, in *milvuspb.DescribeCollectionRequest) (*milvuspb.DescribeCollectionResponse, error) { - if in.GetTimeStamp() <= m.compactTs { + if in.TimeStamp != 0 && in.GetTimeStamp() <= m.compactTs { return &milvuspb.DescribeCollectionResponse{ Status: &commonpb.Status{ ErrorCode: commonpb.ErrorCode_UnexpectedError, @@ -620,50 +626,62 @@ func TestInsertBufferNode_bufferInsertMsg(t *testing.T) { Params.EtcdCfg.MetaRootPath = testPath Factory := &MetaFactory{} - collMeta := Factory.GetCollectionMeta(UniqueID(0), "coll1") - - rcf := &RootCoordFactory{} - mockRootCoord := &CompactedRootCoord{ - RootCoord: rcf, - compactTs: 100, + tests := []struct { + collID UniqueID + pkType schemapb.DataType + description string + }{ + {0, schemapb.DataType_Int64, "int64PrimaryData"}, + {0, schemapb.DataType_VarChar, "varCharPrimaryData"}, } cm := storage.NewLocalChunkManager(storage.RootPath(insertNodeTestDir)) defer cm.RemoveWithPrefix("") - replica, err := newReplica(ctx, mockRootCoord, cm, collMeta.ID) - assert.Nil(t, err) + for _, test := range tests { + collMeta := Factory.GetCollectionMeta(test.collID, "collection", test.pkType) + rcf := &RootCoordFactory{ + pkType: test.pkType, + } + mockRootCoord := &CompactedRootCoord{ + RootCoord: rcf, + compactTs: 100, + } - err = replica.addNewSegment(1, collMeta.ID, 0, insertChannelName, 
&internalpb.MsgPosition{}, &internalpb.MsgPosition{}) - require.NoError(t, err) - - msFactory := msgstream.NewPmsFactory() - err = msFactory.Init(&Params) - assert.Nil(t, err) - - fm := NewRendezvousFlushManager(&allocator{}, cm, replica, func(*segmentFlushPack) {}, emptyFlushAndDropFunc) - - flushChan := make(chan flushMsg, 100) - c := &nodeConfig{ - replica: replica, - msFactory: msFactory, - allocator: NewAllocatorFactory(), - vChannelName: "string", - } - iBNode, err := newInsertBufferNode(ctx, collMeta.ID, flushChan, fm, newCache(), c) - require.NoError(t, err) - - inMsg := genFlowGraphInsertMsg(insertChannelName) - for _, msg := range inMsg.insertMessages { - msg.EndTimestamp = 101 // ts valid - err = iBNode.bufferInsertMsg(msg, &internalpb.MsgPosition{}) + replica, err := newReplica(ctx, mockRootCoord, cm, collMeta.ID) assert.Nil(t, err) - } - for _, msg := range inMsg.insertMessages { - msg.EndTimestamp = 101 // ts valid - msg.RowIDs = []int64{} //misaligned data - err = iBNode.bufferInsertMsg(msg, &internalpb.MsgPosition{}) - assert.NotNil(t, err) + err = replica.addNewSegment(1, collMeta.ID, 0, insertChannelName, &internalpb.MsgPosition{}, &internalpb.MsgPosition{Timestamp: 101}) + require.NoError(t, err) + + msFactory := msgstream.NewPmsFactory() + err = msFactory.Init(&Params) + assert.Nil(t, err) + + fm := NewRendezvousFlushManager(&allocator{}, cm, replica, func(*segmentFlushPack) {}, emptyFlushAndDropFunc) + + flushChan := make(chan flushMsg, 100) + c := &nodeConfig{ + replica: replica, + msFactory: msFactory, + allocator: NewAllocatorFactory(), + vChannelName: "string", + } + iBNode, err := newInsertBufferNode(ctx, collMeta.ID, flushChan, fm, newCache(), c) + require.NoError(t, err) + + inMsg := genFlowGraphInsertMsg(insertChannelName) + for _, msg := range inMsg.insertMessages { + msg.EndTimestamp = 101 // ts valid + err = iBNode.bufferInsertMsg(msg, &internalpb.MsgPosition{}) + assert.Nil(t, err) + } + + for _, msg := range inMsg.insertMessages { + msg.EndTimestamp = 101 // ts valid + msg.RowIDs = []int64{} //misaligned data + err = iBNode.bufferInsertMsg(msg, &internalpb.MsgPosition{}) + assert.NotNil(t, err) + } } } @@ -681,7 +699,7 @@ func TestInsertBufferNode_updateSegStatesInReplica(te *testing.T) { } for _, test := range invalideTests { - replica, err := newReplica(context.Background(), &RootCoordFactory{}, cm, test.replicaCollID) + replica, err := newReplica(context.Background(), &RootCoordFactory{pkType: schemapb.DataType_Int64}, cm, test.replicaCollID) assert.Nil(te, err) ibNode := &insertBufferNode{ diff --git a/internal/datanode/flow_graph_manager_test.go b/internal/datanode/flow_graph_manager_test.go index 841280438c..75de8d3a81 100644 --- a/internal/datanode/flow_graph_manager_test.go +++ b/internal/datanode/flow_graph_manager_test.go @@ -22,6 +22,7 @@ import ( "github.com/milvus-io/milvus/internal/proto/datapb" "github.com/milvus-io/milvus/internal/proto/internalpb" + "github.com/milvus-io/milvus/internal/proto/schemapb" "github.com/milvus-io/milvus/internal/util/etcd" "github.com/stretchr/testify/assert" @@ -36,7 +37,7 @@ func TestFlowGraphManager(t *testing.T) { assert.Nil(t, err) defer etcdCli.Close() - node := newIDLEDataNodeMock(ctx) + node := newIDLEDataNodeMock(ctx, schemapb.DataType_Int64) node.SetEtcdClient(etcdCli) err = node.Init() require.Nil(t, err) diff --git a/internal/datanode/flush_manager_test.go b/internal/datanode/flush_manager_test.go index 94e43a7d41..5459d451c0 100644 --- a/internal/datanode/flush_manager_test.go +++ 
b/internal/datanode/flush_manager_test.go @@ -26,6 +26,7 @@ import ( "github.com/milvus-io/milvus/internal/proto/datapb" "github.com/milvus-io/milvus/internal/proto/internalpb" + "github.com/milvus-io/milvus/internal/proto/schemapb" "github.com/milvus-io/milvus/internal/storage" "github.com/milvus-io/milvus/internal/util/retry" "github.com/stretchr/testify/assert" @@ -516,7 +517,9 @@ func TestRendezvousFlushManager_close(t *testing.T) { func TestFlushNotifyFunc(t *testing.T) { ctx := context.Background() - rcf := &RootCoordFactory{} + rcf := &RootCoordFactory{ + pkType: schemapb.DataType_Int64, + } cm := storage.NewLocalChunkManager(storage.RootPath(flushTestDir)) replica, err := newReplica(ctx, rcf, cm, 1) @@ -568,7 +571,10 @@ func TestFlushNotifyFunc(t *testing.T) { func TestDropVirtualChannelFunc(t *testing.T) { ctx := context.Background() - rcf := &RootCoordFactory{} + rcf := &RootCoordFactory{ + pkType: schemapb.DataType_Int64, + } + cm := storage.NewLocalChunkManager(storage.RootPath(flushTestDir)) replica, err := newReplica(ctx, rcf, cm, 1) require.NoError(t, err) diff --git a/internal/datanode/meta_service_test.go b/internal/datanode/meta_service_test.go index 6fc7f0e769..ee66afefb6 100644 --- a/internal/datanode/meta_service_test.go +++ b/internal/datanode/meta_service_test.go @@ -23,6 +23,7 @@ import ( "github.com/milvus-io/milvus/internal/proto/commonpb" "github.com/milvus-io/milvus/internal/proto/milvuspb" + "github.com/milvus-io/milvus/internal/proto/schemapb" "github.com/stretchr/testify/assert" ) @@ -37,7 +38,9 @@ func TestMetaService_All(t *testing.T) { ctx, cancel := context.WithCancel(context.Background()) defer cancel() - mFactory := &RootCoordFactory{} + mFactory := &RootCoordFactory{ + pkType: schemapb.DataType_Int64, + } mFactory.setCollectionID(collectionID0) mFactory.setCollectionName(collectionName0) ms := newMetaService(mFactory, collectionID0) @@ -52,7 +55,7 @@ func TestMetaService_All(t *testing.T) { t.Run("Test printCollectionStruct", func(t *testing.T) { mf := &MetaFactory{} - collectionMeta := mf.GetCollectionMeta(collectionID0, collectionName0) + collectionMeta := mf.GetCollectionMeta(collectionID0, collectionName0, schemapb.DataType_Int64) printCollectionStruct(collectionMeta) }) } diff --git a/internal/datanode/mock_test.go b/internal/datanode/mock_test.go index 995c0cb6e9..684c2f240d 100644 --- a/internal/datanode/mock_test.go +++ b/internal/datanode/mock_test.go @@ -52,7 +52,7 @@ const debug = false var emptyFlushAndDropFunc flushAndDropFunc = func(_ []*segmentFlushPack) {} -func newIDLEDataNodeMock(ctx context.Context) *DataNode { +func newIDLEDataNodeMock(ctx context.Context, pkType schemapb.DataType) *DataNode { msFactory := msgstream.NewRmsFactory() node := NewDataNode(ctx, msFactory) @@ -60,6 +60,7 @@ func newIDLEDataNodeMock(ctx context.Context) *DataNode { ID: 0, collectionID: 1, collectionName: "collection-1", + pkType: pkType, } node.rootCoord = rc @@ -146,7 +147,8 @@ func NewMetaFactory() *MetaFactory { } type DataFactory struct { - rawData []byte + rawData []byte + columnData []*schemapb.FieldData } type RootCoordFactory struct { @@ -154,6 +156,7 @@ type RootCoordFactory struct { ID UniqueID collectionName string collectionID UniqueID + pkType schemapb.DataType } type DataCoordFactory struct { @@ -209,144 +212,166 @@ func (ds *DataCoordFactory) DropVirtualChannel(ctx context.Context, req *datapb. 
}, nil } -func (mf *MetaFactory) GetCollectionMeta(collectionID UniqueID, collectionName string) *etcdpb.CollectionMeta { +func (mf *MetaFactory) GetCollectionMeta(collectionID UniqueID, collectionName string, pkDataType schemapb.DataType) *etcdpb.CollectionMeta { sch := schemapb.CollectionSchema{ Name: collectionName, Description: "test collection by meta factory", AutoID: true, - Fields: []*schemapb.FieldSchema{ - { - FieldID: 0, - Name: "RowID", - Description: "RowID field", - DataType: schemapb.DataType_Int64, - TypeParams: []*commonpb.KeyValuePair{ - { - Key: "f0_tk1", - Value: "f0_tv1", - }, - }, - }, - { - FieldID: 1, - Name: "Timestamp", - Description: "Timestamp field", - DataType: schemapb.DataType_Int64, - TypeParams: []*commonpb.KeyValuePair{ - { - Key: "f1_tk1", - Value: "f1_tv1", - }, - }, - }, - { - FieldID: 100, - Name: "float_vector_field", - Description: "field 100", - DataType: schemapb.DataType_FloatVector, - TypeParams: []*commonpb.KeyValuePair{ - { - Key: "dim", - Value: "2", - }, - }, - IndexParams: []*commonpb.KeyValuePair{ - { - Key: "indexkey", - Value: "indexvalue", - }, - }, - }, - { - FieldID: 101, - Name: "binary_vector_field", - Description: "field 101", - DataType: schemapb.DataType_BinaryVector, - TypeParams: []*commonpb.KeyValuePair{ - { - Key: "dim", - Value: "32", - }, - }, - IndexParams: []*commonpb.KeyValuePair{ - { - Key: "indexkey", - Value: "indexvalue", - }, - }, - }, - { - FieldID: 102, - Name: "bool_field", - Description: "field 102", - DataType: schemapb.DataType_Bool, - TypeParams: []*commonpb.KeyValuePair{}, - IndexParams: []*commonpb.KeyValuePair{}, - }, - { - FieldID: 103, - Name: "int8_field", - Description: "field 103", - DataType: schemapb.DataType_Int8, - TypeParams: []*commonpb.KeyValuePair{}, - IndexParams: []*commonpb.KeyValuePair{}, - }, - { - FieldID: 104, - Name: "int16_field", - Description: "field 104", - DataType: schemapb.DataType_Int16, - TypeParams: []*commonpb.KeyValuePair{}, - IndexParams: []*commonpb.KeyValuePair{}, - }, - { - FieldID: 105, - Name: "int32_field", - Description: "field 105", - DataType: schemapb.DataType_Int32, - TypeParams: []*commonpb.KeyValuePair{}, - IndexParams: []*commonpb.KeyValuePair{}, - }, - { - FieldID: 106, - Name: "int64_field", - Description: "field 106", - DataType: schemapb.DataType_Int64, - TypeParams: []*commonpb.KeyValuePair{}, - IndexParams: []*commonpb.KeyValuePair{}, - IsPrimaryKey: true, - }, - { - FieldID: 107, - Name: "float32_field", - Description: "field 107", - DataType: schemapb.DataType_Float, - TypeParams: []*commonpb.KeyValuePair{}, - IndexParams: []*commonpb.KeyValuePair{}, - }, - { - FieldID: 108, - Name: "float64_field", - Description: "field 108", - DataType: schemapb.DataType_Double, - TypeParams: []*commonpb.KeyValuePair{}, - IndexParams: []*commonpb.KeyValuePair{}, - }, - }, + } + sch.Fields = mf.GetFieldSchema() + for _, field := range sch.Fields { + if field.GetDataType() == pkDataType && field.FieldID >= 100 { + field.IsPrimaryKey = true + } } - collection := etcdpb.CollectionMeta{ + return &etcdpb.CollectionMeta{ ID: collectionID, Schema: &sch, CreateTime: Timestamp(1), SegmentIDs: make([]UniqueID, 0), PartitionIDs: []UniqueID{0}, } - return &collection +} + +func (mf *MetaFactory) GetFieldSchema() []*schemapb.FieldSchema { + fields := []*schemapb.FieldSchema{ + { + FieldID: 0, + Name: "RowID", + Description: "RowID field", + DataType: schemapb.DataType_Int64, + TypeParams: []*commonpb.KeyValuePair{ + { + Key: "f0_tk1", + Value: "f0_tv1", + }, + }, + }, + { + 
FieldID: 1, + Name: "Timestamp", + Description: "Timestamp field", + DataType: schemapb.DataType_Int64, + TypeParams: []*commonpb.KeyValuePair{ + { + Key: "f1_tk1", + Value: "f1_tv1", + }, + }, + }, + { + FieldID: 100, + Name: "float_vector_field", + Description: "field 100", + DataType: schemapb.DataType_FloatVector, + TypeParams: []*commonpb.KeyValuePair{ + { + Key: "dim", + Value: "2", + }, + }, + IndexParams: []*commonpb.KeyValuePair{ + { + Key: "indexkey", + Value: "indexvalue", + }, + }, + }, + { + FieldID: 101, + Name: "binary_vector_field", + Description: "field 101", + DataType: schemapb.DataType_BinaryVector, + TypeParams: []*commonpb.KeyValuePair{ + { + Key: "dim", + Value: "32", + }, + }, + IndexParams: []*commonpb.KeyValuePair{ + { + Key: "indexkey", + Value: "indexvalue", + }, + }, + }, + { + FieldID: 102, + Name: "bool_field", + Description: "field 102", + DataType: schemapb.DataType_Bool, + TypeParams: []*commonpb.KeyValuePair{}, + IndexParams: []*commonpb.KeyValuePair{}, + }, + { + FieldID: 103, + Name: "int8_field", + Description: "field 103", + DataType: schemapb.DataType_Int8, + TypeParams: []*commonpb.KeyValuePair{}, + IndexParams: []*commonpb.KeyValuePair{}, + }, + { + FieldID: 104, + Name: "int16_field", + Description: "field 104", + DataType: schemapb.DataType_Int16, + TypeParams: []*commonpb.KeyValuePair{}, + IndexParams: []*commonpb.KeyValuePair{}, + }, + { + FieldID: 105, + Name: "int32_field", + Description: "field 105", + DataType: schemapb.DataType_Int32, + TypeParams: []*commonpb.KeyValuePair{}, + IndexParams: []*commonpb.KeyValuePair{}, + }, + { + FieldID: 106, + Name: "int64_field", + Description: "field 106", + DataType: schemapb.DataType_Int64, + TypeParams: []*commonpb.KeyValuePair{}, + IndexParams: []*commonpb.KeyValuePair{}, + }, + { + FieldID: 107, + Name: "float32_field", + Description: "field 107", + DataType: schemapb.DataType_Float, + TypeParams: []*commonpb.KeyValuePair{}, + IndexParams: []*commonpb.KeyValuePair{}, + }, + { + FieldID: 108, + Name: "float64_field", + Description: "field 108", + DataType: schemapb.DataType_Double, + TypeParams: []*commonpb.KeyValuePair{}, + IndexParams: []*commonpb.KeyValuePair{}, + }, + { + FieldID: 109, + Name: "varChar_field", + Description: "field 109", + DataType: schemapb.DataType_VarChar, + TypeParams: []*commonpb.KeyValuePair{ + { + Key: "max_length_per_row", + Value: "100", + }, + }, + IndexParams: []*commonpb.KeyValuePair{}, + }, + } + + return fields } func NewDataFactory() *DataFactory { - return &DataFactory{rawData: GenRowData()} + return &DataFactory{rawData: GenRowData(), columnData: GenColumnData()} } func GenRowData() (rawData []byte) { @@ -427,6 +452,192 @@ func GenRowData() (rawData []byte) { return } +func GenColumnData() (fieldsData []*schemapb.FieldData) { + // Float vector + var fVector = []float32{1, 2} + floatVectorData := &schemapb.FieldData{ + Type: schemapb.DataType_FloatVector, + FieldName: "float_vector_field", + FieldId: 100, + Field: &schemapb.FieldData_Vectors{ + Vectors: &schemapb.VectorField{ + Dim: 2, + Data: &schemapb.VectorField_FloatVector{ + FloatVector: &schemapb.FloatArray{ + Data: fVector, + }, + }, + }, + }, + } + fieldsData = append(fieldsData, floatVectorData) + + // Binary vector + // Dimension of binary vector is 32 + // size := 4, = 32 / 8 + binaryVector := []byte{255, 255, 255, 0} + binaryVectorData := &schemapb.FieldData{ + Type: schemapb.DataType_BinaryVector, + FieldName: "binary_vector_field", + FieldId: 101, + Field: &schemapb.FieldData_Vectors{ + Vectors: 
&schemapb.VectorField{ + Dim: 32, + Data: &schemapb.VectorField_BinaryVector{ + BinaryVector: binaryVector, + }, + }, + }, + } + fieldsData = append(fieldsData, binaryVectorData) + + // bool + boolData := []bool{true} + boolFieldData := &schemapb.FieldData{ + Type: schemapb.DataType_Bool, + FieldName: "bool_field", + FieldId: 102, + Field: &schemapb.FieldData_Scalars{ + Scalars: &schemapb.ScalarField{ + Data: &schemapb.ScalarField_BoolData{ + BoolData: &schemapb.BoolArray{ + Data: boolData, + }, + }, + }, + }, + } + fieldsData = append(fieldsData, boolFieldData) + + // int8 + int8Data := []int32{100} + int8FieldData := &schemapb.FieldData{ + Type: schemapb.DataType_Int8, + FieldName: "int8_field", + FieldId: 103, + Field: &schemapb.FieldData_Scalars{ + Scalars: &schemapb.ScalarField{ + Data: &schemapb.ScalarField_IntData{ + IntData: &schemapb.IntArray{ + Data: int8Data, + }, + }, + }, + }, + } + fieldsData = append(fieldsData, int8FieldData) + + // int16 + int16Data := []int32{200} + int16FieldData := &schemapb.FieldData{ + Type: schemapb.DataType_Int16, + FieldName: "int16_field", + FieldId: 104, + Field: &schemapb.FieldData_Scalars{ + Scalars: &schemapb.ScalarField{ + Data: &schemapb.ScalarField_IntData{ + IntData: &schemapb.IntArray{ + Data: int16Data, + }, + }, + }, + }, + } + fieldsData = append(fieldsData, int16FieldData) + + // int32 + int32Data := []int32{300} + int32FieldData := &schemapb.FieldData{ + Type: schemapb.DataType_Int32, + FieldName: "int32_field", + FieldId: 105, + Field: &schemapb.FieldData_Scalars{ + Scalars: &schemapb.ScalarField{ + Data: &schemapb.ScalarField_IntData{ + IntData: &schemapb.IntArray{ + Data: int32Data, + }, + }, + }, + }, + } + fieldsData = append(fieldsData, int32FieldData) + + // int64 + int64Data := []int64{400} + int64FieldData := &schemapb.FieldData{ + Type: schemapb.DataType_Int64, + FieldName: "int64_field", + FieldId: 106, + Field: &schemapb.FieldData_Scalars{ + Scalars: &schemapb.ScalarField{ + Data: &schemapb.ScalarField_LongData{ + LongData: &schemapb.LongArray{ + Data: int64Data, + }, + }, + }, + }, + } + fieldsData = append(fieldsData, int64FieldData) + + // float + floatData := []float32{1.1} + floatFieldData := &schemapb.FieldData{ + Type: schemapb.DataType_Float, + FieldName: "float32_field", + FieldId: 107, + Field: &schemapb.FieldData_Scalars{ + Scalars: &schemapb.ScalarField{ + Data: &schemapb.ScalarField_FloatData{ + FloatData: &schemapb.FloatArray{ + Data: floatData, + }, + }, + }, + }, + } + fieldsData = append(fieldsData, floatFieldData) + + //double + doubleData := []float64{2.2} + doubleFieldData := &schemapb.FieldData{ + Type: schemapb.DataType_Double, + FieldName: "float64_field", + FieldId: 108, + Field: &schemapb.FieldData_Scalars{ + Scalars: &schemapb.ScalarField{ + Data: &schemapb.ScalarField_DoubleData{ + DoubleData: &schemapb.DoubleArray{ + Data: doubleData, + }, + }, + }, + }, + } + fieldsData = append(fieldsData, doubleFieldData) + + //var char + varCharData := []string{"test"} + varCharFieldData := &schemapb.FieldData{ + Type: schemapb.DataType_VarChar, + FieldName: "varChar_field", + FieldId: 109, + Field: &schemapb.FieldData_Scalars{ + Scalars: &schemapb.ScalarField{ + Data: &schemapb.ScalarField_StringData{ + StringData: &schemapb.StringArray{ + Data: varCharData, + }, + }, + }, + }, + } + fieldsData = append(fieldsData, varCharFieldData) + + return +} + func (df *DataFactory) GenMsgStreamInsertMsg(idx int, chanName string) *msgstream.InsertMsg { var msg = &msgstream.InsertMsg{ BaseMsg: msgstream.BaseMsg{ @@ 
-446,7 +657,10 @@ func (df *DataFactory) GenMsgStreamInsertMsg(idx int, chanName string) *msgstrea ShardName: chanName, Timestamps: []Timestamp{Timestamp(idx + 1000)}, RowIDs: []UniqueID{UniqueID(idx)}, - RowData: []*commonpb.Blob{{Value: df.rawData}}, + // RowData: []*commonpb.Blob{{Value: df.rawData}}, + FieldsData: df.columnData, + Version: internalpb.InsertDataVersion_ColumnBased, + NumRows: 1, }, } return msg @@ -662,7 +876,7 @@ func (m *RootCoordFactory) ShowCollections(ctx context.Context, in *milvuspb.Sho func (m *RootCoordFactory) DescribeCollection(ctx context.Context, in *milvuspb.DescribeCollectionRequest) (*milvuspb.DescribeCollectionResponse, error) { f := MetaFactory{} - meta := f.GetCollectionMeta(m.collectionID, m.collectionName) + meta := f.GetCollectionMeta(m.collectionID, m.collectionName, m.pkType) resp := &milvuspb.DescribeCollectionResponse{ Status: &commonpb.Status{ ErrorCode: commonpb.ErrorCode_UnexpectedError, @@ -764,6 +978,10 @@ func genInsertData() *InsertData { NumRows: []int64{2}, Data: []float64{3.333, 3.334}, }, + 109: &s.StringFieldData{ + NumRows: []int64{2}, + Data: []string{"test1", "test2"}, + }, }} } @@ -816,6 +1034,10 @@ func genEmptyInsertData() *InsertData { NumRows: []int64{0}, Data: []float64{}, }, + 109: &s.StringFieldData{ + NumRows: []int64{0}, + Data: []string{}, + }, }} } @@ -868,6 +1090,10 @@ func genInsertDataWithExpiredTS() *InsertData { NumRows: []int64{2}, Data: []float64{3.333, 3.334}, }, + 109: &s.StringFieldData{ + NumRows: []int64{2}, + Data: []string{"test1", "test2"}, + }, }} } diff --git a/internal/datanode/segment_replica.go b/internal/datanode/segment_replica.go index e27358add3..0a9617d6af 100644 --- a/internal/datanode/segment_replica.go +++ b/internal/datanode/segment_replica.go @@ -19,7 +19,6 @@ package datanode import ( "context" "fmt" - "math" "sync" "sync/atomic" @@ -42,6 +41,10 @@ const ( maxBloomFalsePositive float64 = 0.005 ) +type PrimaryKey = storage.PrimaryKey +type Int64PrimaryKey = storage.Int64PrimaryKey +type StringPrimaryKey = storage.StringPrimaryKey + // Replica is DataNode unique replication type Replica interface { getCollectionID() UniqueID @@ -57,8 +60,8 @@ type Replica interface { listSegmentsCheckPoints() map[UniqueID]segmentCheckPoint updateSegmentEndPosition(segID UniqueID, endPos *internalpb.MsgPosition) updateSegmentCheckPoint(segID UniqueID) - updateSegmentPKRange(segID UniqueID, rowIDs []int64) - mergeFlushedSegments(segID, collID, partID, planID UniqueID, compactedFrom []UniqueID, channelName string, numOfRows int64) + updateSegmentPKRange(segID UniqueID, ids storage.FieldData) + mergeFlushedSegments(segID, collID, partID, planID UniqueID, compactedFrom []UniqueID, channelName string, numOfRows int64) error hasSegment(segID UniqueID, countFlushed bool) bool removeSegments(segID ...UniqueID) listCompactedSegmentIDs() map[UniqueID][]UniqueID @@ -87,8 +90,8 @@ type Segment struct { pkFilter *bloom.BloomFilter // bloom filter of pk inside a segment // TODO silverxia, needs to change to interface to support `string` type PK - minPK int64 // minimal pk value, shortcut for checking whether a pk is inside this segment - maxPK int64 // maximal pk value, same above + minPK PrimaryKey // minimal pk value, shortcut for checking whether a pk is inside this segment + maxPK PrimaryKey // maximal pk value, same above } // SegmentReplica is the data replication of persistent data in datanode. 
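
The hunk below swaps the int64 min/max sentinels (math.MaxInt64 / math.MinInt64) for the storage.PrimaryKey interface, with nil now meaning "no value seen yet". A minimal sketch of the comparison contract the new updatePk relies on — only the GT/LT methods this hunk exercises are modeled, and the struct layouts here are assumptions (the real types live in internal/storage and carry more methods):

```go
package main

import "fmt"

// PrimaryKey models just the comparison contract updatePk uses below;
// the real storage.PrimaryKey interface has additional methods.
type PrimaryKey interface {
	GT(other PrimaryKey) bool
	LT(other PrimaryKey) bool
}

// Int64PrimaryKey and StringPrimaryKey mirror the two PK kinds this PR supports.
type Int64PrimaryKey struct{ Value int64 }

func (pk *Int64PrimaryKey) GT(o PrimaryKey) bool { return pk.Value > o.(*Int64PrimaryKey).Value }
func (pk *Int64PrimaryKey) LT(o PrimaryKey) bool { return pk.Value < o.(*Int64PrimaryKey).Value }

type StringPrimaryKey struct{ Value string }

func (pk *StringPrimaryKey) GT(o PrimaryKey) bool { return pk.Value > o.(*StringPrimaryKey).Value }
func (pk *StringPrimaryKey) LT(o PrimaryKey) bool { return pk.Value < o.(*StringPrimaryKey).Value }

// track follows Segment.updatePk in the hunk below: a nil bound means no
// value has been tracked, replacing the old MaxInt64/MinInt64 sentinels.
func track(min, max, pk PrimaryKey) (PrimaryKey, PrimaryKey) {
	if min == nil || min.GT(pk) {
		min = pk
	}
	if max == nil || max.LT(pk) {
		max = pk
	}
	return min, max
}

func main() {
	var min, max PrimaryKey
	for _, v := range []int64{7, 3, 9} {
		min, max = track(min, max, &Int64PrimaryKey{Value: v})
	}
	fmt.Printf("min=%d max=%d\n", min.(*Int64PrimaryKey).Value, max.(*Int64PrimaryKey).Value) // min=3 max=9
}
```
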
@@ -107,23 +110,58 @@ type SegmentReplica struct { chunkManager storage.ChunkManager } -func (s *Segment) updatePKRange(pks []int64) { - buf := make([]byte, 8) - for _, pk := range pks { - common.Endian.PutUint64(buf, uint64(pk)) - s.pkFilter.Add(buf) - if pk > s.maxPK { - s.maxPK = pk +func (s *Segment) updatePk(pk PrimaryKey) error { + if s.minPK == nil { + s.minPK = pk + } else if s.minPK.GT(pk) { + s.minPK = pk + } + + if s.maxPK == nil { + s.maxPK = pk + } else if s.maxPK.LT(pk) { + s.maxPK = pk + } + + return nil +} + +func (s *Segment) updatePKRange(ids storage.FieldData) error { + switch pks := ids.(type) { + case *storage.Int64FieldData: + buf := make([]byte, 8) + for _, pk := range pks.Data { + id := &Int64PrimaryKey{ + Value: pk, + } + err := s.updatePk(id) + if err != nil { + return err + } + common.Endian.PutUint64(buf, uint64(pk)) + s.pkFilter.Add(buf) } - if pk < s.minPK { - s.minPK = pk + case *storage.StringFieldData: + for _, pk := range pks.Data { + id := &StringPrimaryKey{ + Value: pk, + } + err := s.updatePk(id) + if err != nil { + return err + } + s.pkFilter.Add([]byte(pk)) } + default: + //TODO:: } log.Info("update pk range", zap.Int64("collectionID", s.collectionID), zap.Int64("partitionID", s.partitionID), zap.Int64("segmentID", s.segmentID), zap.String("channel", s.channelName), - zap.Int64("num_rows", s.numRows), zap.Int64("minPK", s.minPK), zap.Int64("maxPK", s.maxPK)) + zap.Int64("num_rows", s.numRows), zap.Any("minPK", s.minPK), zap.Any("maxPK", s.maxPK)) + + return nil } var _ Replica = &SegmentReplica{} @@ -241,8 +279,6 @@ func (replica *SegmentReplica) addNewSegment(segID, collID, partitionID UniqueID endPos: endPos, pkFilter: bloom.NewWithEstimates(bloomFilterSize, maxBloomFalsePositive), - minPK: math.MaxInt64, // use max value, represents no value - maxPK: math.MinInt64, // use min value represents no value } seg.isNew.Store(true) @@ -328,9 +364,8 @@ func (replica *SegmentReplica) addNormalSegment(segID, collID, partitionID Uniqu numRows: numOfRows, pkFilter: bloom.NewWithEstimates(bloomFilterSize, maxBloomFalsePositive), - minPK: math.MaxInt64, // use max value, represents no value - maxPK: math.MinInt64, // use min value represents no value } + if cp != nil { seg.checkPoint = *cp seg.endPos = &cp.pos @@ -378,8 +413,6 @@ func (replica *SegmentReplica) addFlushedSegment(segID, collID, partitionID Uniq //TODO silverxia, normal segments bloom filter and pk range should be loaded from serialized files pkFilter: bloom.NewWithEstimates(bloomFilterSize, maxBloomFalsePositive), - minPK: math.MaxInt64, // use max value, represents no value - maxPK: math.MinInt64, // use min value represents no value } err := replica.initPKBloomFilter(seg, statsBinlogs, recoverTs) @@ -444,13 +477,8 @@ func (replica *SegmentReplica) initPKBloomFilter(s *Segment, statsBinlogs []*dat if err != nil { return err } - if s.minPK > stat.Min { - s.minPK = stat.Min - } - - if s.maxPK < stat.Max { - s.maxPK = stat.Max - } + s.updatePk(stat.MinPk) + s.updatePk(stat.MaxPk) } return nil } @@ -513,25 +541,25 @@ func (replica *SegmentReplica) updateSegmentEndPosition(segID UniqueID, endPos * log.Warn("No match segment", zap.Int64("ID", segID)) } -func (replica *SegmentReplica) updateSegmentPKRange(segID UniqueID, pks []int64) { +func (replica *SegmentReplica) updateSegmentPKRange(segID UniqueID, ids storage.FieldData) { replica.segMu.Lock() defer replica.segMu.Unlock() seg, ok := replica.newSegments[segID] if ok { - seg.updatePKRange(pks) + seg.updatePKRange(ids) return } seg, ok = 
replica.normalSegments[segID] if ok { - seg.updatePKRange(pks) + seg.updatePKRange(ids) return } seg, ok = replica.flushedSegments[segID] if ok { - seg.updatePKRange(pks) + seg.updatePKRange(ids) return } @@ -684,12 +712,12 @@ func (replica *SegmentReplica) updateSegmentCheckPoint(segID UniqueID) { log.Warn("There's no segment", zap.Int64("ID", segID)) } -func (replica *SegmentReplica) mergeFlushedSegments(segID, collID, partID, planID UniqueID, compactedFrom []UniqueID, channelName string, numOfRows int64) { +func (replica *SegmentReplica) mergeFlushedSegments(segID, collID, partID, planID UniqueID, compactedFrom []UniqueID, channelName string, numOfRows int64) error { if collID != replica.collectionID { log.Warn("Mismatch collection", zap.Int64("input ID", collID), zap.Int64("expected ID", replica.collectionID)) - return + return fmt.Errorf("mismatch collection, ID=%d", collID) } log.Info("merge flushed segments", @@ -708,8 +736,6 @@ func (replica *SegmentReplica) mergeFlushedSegments(segID, collID, partID, planI numRows: numOfRows, pkFilter: bloom.NewWithEstimates(bloomFilterSize, maxBloomFalsePositive), - minPK: math.MaxInt64, // use max value, represents no value - maxPK: math.MinInt64, // use min value represents no value } replica.segMu.Lock() @@ -735,15 +761,17 @@ func (replica *SegmentReplica) mergeFlushedSegments(segID, collID, partID, planI replica.segMu.Lock() replica.flushedSegments[segID] = seg replica.segMu.Unlock() + + return nil } // for tests only -func (replica *SegmentReplica) addFlushedSegmentWithPKs(segID, collID, partID UniqueID, channelName string, numOfRows int64, pks []UniqueID) { +func (replica *SegmentReplica) addFlushedSegmentWithPKs(segID, collID, partID UniqueID, channelName string, numOfRows int64, ids storage.FieldData) error { if collID != replica.collectionID { log.Warn("Mismatch collection", zap.Int64("input ID", collID), zap.Int64("expected ID", replica.collectionID)) - return + return fmt.Errorf("mismatch collection, ID=%d", collID) } log.Info("Add Flushed segment", @@ -761,11 +789,9 @@ func (replica *SegmentReplica) addFlushedSegmentWithPKs(segID, collID, partID Un numRows: numOfRows, pkFilter: bloom.NewWithEstimates(bloomFilterSize, maxBloomFalsePositive), - minPK: math.MaxInt64, // use max value, represents no value - maxPK: math.MinInt64, // use min value represents no value } - seg.updatePKRange(pks) + seg.updatePKRange(ids) seg.isNew.Store(false) seg.isFlushed.Store(true) @@ -773,6 +799,8 @@ func (replica *SegmentReplica) addFlushedSegmentWithPKs(segID, collID, partID Un replica.segMu.Lock() replica.flushedSegments[segID] = seg replica.segMu.Unlock() + + return nil } func (replica *SegmentReplica) listAllSegmentIDs() []UniqueID { diff --git a/internal/datanode/segment_replica_test.go b/internal/datanode/segment_replica_test.go index 9754e1f0ec..403381d94b 100644 --- a/internal/datanode/segment_replica_test.go +++ b/internal/datanode/segment_replica_test.go @@ -20,7 +20,6 @@ import ( "context" "encoding/json" "fmt" - "math" "math/rand" "testing" @@ -31,6 +30,7 @@ import ( "github.com/milvus-io/milvus/internal/common" "github.com/milvus-io/milvus/internal/proto/datapb" "github.com/milvus-io/milvus/internal/proto/internalpb" + "github.com/milvus-io/milvus/internal/proto/schemapb" "github.com/milvus-io/milvus/internal/storage" ) @@ -50,7 +50,7 @@ type mockDataCM struct { } func (kv *mockDataCM) MultiRead(keys []string) ([][]byte, error) { - stats := &storage.Int64Stats{ + stats := &storage.PrimaryKeyStats{ FieldID: common.RowIDField, Min: 0, Max: 10, 
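
updatePKRange above feeds the per-segment bloom filter differently per PK type: int64 PKs are serialized into 8 bytes via common.Endian, string PKs are added as their raw bytes. A runnable sketch of those two encodings, assuming the bits-and-blooms package behind the `bloom.NewWithEstimates` calls and little-endian byte order for `common.Endian`:

```go
package main

import (
	"encoding/binary"
	"fmt"

	"github.com/bits-and-blooms/bloom/v3"
)

// pkToBytes mirrors the two encodings used in updatePKRange: int64 PKs become
// 8 endian-encoded bytes (common.Endian in Milvus; assumed little-endian here),
// string PKs are used as raw bytes.
func pkToBytes(pk interface{}) []byte {
	switch v := pk.(type) {
	case int64:
		buf := make([]byte, 8)
		binary.LittleEndian.PutUint64(buf, uint64(v))
		return buf
	case string:
		return []byte(v)
	default:
		return nil
	}
}

func main() {
	// same estimates as the diff's bloomFilterSize / maxBloomFalsePositive
	filter := bloom.NewWithEstimates(100000, 0.005)
	filter.Add(pkToBytes(int64(9)))
	filter.Add(pkToBytes("pk-9"))

	fmt.Println(filter.Test(pkToBytes(int64(9))))  // true
	fmt.Println(filter.Test(pkToBytes("absent"))) // false, with high probability
}
```
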
@@ -65,7 +65,7 @@ type mockPkfilterMergeError struct { } func (kv *mockPkfilterMergeError) MultiRead(keys []string) ([][]byte, error) { - stats := &storage.Int64Stats{ + stats := &storage.PrimaryKeyStats{ FieldID: common.RowIDField, Min: 0, Max: 10, @@ -170,7 +170,9 @@ func TestSegmentReplica_getCollectionAndPartitionID(te *testing.T) { } func TestSegmentReplica(t *testing.T) { - rc := &RootCoordFactory{} + rc := &RootCoordFactory{ + pkType: schemapb.DataType_Int64, + } collID := UniqueID(1) cm := storage.NewLocalChunkManager(storage.RootPath(segmentReplicaNodeTestDir)) defer cm.RemoveWithPrefix("") @@ -252,7 +254,9 @@ func TestSegmentReplica(t *testing.T) { } func TestSegmentReplica_InterfaceMethod(t *testing.T) { - rc := &RootCoordFactory{} + rc := &RootCoordFactory{ + pkType: schemapb.DataType_Int64, + } cm := storage.NewLocalChunkManager(storage.RootPath(segmentReplicaNodeTestDir)) defer cm.RemoveWithPrefix("") @@ -268,17 +272,20 @@ func TestSegmentReplica_InterfaceMethod(t *testing.T) { {false, 1, 2, "invalid input collection with replica collection"}, } + primaryKeyData := &storage.Int64FieldData{ + Data: []int64{9}, + } for _, test := range tests { t.Run(test.description, func(t *testing.T) { replica, err := newReplica(context.TODO(), rc, cm, test.replicaCollID) require.NoError(t, err) if test.isvalid { - replica.addFlushedSegmentWithPKs(100, test.incollID, 10, "a", 1, []int64{9}) + replica.addFlushedSegmentWithPKs(100, test.incollID, 10, "a", 1, primaryKeyData) assert.True(t, replica.hasSegment(100, true)) assert.False(t, replica.hasSegment(100, false)) } else { - replica.addFlushedSegmentWithPKs(100, test.incollID, 10, "a", 1, []int64{9}) + replica.addFlushedSegmentWithPKs(100, test.incollID, 10, "a", 1, primaryKeyData) assert.False(t, replica.hasSegment(100, true)) assert.False(t, replica.hasSegment(100, false)) } @@ -572,6 +579,7 @@ func TestSegmentReplica_InterfaceMethod(t *testing.T) { } }) } + rc.setCollectionID(1) }) t.Run("Test listAllSegmentIDs", func(t *testing.T) { @@ -628,8 +636,11 @@ func TestSegmentReplica_InterfaceMethod(t *testing.T) { sr, err := newReplica(context.Background(), rc, cm, 1) assert.Nil(t, err) - sr.addFlushedSegmentWithPKs(1, 1, 0, "channel", 10, []UniqueID{1}) - sr.addFlushedSegmentWithPKs(2, 1, 0, "channel", 10, []UniqueID{1}) + primaryKeyData := &storage.Int64FieldData{ + Data: []UniqueID{1}, + } + sr.addFlushedSegmentWithPKs(1, 1, 0, "channel", 10, primaryKeyData) + sr.addFlushedSegmentWithPKs(2, 1, 0, "channel", 10, primaryKeyData) require.True(t, sr.hasSegment(1, true)) require.True(t, sr.hasSegment(2, true)) @@ -648,7 +659,9 @@ func TestSegmentReplica_InterfaceMethod(t *testing.T) { } func TestInnerFunctionSegment(t *testing.T) { - rc := &RootCoordFactory{} + rc := &RootCoordFactory{ + pkType: schemapb.DataType_Int64, + } collID := UniqueID(1) cm := storage.NewLocalChunkManager(storage.RootPath(segmentReplicaNodeTestDir)) defer cm.RemoveWithPrefix("") @@ -747,8 +760,6 @@ func TestInnerFunctionSegment(t *testing.T) { func TestSegmentReplica_UpdatePKRange(t *testing.T) { seg := &Segment{ pkFilter: bloom.NewWithEstimates(100000, 0.005), - maxPK: math.MinInt64, - minPK: math.MaxInt64, } cases := make([]int64, 0, 100) @@ -757,10 +768,16 @@ func TestSegmentReplica_UpdatePKRange(t *testing.T) { } buf := make([]byte, 8) for _, c := range cases { - seg.updatePKRange([]int64{c}) + seg.updatePKRange(&storage.Int64FieldData{ + Data: []int64{c}, + }) - assert.LessOrEqual(t, seg.minPK, c) - assert.GreaterOrEqual(t, seg.maxPK, c) + pk := &Int64PrimaryKey{ + 
Value: c, + } + + assert.Equal(t, true, seg.minPK.LE(pk)) + assert.Equal(t, true, seg.maxPK.GE(pk)) common.Endian.PutUint64(buf, uint64(c)) assert.True(t, seg.pkFilter.Test(buf)) @@ -768,7 +785,9 @@ func TestSegmentReplica_UpdatePKRange(t *testing.T) { } func TestReplica_UpdatePKRange(t *testing.T) { - rc := &RootCoordFactory{} + rc := &RootCoordFactory{ + pkType: schemapb.DataType_Int64, + } collID := UniqueID(1) partID := UniqueID(2) chanName := "insert-02" @@ -797,14 +816,19 @@ func TestReplica_UpdatePKRange(t *testing.T) { } buf := make([]byte, 8) for _, c := range cases { - replica.updateSegmentPKRange(1, []int64{c}) // new segment - replica.updateSegmentPKRange(2, []int64{c}) // normal segment - replica.updateSegmentPKRange(3, []int64{c}) // non-exist segment + replica.updateSegmentPKRange(1, &storage.Int64FieldData{Data: []int64{c}}) // new segment + replica.updateSegmentPKRange(2, &storage.Int64FieldData{Data: []int64{c}}) // normal segment + replica.updateSegmentPKRange(3, &storage.Int64FieldData{Data: []int64{c}}) // non-exist segment - assert.LessOrEqual(t, segNew.minPK, c) - assert.GreaterOrEqual(t, segNew.maxPK, c) - assert.LessOrEqual(t, segNormal.minPK, c) - assert.GreaterOrEqual(t, segNormal.maxPK, c) + pk := &Int64PrimaryKey{ + Value: c, + } + + assert.Equal(t, true, segNew.minPK.LE(pk)) + assert.Equal(t, true, segNew.maxPK.GE(pk)) + + assert.Equal(t, true, segNormal.minPK.LE(pk)) + assert.Equal(t, true, segNormal.maxPK.GE(pk)) common.Endian.PutUint64(buf, uint64(c)) assert.True(t, segNew.pkFilter.Test(buf)) diff --git a/internal/mq/msgstream/msg.go b/internal/mq/msgstream/msg.go index 1c47860010..616abd8237 100644 --- a/internal/mq/msgstream/msg.go +++ b/internal/mq/msgstream/msg.go @@ -19,9 +19,11 @@ package msgstream import ( "context" "errors" + "fmt" "time" "github.com/milvus-io/milvus/internal/proto/schemapb" + "github.com/milvus-io/milvus/internal/util/funcutil" "github.com/milvus-io/milvus/internal/util/typeutil" "github.com/golang/protobuf/proto" @@ -189,9 +191,32 @@ func (it *InsertMsg) NRows() uint64 { return it.InsertRequest.GetNumRows() } -func (it *InsertMsg) CheckAligned() bool { - return len(it.GetRowIDs()) == len(it.GetTimestamps()) && - uint64(len(it.GetRowIDs())) == it.NRows() +func (it *InsertMsg) CheckAligned() error { + numRowsOfFieldDataMismatch := func(fieldName string, fieldNumRows, passedNumRows uint64) error { + return fmt.Errorf("the num_rows(%d) of %sth field is not equal to passed NumRows(%d)", fieldNumRows, fieldName, passedNumRows) + } + rowNums := it.NRows() + if it.IsColumnBased() { + for _, field := range it.FieldsData { + fieldNumRows, err := funcutil.GetNumRowOfFieldData(field) + if err != nil { + return err + } + if fieldNumRows != rowNums { + return numRowsOfFieldDataMismatch(field.FieldName, fieldNumRows, rowNums) + } + } + } + + if len(it.GetRowIDs()) != len(it.GetTimestamps()) { + return fmt.Errorf("the num_rows(%d) of rowIDs is not equal to the num_rows(%d) of timestamps", len(it.GetRowIDs()), len(it.GetTimestamps())) + } + + if uint64(len(it.GetRowIDs())) != it.NRows() { + return fmt.Errorf("the num_rows(%d) of rowIDs is not equal to passed NumRows(%d)", len(it.GetRowIDs()), it.NRows()) + } + + return nil } func (it *InsertMsg) rowBasedIndexRequest(index int) internalpb.InsertRequest { diff --git a/internal/mq/msgstream/msg_test.go b/internal/mq/msgstream/msg_test.go index 868b7023f1..c5a00c876c 100644 --- a/internal/mq/msgstream/msg_test.go +++ b/internal/mq/msgstream/msg_test.go @@ -191,14 +191,26 @@ func 
TestInsertMsg_CheckAligned(t *testing.T) { Version: internalpb.InsertDataVersion_RowBased, }, } - assert.True(t, msg1.CheckAligned()) + msg1.InsertRequest.NumRows = 1 + assert.NoError(t, msg1.CheckAligned()) msg1.InsertRequest.RowData = nil msg1.InsertRequest.FieldsData = []*schemapb.FieldData{ - {}, + { + Type: schemapb.DataType_Int64, + Field: &schemapb.FieldData_Scalars{ + Scalars: &schemapb.ScalarField{ + Data: &schemapb.ScalarField_LongData{ + LongData: &schemapb.LongArray{ + Data: []int64{1}, + }, + }, + }, + }, + }, } - msg1.InsertRequest.NumRows = 1 + msg1.Version = internalpb.InsertDataVersion_ColumnBased - assert.True(t, msg1.CheckAligned()) + assert.NoError(t, msg1.CheckAligned()) } func TestInsertMsg_IndexMsg(t *testing.T) { diff --git a/internal/mq/msgstream/repack_func.go b/internal/mq/msgstream/repack_func.go index 71ca4d25a5..2067b9fa85 100644 --- a/internal/mq/msgstream/repack_func.go +++ b/internal/mq/msgstream/repack_func.go @@ -37,7 +37,10 @@ func InsertRepackFunc(tsMsgs []TsMsg, hashKeys [][]int32) (map[int32]*MsgPack, e keysLen := len(keys) - if !insertRequest.CheckAligned() || insertRequest.NRows() != uint64(keysLen) { + if err := insertRequest.CheckAligned(); err != nil { + return nil, err + } + if insertRequest.NRows() != uint64(keysLen) { return nil, errors.New("the length of hashValue, timestamps, rowIDs, RowData are not equal") } for index, key := range keys { diff --git a/internal/proxy/impl.go b/internal/proxy/impl.go index c3837eb48c..b702bccc36 100644 --- a/internal/proxy/impl.go +++ b/internal/proxy/impl.go @@ -2177,18 +2177,22 @@ func (node *Proxy) Insert(ctx context.Context, request *milvuspb.InsertRequest) it := &insertTask{ ctx: ctx, Condition: NewTaskCondition(ctx), - req: request, + // req: request, BaseInsertTask: BaseInsertTask{ BaseMsg: msgstream.BaseMsg{ HashValues: request.HashKeys, }, InsertRequest: internalpb.InsertRequest{ Base: &commonpb.MsgBase{ - MsgType: commonpb.MsgType_Insert, - MsgID: 0, + MsgType: commonpb.MsgType_Insert, + MsgID: 0, + SourceID: Params.ProxyCfg.ProxyID, }, CollectionName: request.CollectionName, PartitionName: request.PartitionName, + FieldsData: request.FieldsData, + NumRows: uint64(request.NumRows), + Version: internalpb.InsertDataVersion_ColumnBased, // RowData: transfer column based request to this }, }, @@ -2203,7 +2207,7 @@ func (node *Proxy) Insert(ctx context.Context, request *milvuspb.InsertRequest) } constructFailedResponse := func(err error) *milvuspb.MutationResult { - numRows := it.req.NumRows + numRows := request.NumRows errIndex := make([]uint32, numRows) for i := uint32(0); i < numRows; i++ { errIndex[i] = i @@ -2256,7 +2260,7 @@ func (node *Proxy) Insert(ctx context.Context, request *milvuspb.InsertRequest) if it.result.Status.ErrorCode != commonpb.ErrorCode_Success { setErrorIndex := func() { - numRows := it.req.NumRows + numRows := request.NumRows errIndex := make([]uint32, numRows) for i := uint32(0); i < numRows; i++ { errIndex[i] = i @@ -2268,7 +2272,7 @@ func (node *Proxy) Insert(ctx context.Context, request *milvuspb.InsertRequest) } // InsertCnt always equals to the number of entities in the request - it.result.InsertCnt = int64(it.req.NumRows) + it.result.InsertCnt = int64(request.NumRows) metrics.ProxyInsertCount.WithLabelValues(strconv.FormatInt(Params.ProxyCfg.ProxyID, 10), metrics.SuccessLabel).Inc() diff --git a/internal/proxy/mock_test.go b/internal/proxy/mock_test.go index 7047e833ed..53d96ae702 100644 --- a/internal/proxy/mock_test.go +++ b/internal/proxy/mock_test.go @@ -395,6 +395,99 
@@ func newSimpleMockMsgStreamFactory() *simpleMockMsgStreamFactory { return &simpleMockMsgStreamFactory{} } +func generateFieldData(dataType schemapb.DataType, fieldName string, fieldID int64, numRows int) *schemapb.FieldData { + fieldData := &schemapb.FieldData{ + Type: dataType, + FieldName: fieldName, + FieldId: fieldID, + } + switch dataType { + case schemapb.DataType_Bool: + fieldData.FieldName = testBoolField + fieldData.Field = &schemapb.FieldData_Scalars{ + Scalars: &schemapb.ScalarField{ + Data: &schemapb.ScalarField_BoolData{ + BoolData: &schemapb.BoolArray{ + Data: generateBoolArray(numRows), + }, + }, + }, + } + case schemapb.DataType_Int32: + fieldData.FieldName = testInt32Field + fieldData.Field = &schemapb.FieldData_Scalars{ + Scalars: &schemapb.ScalarField{ + Data: &schemapb.ScalarField_IntData{ + IntData: &schemapb.IntArray{ + Data: generateInt32Array(numRows), + }, + }, + }, + } + case schemapb.DataType_Int64: + fieldData.FieldName = testInt64Field + fieldData.Field = &schemapb.FieldData_Scalars{ + Scalars: &schemapb.ScalarField{ + Data: &schemapb.ScalarField_LongData{ + LongData: &schemapb.LongArray{ + Data: generateInt64Array(numRows), + }, + }, + }, + } + case schemapb.DataType_Float: + fieldData.FieldName = testFloatField + fieldData.Field = &schemapb.FieldData_Scalars{ + Scalars: &schemapb.ScalarField{ + Data: &schemapb.ScalarField_FloatData{ + FloatData: &schemapb.FloatArray{ + Data: generateFloat32Array(numRows), + }, + }, + }, + } + case schemapb.DataType_Double: + fieldData.FieldName = testDoubleField + fieldData.Field = &schemapb.FieldData_Scalars{ + Scalars: &schemapb.ScalarField{ + Data: &schemapb.ScalarField_DoubleData{ + DoubleData: &schemapb.DoubleArray{ + Data: generateFloat64Array(numRows), + }, + }, + }, + } + case schemapb.DataType_VarChar: + //TODO:: + case schemapb.DataType_FloatVector: + fieldData.FieldName = testFloatVecField + fieldData.Field = &schemapb.FieldData_Vectors{ + Vectors: &schemapb.VectorField{ + Dim: int64(testVecDim), + Data: &schemapb.VectorField_FloatVector{ + FloatVector: &schemapb.FloatArray{ + Data: generateFloatVectors(numRows, testVecDim), + }, + }, + }, + } + case schemapb.DataType_BinaryVector: + fieldData.FieldName = testBinaryVecField + fieldData.Field = &schemapb.FieldData_Vectors{ + Vectors: &schemapb.VectorField{ + Dim: int64(testVecDim), + Data: &schemapb.VectorField_BinaryVector{ + BinaryVector: generateBinaryVectors(numRows, testVecDim), + }, + }, + } + default: + //TODO:: + } + + return fieldData +} + func generateBoolArray(numRows int) []bool { ret := make([]bool, 0, numRows) for i := 0; i < numRows; i++ { @@ -435,6 +528,14 @@ func generateInt64Array(numRows int) []int64 { return ret } +func generateUint64Array(numRows int) []uint64 { + ret := make([]uint64, 0, numRows) + for i := 0; i < numRows; i++ { + ret = append(ret, rand.Uint64()) + } + return ret +} + func generateFloat32Array(numRows int) []float32 { ret := make([]float32, 0, numRows) for i := 0; i < numRows; i++ { @@ -470,14 +571,23 @@ func generateBinaryVectors(numRows, dim int) []byte { return ret } -func newScalarFieldData(dType schemapb.DataType, fieldName string, numRows int) *schemapb.FieldData { +func generateVarCharArray(numRows int, maxLen int) []string { + ret := make([]string, numRows) + for i := 0; i < numRows; i++ { + ret[i] = funcutil.RandomString(rand.Intn(maxLen)) + } + + return ret +} + +func newScalarFieldData(fieldSchema *schemapb.FieldSchema, fieldName string, numRows int) *schemapb.FieldData { ret := &schemapb.FieldData{ - Type: dType, 
+ Type: fieldSchema.DataType, FieldName: fieldName, Field: nil, } - switch dType { + switch fieldSchema.DataType { case schemapb.DataType_Bool: ret.Field = &schemapb.FieldData_Scalars{ Scalars: &schemapb.ScalarField{ @@ -548,6 +658,16 @@ func newScalarFieldData(dType schemapb.DataType, fieldName string, numRows int) }, }, } + case schemapb.DataType_VarChar: + ret.Field = &schemapb.FieldData_Scalars{ + Scalars: &schemapb.ScalarField{ + Data: &schemapb.ScalarField_StringData{ + StringData: &schemapb.StringArray{ + Data: generateVarCharArray(numRows, testMaxVarCharLength), + }, + }, + }, + } } return ret diff --git a/internal/proxy/task.go b/internal/proxy/task.go index db7fde30ea..0400e2bb0a 100644 --- a/internal/proxy/task.go +++ b/internal/proxy/task.go @@ -21,13 +21,10 @@ import ( "errors" "fmt" "math" - "reflect" "regexp" - "sort" "strconv" "strings" "time" - "unsafe" "github.com/golang/protobuf/proto" "go.uber.org/zap" @@ -122,7 +119,7 @@ type BaseInsertTask = msgstream.InsertMsg type insertTask struct { BaseInsertTask - req *milvuspb.InsertRequest + // req *milvuspb.InsertRequest Condition ctx context.Context @@ -208,40 +205,9 @@ func (it *insertTask) getChannels() ([]pChan, error) { } func (it *insertTask) OnEnqueue() error { - it.BaseInsertTask.InsertRequest.Base = &commonpb.MsgBase{} return nil } -func getNumRowsOfScalarField(datas interface{}) uint32 { - realTypeDatas := reflect.ValueOf(datas) - return uint32(realTypeDatas.Len()) -} - -func getNumRowsOfFloatVectorField(fDatas []float32, dim int64) (uint32, error) { - if dim <= 0 { - return 0, errDimLessThanOrEqualToZero(int(dim)) - } - l := len(fDatas) - if int64(l)%dim != 0 { - return 0, fmt.Errorf("the length(%d) of float data should divide the dim(%d)", l, dim) - } - return uint32(int(int64(l) / dim)), nil -} - -func getNumRowsOfBinaryVectorField(bDatas []byte, dim int64) (uint32, error) { - if dim <= 0 { - return 0, errDimLessThanOrEqualToZero(int(dim)) - } - if dim%8 != 0 { - return 0, errDimShouldDivide8(int(dim)) - } - l := len(bDatas) - if (8*int64(l))%dim != 0 { - return 0, fmt.Errorf("the num(%d) of all bits should divide the dim(%d)", 8*l, dim) - } - return uint32(int((8 * int64(l)) / dim)), nil -} - func (it *insertTask) checkLengthOfFieldsData() error { neededFieldsNum := 0 for _, field := range it.schema.Fields { @@ -250,236 +216,62 @@ func (it *insertTask) checkLengthOfFieldsData() error { } } - if len(it.req.FieldsData) < neededFieldsNum { - return errFieldsLessThanNeeded(len(it.req.FieldsData), neededFieldsNum) + if len(it.FieldsData) < neededFieldsNum { + return errFieldsLessThanNeeded(len(it.FieldsData), neededFieldsNum) } return nil } -func (it *insertTask) checkRowNums() error { - if it.req.NumRows <= 0 { - return errNumRowsLessThanOrEqualToZero(it.req.NumRows) - } - - if err := it.checkLengthOfFieldsData(); err != nil { - return err - } - - rowNums := it.req.NumRows - - for i, field := range it.req.FieldsData { - switch field.Field.(type) { - case *schemapb.FieldData_Scalars: - scalarField := field.GetScalars() - switch scalarField.Data.(type) { - case *schemapb.ScalarField_BoolData: - fieldNumRows := getNumRowsOfScalarField(scalarField.GetBoolData().Data) - if fieldNumRows != rowNums { - return errNumRowsOfFieldDataMismatchPassed(i, fieldNumRows, rowNums) - } - case *schemapb.ScalarField_IntData: - fieldNumRows := getNumRowsOfScalarField(scalarField.GetIntData().Data) - if fieldNumRows != rowNums { - return errNumRowsOfFieldDataMismatchPassed(i, fieldNumRows, rowNums) - } - case 
*schemapb.ScalarField_LongData: - fieldNumRows := getNumRowsOfScalarField(scalarField.GetLongData().Data) - if fieldNumRows != rowNums { - return errNumRowsOfFieldDataMismatchPassed(i, fieldNumRows, rowNums) - } - case *schemapb.ScalarField_FloatData: - fieldNumRows := getNumRowsOfScalarField(scalarField.GetFloatData().Data) - if fieldNumRows != rowNums { - return errNumRowsOfFieldDataMismatchPassed(i, fieldNumRows, rowNums) - } - case *schemapb.ScalarField_DoubleData: - fieldNumRows := getNumRowsOfScalarField(scalarField.GetDoubleData().Data) - if fieldNumRows != rowNums { - return errNumRowsOfFieldDataMismatchPassed(i, fieldNumRows, rowNums) - } - case *schemapb.ScalarField_BytesData: - return errUnsupportedDType("bytes") - case *schemapb.ScalarField_StringData: - return errUnsupportedDType("string") - case nil: - continue - default: - continue - } - case *schemapb.FieldData_Vectors: - vectorField := field.GetVectors() - switch vectorField.Data.(type) { - case *schemapb.VectorField_FloatVector: - dim := vectorField.GetDim() - fieldNumRows, err := getNumRowsOfFloatVectorField(vectorField.GetFloatVector().Data, dim) - if err != nil { - return err - } - if fieldNumRows != rowNums { - return errNumRowsOfFieldDataMismatchPassed(i, fieldNumRows, rowNums) - } - case *schemapb.VectorField_BinaryVector: - dim := vectorField.GetDim() - fieldNumRows, err := getNumRowsOfBinaryVectorField(vectorField.GetBinaryVector(), dim) - if err != nil { - return err - } - if fieldNumRows != rowNums { - return errNumRowsOfFieldDataMismatchPassed(i, fieldNumRows, rowNums) - } - case nil: - continue - default: - continue - } - } - } - - return nil -} - -func (it *insertTask) checkFieldAutoIDAndHashPK() error { +func (it *insertTask) checkPrimaryFieldData() error { + rowNums := uint32(it.NRows()) // TODO(dragondriver): in fact, NumRows is not trustable, we should check all input fields - if it.req.NumRows <= 0 { - return errNumRowsLessThanOrEqualToZero(it.req.NumRows) + if it.NRows() <= 0 { + return errNumRowsLessThanOrEqualToZero(rowNums) } if err := it.checkLengthOfFieldsData(); err != nil { return err } - rowNums := it.req.NumRows - - primaryFieldName := "" - autoIDFieldName := "" - autoIDLoc := -1 - primaryLoc := -1 - fields := it.schema.Fields - - for loc, field := range fields { - if field.AutoID { - autoIDLoc = loc - autoIDFieldName = field.Name - } - if field.IsPrimaryKey { - primaryLoc = loc - primaryFieldName = field.Name - } + primaryFieldSchema, err := typeutil.GetPrimaryFieldSchema(it.schema) + if err != nil { + log.Error("get primary field schema failed", zap.String("collection name", it.CollectionName), zap.Any("schema", it.schema), zap.Error(err)) + return err } - if primaryLoc < 0 { - return fmt.Errorf("primary field is not found") - } - - if autoIDLoc >= 0 && autoIDLoc != primaryLoc { - return fmt.Errorf("currently auto id field is only supported on primary field") - } - - var primaryField *schemapb.FieldData - var primaryData []int64 - for _, field := range it.req.FieldsData { - if field.FieldName == autoIDFieldName { - return fmt.Errorf("autoID field (%v) does not require data", autoIDFieldName) + // get primaryFieldData whether autoID is true or not + var primaryFieldData *schemapb.FieldData + if !primaryFieldSchema.AutoID { + primaryFieldData, err = getPrimaryFieldData(it.GetFieldsData(), primaryFieldSchema) + if err != nil { + log.Error("get primary field data failed", zap.String("collection name", it.CollectionName), zap.Error(err)) + return err } - if field.FieldName == primaryFieldName { - 
primaryField = field - } - } - - if primaryField != nil { - if primaryField.Type != schemapb.DataType_Int64 { - return fmt.Errorf("currently only support DataType Int64 as PrimaryField and Enable autoID") - } - switch primaryField.Field.(type) { - case *schemapb.FieldData_Scalars: - scalarField := primaryField.GetScalars() - switch scalarField.Data.(type) { - case *schemapb.ScalarField_LongData: - primaryData = scalarField.GetLongData().Data - default: - return fmt.Errorf("currently only support DataType Int64 as PrimaryField and Enable autoID") - } - default: - return fmt.Errorf("currently only support DataType Int64 as PrimaryField and Enable autoID") - } - it.result.IDs.IdField = &schemapb.IDs_IntId{ - IntId: &schemapb.LongArray{ - Data: primaryData, - }, - } - } - - var rowIDBegin UniqueID - var rowIDEnd UniqueID - - tr := timerecord.NewTimeRecorder("applyPK") - rowIDBegin, rowIDEnd, _ = it.rowIDAllocator.Alloc(rowNums) - metrics.ProxyApplyPrimaryKeyLatency.WithLabelValues(strconv.FormatInt(Params.ProxyCfg.ProxyID, 10)).Observe(float64(tr.ElapseSpan())) - - it.BaseInsertTask.RowIDs = make([]UniqueID, rowNums) - for i := rowIDBegin; i < rowIDEnd; i++ { - offset := i - rowIDBegin - it.BaseInsertTask.RowIDs[offset] = i - } - - if autoIDLoc >= 0 { - fieldData := schemapb.FieldData{ - FieldName: primaryFieldName, - FieldId: -1, - Type: schemapb.DataType_Int64, - Field: &schemapb.FieldData_Scalars{ - Scalars: &schemapb.ScalarField{ - Data: &schemapb.ScalarField_LongData{ - LongData: &schemapb.LongArray{ - Data: it.BaseInsertTask.RowIDs, - }, - }, - }, - }, - } - - // TODO(dragondriver): when we can ignore the order of input fields, use append directly - // it.req.FieldsData = append(it.req.FieldsData, &fieldData) - it.req.FieldsData = append(it.req.FieldsData, &schemapb.FieldData{}) - copy(it.req.FieldsData[autoIDLoc+1:], it.req.FieldsData[autoIDLoc:]) - it.req.FieldsData[autoIDLoc] = &fieldData - - it.result.IDs.IdField = &schemapb.IDs_IntId{ - IntId: &schemapb.LongArray{ - Data: it.BaseInsertTask.RowIDs, - }, - } - it.HashPK(it.BaseInsertTask.RowIDs) } else { - it.HashPK(primaryData) + // if autoID == true, currently only support autoID for int64 PrimaryField + primaryFieldData, err = autoGenPrimaryFieldData(primaryFieldSchema, it.RowIDs) + if err != nil { + log.Error("generate primary field data failed when autoID == true", zap.String("collection name", it.CollectionName), zap.Error(err)) + return err + } + // if autoID == true, set the primary field data + it.FieldsData = append(it.FieldsData, primaryFieldData) } - sliceIndex := make([]uint32, rowNums) - for i := uint32(0); i < rowNums; i++ { - sliceIndex[i] = i + // parse primaryFieldData to result.IDs, and as returned primary keys + it.result.IDs, err = parsePrimaryFieldData2IDs(primaryFieldData) + if err != nil { + log.Error("parse primary field data to IDs failed", zap.String("collection name", it.CollectionName), zap.Error(err)) + return err } - it.result.SuccIndex = sliceIndex return nil } -func (it *insertTask) HashPK(pks []int64) { - if len(it.HashValues) != 0 { - log.Warn("the hashvalues passed through client is not supported now, and will be overwritten") - } - it.HashValues = make([]uint32, 0, len(pks)) - for _, pk := range pks { - hash, _ := typeutil.Hash32Int64(pk) - it.HashValues = append(it.HashValues, hash) - } -} - func (it *insertTask) PreExecute(ctx context.Context) error { sp, ctx := trace.StartSpanFromContextWithOperationName(it.ctx, "Proxy-Insert-PreExecute") defer sp.Finish() - it.Base.MsgType = 
commonpb.MsgType_Insert - it.Base.SourceID = Params.ProxyCfg.ProxyID it.result = &milvuspb.MutationResult{ Status: &commonpb.Status{ @@ -491,275 +283,199 @@ func (it *insertTask) PreExecute(ctx context.Context) error { Timestamp: it.EndTs(), } - collectionName := it.BaseInsertTask.CollectionName + collectionName := it.CollectionName if err := validateCollectionName(collectionName); err != nil { + log.Error("validate collection name failed", zap.String("collection name", collectionName), zap.Error(err)) return err } - partitionTag := it.BaseInsertTask.PartitionName + partitionTag := it.PartitionName if err := validatePartitionTag(partitionTag, true); err != nil { + log.Error("validate partition name failed", zap.String("partition name", partitionTag), zap.Error(err)) return err } collSchema, err := globalMetaCache.GetCollectionSchema(ctx, collectionName) - log.Debug("Proxy Insert PreExecute", zap.Any("collSchema", collSchema)) if err != nil { + log.Error("get collection schema from global meta cache failed", zap.String("collection name", collectionName), zap.Error(err)) return err } it.schema = collSchema - err = it.checkRowNums() - if err != nil { - return err - } + rowNums := uint32(it.NRows()) + // set insertTask.rowIDs + var rowIDBegin UniqueID + var rowIDEnd UniqueID + tr := timerecord.NewTimeRecorder("applyPK") + rowIDBegin, rowIDEnd, _ = it.rowIDAllocator.Alloc(rowNums) + metrics.ProxyApplyPrimaryKeyLatency.WithLabelValues(strconv.FormatInt(Params.ProxyCfg.ProxyID, 10)).Observe(float64(tr.ElapseSpan())) - err = it.checkFieldAutoIDAndHashPK() - if err != nil { - return err + it.RowIDs = make([]UniqueID, rowNums) + for i := rowIDBegin; i < rowIDEnd; i++ { + offset := i - rowIDBegin + it.RowIDs[offset] = i } - - it.BaseInsertTask.InsertRequest.Version = internalpb.InsertDataVersion_ColumnBased - it.BaseInsertTask.InsertRequest.FieldsData = it.req.GetFieldsData() - it.BaseInsertTask.InsertRequest.NumRows = uint64(it.req.GetNumRows()) - err = typeutil.FillFieldBySchema(it.BaseInsertTask.InsertRequest.GetFieldsData(), collSchema) - if err != nil { - return err - } - - rowNum := it.req.NumRows + // set insertTask.timeStamps + rowNum := it.NRows() it.Timestamps = make([]uint64, rowNum) for index := range it.Timestamps { it.Timestamps[index] = it.BeginTimestamp } + // set result.SuccIndex + sliceIndex := make([]uint32, rowNums) + for i := uint32(0); i < rowNums; i++ { + sliceIndex[i] = i + } + it.result.SuccIndex = sliceIndex + + // check primaryFieldData whether autoID is true or not + // set rowIDs as primary data if autoID == true + err = it.checkPrimaryFieldData() + if err != nil { + log.Error("check primary field data and hash primary key failed", zap.Int64("msgID", it.Base.MsgID), zap.String("collection name", collectionName), zap.Error(err)) + return err + } + + // set field ID to insert field data + err = fillFieldIDBySchema(it.GetFieldsData(), collSchema) + if err != nil { + log.Error("set fieldID to fieldData failed", zap.Int64("msgID", it.Base.MsgID), zap.String("collection name", collectionName), zap.Error(err)) + return err + } + + // check that all fields' row counts are equal + if err = it.CheckAligned(); err != nil { + log.Error("field data is not aligned", zap.Int64("msgID", it.Base.MsgID), zap.String("collection name", collectionName), zap.Error(err)) + return err + } + + log.Debug("Proxy Insert PreExecute done", zap.Int64("msgID", it.Base.MsgID), zap.String("collection name", collectionName)) + return nil }
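Reviewer note: the rowID fill above is a small but easy-to-misread loop. A runnable sketch of the same pattern, with a hypothetical mockAlloc standing in for the proxy's rowIDAllocator:

package main

import "fmt"

// mockAlloc is a hypothetical allocator: it hands out the half-open range
// [begin, end) of size count, like rowIDAllocator.Alloc does in the proxy.
func mockAlloc(count uint32) (int64, int64) {
    const next int64 = 42000 // pretend allocator state
    return next, next + int64(count)
}

func main() {
    numRows := uint32(4)
    begin, end := mockAlloc(numRows)
    rowIDs := make([]int64, numRows)
    for i := begin; i < end; i++ {
        rowIDs[i-begin] = i // offset i-begin receives the i-th allocated ID
    }
    fmt.Println(rowIDs) // [42000 42001 42002 42003]
}

With autoID enabled, checkPrimaryFieldData reuses exactly these RowIDs as the int64 primary keys, which is why the allocation now happens before the primary-key check.

-func (it *insertTask) _assignSegmentID(stream msgstream.MsgStream, pack 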
*msgstream.MsgPack) (*msgstream.MsgPack, error) { - newPack := &msgstream.MsgPack{ - BeginTs: pack.BeginTs, - EndTs: pack.EndTs, - StartPositions: pack.StartPositions, - EndPositions: pack.EndPositions, - Msgs: nil, - } - tsMsgs := pack.Msgs - hashKeys := stream.ComputeProduceChannelIndexes(tsMsgs) - if len(hashKeys) == 0 { - return nil, fmt.Errorf("the length of hashKeys is 0") - } - reqID := it.Base.MsgID - channelCountMap := make(map[int32]uint32) // channelID to count - channelMaxTSMap := make(map[int32]Timestamp) // channelID to max Timestamp - channelNames, err := it.chMgr.getVChannels(it.GetCollectionID()) - if err != nil { - return nil, err - } - log.Debug("_assignSemgentID, produceChannels:", zap.Any("Channels", channelNames)) +func (it *insertTask) assignSegmentID(channelNames []string) (*msgstream.MsgPack, error) { + threshold := Params.PulsarCfg.MaxMessageSize - for i, request := range tsMsgs { - if request.Type() != commonpb.MsgType_Insert { - return nil, fmt.Errorf("msg's must be Insert") - } - insertRequest, ok := request.(*msgstream.InsertMsg) - if !ok { - return nil, fmt.Errorf("msg's must be Insert") - } - - keys := hashKeys[i] - keysLen := len(keys) - - if !insertRequest.CheckAligned() { - return nil, - fmt.Errorf("the length of timestamps(%d), rowIDs(%d) and num_rows(%d) are not equal", - len(insertRequest.GetTimestamps()), - len(insertRequest.GetRowIDs()), - insertRequest.NRows()) - } - if uint64(keysLen) != insertRequest.NRows() { - return nil, - fmt.Errorf( - "the length of hashValue(%d), num_rows(%d) are not equal", - keysLen, insertRequest.NRows()) - } - for idx, channelID := range keys { - channelCountMap[channelID]++ - if _, ok := channelMaxTSMap[channelID]; !ok { - channelMaxTSMap[channelID] = typeutil.ZeroTimestamp - } - ts := insertRequest.Timestamps[idx] - if channelMaxTSMap[channelID] < ts { - channelMaxTSMap[channelID] = ts - } - } + result := &msgstream.MsgPack{ + BeginTs: it.BeginTs(), + EndTs: it.EndTs(), } - reqSegCountMap := make(map[int32]map[UniqueID]uint32) - for channelID, count := range channelCountMap { - ts, ok := channelMaxTSMap[channelID] - if !ok { - ts = typeutil.ZeroTimestamp - log.Debug("Warning: did not get max Timestamp!") - } + // generate hash value for every primary key + if len(it.HashValues) != 0 { + log.Warn("the hashvalues passed through client is not supported now, and will be overwritten") + } + it.HashValues = typeutil.HashPK2Channels(it.result.IDs, channelNames) + // group row offsets by the dmChannel their primary key hashes to + channel2RowOffsets := make(map[string][]int) // channelName to row offsets + channelMaxTSMap := make(map[string]Timestamp) // channelName to max Timestamp + + // assert len(it.hashValues) < maxInt + for offset, channelID := range it.HashValues { channelName := channelNames[channelID] - if channelName == "" { - return nil, fmt.Errorf("proxy, repack_func, can not found channelName") + if _, ok := channel2RowOffsets[channelName]; !ok { + channel2RowOffsets[channelName] = []int{} } - mapInfo, err := it.segIDAssigner.GetSegmentID(it.CollectionID, it.PartitionID, channelName, count, ts) + channel2RowOffsets[channelName] = append(channel2RowOffsets[channelName], offset) + + if _, ok := channelMaxTSMap[channelName]; !ok { + channelMaxTSMap[channelName] = typeutil.ZeroTimestamp + } + ts := it.Timestamps[offset] + if channelMaxTSMap[channelName] < ts { + channelMaxTSMap[channelName] = ts + } + }
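Reviewer note: HashPK2Channels plus the loop above implement a bucket-by-channel step. A self-contained sketch of that grouping, assuming for illustration that hashing reduces to FNV-1a over the key bytes modulo the channel count (the real typeutil helper also covers varchar keys):

package main

import (
    "encoding/binary"
    "fmt"
    "hash/fnv"
)

// hashPK2Channel is a simplified stand-in for typeutil.HashPK2Channels,
// mapping one int64 primary key to a dmChannel index.
func hashPK2Channel(pk int64, numChannels int) uint32 {
    h := fnv.New32a()
    b := make([]byte, 8)
    binary.LittleEndian.PutUint64(b, uint64(pk))
    h.Write(b)
    return h.Sum32() % uint32(numChannels)
}

func main() {
    channels := []string{"by-dev-dml_0", "by-dev-dml_1"}
    pks := []int64{7, 8, 9, 10}
    // group row offsets by the channel their primary key hashes to,
    // mirroring channel2RowOffsets in assignSegmentID
    channel2RowOffsets := make(map[string][]int)
    for offset, pk := range pks {
        name := channels[hashPK2Channel(pk, len(channels))]
        channel2RowOffsets[name] = append(channel2RowOffsets[name], offset)
    }
    fmt.Println(channel2RowOffsets)
}

Grouping by row offset, rather than copying rows immediately, keeps the field data in place until the per-segment repack below.

+ + // create empty insert message + createInsertMsg := func(segmentID UniqueID, channelName string) *msgstream.InsertMsg { + insertReq := 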
internalpb.InsertRequest{ + Base: &commonpb.MsgBase{ + MsgType: commonpb.MsgType_Insert, + MsgID: it.Base.MsgID, + Timestamp: it.BeginTimestamp, // entity's timestamp was set to equal it.BeginTimestamp in preExecute() + SourceID: it.Base.SourceID, + }, + CollectionID: it.CollectionID, + PartitionID: it.PartitionID, + CollectionName: it.CollectionName, + PartitionName: it.PartitionName, + SegmentID: segmentID, + ShardName: channelName, + Version: internalpb.InsertDataVersion_ColumnBased, + } + insertReq.FieldsData = make([]*schemapb.FieldData, len(it.GetFieldsData())) + + insertMsg := &msgstream.InsertMsg{ + BaseMsg: msgstream.BaseMsg{ + Ctx: it.TraceCtx(), + }, + InsertRequest: insertReq, + } + + return insertMsg + } + + // repack the row data corresponding to the offset to insertMsg + getInsertMsgsBySegmentID := func(segmentID UniqueID, rowOffsets []int, channelName string, maxMessageSize int) ([]msgstream.TsMsg, error) { + repackedMsgs := make([]msgstream.TsMsg, 0) + requestSize := 0 + insertMsg := createInsertMsg(segmentID, channelName) + for _, offset := range rowOffsets { + curRowMessageSize, err := typeutil.EstimateEntitySize(it.InsertRequest.GetFieldsData(), offset) + if err != nil { + return nil, err + } + + // if insertMsg's size is greater than the threshold, split into multiple insertMsgs + if requestSize+curRowMessageSize >= maxMessageSize { + repackedMsgs = append(repackedMsgs, insertMsg) + insertMsg = createInsertMsg(segmentID, channelName) + requestSize = 0 + } + + typeutil.AppendFieldData(insertMsg.FieldsData, it.GetFieldsData(), int64(offset)) + insertMsg.HashValues = append(insertMsg.HashValues, it.HashValues[offset]) + insertMsg.Timestamps = append(insertMsg.Timestamps, it.Timestamps[offset]) + insertMsg.RowIDs = append(insertMsg.RowIDs, it.RowIDs[offset]) + insertMsg.NumRows++ + requestSize += curRowMessageSize + } + repackedMsgs = append(repackedMsgs, insertMsg) + + return repackedMsgs, nil + }
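Reviewer note: the size-splitting logic in getInsertMsgsBySegmentID is worth isolating. A runnable sketch, under the simplifying assumption that each row's serialized size is known up front (the real code asks typeutil.EstimateEntitySize per offset):

package main

import "fmt"

// splitBySize batches row offsets so that no batch's accumulated size
// crosses maxMessageSize, mirroring the flush-and-restart loop above.
func splitBySize(rowSizes []int, maxMessageSize int) [][]int {
    var batches [][]int
    var current []int
    running := 0
    for offset, size := range rowSizes {
        if running+size >= maxMessageSize && len(current) > 0 {
            batches = append(batches, current) // flush the full batch
            current = nil
            running = 0
        }
        current = append(current, offset)
        running += size
    }
    return append(batches, current)
}

func main() {
    fmt.Println(splitBySize([]int{40, 40, 40, 40}, 100)) // [[0 1] [2 3]]
}

One difference worth noting: the sketch only flushes a non-empty batch, whereas the loop in the diff flushes unconditionally, so a single row larger than the Pulsar limit would emit an empty leading insertMsg there.

+ + // get allocated segmentID info for every dmChannel and repack insertMsgs for every segmentID + for channelName, rowOffsets := range channel2RowOffsets { + assignedSegmentInfos, err := it.segIDAssigner.GetSegmentID(it.CollectionID, it.PartitionID, channelName, uint32(len(rowOffsets)), channelMaxTSMap[channelName]) + if err != nil { + log.Error("allocate segmentID for insert data failed", + zap.Int64("collectionID", it.CollectionID), + zap.String("channel name", channelName), + zap.Int("allocate count", len(rowOffsets)), + zap.Error(err)) + return nil, err + } + startPos := 0 + for segmentID, count := range assignedSegmentInfos { + subRowOffsets := rowOffsets[startPos : startPos+int(count)] + insertMsgs, err := 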
getInsertMsgsBySegmentID(segmentID, subRowOffsets, channelName, threshold) + if err != nil { + log.Error("repack insert data to insert msgs failed", + zap.Int64("collectionID", it.CollectionID), + zap.Error(err)) + return nil, err } - reqSegAccumulateCountMap[channelID] = append( - reqSegAccumulateCountMap[channelID], - accumulate, - ) - if _, ok := reqSegIDMap[channelID]; !ok { - reqSegIDMap[channelID] = make([]UniqueID, 0) - } - reqSegIDMap[channelID] = append( - reqSegIDMap[channelID], - key, - ) + result.Msgs = append(result.Msgs, insertMsgs...) + startPos += int(count) } } - var getSegmentID = func(channelID int32) UniqueID { - reqSegAllocateCounter[channelID]++ - cur := reqSegAllocateCounter[channelID] - accumulateSlice := reqSegAccumulateCountMap[channelID] - segIDSlice := reqSegIDMap[channelID] - for index, count := range accumulateSlice { - if cur <= count { - return segIDSlice[index] - } - } - log.Warn("Can't Found SegmentID", zap.Any("reqSegAllocateCounter", reqSegAllocateCounter)) - return 0 - } - - threshold := Params.PulsarCfg.MaxMessageSize - // not accurate - /* #nosec G103 */ - getFixedSizeOfInsertMsg := func(msg *msgstream.InsertMsg) int { - size := 0 - - size += int(unsafe.Sizeof(*msg.Base)) - size += int(unsafe.Sizeof(msg.DbName)) - size += int(unsafe.Sizeof(msg.CollectionName)) - size += int(unsafe.Sizeof(msg.PartitionName)) - size += int(unsafe.Sizeof(msg.DbID)) - size += int(unsafe.Sizeof(msg.CollectionID)) - size += int(unsafe.Sizeof(msg.PartitionID)) - size += int(unsafe.Sizeof(msg.SegmentID)) - size += int(unsafe.Sizeof(msg.ShardName)) - size += int(unsafe.Sizeof(msg.Timestamps)) - size += int(unsafe.Sizeof(msg.RowIDs)) - return size - } - - sizePerRow, _ := typeutil.EstimateSizePerRecord(it.schema) - result := make(map[int32]msgstream.TsMsg) - curMsgSizeMap := make(map[int32]int) - - for i, request := range tsMsgs { - insertRequest := request.(*msgstream.InsertMsg) - keys := hashKeys[i] - collectionName := insertRequest.CollectionName - collectionID := insertRequest.CollectionID - partitionID := insertRequest.PartitionID - partitionName := insertRequest.PartitionName - proxyID := insertRequest.Base.SourceID - for index, key := range keys { - ts := insertRequest.Timestamps[index] - rowID := insertRequest.RowIDs[index] - segmentID := getSegmentID(key) - if segmentID == 0 { - return nil, fmt.Errorf("get SegmentID failed, segmentID is zero") - } - _, ok := result[key] - if !ok { - sliceRequest := internalpb.InsertRequest{ - Base: &commonpb.MsgBase{ - MsgType: commonpb.MsgType_Insert, - MsgID: reqID, - Timestamp: ts, - SourceID: proxyID, - }, - CollectionID: collectionID, - PartitionID: partitionID, - CollectionName: collectionName, - PartitionName: partitionName, - SegmentID: segmentID, - ShardName: channelNames[key], - } - - sliceRequest.Version = internalpb.InsertDataVersion_ColumnBased - sliceRequest.NumRows = 0 - sliceRequest.FieldsData = make([]*schemapb.FieldData, len(it.BaseInsertTask.InsertRequest.GetFieldsData())) - - insertMsg := &msgstream.InsertMsg{ - BaseMsg: msgstream.BaseMsg{ - Ctx: request.TraceCtx(), - }, - InsertRequest: sliceRequest, - } - result[key] = insertMsg - curMsgSizeMap[key] = getFixedSizeOfInsertMsg(insertMsg) - } - - curMsg := result[key].(*msgstream.InsertMsg) - curMsgSize := curMsgSizeMap[key] - - curMsg.HashValues = append(curMsg.HashValues, insertRequest.HashValues[index]) - curMsg.Timestamps = append(curMsg.Timestamps, ts) - curMsg.RowIDs = append(curMsg.RowIDs, rowID) - - typeutil.AppendFieldData(curMsg.FieldsData, 
it.BaseInsertTask.InsertRequest.GetFieldsData(), int64(index)) - curMsg.NumRows++ - curMsgSize += sizePerRow - - if curMsgSize >= threshold { - newPack.Msgs = append(newPack.Msgs, curMsg) - delete(result, key) - curMsgSize = 0 - } - - curMsgSizeMap[key] = curMsgSize - } - } - for _, msg := range result { - if msg != nil { - newPack.Msgs = append(newPack.Msgs, msg) - } - } - - return newPack, nil + return result, nil } func (it *insertTask) Execute(ctx context.Context) error { @@ -769,7 +485,7 @@ func (it *insertTask) Execute(ctx context.Context) error { tr := timerecord.NewTimeRecorder(fmt.Sprintf("proxy execute insert %d", it.ID())) defer tr.Elapse("done") - collectionName := it.BaseInsertTask.CollectionName + collectionName := it.CollectionName collID, err := globalMetaCache.GetCollectionID(ctx, collectionName) if err != nil { return err @@ -790,16 +506,6 @@ func (it *insertTask) Execute(ctx context.Context) error { it.PartitionID = partitionID tr.Record("get collection id & partition id from cache") - var tsMsg msgstream.TsMsg = &it.BaseInsertTask - it.BaseMsg.Ctx = ctx - msgPack := msgstream.MsgPack{ - BeginTs: it.BeginTs(), - EndTs: it.EndTs(), - Msgs: make([]msgstream.TsMsg, 1), - } - - msgPack.Msgs[0] = tsMsg - stream, err := it.chMgr.getDMLStream(collID) if err != nil { err = it.chMgr.createDMLMsgStream(collID) @@ -817,16 +523,27 @@ func (it *insertTask) Execute(ctx context.Context) error { } tr.Record("get used message stream") - // Assign SegmentID - var pack *msgstream.MsgPack - pack, err = it._assignSegmentID(stream, &msgPack) + channelNames, err := it.chMgr.getVChannels(collID) if err != nil { + log.Error("get vChannels failed", zap.Int64("msgID", it.Base.MsgID), zap.Int64("collectionID", collID), zap.Error(err)) + it.result.Status.ErrorCode = commonpb.ErrorCode_UnexpectedError + it.result.Status.Reason = err.Error() return err } + + // assign segmentID for insert data and repack data by segmentID + msgPack, err := it.assignSegmentID(channelNames) + if err != nil { + log.Error("assign segmentID and repack insert data failed", zap.Int64("msgID", it.Base.MsgID), zap.Int64("collectionID", collID), zap.Error(err)) + it.result.Status.ErrorCode = commonpb.ErrorCode_UnexpectedError + it.result.Status.Reason = err.Error() + return err + } + log.Debug("assign segmentID for insert data success", zap.Int64("msgID", it.Base.MsgID), zap.Int64("collectionID", collID), zap.String("collection name", it.CollectionName)) tr.Record("assign segment id") tr.Record("sendInsertMsg") - err = stream.Produce(pack) + err = stream.Produce(msgPack) if err != nil { it.result.Status.ErrorCode = commonpb.ErrorCode_UnexpectedError it.result.Status.Reason = err.Error() @@ -835,6 +552,8 @@ func (it *insertTask) Execute(ctx context.Context) error { sendMsgDur := tr.Record("send insert request to message stream") metrics.ProxySendInsertReqLatency.WithLabelValues(strconv.FormatInt(Params.ProxyCfg.ProxyID, 10)).Observe(float64(sendMsgDur.Milliseconds())) + log.Debug("Proxy Insert Execute done", zap.Int64("msgID", it.Base.MsgID), zap.String("collection name", collectionName)) + return nil } @@ -4549,8 +4268,6 @@ func (dt *deleteTask) PreExecute(ctx context.Context) error { } dt.result.DeleteCnt = int64(len(primaryKeys)) - dt.HashPK(primaryKeys) - rowNum := len(primaryKeys) dt.Timestamps = make([]uint64, rowNum) for index := range dt.Timestamps { @@ -4564,14 +4281,6 @@ func (dt *deleteTask) Execute(ctx context.Context) (err error) { sp, ctx := trace.StartSpanFromContextWithOperationName(dt.ctx, 
"Proxy-Delete-Execute") defer sp.Finish() - var tsMsg msgstream.TsMsg = &dt.BaseDeleteTask - msgPack := msgstream.MsgPack{ - BeginTs: dt.BeginTs(), - EndTs: dt.EndTs(), - Msgs: make([]msgstream.TsMsg, 1), - } - msgPack.Msgs[0] = tsMsg - collID := dt.DeleteRequest.CollectionID stream, err := dt.chMgr.getDMLStream(collID) if err != nil { @@ -4588,64 +4297,66 @@ func (dt *deleteTask) Execute(ctx context.Context) (err error) { return err } } - result := make(map[int32]msgstream.TsMsg) - hashKeys := stream.ComputeProduceChannelIndexes(msgPack.Msgs) - // For each msg, assign PK to different message buckets by hash value of PK. - for i, request := range msgPack.Msgs { - deleteRequest := request.(*msgstream.DeleteMsg) - keys := hashKeys[i] - collectionName := deleteRequest.CollectionName - collectionID := deleteRequest.CollectionID - partitionID := deleteRequest.PartitionID - partitionName := deleteRequest.PartitionName - proxyID := deleteRequest.Base.SourceID - for index, key := range keys { - ts := deleteRequest.Timestamps[index] - pks := deleteRequest.PrimaryKeys[index] - _, ok := result[key] - if !ok { - sliceRequest := internalpb.DeleteRequest{ - Base: &commonpb.MsgBase{ - MsgType: commonpb.MsgType_Delete, - MsgID: dt.Base.MsgID, - Timestamp: ts, - SourceID: proxyID, - }, - CollectionID: collectionID, - PartitionID: partitionID, - CollectionName: collectionName, - PartitionName: partitionName, - } - deleteMsg := &msgstream.DeleteMsg{ - BaseMsg: msgstream.BaseMsg{ - Ctx: ctx, - }, - DeleteRequest: sliceRequest, - } - result[key] = deleteMsg + + // hash primary keys to channels + channelNames, err := dt.chMgr.getVChannels(collID) + if err != nil { + log.Error("get vChannels failed", zap.Int64("collectionID", collID), zap.Error(err)) + dt.result.Status.ErrorCode = commonpb.ErrorCode_UnexpectedError + dt.result.Status.Reason = err.Error() + return err + } + dt.HashValues = typeutil.HashPK2Channels(dt.result.IDs, channelNames) + + // repack delete msg by dmChannel + result := make(map[uint32]msgstream.TsMsg) + collectionName := dt.CollectionName + collectionID := dt.CollectionID + partitionID := dt.PartitionID + partitionName := dt.PartitionName + proxyID := dt.Base.SourceID + for index, key := range dt.HashValues { + ts := dt.Timestamps[index] + _, ok := result[key] + if !ok { + sliceRequest := internalpb.DeleteRequest{ + Base: &commonpb.MsgBase{ + MsgType: commonpb.MsgType_Delete, + MsgID: dt.Base.MsgID, + Timestamp: ts, + SourceID: proxyID, + }, + CollectionID: collectionID, + PartitionID: partitionID, + CollectionName: collectionName, + PartitionName: partitionName, } - curMsg := result[key].(*msgstream.DeleteMsg) - curMsg.HashValues = append(curMsg.HashValues, deleteRequest.HashValues[index]) - curMsg.Timestamps = append(curMsg.Timestamps, ts) - curMsg.PrimaryKeys = append(curMsg.PrimaryKeys, pks) + deleteMsg := &msgstream.DeleteMsg{ + BaseMsg: msgstream.BaseMsg{ + Ctx: ctx, + }, + DeleteRequest: sliceRequest, + } + result[key] = deleteMsg } + curMsg := result[key].(*msgstream.DeleteMsg) + curMsg.HashValues = append(curMsg.HashValues, dt.HashValues[index]) + curMsg.Timestamps = append(curMsg.Timestamps, dt.Timestamps[index]) + curMsg.PrimaryKeys = append(curMsg.PrimaryKeys, dt.PrimaryKeys[index]) } - newPack := &msgstream.MsgPack{ - BeginTs: msgPack.BeginTs, - EndTs: msgPack.EndTs, - StartPositions: msgPack.StartPositions, - EndPositions: msgPack.EndPositions, - Msgs: make([]msgstream.TsMsg, 0), + // send delete request to log broker + msgPack := &msgstream.MsgPack{ + BeginTs: dt.BeginTs(), + 
EndTs: dt.EndTs(), } - for _, msg := range result { if msg != nil { - newPack.Msgs = append(newPack.Msgs, msg) + msgPack.Msgs = append(msgPack.Msgs, msg) } } - err = stream.Produce(newPack) + err = stream.Produce(msgPack) if err != nil { dt.result.Status.ErrorCode = commonpb.ErrorCode_UnexpectedError dt.result.Status.Reason = err.Error() @@ -4658,17 +4369,6 @@ func (dt *deleteTask) PostExecute(ctx context.Context) error { return nil } -func (dt *deleteTask) HashPK(pks []int64) { - if len(dt.HashValues) != 0 { - log.Warn("the hashvalues passed through client is not supported now, and will be overwritten") - } - dt.HashValues = make([]uint32, 0, len(pks)) - for _, pk := range pks { - hash, _ := typeutil.Hash32Int64(pk) - dt.HashValues = append(dt.HashValues, hash) - } -} - // CreateAliasTask contains task information of CreateAlias type CreateAliasTask struct { Condition diff --git a/internal/proxy/task_test.go b/internal/proxy/task_test.go index 39360575bc..b9416f6528 100644 --- a/internal/proxy/task_test.go +++ b/internal/proxy/task_test.go @@ -51,6 +51,20 @@ import ( // TODO(dragondriver): add more test cases +const ( + maxTestStringLen = 100 + testBoolField = "bool" + testInt32Field = "int32" + testInt64Field = "int64" + testFloatField = "float" + testDoubleField = "double" + testStringField = "stringField" + testFloatVecField = "fvec" + testBinaryVecField = "bvec" + testVecDim = 128 + testMaxVarCharLength = 100 +) + func constructCollectionSchema( int64Field, floatVecField string, dim int, @@ -93,6 +107,36 @@ func constructCollectionSchema( } } +func constructCollectionSchemaByDataType(collectionName string, fieldName2DataType map[string]schemapb.DataType, primaryFieldName string, autoID bool) *schemapb.CollectionSchema { + fieldsSchema := make([]*schemapb.FieldSchema, 0) + + for fieldName, dataType := range fieldName2DataType { + fieldSchema := &schemapb.FieldSchema{ + Name: fieldName, + DataType: dataType, + } + if dataType == schemapb.DataType_FloatVector || dataType == schemapb.DataType_BinaryVector { + fieldSchema.TypeParams = []*commonpb.KeyValuePair{ + { + Key: "dim", + Value: strconv.Itoa(testVecDim), + }, + } + } + if fieldName == primaryFieldName { + fieldSchema.IsPrimaryKey = true + fieldSchema.AutoID = autoID + } + + fieldsSchema = append(fieldsSchema, fieldSchema) + } + + return &schemapb.CollectionSchema{ + Name: collectionName, + Fields: fieldsSchema, + } +} + func constructCollectionSchemaWithAllType( boolField, int32Field, int64Field, floatField, doubleField string, floatVecField, binaryVecField string, @@ -297,91 +341,6 @@ func constructSearchRequest( } } -func TestGetNumRowsOfScalarField(t *testing.T) { - cases := []struct { - datas interface{} - want uint32 - }{ - {[]bool{}, 0}, - {[]bool{true, false}, 2}, - {[]int32{}, 0}, - {[]int32{1, 2}, 2}, - {[]int64{}, 0}, - {[]int64{1, 2}, 2}, - {[]float32{}, 0}, - {[]float32{1.0, 2.0}, 2}, - {[]float64{}, 0}, - {[]float64{1.0, 2.0}, 2}, - } - - for _, test := range cases { - if got := getNumRowsOfScalarField(test.datas); got != test.want { - t.Errorf("getNumRowsOfScalarField(%v) = %v", test.datas, test.want) - } - } -} - -func TestGetNumRowsOfFloatVectorField(t *testing.T) { - cases := []struct { - fDatas []float32 - dim int64 - want uint32 - errIsNil bool - }{ - {[]float32{}, -1, 0, false}, // dim <= 0 - {[]float32{}, 0, 0, false}, // dim <= 0 - {[]float32{1.0}, 128, 0, false}, // length % dim != 0 - {[]float32{}, 128, 0, true}, - {[]float32{1.0, 2.0}, 2, 1, true}, - {[]float32{1.0, 2.0, 3.0, 4.0}, 2, 2, true}, - } - - for 
_, test := range cases { - got, err := getNumRowsOfFloatVectorField(test.fDatas, test.dim) - if test.errIsNil { - assert.Equal(t, nil, err) - if got != test.want { - t.Errorf("getNumRowsOfFloatVectorField(%v, %v) = %v, %v", test.fDatas, test.dim, test.want, nil) - } - } else { - assert.NotEqual(t, nil, err) - } - } -} - -func TestGetNumRowsOfBinaryVectorField(t *testing.T) { - cases := []struct { - bDatas []byte - dim int64 - want uint32 - errIsNil bool - }{ - {[]byte{}, -1, 0, false}, // dim <= 0 - {[]byte{}, 0, 0, false}, // dim <= 0 - {[]byte{1.0}, 128, 0, false}, // length % dim != 0 - {[]byte{}, 128, 0, true}, - {[]byte{1.0}, 1, 0, false}, // dim % 8 != 0 - {[]byte{1.0}, 4, 0, false}, // dim % 8 != 0 - {[]byte{1.0, 2.0}, 8, 2, true}, - {[]byte{1.0, 2.0}, 16, 1, true}, - {[]byte{1.0, 2.0, 3.0, 4.0}, 8, 4, true}, - {[]byte{1.0, 2.0, 3.0, 4.0}, 16, 2, true}, - {[]byte{1.0}, 128, 0, false}, // (8*l) % dim != 0 - } - - for _, test := range cases { - got, err := getNumRowsOfBinaryVectorField(test.bDatas, test.dim) - if test.errIsNil { - assert.Equal(t, nil, err) - if got != test.want { - t.Errorf("getNumRowsOfBinaryVectorField(%v, %v) = %v, %v", test.bDatas, test.dim, test.want, nil) - } - } else { - assert.NotEqual(t, nil, err) - } - } -} - func TestInsertTask_checkLengthOfFieldsData(t *testing.T) { var err error @@ -393,13 +352,18 @@ func TestInsertTask_checkLengthOfFieldsData(t *testing.T) { AutoID: false, Fields: []*schemapb.FieldSchema{}, }, - req: &milvuspb.InsertRequest{ - DbName: "TestInsertTask_checkLengthOfFieldsData", - CollectionName: "TestInsertTask_checkLengthOfFieldsData", - PartitionName: "TestInsertTask_checkLengthOfFieldsData", - FieldsData: nil, + BaseInsertTask: BaseInsertTask{ + InsertRequest: internalpb.InsertRequest{ + Base: &commonpb.MsgBase{ + MsgType: commonpb.MsgType_Insert, + }, + DbName: "TestInsertTask_checkLengthOfFieldsData", + CollectionName: "TestInsertTask_checkLengthOfFieldsData", + PartitionName: "TestInsertTask_checkLengthOfFieldsData", + }, }, } + err = case1.checkLengthOfFieldsData() assert.Equal(t, nil, err) @@ -422,28 +386,32 @@ func TestInsertTask_checkLengthOfFieldsData(t *testing.T) { }, } // passed fields is empty - case2.req = &milvuspb.InsertRequest{} + // case2.BaseInsertTask = BaseInsertTask{ + // InsertRequest: internalpb.InsertRequest{ + // Base: &commonpb.MsgBase{ + // MsgType: commonpb.MsgType_Insert, + // MsgID: 0, + // SourceID: Params.ProxyCfg.ProxyID, + // }, + // }, + // } err = case2.checkLengthOfFieldsData() assert.NotEqual(t, nil, err) // the num of passed fields is less than needed - case2.req = &milvuspb.InsertRequest{ - FieldsData: []*schemapb.FieldData{ - { - Type: schemapb.DataType_Int64, - }, + case2.FieldsData = []*schemapb.FieldData{ + { + Type: schemapb.DataType_Int64, }, } err = case2.checkLengthOfFieldsData() assert.NotEqual(t, nil, err) // satisfied - case2.req = &milvuspb.InsertRequest{ - FieldsData: []*schemapb.FieldData{ - { - Type: schemapb.DataType_Int64, - }, - { - Type: schemapb.DataType_Int64, - }, + case2.FieldsData = []*schemapb.FieldData{ + { + Type: schemapb.DataType_Int64, + }, + { + Type: schemapb.DataType_Int64, }, } err = case2.checkLengthOfFieldsData() @@ -468,15 +436,13 @@ func TestInsertTask_checkLengthOfFieldsData(t *testing.T) { }, } // passed fields is empty - case3.req = &milvuspb.InsertRequest{} + // case3.req = &milvuspb.InsertRequest{} err = case3.checkLengthOfFieldsData() assert.NotEqual(t, nil, err) // satisfied - case3.req = &milvuspb.InsertRequest{ - FieldsData: []*schemapb.FieldData{ - { 
- Type: schemapb.DataType_Int64, - }, + case3.FieldsData = []*schemapb.FieldData{ + { + Type: schemapb.DataType_Int64, }, } err = case3.checkLengthOfFieldsData() @@ -498,180 +464,219 @@ func TestInsertTask_checkLengthOfFieldsData(t *testing.T) { } // passed fields is empty // satisfied - case4.req = &milvuspb.InsertRequest{} + // case4.req = &milvuspb.InsertRequest{} err = case4.checkLengthOfFieldsData() assert.Equal(t, nil, err) } -func TestInsertTask_checkRowNums(t *testing.T) { +func TestInsertTask_CheckAligned(t *testing.T) { var err error // passed NumRows is less than 0 case1 := insertTask{ - req: &milvuspb.InsertRequest{ - NumRows: 0, + BaseInsertTask: BaseInsertTask{ + InsertRequest: internalpb.InsertRequest{ + Base: &commonpb.MsgBase{ + MsgType: commonpb.MsgType_Insert, + }, + NumRows: 0, + }, }, } - err = case1.checkRowNums() - assert.NotEqual(t, nil, err) + err = case1.CheckAligned() + assert.NoError(t, err) // checkLengthOfFieldsData was already checked by TestInsertTask_checkLengthOfFieldsData + boolFieldSchema := &schemapb.FieldSchema{DataType: schemapb.DataType_Bool} + int8FieldSchema := &schemapb.FieldSchema{DataType: schemapb.DataType_Int8} + int16FieldSchema := &schemapb.FieldSchema{DataType: schemapb.DataType_Int16} + int32FieldSchema := &schemapb.FieldSchema{DataType: schemapb.DataType_Int32} + int64FieldSchema := &schemapb.FieldSchema{DataType: schemapb.DataType_Int64} + floatFieldSchema := &schemapb.FieldSchema{DataType: schemapb.DataType_Float} + doubleFieldSchema := &schemapb.FieldSchema{DataType: schemapb.DataType_Double} + floatVectorFieldSchema := &schemapb.FieldSchema{DataType: schemapb.DataType_FloatVector} + binaryVectorFieldSchema := &schemapb.FieldSchema{DataType: schemapb.DataType_BinaryVector} + varCharFieldSchema := &schemapb.FieldSchema{DataType: schemapb.DataType_VarChar} + numRows := 20 dim := 128 case2 := insertTask{ + BaseInsertTask: BaseInsertTask{ + InsertRequest: internalpb.InsertRequest{ + Base: &commonpb.MsgBase{ + MsgType: commonpb.MsgType_Insert, + }, + Version: internalpb.InsertDataVersion_ColumnBased, + RowIDs: generateInt64Array(numRows), + Timestamps: generateUint64Array(numRows), + }, + }, schema: &schemapb.CollectionSchema{ Name: "TestInsertTask_checkRowNums", Description: "TestInsertTask_checkRowNums", AutoID: false, Fields: []*schemapb.FieldSchema{ - {DataType: schemapb.DataType_Bool}, - {DataType: schemapb.DataType_Int8}, - {DataType: schemapb.DataType_Int16}, - {DataType: schemapb.DataType_Int32}, - {DataType: schemapb.DataType_Int64}, - {DataType: schemapb.DataType_Float}, - {DataType: schemapb.DataType_Double}, - {DataType: schemapb.DataType_FloatVector}, - {DataType: schemapb.DataType_BinaryVector}, + boolFieldSchema, + int8FieldSchema, + int16FieldSchema, + int32FieldSchema, + int64FieldSchema, + floatFieldSchema, + doubleFieldSchema, + floatVectorFieldSchema, + binaryVectorFieldSchema, + varCharFieldSchema, }, }, } // satisfied - case2.req = &milvuspb.InsertRequest{ - NumRows: uint32(numRows), - FieldsData: []*schemapb.FieldData{ - newScalarFieldData(schemapb.DataType_Bool, "Bool", numRows), - newScalarFieldData(schemapb.DataType_Int8, "Int8", numRows), - newScalarFieldData(schemapb.DataType_Int16, "Int16", numRows), - newScalarFieldData(schemapb.DataType_Int32, "Int32", numRows), - newScalarFieldData(schemapb.DataType_Int64, "Int64", numRows), - newScalarFieldData(schemapb.DataType_Float, "Float", numRows), - newScalarFieldData(schemapb.DataType_Double, "Double", numRows), - newFloatVectorFieldData("FloatVector", numRows, dim), 
- newBinaryVectorFieldData("BinaryVector", numRows, dim), - }, + case2.NumRows = uint64(numRows) + case2.FieldsData = []*schemapb.FieldData{ + newScalarFieldData(boolFieldSchema, "Bool", numRows), + newScalarFieldData(int8FieldSchema, "Int8", numRows), + newScalarFieldData(int16FieldSchema, "Int16", numRows), + newScalarFieldData(int32FieldSchema, "Int32", numRows), + newScalarFieldData(int64FieldSchema, "Int64", numRows), + newScalarFieldData(floatFieldSchema, "Float", numRows), + newScalarFieldData(doubleFieldSchema, "Double", numRows), + newFloatVectorFieldData("FloatVector", numRows, dim), + newBinaryVectorFieldData("BinaryVector", numRows, dim), + newScalarFieldData(varCharFieldSchema, "VarChar", numRows), } - err = case2.checkRowNums() - assert.Equal(t, nil, err) + err = case2.CheckAligned() + assert.NoError(t, err) // less bool data - case2.req.FieldsData[0] = newScalarFieldData(schemapb.DataType_Bool, "Bool", numRows/2) - err = case2.checkRowNums() - assert.NotEqual(t, nil, err) + case2.FieldsData[0] = newScalarFieldData(boolFieldSchema, "Bool", numRows/2) + err = case2.CheckAligned() + assert.Error(t, err) // more bool data - case2.req.FieldsData[0] = newScalarFieldData(schemapb.DataType_Bool, "Bool", numRows*2) - err = case2.checkRowNums() - assert.NotEqual(t, nil, err) + case2.FieldsData[0] = newScalarFieldData(boolFieldSchema, "Bool", numRows*2) + err = case2.CheckAligned() + assert.Error(t, err) // revert - case2.req.FieldsData[0] = newScalarFieldData(schemapb.DataType_Bool, "Bool", numRows) - err = case2.checkRowNums() - assert.Equal(t, nil, err) + case2.FieldsData[0] = newScalarFieldData(boolFieldSchema, "Bool", numRows) + err = case2.CheckAligned() + assert.NoError(t, err) // less int8 data - case2.req.FieldsData[1] = newScalarFieldData(schemapb.DataType_Int8, "Int8", numRows/2) - err = case2.checkRowNums() - assert.NotEqual(t, nil, err) + case2.FieldsData[1] = newScalarFieldData(int8FieldSchema, "Int8", numRows/2) + err = case2.CheckAligned() + assert.Error(t, err) // more int8 data - case2.req.FieldsData[1] = newScalarFieldData(schemapb.DataType_Int8, "Int8", numRows*2) - err = case2.checkRowNums() - assert.NotEqual(t, nil, err) + case2.FieldsData[1] = newScalarFieldData(int8FieldSchema, "Int8", numRows*2) + err = case2.CheckAligned() + assert.Error(t, err) // revert - case2.req.FieldsData[1] = newScalarFieldData(schemapb.DataType_Int8, "Int8", numRows) - err = case2.checkRowNums() - assert.Equal(t, nil, err) + case2.FieldsData[1] = newScalarFieldData(int8FieldSchema, "Int8", numRows) + err = case2.CheckAligned() + assert.NoError(t, err) // less int16 data - case2.req.FieldsData[2] = newScalarFieldData(schemapb.DataType_Int16, "Int16", numRows/2) - err = case2.checkRowNums() - assert.NotEqual(t, nil, err) + case2.FieldsData[2] = newScalarFieldData(int16FieldSchema, "Int16", numRows/2) + err = case2.CheckAligned() + assert.Error(t, err) // more int16 data - case2.req.FieldsData[2] = newScalarFieldData(schemapb.DataType_Int16, "Int16", numRows*2) - err = case2.checkRowNums() - assert.NotEqual(t, nil, err) + case2.FieldsData[2] = newScalarFieldData(int16FieldSchema, "Int16", numRows*2) + err = case2.CheckAligned() + assert.Error(t, err) // revert - case2.req.FieldsData[2] = newScalarFieldData(schemapb.DataType_Int16, "Int16", numRows) - err = case2.checkRowNums() - assert.Equal(t, nil, err) + case2.FieldsData[2] = newScalarFieldData(int16FieldSchema, "Int16", numRows) + err = case2.CheckAligned() + assert.NoError(t, err) // less int32 data - case2.req.FieldsData[3] = 
newScalarFieldData(schemapb.DataType_Int32, "Int32", numRows/2) - err = case2.checkRowNums() - assert.NotEqual(t, nil, err) + case2.FieldsData[3] = newScalarFieldData(int32FieldSchema, "Int32", numRows/2) + err = case2.CheckAligned() + assert.Error(t, err) // more int32 data - case2.req.FieldsData[3] = newScalarFieldData(schemapb.DataType_Int32, "Int32", numRows*2) - err = case2.checkRowNums() - assert.NotEqual(t, nil, err) + case2.FieldsData[3] = newScalarFieldData(int32FieldSchema, "Int32", numRows*2) + err = case2.CheckAligned() + assert.Error(t, err) // revert - case2.req.FieldsData[3] = newScalarFieldData(schemapb.DataType_Int32, "Int32", numRows) - err = case2.checkRowNums() - assert.Equal(t, nil, err) + case2.FieldsData[3] = newScalarFieldData(int32FieldSchema, "Int32", numRows) + err = case2.CheckAligned() + assert.NoError(t, err) // less int64 data - case2.req.FieldsData[4] = newScalarFieldData(schemapb.DataType_Int64, "Int64", numRows/2) - err = case2.checkRowNums() - assert.NotEqual(t, nil, err) + case2.FieldsData[4] = newScalarFieldData(int64FieldSchema, "Int64", numRows/2) + err = case2.CheckAligned() + assert.Error(t, err) // more int64 data - case2.req.FieldsData[4] = newScalarFieldData(schemapb.DataType_Int64, "Int64", numRows*2) - err = case2.checkRowNums() - assert.NotEqual(t, nil, err) + case2.FieldsData[4] = newScalarFieldData(int64FieldSchema, "Int64", numRows*2) + err = case2.CheckAligned() + assert.Error(t, err) // revert - case2.req.FieldsData[4] = newScalarFieldData(schemapb.DataType_Int64, "Int64", numRows) - err = case2.checkRowNums() - assert.Equal(t, nil, err) + case2.FieldsData[4] = newScalarFieldData(int64FieldSchema, "Int64", numRows) + err = case2.CheckAligned() + assert.NoError(t, err) // less float data - case2.req.FieldsData[5] = newScalarFieldData(schemapb.DataType_Float, "Float", numRows/2) - err = case2.checkRowNums() - assert.NotEqual(t, nil, err) + case2.FieldsData[5] = newScalarFieldData(floatFieldSchema, "Float", numRows/2) + err = case2.CheckAligned() + assert.Error(t, err) // more float data - case2.req.FieldsData[5] = newScalarFieldData(schemapb.DataType_Float, "Float", numRows*2) - err = case2.checkRowNums() - assert.NotEqual(t, nil, err) + case2.FieldsData[5] = newScalarFieldData(floatFieldSchema, "Float", numRows*2) + err = case2.CheckAligned() + assert.Error(t, err) // revert - case2.req.FieldsData[5] = newScalarFieldData(schemapb.DataType_Float, "Float", numRows) - err = case2.checkRowNums() - assert.Equal(t, nil, err) + case2.FieldsData[5] = newScalarFieldData(floatFieldSchema, "Float", numRows) + err = case2.CheckAligned() + assert.NoError(t, err) // less double data - case2.req.FieldsData[6] = newScalarFieldData(schemapb.DataType_Double, "Double", numRows/2) - err = case2.checkRowNums() - assert.NotEqual(t, nil, err) + case2.FieldsData[6] = newScalarFieldData(doubleFieldSchema, "Double", numRows/2) + err = case2.CheckAligned() + assert.Error(t, err) // more double data - case2.req.FieldsData[6] = newScalarFieldData(schemapb.DataType_Double, "Double", numRows*2) - err = case2.checkRowNums() - assert.NotEqual(t, nil, err) + case2.FieldsData[6] = newScalarFieldData(doubleFieldSchema, "Double", numRows*2) + err = case2.CheckAligned() + assert.Error(t, err) // revert - case2.req.FieldsData[6] = newScalarFieldData(schemapb.DataType_Double, "Double", numRows) - err = case2.checkRowNums() - assert.Equal(t, nil, err) + case2.FieldsData[6] = newScalarFieldData(doubleFieldSchema, "Double", numRows) + err = case2.CheckAligned() + assert.NoError(t, err) // less float vectors - case2.req.FieldsData[7] = newFloatVectorFieldData("FloatVector", numRows/2, dim) - err = case2.checkRowNums() - assert.NotEqual(t, nil, err) + case2.FieldsData[7] = newFloatVectorFieldData("FloatVector", numRows/2, dim) + err = case2.CheckAligned() + assert.Error(t, err) // more float vectors - case2.req.FieldsData[7] = newFloatVectorFieldData("FloatVector", numRows*2, dim) - err = case2.checkRowNums() - assert.NotEqual(t, nil, err) + case2.FieldsData[7] = newFloatVectorFieldData("FloatVector", numRows*2, dim) + err = case2.CheckAligned() + assert.Error(t, err) // revert - case2.req.FieldsData[7] = newFloatVectorFieldData("FloatVector", numRows, dim) - err = case2.checkRowNums() - assert.Equal(t, nil, err) + case2.FieldsData[7] = newFloatVectorFieldData("FloatVector", numRows, dim) + err = case2.CheckAligned() + assert.NoError(t, err) // less binary vectors - case2.req.FieldsData[7] = newBinaryVectorFieldData("BinaryVector", numRows/2, dim) - err = case2.checkRowNums() - assert.NotEqual(t, nil, err) + case2.FieldsData[8] = newBinaryVectorFieldData("BinaryVector", numRows/2, dim) + err = case2.CheckAligned() + assert.Error(t, err) // more binary vectors - case2.req.FieldsData[7] = newBinaryVectorFieldData("BinaryVector", numRows*2, dim) - err = case2.checkRowNums() - assert.NotEqual(t, nil, err) + case2.FieldsData[8] = newBinaryVectorFieldData("BinaryVector", numRows*2, dim) + err = case2.CheckAligned() + assert.Error(t, err) // revert - case2.req.FieldsData[7] = newBinaryVectorFieldData("BinaryVector", numRows, dim) - err = case2.checkRowNums() - assert.Equal(t, nil, err) + case2.FieldsData[8] = newBinaryVectorFieldData("BinaryVector", numRows, dim) + err = case2.CheckAligned() + assert.NoError(t, err) + + // less varChar data + case2.FieldsData[9] = newScalarFieldData(varCharFieldSchema, "VarChar", numRows/2) + err = case2.CheckAligned() + assert.Error(t, err) + // more varChar data + case2.FieldsData[9] = newScalarFieldData(varCharFieldSchema, "VarChar", numRows*2) + err = case2.CheckAligned() + assert.Error(t, err) + // revert + case2.FieldsData[9] = newScalarFieldData(varCharFieldSchema, "VarChar", numRows) + err = case2.CheckAligned() + assert.NoError(t, err) } func TestTranslateOutputFields(t *testing.T) {
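Reviewer note: the long sequence above is one mutate/check/revert pattern repeated per field. A condensed, self-contained sketch of that pattern over plain row counts (a toy check stands in for CheckAligned):

package main

import "fmt"

func main() {
    numRows := 20
    fields := map[string]int{"Bool": numRows, "Int64": numRows, "VarChar": numRows}
    // toy stand-in for CheckAligned: every field must hold exactly numRows rows
    check := func() error {
        for name, rows := range fields {
            if rows != numRows {
                return fmt.Errorf("the num_rows(%d) of field %s is not equal to passed NumRows(%d)", rows, name, numRows)
            }
        }
        return nil
    }
    for _, name := range []string{"Bool", "Int64", "VarChar"} {
        fields[name] = numRows / 2 // fewer rows must be rejected
        if err := check(); err == nil {
            panic("expected a misalignment error")
        }
        fields[name] = numRows // reverting must pass again
        if err := check(); err != nil {
            panic(err)
        }
    }
    fmt.Println("mutate/check/revert holds for every field")
}

@@ -1731,24 +1736,27 @@ func TestSearchTask_all(t *testing.T) { prefix := "TestSearchTask_all" dbName := "" collectionName := prefix + funcutil.GenRandomStr() - boolField := "bool" - int32Field := "int32" - int64Field := "int64" - floatField := "float" - doubleField := "double" - floatVecField := "fvec" - binaryVecField := "bvec" - fieldsLen := len([]string{boolField, int32Field, int64Field, floatField, doubleField, floatVecField, binaryVecField}) + dim := 128 - expr := fmt.Sprintf("%s > 0", int64Field) + expr := fmt.Sprintf("%s > 0", testInt64Field) nq := 10 topk := 10 roundDecimal := 3 nprobe := 10 - schema := constructCollectionSchemaWithAllType( - boolField, int32Field, int64Field, floatField, doubleField, - floatVecField, binaryVecField, dim, collectionName) + fieldName2Types := map[string]schemapb.DataType{ + testBoolField: schemapb.DataType_Bool, + testInt32Field: schemapb.DataType_Int32, + testInt64Field: schemapb.DataType_Int64, + testFloatField: schemapb.DataType_Float, + testDoubleField: schemapb.DataType_Double, + testFloatVecField: schemapb.DataType_FloatVector, + } + if enableMultipleVectorFields { + fieldName2Types[testBinaryVecField] = schemapb.DataType_BinaryVector + } + + schema := 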
constructCollectionSchemaByDataType(collectionName, fieldName2Types, testInt64Field, false) marshaledSchema, err := proto.Marshal(schema) assert.NoError(t, err) @@ -1801,7 +1809,7 @@ func TestSearchTask_all(t *testing.T) { req := constructSearchRequest(dbName, collectionName, expr, - floatVecField, + testFloatVecField, nq, dim, nprobe, topk, roundDecimal) task := &searchTask{ @@ -1866,7 +1874,6 @@ func TestSearchTask_all(t *testing.T) { resultData := &schemapb.SearchResultData{ NumQueries: int64(nq), TopK: int64(topk), - FieldsData: make([]*schemapb.FieldData, fieldsLen), Scores: make([]float32, nq*topk), Ids: &schemapb.IDs{ IdField: &schemapb.IDs_IntId{ @@ -1878,109 +1885,10 @@ func TestSearchTask_all(t *testing.T) { Topks: make([]int64, nq), } - resultData.FieldsData[0] = &schemapb.FieldData{ - Type: schemapb.DataType_Bool, - FieldName: boolField, - Field: &schemapb.FieldData_Scalars{ - Scalars: &schemapb.ScalarField{ - Data: &schemapb.ScalarField_BoolData{ - BoolData: &schemapb.BoolArray{ - Data: generateBoolArray(nq * topk), - }, - }, - }, - }, - FieldId: common.StartOfUserFieldID + 0, - } - - resultData.FieldsData[1] = &schemapb.FieldData{ - Type: schemapb.DataType_Int32, - FieldName: int32Field, - Field: &schemapb.FieldData_Scalars{ - Scalars: &schemapb.ScalarField{ - Data: &schemapb.ScalarField_IntData{ - IntData: &schemapb.IntArray{ - Data: generateInt32Array(nq * topk), - }, - }, - }, - }, - FieldId: common.StartOfUserFieldID + 1, - } - - resultData.FieldsData[2] = &schemapb.FieldData{ - Type: schemapb.DataType_Int64, - FieldName: int64Field, - Field: &schemapb.FieldData_Scalars{ - Scalars: &schemapb.ScalarField{ - Data: &schemapb.ScalarField_LongData{ - LongData: &schemapb.LongArray{ - Data: generateInt64Array(nq * topk), - }, - }, - }, - }, - FieldId: common.StartOfUserFieldID + 2, - } - - resultData.FieldsData[3] = &schemapb.FieldData{ - Type: schemapb.DataType_Float, - FieldName: floatField, - Field: &schemapb.FieldData_Scalars{ - Scalars: &schemapb.ScalarField{ - Data: &schemapb.ScalarField_FloatData{ - FloatData: &schemapb.FloatArray{ - Data: generateFloat32Array(nq * topk), - }, - }, - }, - }, - FieldId: common.StartOfUserFieldID + 3, - } - - resultData.FieldsData[4] = &schemapb.FieldData{ - Type: schemapb.DataType_Double, - FieldName: doubleField, - Field: &schemapb.FieldData_Scalars{ - Scalars: &schemapb.ScalarField{ - Data: &schemapb.ScalarField_DoubleData{ - DoubleData: &schemapb.DoubleArray{ - Data: generateFloat64Array(nq * topk), - }, - }, - }, - }, - FieldId: common.StartOfUserFieldID + 4, - } - - resultData.FieldsData[5] = &schemapb.FieldData{ - Type: schemapb.DataType_FloatVector, - FieldName: doubleField, - Field: &schemapb.FieldData_Vectors{ - Vectors: &schemapb.VectorField{ - Dim: int64(dim), - Data: &schemapb.VectorField_FloatVector{ - FloatVector: &schemapb.FloatArray{ - Data: generateFloatVectors(nq*topk, dim), - }, - }, - }, - }, - FieldId: common.StartOfUserFieldID + 5, - } - - resultData.FieldsData[6] = &schemapb.FieldData{ - Type: schemapb.DataType_BinaryVector, - FieldName: doubleField, - Field: &schemapb.FieldData_Vectors{ - Vectors: &schemapb.VectorField{ - Dim: int64(dim), - Data: &schemapb.VectorField_BinaryVector{ - BinaryVector: generateBinaryVectors(nq*topk, dim), - }, - }, - }, - FieldId: common.StartOfUserFieldID + 6, + fieldID := common.StartOfUserFieldID + for fieldName, dataType := range fieldName2Types { + resultData.FieldsData = append(resultData.FieldsData, generateFieldData(dataType, fieldName, int64(fieldID), nq*topk)) + fieldID++ } for 
i := 0; i < nq; i++ { @@ -2083,24 +1991,26 @@ func TestSearchTaskWithInvalidRoundDecimal(t *testing.T) { prefix := "TestSearchTask_all" dbName := "" collectionName := prefix + funcutil.GenRandomStr() - boolField := "bool" - int32Field := "int32" - int64Field := "int64" - floatField := "float" - doubleField := "double" - floatVecField := "fvec" - binaryVecField := "bvec" - fieldsLen := len([]string{boolField, int32Field, int64Field, floatField, doubleField, floatVecField, binaryVecField}) + dim := 128 - expr := fmt.Sprintf("%s > 0", int64Field) + expr := fmt.Sprintf("%s > 0", testInt64Field) nq := 10 topk := 10 roundDecimal := 7 nprobe := 10 - schema := constructCollectionSchemaWithAllType( - boolField, int32Field, int64Field, floatField, doubleField, - floatVecField, binaryVecField, dim, collectionName) + fieldName2Types := map[string]schemapb.DataType{ + testBoolField: schemapb.DataType_Bool, + testInt32Field: schemapb.DataType_Int32, + testInt64Field: schemapb.DataType_Int64, + testFloatField: schemapb.DataType_Float, + testDoubleField: schemapb.DataType_Double, + testFloatVecField: schemapb.DataType_FloatVector, + } + if enableMultipleVectorFields { + fieldName2Types[testBinaryVecField] = schemapb.DataType_BinaryVector + } + schema := constructCollectionSchemaByDataType(collectionName, fieldName2Types, testInt64Field, false) marshaledSchema, err := proto.Marshal(schema) assert.NoError(t, err) @@ -2153,7 +2063,7 @@ func TestSearchTaskWithInvalidRoundDecimal(t *testing.T) { req := constructSearchRequest(dbName, collectionName, expr, - floatVecField, + testFloatVecField, nq, dim, nprobe, topk, roundDecimal) task := &searchTask{ @@ -2218,7 +2128,6 @@ func TestSearchTaskWithInvalidRoundDecimal(t *testing.T) { resultData := &schemapb.SearchResultData{ NumQueries: int64(nq), TopK: int64(topk), - FieldsData: make([]*schemapb.FieldData, fieldsLen), Scores: make([]float32, nq*topk), Ids: &schemapb.IDs{ IdField: &schemapb.IDs_IntId{ @@ -2230,109 +2139,10 @@ func TestSearchTaskWithInvalidRoundDecimal(t *testing.T) { Topks: make([]int64, nq), } - resultData.FieldsData[0] = &schemapb.FieldData{ - Type: schemapb.DataType_Bool, - FieldName: boolField, - Field: &schemapb.FieldData_Scalars{ - Scalars: &schemapb.ScalarField{ - Data: &schemapb.ScalarField_BoolData{ - BoolData: &schemapb.BoolArray{ - Data: generateBoolArray(nq * topk), - }, - }, - }, - }, - FieldId: common.StartOfUserFieldID + 0, - } - - resultData.FieldsData[1] = &schemapb.FieldData{ - Type: schemapb.DataType_Int32, - FieldName: int32Field, - Field: &schemapb.FieldData_Scalars{ - Scalars: &schemapb.ScalarField{ - Data: &schemapb.ScalarField_IntData{ - IntData: &schemapb.IntArray{ - Data: generateInt32Array(nq * topk), - }, - }, - }, - }, - FieldId: common.StartOfUserFieldID + 1, - } - - resultData.FieldsData[2] = &schemapb.FieldData{ - Type: schemapb.DataType_Int64, - FieldName: int64Field, - Field: &schemapb.FieldData_Scalars{ - Scalars: &schemapb.ScalarField{ - Data: &schemapb.ScalarField_LongData{ - LongData: &schemapb.LongArray{ - Data: generateInt64Array(nq * topk), - }, - }, - }, - }, - FieldId: common.StartOfUserFieldID + 2, - } - - resultData.FieldsData[3] = &schemapb.FieldData{ - Type: schemapb.DataType_Float, - FieldName: floatField, - Field: &schemapb.FieldData_Scalars{ - Scalars: &schemapb.ScalarField{ - Data: &schemapb.ScalarField_FloatData{ - FloatData: &schemapb.FloatArray{ - Data: generateFloat32Array(nq * topk), - }, - }, - }, - }, - FieldId: common.StartOfUserFieldID + 3, - } - - resultData.FieldsData[4] = 
&schemapb.FieldData{ - Type: schemapb.DataType_Double, - FieldName: doubleField, - Field: &schemapb.FieldData_Scalars{ - Scalars: &schemapb.ScalarField{ - Data: &schemapb.ScalarField_DoubleData{ - DoubleData: &schemapb.DoubleArray{ - Data: generateFloat64Array(nq * topk), - }, - }, - }, - }, - FieldId: common.StartOfUserFieldID + 4, - } - - resultData.FieldsData[5] = &schemapb.FieldData{ - Type: schemapb.DataType_FloatVector, - FieldName: doubleField, - Field: &schemapb.FieldData_Vectors{ - Vectors: &schemapb.VectorField{ - Dim: int64(dim), - Data: &schemapb.VectorField_FloatVector{ - FloatVector: &schemapb.FloatArray{ - Data: generateFloatVectors(nq*topk, dim), - }, - }, - }, - }, - FieldId: common.StartOfUserFieldID + 5, - } - - resultData.FieldsData[6] = &schemapb.FieldData{ - Type: schemapb.DataType_BinaryVector, - FieldName: doubleField, - Field: &schemapb.FieldData_Vectors{ - Vectors: &schemapb.VectorField{ - Dim: int64(dim), - Data: &schemapb.VectorField_BinaryVector{ - BinaryVector: generateBinaryVectors(nq*topk, dim), - }, - }, - }, - FieldId: common.StartOfUserFieldID + 6, + fieldID := common.StartOfUserFieldID + for fieldName, dataType := range fieldName2Types { + resultData.FieldsData = append(resultData.FieldsData, generateFieldData(dataType, fieldName, int64(fieldID), nq*topk)) + fieldID++ } for i := 0; i < nq; i++ { @@ -3231,21 +3041,23 @@ func TestQueryTask_all(t *testing.T) { prefix := "TestQueryTask_all" dbName := "" collectionName := prefix + funcutil.GenRandomStr() - boolField := "bool" - int32Field := "int32" - int64Field := "int64" - floatField := "float" - doubleField := "double" - floatVecField := "fvec" - binaryVecField := "bvec" - fieldsLen := len([]string{boolField, int32Field, int64Field, floatField, doubleField, floatVecField, binaryVecField}) - dim := 128 - expr := fmt.Sprintf("%s > 0", int64Field) + + fieldName2Types := map[string]schemapb.DataType{ + testBoolField: schemapb.DataType_Bool, + testInt32Field: schemapb.DataType_Int32, + testInt64Field: schemapb.DataType_Int64, + testFloatField: schemapb.DataType_Float, + testDoubleField: schemapb.DataType_Double, + testFloatVecField: schemapb.DataType_FloatVector, + } + if enableMultipleVectorFields { + fieldName2Types[testBinaryVecField] = schemapb.DataType_BinaryVector + } + + expr := fmt.Sprintf("%s > 0", testInt64Field) hitNum := 10 - schema := constructCollectionSchemaWithAllType( - boolField, int32Field, int64Field, floatField, doubleField, - floatVecField, binaryVecField, dim, collectionName) + schema := constructCollectionSchemaByDataType(collectionName, fieldName2Types, testInt64Field, false) marshaledSchema, err := proto.Marshal(schema) assert.NoError(t, err) @@ -3310,7 +3122,7 @@ func TestQueryTask_all(t *testing.T) { CollectionID: collectionID, PartitionIDs: nil, SerializedExprPlan: nil, - OutputFieldsId: make([]int64, fieldsLen), + OutputFieldsId: make([]int64, len(fieldName2Types)), TravelTimestamp: 0, GuaranteeTimestamp: 0, }, @@ -3341,7 +3153,7 @@ func TestQueryTask_all(t *testing.T) { qc: qc, ids: nil, } - for i := 0; i < fieldsLen; i++ { + for i := 0; i < len(fieldName2Types); i++ { task.RetrieveRequest.OutputFieldsId[i] = int64(common.StartOfUserFieldID + i) } @@ -3393,115 +3205,15 @@ func TestQueryTask_all(t *testing.T) { }, }, }, - FieldsData: make([]*schemapb.FieldData, fieldsLen), SealedSegmentIDsRetrieved: nil, ChannelIDsRetrieved: nil, GlobalSealedSegmentIDs: nil, } - result1.FieldsData[0] = &schemapb.FieldData{ - Type: schemapb.DataType_Bool, - FieldName: boolField, - Field: 
&schemapb.FieldData_Scalars{ - Scalars: &schemapb.ScalarField{ - Data: &schemapb.ScalarField_BoolData{ - BoolData: &schemapb.BoolArray{ - Data: generateBoolArray(hitNum), - }, - }, - }, - }, - FieldId: common.StartOfUserFieldID + 0, - } - - result1.FieldsData[1] = &schemapb.FieldData{ - Type: schemapb.DataType_Int32, - FieldName: int32Field, - Field: &schemapb.FieldData_Scalars{ - Scalars: &schemapb.ScalarField{ - Data: &schemapb.ScalarField_IntData{ - IntData: &schemapb.IntArray{ - Data: generateInt32Array(hitNum), - }, - }, - }, - }, - FieldId: common.StartOfUserFieldID + 1, - } - - result1.FieldsData[2] = &schemapb.FieldData{ - Type: schemapb.DataType_Int64, - FieldName: int64Field, - Field: &schemapb.FieldData_Scalars{ - Scalars: &schemapb.ScalarField{ - Data: &schemapb.ScalarField_LongData{ - LongData: &schemapb.LongArray{ - Data: generateInt64Array(hitNum), - }, - }, - }, - }, - FieldId: common.StartOfUserFieldID + 2, - } - - result1.FieldsData[3] = &schemapb.FieldData{ - Type: schemapb.DataType_Float, - FieldName: floatField, - Field: &schemapb.FieldData_Scalars{ - Scalars: &schemapb.ScalarField{ - Data: &schemapb.ScalarField_FloatData{ - FloatData: &schemapb.FloatArray{ - Data: generateFloat32Array(hitNum), - }, - }, - }, - }, - FieldId: common.StartOfUserFieldID + 3, - } - - result1.FieldsData[4] = &schemapb.FieldData{ - Type: schemapb.DataType_Double, - FieldName: doubleField, - Field: &schemapb.FieldData_Scalars{ - Scalars: &schemapb.ScalarField{ - Data: &schemapb.ScalarField_DoubleData{ - DoubleData: &schemapb.DoubleArray{ - Data: generateFloat64Array(hitNum), - }, - }, - }, - }, - FieldId: common.StartOfUserFieldID + 4, - } - - result1.FieldsData[5] = &schemapb.FieldData{ - Type: schemapb.DataType_FloatVector, - FieldName: doubleField, - Field: &schemapb.FieldData_Vectors{ - Vectors: &schemapb.VectorField{ - Dim: int64(dim), - Data: &schemapb.VectorField_FloatVector{ - FloatVector: &schemapb.FloatArray{ - Data: generateFloatVectors(hitNum, dim), - }, - }, - }, - }, - FieldId: common.StartOfUserFieldID + 5, - } - - result1.FieldsData[6] = &schemapb.FieldData{ - Type: schemapb.DataType_BinaryVector, - FieldName: doubleField, - Field: &schemapb.FieldData_Vectors{ - Vectors: &schemapb.VectorField{ - Dim: int64(dim), - Data: &schemapb.VectorField_BinaryVector{ - BinaryVector: generateBinaryVectors(hitNum, dim), - }, - }, - }, - FieldId: common.StartOfUserFieldID + 6, + fieldID := common.StartOfUserFieldID + for fieldName, dataType := range fieldName2Types { + result1.FieldsData = append(result1.FieldsData, generateFieldData(dataType, fieldName, int64(fieldID), hitNum)) + fieldID++ } // send search result @@ -3551,25 +3263,22 @@ func TestTask_all(t *testing.T) { dbName := "" collectionName := prefix + funcutil.GenRandomStr() partitionName := prefix + funcutil.GenRandomStr() - boolField := "bool" - int32Field := "int32" - int64Field := "int64" - floatField := "float" - doubleField := "double" - floatVecField := "fvec" - binaryVecField := "bvec" - var fieldsLen int - fieldsLen = len([]string{boolField, int32Field, int64Field, floatField, doubleField, floatVecField}) + + fieldName2Types := map[string]schemapb.DataType{ + testBoolField: schemapb.DataType_Bool, + testInt32Field: schemapb.DataType_Int32, + testInt64Field: schemapb.DataType_Int64, + testFloatField: schemapb.DataType_Float, + testDoubleField: schemapb.DataType_Double, + //testStringField: schemapb.DataType_String, + testFloatVecField: schemapb.DataType_FloatVector} if enableMultipleVectorFields { - fieldsLen = 
len([]string{boolField, int32Field, int64Field, floatField, doubleField, floatVecField, binaryVecField}) + fieldName2Types[testBinaryVecField] = schemapb.DataType_BinaryVector } - dim := 128 nb := 10 t.Run("create collection", func(t *testing.T) { - schema := constructCollectionSchemaWithAllType( - boolField, int32Field, int64Field, floatField, doubleField, - floatVecField, binaryVecField, dim, collectionName) + schema := constructCollectionSchemaByDataType(collectionName, fieldName2Types, testInt64Field, false) marshaledSchema, err := proto.Marshal(schema) assert.NoError(t, err) @@ -3648,27 +3357,18 @@ func TestTask_all(t *testing.T) { }, InsertRequest: internalpb.InsertRequest{ Base: &commonpb.MsgBase{ - MsgType: commonpb.MsgType_Insert, - MsgID: 0, + MsgType: commonpb.MsgType_Insert, + MsgID: 0, + SourceID: Params.ProxyCfg.ProxyID, }, + DbName: dbName, CollectionName: collectionName, PartitionName: partitionName, + NumRows: uint64(nb), + Version: internalpb.InsertDataVersion_ColumnBased, }, }, - req: &milvuspb.InsertRequest{ - Base: &commonpb.MsgBase{ - MsgType: commonpb.MsgType_Insert, - MsgID: 0, - Timestamp: 0, - SourceID: Params.ProxyCfg.ProxyID, - }, - DbName: dbName, - CollectionName: collectionName, - PartitionName: partitionName, - FieldsData: make([]*schemapb.FieldData, fieldsLen), - HashKeys: hash, - NumRows: uint32(nb), - }, + Condition: NewTaskCondition(ctx), ctx: ctx, result: &milvuspb.MutationResult{ @@ -3694,111 +3394,10 @@ func TestTask_all(t *testing.T) { schema: nil, } - task.req.FieldsData[0] = &schemapb.FieldData{ - Type: schemapb.DataType_Bool, - FieldName: boolField, - Field: &schemapb.FieldData_Scalars{ - Scalars: &schemapb.ScalarField{ - Data: &schemapb.ScalarField_BoolData{ - BoolData: &schemapb.BoolArray{ - Data: generateBoolArray(nb), - }, - }, - }, - }, - FieldId: common.StartOfUserFieldID + 0, - } - - task.req.FieldsData[1] = &schemapb.FieldData{ - Type: schemapb.DataType_Int32, - FieldName: int32Field, - Field: &schemapb.FieldData_Scalars{ - Scalars: &schemapb.ScalarField{ - Data: &schemapb.ScalarField_IntData{ - IntData: &schemapb.IntArray{ - Data: generateInt32Array(nb), - }, - }, - }, - }, - FieldId: common.StartOfUserFieldID + 1, - } - - task.req.FieldsData[2] = &schemapb.FieldData{ - Type: schemapb.DataType_Int64, - FieldName: int64Field, - Field: &schemapb.FieldData_Scalars{ - Scalars: &schemapb.ScalarField{ - Data: &schemapb.ScalarField_LongData{ - LongData: &schemapb.LongArray{ - Data: generateInt64Array(nb), - }, - }, - }, - }, - FieldId: common.StartOfUserFieldID + 2, - } - - task.req.FieldsData[3] = &schemapb.FieldData{ - Type: schemapb.DataType_Float, - FieldName: floatField, - Field: &schemapb.FieldData_Scalars{ - Scalars: &schemapb.ScalarField{ - Data: &schemapb.ScalarField_FloatData{ - FloatData: &schemapb.FloatArray{ - Data: generateFloat32Array(nb), - }, - }, - }, - }, - FieldId: common.StartOfUserFieldID + 3, - } - - task.req.FieldsData[4] = &schemapb.FieldData{ - Type: schemapb.DataType_Double, - FieldName: doubleField, - Field: &schemapb.FieldData_Scalars{ - Scalars: &schemapb.ScalarField{ - Data: &schemapb.ScalarField_DoubleData{ - DoubleData: &schemapb.DoubleArray{ - Data: generateFloat64Array(nb), - }, - }, - }, - }, - FieldId: common.StartOfUserFieldID + 4, - } - - task.req.FieldsData[5] = &schemapb.FieldData{ - Type: schemapb.DataType_FloatVector, - FieldName: floatVecField, - Field: &schemapb.FieldData_Vectors{ - Vectors: &schemapb.VectorField{ - Dim: int64(dim), - Data: &schemapb.VectorField_FloatVector{ - FloatVector: 
&schemapb.FloatArray{ - Data: generateFloatVectors(nb, dim), - }, - }, - }, - }, - FieldId: common.StartOfUserFieldID + 5, - } - - if enableMultipleVectorFields { - task.req.FieldsData[6] = &schemapb.FieldData{ - Type: schemapb.DataType_BinaryVector, - FieldName: binaryVecField, - Field: &schemapb.FieldData_Vectors{ - Vectors: &schemapb.VectorField{ - Dim: int64(dim), - Data: &schemapb.VectorField_BinaryVector{ - BinaryVector: generateBinaryVectors(nb, dim), - }, - }, - }, - FieldId: common.StartOfUserFieldID + 6, - } + fieldID := common.StartOfUserFieldID + for fieldName, dataType := range fieldName2Types { + task.FieldsData = append(task.FieldsData, generateFieldData(dataType, fieldName, int64(fieldID), nb)) + fieldID++ } assert.NoError(t, task.OnEnqueue()) diff --git a/internal/proxy/validate_util.go b/internal/proxy/util.go similarity index 79% rename from internal/proxy/validate_util.go rename to internal/proxy/util.go index 4fc3a9afab..83f2f0dcd4 100644 --- a/internal/proxy/validate_util.go +++ b/internal/proxy/util.go @@ -398,7 +398,7 @@ func validateMultipleVectorFields(schema *schemapb.CollectionSchema) error { for i := range schema.Fields { name := schema.Fields[i].Name dType := schema.Fields[i].DataType - isVec := (dType == schemapb.DataType_BinaryVector || dType == schemapb.DataType_FloatVector) + isVec := dType == schemapb.DataType_BinaryVector || dType == schemapb.DataType_FloatVector if isVec && vecExist && !enableMultipleVectorFields { return fmt.Errorf( "multiple vector fields is not supported, fields name: %s, %s", @@ -413,3 +413,98 @@ func validateMultipleVectorFields(schema *schemapb.CollectionSchema) error { return nil } + +// getPrimaryFieldData get primary field data from all field data inserted from sdk +func getPrimaryFieldData(datas []*schemapb.FieldData, primaryFieldSchema *schemapb.FieldSchema) (*schemapb.FieldData, error) { + primaryFieldName := primaryFieldSchema.Name + + var primaryFieldData *schemapb.FieldData + for _, field := range datas { + if field.FieldName == primaryFieldName { + if primaryFieldSchema.AutoID { + return nil, fmt.Errorf("autoID field %v does not require data", primaryFieldName) + } + primaryFieldData = field + } + } + + if primaryFieldData == nil { + return nil, fmt.Errorf("can't find data for primary field %v", primaryFieldName) + } + + return primaryFieldData, nil +} + +// parsePrimaryFieldData2IDs get IDs to fill grpc result, for example insert request, delete request etc. 
+func parsePrimaryFieldData2IDs(fieldData *schemapb.FieldData) (*schemapb.IDs, error) { + primaryData := &schemapb.IDs{} + switch fieldData.Field.(type) { + case *schemapb.FieldData_Scalars: + scalarField := fieldData.GetScalars() + switch scalarField.Data.(type) { + case *schemapb.ScalarField_LongData: + primaryData.IdField = &schemapb.IDs_IntId{ + IntId: scalarField.GetLongData(), + } + case *schemapb.ScalarField_StringData: + primaryData.IdField = &schemapb.IDs_StrId{ + StrId: scalarField.GetStringData(), + } + default: + return nil, errors.New("currently only support DataType Int64 or VarChar as PrimaryField") + } + default: + return nil, errors.New("currently do not support vector field as PrimaryField") + } + + return primaryData, nil +} + +// autoGenPrimaryFieldData generates primary data when autoID == true +func autoGenPrimaryFieldData(fieldSchema *schemapb.FieldSchema, data interface{}) (*schemapb.FieldData, error) { + var fieldData schemapb.FieldData + fieldData.FieldName = fieldSchema.Name + fieldData.Type = fieldSchema.DataType + switch data := data.(type) { + case []int64: + if fieldSchema.DataType != schemapb.DataType_Int64 { + return nil, errors.New("the data type of the data and the schema do not match") + } + fieldData.Field = &schemapb.FieldData_Scalars{ + Scalars: &schemapb.ScalarField{ + Data: &schemapb.ScalarField_LongData{ + LongData: &schemapb.LongArray{ + Data: data, + }, + }, + }, + } + default: + return nil, errors.New("currently only support autoID for int64 PrimaryField") + } + + return &fieldData, nil +} + +// fillFieldIDBySchema sets fieldID on each fieldData according to the FieldSchemas +func fillFieldIDBySchema(columns []*schemapb.FieldData, schema *schemapb.CollectionSchema) error { + if len(columns) != len(schema.GetFields()) { + return fmt.Errorf("len(columns) mismatch the len(fields), len(columns): %d, len(fields): %d", + len(columns), len(schema.GetFields())) + } + fieldName2Schema := make(map[string]*schemapb.FieldSchema) + for _, field := range schema.GetFields() { + fieldName2Schema[field.Name] = field + } + + for _, fieldData := range columns { + if fieldSchema, ok := fieldName2Schema[fieldData.FieldName]; ok { + fieldData.FieldId = fieldSchema.FieldID + fieldData.Type = fieldSchema.DataType + } else { + return fmt.Errorf("fieldName %v not exist in collection schema", fieldData.FieldName) + } + } + + return nil +} diff --git a/internal/proxy/validate_util_test.go b/internal/proxy/util_test.go similarity index 94% rename from internal/proxy/validate_util_test.go rename to internal/proxy/util_test.go index 7f9f308219..af09a74b59 100644 --- a/internal/proxy/validate_util_test.go +++ b/internal/proxy/util_test.go @@ -502,3 +502,28 @@ func TestValidateMultipleVectorFields(t *testing.T) { assert.Error(t, validateMultipleVectorFields(schema3)) } } + +func TestFillFieldIDBySchema(t *testing.T) { + schema := &schemapb.CollectionSchema{} + columns := []*schemapb.FieldData{ + { + FieldName: "TestFillFieldIDBySchema", + }, + } + + // length mismatch + assert.Error(t, fillFieldIDBySchema(columns, schema)) + schema = &schemapb.CollectionSchema{ + Fields: []*schemapb.FieldSchema{ + { + Name: "TestFillFieldIDBySchema", + DataType: schemapb.DataType_Int64, + FieldID: 1, + }, + }, + } + assert.NoError(t, fillFieldIDBySchema(columns, schema)) + assert.Equal(t, "TestFillFieldIDBySchema", columns[0].FieldName) + assert.Equal(t, schemapb.DataType_Int64, columns[0].Type) + assert.Equal(t, int64(1), columns[0].FieldId) +}
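Taken together, these helpers form the proxy-side primary-key pipeline: stamp field IDs from the schema, locate the primary-key column, then convert it into the IDs payload of the gRPC result. A minimal sketch of how an insert path might chain them inside package proxy (the wrapper function and its arguments are illustrative, not part of the patch):

// buildIDsForInsert is a hypothetical wrapper showing the intended call order.
func buildIDsForInsert(columns []*schemapb.FieldData, schema *schemapb.CollectionSchema, pkField *schemapb.FieldSchema) (*schemapb.IDs, error) {
	// stamp FieldId and Type onto each column from the collection schema
	if err := fillFieldIDBySchema(columns, schema); err != nil {
		return nil, err
	}
	// locate the user-supplied primary-key column; errors out for autoID fields
	pkData, err := getPrimaryFieldData(columns, pkField)
	if err != nil {
		return nil, err
	}
	// Int64 columns become IDs_IntId, VarChar columns become IDs_StrId
	return parsePrimaryFieldData2IDs(pkData)
}

diff --git a/internal/querynode/flow_graph_filter_dm_node.go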
b/internal/querynode/flow_graph_filter_dm_node.go index 95421543e8..2c49507eec 100644 --- a/internal/querynode/flow_graph_filter_dm_node.go +++ b/internal/querynode/flow_graph_filter_dm_node.go @@ -145,7 +145,7 @@ func (fdmNode *filterDmNode) filterInvalidDeleteMessage(msg *msgstream.DeleteMsg // filterInvalidInsertMessage would filter out invalid insert messages func (fdmNode *filterDmNode) filterInvalidInsertMessage(msg *msgstream.InsertMsg) *msgstream.InsertMsg { - if !msg.CheckAligned() { + if err := msg.CheckAligned(); err != nil { // TODO: what if the messages are misaligned? Here, we ignore those messages and print error log.Warn("Error, misaligned messages detected") return nil diff --git a/internal/querynode/flow_graph_insert_node.go b/internal/querynode/flow_graph_insert_node.go index a34df4102b..570ca7e539 100644 --- a/internal/querynode/flow_graph_insert_node.go +++ b/internal/querynode/flow_graph_insert_node.go @@ -102,15 +102,6 @@ func (iNode *insertNode) Operate(in []flowgraph.Msg) []flowgraph.Msg { // 1. hash insertMessages to insertData for _, insertMsg := range iMsg.insertMessages { - if insertMsg.IsColumnBased() { - var err error - insertMsg.RowData, err = typeutil.TransferColumnBasedDataToRowBasedData(insertMsg.FieldsData) - if err != nil { - log.Error("failed to transfer column-based data to row-based data", zap.Error(err)) - return []Msg{} - } - } - // if loadType is loadCollection, check if partition exists, if not, create partition col, err := iNode.streamingReplica.getCollectionByID(insertMsg.CollectionID) if err != nil { @@ -134,6 +125,15 @@ func (iNode *insertNode) Operate(in []flowgraph.Msg) []flowgraph.Msg { } } + // transfer column-based field data to row-based data + if insertMsg.IsColumnBased() { + insertMsg.RowData, err = typeutil.TransferColumnBasedDataToRowBasedData(col.schema, insertMsg.FieldsData) + if err != nil { + log.Error("failed to transfer column-based data to row-based data", zap.Error(err)) + return []Msg{} + } + } + iData.insertIDs[insertMsg.SegmentID] = append(iData.insertIDs[insertMsg.SegmentID], insertMsg.RowIDs...) iData.insertTimestamps[insertMsg.SegmentID] = append(iData.insertTimestamps[insertMsg.SegmentID], insertMsg.Timestamps...) // using insertMsg.RowData is valid here, since we have already transferred the column-based data.
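Both querynode call sites now treat alignment as an error value rather than a bool, which lets the flow graph log and skip a misaligned message in one idiomatic check. A hedged sketch of what that contract amounts to (this stand-in is not the msgstream definition; the real InsertMsg validates timestamps, row IDs, and every field column):

// checkAligned mirrors the error-returning alignment contract assumed above:
// every per-row slice must match the declared row count before rows are consumed.
func checkAligned(timestamps []uint64, rowIDs []int64, numRows uint64) error {
	if uint64(len(timestamps)) != numRows || uint64(len(rowIDs)) != numRows {
		return fmt.Errorf("misaligned insert message: %d timestamps, %d rowIDs, expected %d rows",
			len(timestamps), len(rowIDs), numRows)
	}
	return nil
}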
@@ -345,7 +345,7 @@ func (iNode *insertNode) delete(deleteData *deleteData, segmentID UniqueID, wg * // TODO: remove this function to proper file // getPrimaryKeys would get primary keys by insert messages func getPrimaryKeys(msg *msgstream.InsertMsg, streamingReplica ReplicaInterface) ([]int64, error) { - if !msg.CheckAligned() { + if err := msg.CheckAligned(); err != nil { log.Warn("misaligned messages detected") return nil, errors.New("misaligned messages detected") } diff --git a/internal/storage/cwrapper/ColumnType.h b/internal/storage/cwrapper/ColumnType.h index ee5877c7bc..7caf2ca9f3 100644 --- a/internal/storage/cwrapper/ColumnType.h +++ b/internal/storage/cwrapper/ColumnType.h @@ -25,6 +25,7 @@ enum ColumnType : int { FLOAT = 10, DOUBLE = 11, STRING = 20, + VARCHAR = 21, VECTOR_BINARY = 100, VECTOR_FLOAT = 101 }; diff --git a/internal/storage/cwrapper/ParquetWrapper.cpp b/internal/storage/cwrapper/ParquetWrapper.cpp index cc73d7b9af..23a389c061 100644 --- a/internal/storage/cwrapper/ParquetWrapper.cpp +++ b/internal/storage/cwrapper/ParquetWrapper.cpp @@ -76,6 +76,7 @@ CPayloadWriter NewPayloadWriter(int columnType) { p->schema = arrow::schema({arrow::field("val", arrow::float64())}); break; } + case ColumnType::VARCHAR: case ColumnType::STRING : { p->columnType = ColumnType::STRING; p->builder = std::make_shared<arrow::StringBuilder>(); break; } @@ -384,6 +385,7 @@ CPayloadReader NewPayloadReader(int columnType, uint8_t *buffer, int64_t buf_siz case ColumnType::FLOAT : case ColumnType::DOUBLE : case ColumnType::STRING : + case ColumnType::VARCHAR: case ColumnType::VECTOR_BINARY : case ColumnType::VECTOR_FLOAT : { break; diff --git a/internal/storage/data_codec.go b/internal/storage/data_codec.go index 3591fb662a..1170d1f34c 100644 --- a/internal/storage/data_codec.go +++ b/internal/storage/data_codec.go @@ -360,7 +360,7 @@ func (insertCodec *InsertCodec) Serialize(partitionID UniqueID, segmentID Unique return nil, nil, err } writer.AddExtra(originalSizeKey, fmt.Sprintf("%v", singleData.(*DoubleFieldData).GetMemorySize())) - case schemapb.DataType_String: + case schemapb.DataType_String, schemapb.DataType_VarChar: for _, singleString := range singleData.(*StringFieldData).Data { err = eventWriter.AddOneStringToPayload(singleString) if err != nil { @@ -416,10 +416,9 @@ func (insertCodec *InsertCodec) Serialize(partitionID UniqueID, segmentID Unique writer.Close() // stats fields - switch field.DataType { - case schemapb.DataType_Int64: + if field.GetIsPrimaryKey() { statsWriter := &StatsWriter{} - err = statsWriter.StatsInt64(field.FieldID, field.IsPrimaryKey, singleData.(*Int64FieldData).Data) + err = statsWriter.generatePrimaryKeyStats(field.FieldID, field.DataType, singleData) if err != nil { return nil, nil, err } @@ -621,7 +620,7 @@ func (insertCodec *InsertCodec) DeserializeAll(blobs []*Blob) ( totalLength += length doubleFieldData.NumRows = append(doubleFieldData.NumRows, int64(length)) resultData.Data[fieldID] = doubleFieldData - case schemapb.DataType_String: + case schemapb.DataType_String, schemapb.DataType_VarChar: if resultData.Data[fieldID] == nil { resultData.Data[fieldID] = &StringFieldData{} } diff --git a/internal/storage/data_sorter.go b/internal/storage/data_sorter.go index 4a3fc5d72e..08da02185f 100644 --- a/internal/storage/data_sorter.go +++ b/internal/storage/data_sorter.go @@ -77,7 +77,7 @@ func (ds *DataSorter) Swap(i, j int) { case schemapb.DataType_Double: data := singleData.(*DoubleFieldData).Data data[i], data[j] = data[j], data[i] - case schemapb.DataType_String: + case
schemapb.DataType_String, schemapb.DataType_VarChar: data := singleData.(*StringFieldData).Data data[i], data[j] = data[j], data[i] case schemapb.DataType_BinaryVector: diff --git a/internal/storage/payload.go b/internal/storage/payload.go index e1b5e69478..f89b00bf9f 100644 --- a/internal/storage/payload.go +++ b/internal/storage/payload.go @@ -145,7 +145,7 @@ func (w *PayloadWriter) AddDataToPayload(msgs interface{}, dim ...int) error { return errors.New("incorrect data type") } return w.AddDoubleToPayload(val) - case schemapb.DataType_String: + case schemapb.DataType_String, schemapb.DataType_VarChar: val, ok := msgs.(string) if !ok { return errors.New("incorrect data type") @@ -387,7 +387,7 @@ func (r *PayloadReader) GetDataFromPayload(idx ...int) (interface{}, int, error) switch len(idx) { case 1: switch r.colType { - case schemapb.DataType_String: + case schemapb.DataType_String, schemapb.DataType_VarChar: val, err := r.GetOneStringFromPayload(idx[0]) return val, 0, err default: @@ -573,7 +573,7 @@ func (r *PayloadReader) GetDoubleFromPayload() ([]float64, error) { } func (r *PayloadReader) GetOneStringFromPayload(idx int) (string, error) { - if r.colType != schemapb.DataType_String { + if r.colType != schemapb.DataType_String && r.colType != schemapb.DataType_VarChar { return "", errors.New("incorrect data type") } diff --git a/internal/storage/primary_key.go b/internal/storage/primary_key.go new file mode 100644 index 0000000000..96e5513037 --- /dev/null +++ b/internal/storage/primary_key.go @@ -0,0 +1,233 @@ +// Licensed to the LF AI & Data foundation under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. 
+ +package storage + +import ( + "encoding/json" + "fmt" + "strings" + + "github.com/milvus-io/milvus/internal/log" +) + +type PrimaryKey interface { + GT(key PrimaryKey) bool + GE(key PrimaryKey) bool + LT(key PrimaryKey) bool + LE(key PrimaryKey) bool + EQ(key PrimaryKey) bool + MarshalJSON() ([]byte, error) + UnmarshalJSON(data []byte) error + SetValue(interface{}) error +} + +type Int64PrimaryKey struct { + Value int64 `json:"pkValue"` +} + +func (ip *Int64PrimaryKey) GT(key PrimaryKey) bool { + pk, ok := key.(*Int64PrimaryKey) + if !ok { + log.Warn("type of compared pk is not int64") + return false + } + if ip.Value > pk.Value { + return true + } + + return false +} + +func (ip *Int64PrimaryKey) GE(key PrimaryKey) bool { + pk, ok := key.(*Int64PrimaryKey) + if !ok { + log.Warn("type of compared pk is not int64") + return false + } + if ip.Value >= pk.Value { + return true + } + + return false +} + +func (ip *Int64PrimaryKey) LT(key PrimaryKey) bool { + pk, ok := key.(*Int64PrimaryKey) + if !ok { + log.Warn("type of compared pk is not int64") + return false + } + + if ip.Value < pk.Value { + return true + } + + return false +} + +func (ip *Int64PrimaryKey) LE(key PrimaryKey) bool { + pk, ok := key.(*Int64PrimaryKey) + if !ok { + log.Warn("type of compared pk is not int64") + return false + } + + if ip.Value <= pk.Value { + return true + } + + return false +} + +func (ip *Int64PrimaryKey) EQ(key PrimaryKey) bool { + pk, ok := key.(*Int64PrimaryKey) + if !ok { + log.Warn("type of compared pk is not int64") + return false + } + + if ip.Value == pk.Value { + return true + } + + return false +} + +func (ip *Int64PrimaryKey) MarshalJSON() ([]byte, error) { + ret, err := json.Marshal(ip.Value) + if err != nil { + return nil, err + } + + return ret, nil +} + +func (ip *Int64PrimaryKey) UnmarshalJSON(data []byte) error { + err := json.Unmarshal(data, &ip.Value) + if err != nil { + return err + } + + return nil +} + +func (ip *Int64PrimaryKey) SetValue(data interface{}) error { + value, ok := data.(int64) + if !ok { + return fmt.Errorf("wrong type value when setValue for Int64PrimaryKey") + } + + ip.Value = value + return nil +} + +type StringPrimaryKey struct { + Value string +} + +func (sp *StringPrimaryKey) GT(key PrimaryKey) bool { + pk, ok := key.(*StringPrimaryKey) + if !ok { + log.Warn("type of compared pk is not string") + return false + } + if strings.Compare(sp.Value, pk.Value) > 0 { + return true + } + + return false +} + +func (sp *StringPrimaryKey) GE(key PrimaryKey) bool { + pk, ok := key.(*StringPrimaryKey) + if !ok { + log.Warn("type of compared pk is not string") + return false + } + if strings.Compare(sp.Value, pk.Value) >= 0 { + return true + } + + return false +} + +func (sp *StringPrimaryKey) LT(key PrimaryKey) bool { + pk, ok := key.(*StringPrimaryKey) + if !ok { + log.Warn("type of compared pk is not string") + return false + } + if strings.Compare(sp.Value, pk.Value) < 0 { + return true + } + + return false +} + +func (sp *StringPrimaryKey) LE(key PrimaryKey) bool { + pk, ok := key.(*StringPrimaryKey) + if !ok { + log.Warn("type of compared pk is not string") + return false + } + if strings.Compare(sp.Value, pk.Value) <= 0 { + return true + } + + return false +} + +func (sp *StringPrimaryKey) EQ(key PrimaryKey) bool { + pk, ok := key.(*StringPrimaryKey) + if !ok { + log.Warn("type of compared pk is not string") + return false + } + if strings.Compare(sp.Value, pk.Value) == 0 { + return true + } + + return false +} + +func (sp *StringPrimaryKey) MarshalJSON() ([]byte, 
error) { + ret, err := json.Marshal(sp.Value) + if err != nil { + return nil, err + } + + return ret, nil +} + +func (sp *StringPrimaryKey) UnmarshalJSON(data []byte) error { + err := json.Unmarshal(data, &sp.Value) + if err != nil { + return err + } + + return nil +} + +func (sp *StringPrimaryKey) SetValue(data interface{}) error { + value, ok := data.(string) + if !ok { + return fmt.Errorf("wrong type value when setValue for StringPrimaryKey") + } + + sp.Value = value + return nil +} diff --git a/internal/storage/primary_key_test.go b/internal/storage/primary_key_test.go new file mode 100644 index 0000000000..94210eed3c --- /dev/null +++ b/internal/storage/primary_key_test.go @@ -0,0 +1,82 @@ +package storage + +import ( + "encoding/json" + "testing" + + "github.com/stretchr/testify/assert" +) + +func TestStringPrimaryKey(t *testing.T) { + pk := &StringPrimaryKey{ + Value: "milvus", + } + + testPk := &StringPrimaryKey{ + Value: "milvus", + } + + // test GE + assert.Equal(t, true, pk.GE(testPk)) + // test LE + assert.Equal(t, true, pk.LE(testPk)) + // test EQ + assert.Equal(t, true, pk.EQ(testPk)) + + // test GT + err := testPk.SetValue("bivlus") + assert.Nil(t, err) + assert.Equal(t, true, pk.GT(testPk)) + + // test LT + err = testPk.SetValue("mivlut") + assert.Nil(t, err) + assert.Equal(t, true, pk.LT(testPk)) + + t.Run("unmarshal", func(t *testing.T) { + blob, err := json.Marshal(pk) + assert.Nil(t, err) + + unmarshalledPk := &StringPrimaryKey{} + err = json.Unmarshal(blob, unmarshalledPk) + assert.Nil(t, err) + assert.Equal(t, pk.Value, unmarshalledPk.Value) + }) +} + +func TestInt64PrimaryKey(t *testing.T) { + pk := &Int64PrimaryKey{ + Value: 100, + } + + testPk := &Int64PrimaryKey{ + Value: 100, + } + + // test GE + assert.Equal(t, true, pk.GE(testPk)) + // test LE + assert.Equal(t, true, pk.LE(testPk)) + // test EQ + assert.Equal(t, true, pk.EQ(testPk)) + + // test GT + err := testPk.SetValue(int64(10)) + assert.Nil(t, err) + assert.Equal(t, true, pk.GT(testPk)) + + // test LT + err = testPk.SetValue(int64(200)) + assert.Nil(t, err) + assert.Equal(t, true, pk.LT(testPk)) + + t.Run("unmarshal", func(t *testing.T) { + blob, err := json.Marshal(pk) + assert.Nil(t, err) + + unmarshalledPk := &Int64PrimaryKey{} + err = json.Unmarshal(blob, unmarshalledPk) + assert.Nil(t, err) + assert.Equal(t, pk.Value, unmarshalledPk.Value) + }) +} diff --git a/internal/storage/print_binlog.go b/internal/storage/print_binlog.go index 98596a6be0..4c308aa0ed 100644 --- a/internal/storage/print_binlog.go +++ b/internal/storage/print_binlog.go @@ -277,7 +277,7 @@ func printPayloadValues(colType schemapb.DataType, reader PayloadReaderInterface for i, v := range val { fmt.Printf("\t\t%d : %v\n", i, v) } - case schemapb.DataType_String: + case schemapb.DataType_String, schemapb.DataType_VarChar: rows, err := reader.GetPayloadLengthFromReader() if err != nil { return err diff --git a/internal/storage/stats.go b/internal/storage/stats.go index 5a4ec82e78..e655c06c71 100644 --- a/internal/storage/stats.go +++ b/internal/storage/stats.go @@ -20,7 +20,9 @@ import ( "encoding/json" "github.com/bits-and-blooms/bloom/v3" + "github.com/milvus-io/milvus/internal/common" + "github.com/milvus-io/milvus/internal/proto/schemapb" ) const ( @@ -29,12 +31,107 @@ const ( maxBloomFalsePositive float64 = 0.005 ) -// Int64Stats contains statistics data for int64 column -type Int64Stats struct { +// PrimaryKeyStats contains statistics data for pk column +type PrimaryKeyStats struct { FieldID int64 `json:"fieldID"` - Max int64 
`json:"max"` - Min int64 `json:"min"` + Max int64 `json:"max"` // useless, will delete + Min int64 `json:"min"` // useless, will delete BF *bloom.BloomFilter `json:"bf"` + PkType int64 `json:"pkType"` + MaxPk PrimaryKey `json:"maxPk"` + MinPk PrimaryKey `json:"minPk"` +} + +// UnmarshalJSON unmarshals bytes to PrimaryKeyStats +func (stats *PrimaryKeyStats) UnmarshalJSON(data []byte) error { + var messageMap map[string]*json.RawMessage + err := json.Unmarshal(data, &messageMap) + if err != nil { + return err + } + + err = json.Unmarshal(*messageMap["fieldID"], &stats.FieldID) + if err != nil { + return err + } + + stats.PkType = int64(schemapb.DataType_Int64) + var typeValue int64 + err = json.Unmarshal(*messageMap["pkType"], &typeValue) + if err != nil { + return err + } + // valid pkType + if typeValue > 0 { + stats.PkType = typeValue + } + + switch schemapb.DataType(stats.PkType) { + case schemapb.DataType_Int64: + stats.MaxPk = &Int64PrimaryKey{} + stats.MinPk = &Int64PrimaryKey{} + + // Compatible with versions that only support int64 type primary keys + err = json.Unmarshal(*messageMap["max"], &stats.Max) + if err != nil { + return err + } + err = stats.MaxPk.SetValue(stats.Max) + if err != nil { + return err + } + + err = json.Unmarshal(*messageMap["min"], &stats.Min) + if err != nil { + return err + } + err = stats.MinPk.SetValue(stats.Min) + if err != nil { + return err + } + case schemapb.DataType_VarChar: + stats.MaxPk = &StringPrimaryKey{} + stats.MinPk = &StringPrimaryKey{} + } + + if maxPkMessage, ok := messageMap["maxPk"]; ok && maxPkMessage != nil { + err = json.Unmarshal(*maxPkMessage, stats.MaxPk) + if err != nil { + return err + } + } + + if minPkMessage, ok := messageMap["minPk"]; ok && minPkMessage != nil { + err = json.Unmarshal(*minPkMessage, stats.MinPk) + if err != nil { + return err + } + } + + stats.BF = bloom.NewWithEstimates(bloomFilterSize, maxBloomFalsePositive) + if bfMessage, ok := messageMap["bf"]; ok && bfMessage != nil { + err = stats.BF.UnmarshalJSON(*bfMessage) + if err != nil { + return err + } + } + + return nil +} + +// updatePk updates minPk and maxPk value +func (stats *PrimaryKeyStats) updatePk(pk PrimaryKey) { + if stats.MinPk == nil { + stats.MinPk = pk + } else if stats.MinPk.GT(pk) { + stats.MinPk = pk + } + + if stats.MaxPk == nil { + stats.MaxPk = pk + } else if stats.MaxPk.LT(pk) { + stats.MaxPk = pk + } } // StatsWriter writes stats to buffer @@ -47,26 +144,49 @@ func (sw *StatsWriter) GetBuffer() []byte { return sw.buffer } -// StatsInt64 writes Int64Stats from @msgs with @fieldID to @buffer -func (sw *StatsWriter) StatsInt64(fieldID int64, isPrimaryKey bool, msgs []int64) error { - if len(msgs) < 1 { - // return error: msgs must has one element at least - return nil +// generatePrimaryKeyStats writes PrimaryKeyStats from @msgs with @fieldID to @buffer +func (sw *StatsWriter) generatePrimaryKeyStats(fieldID int64, pkType schemapb.DataType, msgs FieldData) error { + stats := &PrimaryKeyStats{ + FieldID: fieldID, + PkType: int64(pkType), } - stats := &Int64Stats{ - FieldID: fieldID, - Max: msgs[len(msgs)-1], - Min: msgs[0], - } - if isPrimaryKey { - stats.BF = bloom.NewWithEstimates(bloomFilterSize, maxBloomFalsePositive) + stats.BF = bloom.NewWithEstimates(bloomFilterSize, maxBloomFalsePositive) + switch fieldData := msgs.(type) { + case *Int64FieldData: + data := fieldData.Data + if len(data) < 1 { + // return error: msgs must have at least one element + return nil + } + b := make([]byte, 8) - for _, msg := range msgs { -
common.Endian.PutUint64(b, uint64(msg)) + for _, int64Value := range data { + pk := &Int64PrimaryKey{ + Value: int64Value, + } + stats.updatePk(pk) + common.Endian.PutUint64(b, uint64(int64Value)) stats.BF.Add(b) } + case *StringFieldData: + data := fieldData.Data + if len(data) < 1 { + // return error: msgs must have at least one element + return nil + } + + for _, str := range data { + pk := &StringPrimaryKey{ + Value: str, + } + stats.updatePk(pk) + stats.BF.Add([]byte(str)) + } + default: + //TODO:: } + b, err := json.Marshal(stats) if err != nil { return err @@ -86,26 +206,27 @@ func (sr *StatsReader) SetBuffer(buffer []byte) { sr.buffer = buffer } -// GetInt64Stats returns buffer as Int64Stats -func (sr *StatsReader) GetInt64Stats() (*Int64Stats, error) { - stats := &Int64Stats{} +// GetPrimaryKeyStats returns buffer as PrimaryKeyStats +func (sr *StatsReader) GetPrimaryKeyStats() (*PrimaryKeyStats, error) { + stats := &PrimaryKeyStats{} err := json.Unmarshal(sr.buffer, &stats) if err != nil { return nil, err } + return stats, nil } -// DeserializeStats deserialize @blobs as []*Int64Stats -func DeserializeStats(blobs []*Blob) ([]*Int64Stats, error) { - results := make([]*Int64Stats, 0, len(blobs)) +// DeserializeStats deserializes @blobs as []*PrimaryKeyStats +func DeserializeStats(blobs []*Blob) ([]*PrimaryKeyStats, error) { + results := make([]*PrimaryKeyStats, 0, len(blobs)) for _, blob := range blobs { if blob.Value == nil { continue } sr := &StatsReader{} sr.SetBuffer(blob.Value) - stats, err := sr.GetInt64Stats() + stats, err := sr.GetPrimaryKeyStats() if err != nil { return nil, err } diff --git a/internal/storage/stats_test.go b/internal/storage/stats_test.go index 95823a87d0..4c5d0f0cbc 100644 --- a/internal/storage/stats_test.go +++ b/internal/storage/stats_test.go @@ -17,33 +17,117 @@ package storage import ( + "encoding/json" "testing" - "github.com/milvus-io/milvus/internal/common" - "github.com/milvus-io/milvus/internal/rootcoord" + "github.com/bits-and-blooms/bloom/v3" "github.com/stretchr/testify/assert" + + "github.com/milvus-io/milvus/internal/common" + "github.com/milvus-io/milvus/internal/proto/schemapb" + "github.com/milvus-io/milvus/internal/rootcoord" ) -func TestStatsWriter_StatsInt64(t *testing.T) { - data := []int64{1, 2, 3, 4, 5, 6, 7, 8, 9} +func TestStatsWriter_Int64PrimaryKey(t *testing.T) { + data := &Int64FieldData{ + Data: []int64{1, 2, 3, 4, 5, 6, 7, 8, 9}, + } sw := &StatsWriter{} - err := sw.StatsInt64(common.RowIDField, true, data) + err := sw.generatePrimaryKeyStats(common.RowIDField, schemapb.DataType_Int64, data) assert.NoError(t, err) b := sw.GetBuffer() sr := &StatsReader{} sr.SetBuffer(b) - stats, err := sr.GetInt64Stats() + stats, err := sr.GetPrimaryKeyStats() assert.Nil(t, err) - assert.Equal(t, stats.Max, int64(9)) - assert.Equal(t, stats.Min, int64(1)) + maxPk := &Int64PrimaryKey{ + Value: 9, + } + minPk := &Int64PrimaryKey{ + Value: 1, + } + assert.Equal(t, true, stats.MaxPk.EQ(maxPk)) + assert.Equal(t, true, stats.MinPk.EQ(minPk)) buffer := make([]byte, 8) - for _, id := range data { + for _, id := range data.Data { common.Endian.PutUint64(buffer, uint64(id)) assert.True(t, stats.BF.Test(buffer)) } - msgs := []int64{} - err = sw.StatsInt64(rootcoord.RowIDField, true, msgs) + msgs := &Int64FieldData{ + Data: []int64{}, + } + err = sw.generatePrimaryKeyStats(rootcoord.RowIDField, schemapb.DataType_Int64, msgs) assert.Nil(t, err) } + +func TestStatsWriter_StringPrimaryKey(t *testing.T) { + data := &StringFieldData{ + Data: []string{"bc", "ac",
"abd", "cd", "milvus"}, + } + sw := &StatsWriter{} + err := sw.generatePrimaryKeyStats(common.RowIDField, schemapb.DataType_VarChar, data) + assert.NoError(t, err) + b := sw.GetBuffer() + + sr := &StatsReader{} + sr.SetBuffer(b) + stats, err := sr.GetPrimaryKeyStats() + assert.Nil(t, err) + maxPk := &StringPrimaryKey{ + Value: "milvus", + } + minPk := &StringPrimaryKey{ + Value: "abd", + } + assert.Equal(t, true, stats.MaxPk.EQ(maxPk)) + assert.Equal(t, true, stats.MinPk.EQ(minPk)) + for _, id := range data.Data { + assert.True(t, stats.BF.TestString(id)) + } + + msgs := &Int64FieldData{ + Data: []int64{}, + } + err = sw.generatePrimaryKeyStats(rootcoord.RowIDField, schemapb.DataType_Int64, msgs) + assert.Nil(t, err) +} + +func TestStatsWriter_UpgradePrimaryKey(t *testing.T) { + data := &Int64FieldData{ + Data: []int64{1, 2, 3, 4, 5, 6, 7, 8, 9}, + } + + stats := &PrimaryKeyStats{ + FieldID: common.RowIDField, + Min: 1, + Max: 9, + BF: bloom.NewWithEstimates(bloomFilterSize, maxBloomFalsePositive), + } + + b := make([]byte, 8) + for _, int64Value := range data.Data { + common.Endian.PutUint64(b, uint64(int64Value)) + stats.BF.Add(b) + } + blob, err := json.Marshal(stats) + assert.Nil(t, err) + sr := &StatsReader{} + sr.SetBuffer(blob) + unmarshaledStats, err := sr.GetPrimaryKeyStats() + assert.Nil(t, err) + maxPk := &Int64PrimaryKey{ + Value: 9, + } + minPk := &Int64PrimaryKey{ + Value: 1, + } + assert.Equal(t, true, unmarshaledStats.MaxPk.EQ(maxPk)) + assert.Equal(t, true, unmarshaledStats.MinPk.EQ(minPk)) + buffer := make([]byte, 8) + for _, id := range data.Data { + common.Endian.PutUint64(buffer, uint64(id)) + assert.True(t, unmarshaledStats.BF.Test(buffer)) + } +} diff --git a/internal/storage/utils.go b/internal/storage/utils.go index 12b0f413e9..204051fd12 100644 --- a/internal/storage/utils.go +++ b/internal/storage/utils.go @@ -617,6 +617,16 @@ func ColumnBasedInsertMsgToInsertData(msg *msgstream.InsertMsg, collSchema *sche } fieldData.Data = append(fieldData.Data, srcData...) + idata.Data[field.FieldID] = fieldData + case schemapb.DataType_String, schemapb.DataType_VarChar: + srcData := srcFields[field.FieldID].GetScalars().GetStringData().GetData() + + fieldData := &StringFieldData{ + NumRows: []int64{int64(msg.NumRows)}, + Data: make([]string, 0, len(srcData)), + } + + fieldData.Data = append(fieldData.Data, srcData...) idata.Data[field.FieldID] = fieldData } } @@ -812,7 +822,7 @@ func MergeInsertData(datas ...*InsertData) *InsertData { } // TODO: string type. 
-func GetPkFromInsertData(collSchema *schemapb.CollectionSchema, data *InsertData) ([]int64, error) { +func GetPkFromInsertData(collSchema *schemapb.CollectionSchema, data *InsertData) (FieldData, error) { helper, err := typeutil.CreateSchemaHelper(collSchema) if err != nil { log.Error("failed to create schema helper", zap.Error(err)) @@ -831,13 +841,21 @@ func GetPkFromInsertData(collSchema *schemapb.CollectionSchema, data *InsertData return nil, errors.New("no primary field found in insert msg") } - realPfData, ok := pfData.(*Int64FieldData) + var realPfData FieldData + switch pf.DataType { + case schemapb.DataType_Int64: + realPfData, ok = pfData.(*Int64FieldData) + case schemapb.DataType_VarChar: + realPfData, ok = pfData.(*StringFieldData) + default: + //TODO + } if !ok { - log.Warn("primary field not in int64 format", zap.Int64("fieldID", pf.FieldID)) - return nil, errors.New("primary field not in int64 format") + log.Warn("primary field not in Int64 or VarChar format", zap.Int64("fieldID", pf.FieldID)) + return nil, errors.New("primary field not in Int64 or VarChar format") } - return realPfData.Data, nil + return realPfData, nil } func boolFieldDataToPbBytes(field *BoolFieldData) ([]byte, error) { diff --git a/internal/storage/utils_test.go b/internal/storage/utils_test.go index 90425fe31d..416a31c189 100644 --- a/internal/storage/utils_test.go +++ b/internal/storage/utils_test.go @@ -1423,7 +1423,7 @@ func TestGetPkFromInsertData(t *testing.T) { } d, err := GetPkFromInsertData(pfSchema, realInt64Data) assert.NoError(t, err) - assert.Equal(t, []int64{1, 2, 3}, d) + assert.Equal(t, []int64{1, 2, 3}, d.(*Int64FieldData).Data) } func Test_boolFieldDataToBytes(t *testing.T) { diff --git a/internal/util/funcutil/func.go b/internal/util/funcutil/func.go index cbb1e918c3..c4101dd10a 100644 --- a/internal/util/funcutil/func.go +++ b/internal/util/funcutil/func.go @@ -24,6 +24,7 @@ import ( "io/ioutil" "net" "net/http" + "reflect" "strconv" "time" @@ -265,3 +266,81 @@ func ConvertChannelName(chanName string, tokenFrom string, tokenTo string) (stri } return "", fmt.Errorf("cannot find token '%s' in '%s'", tokenFrom, chanName) } + +func getNumRowsOfScalarField(datas interface{}) uint64 { + realTypeDatas := reflect.ValueOf(datas) + return uint64(realTypeDatas.Len()) +} + +func getNumRowsOfFloatVectorField(fDatas []float32, dim int64) (uint64, error) { + if dim <= 0 { + return 0, fmt.Errorf("dim(%d) should be greater than 0", dim) + } + l := len(fDatas) + if int64(l)%dim != 0 { + return 0, fmt.Errorf("the length(%d) of float data should be divisible by the dim(%d)", l, dim) + } + return uint64(int64(l) / dim), nil +} + +func getNumRowsOfBinaryVectorField(bDatas []byte, dim int64) (uint64, error) { + if dim <= 0 { + return 0, fmt.Errorf("dim(%d) should be greater than 0", dim) + } + if dim%8 != 0 { + return 0, fmt.Errorf("dim(%d) should be a multiple of 8", dim) + } + l := len(bDatas) + if (8*int64(l))%dim != 0 { + return 0, fmt.Errorf("the num(%d) of all bits should be divisible by the dim(%d)", 8*l, dim) + } + return uint64((8 * int64(l)) / dim), nil +} + +// GetNumRowOfFieldData returns num rows of the field data +func GetNumRowOfFieldData(fieldData *schemapb.FieldData) (uint64, error) { + var fieldNumRows uint64 + var err error + switch fieldType := fieldData.Field.(type) { + case *schemapb.FieldData_Scalars: + scalarField := fieldData.GetScalars() + switch scalarType := scalarField.Data.(type) { + case *schemapb.ScalarField_BoolData: + fieldNumRows = getNumRowsOfScalarField(scalarField.GetBoolData().Data) + case
*schemapb.ScalarField_IntData: + fieldNumRows = getNumRowsOfScalarField(scalarField.GetIntData().Data) + case *schemapb.ScalarField_LongData: + fieldNumRows = getNumRowsOfScalarField(scalarField.GetLongData().Data) + case *schemapb.ScalarField_FloatData: + fieldNumRows = getNumRowsOfScalarField(scalarField.GetFloatData().Data) + case *schemapb.ScalarField_DoubleData: + fieldNumRows = getNumRowsOfScalarField(scalarField.GetDoubleData().Data) + case *schemapb.ScalarField_StringData: + fieldNumRows = getNumRowsOfScalarField(scalarField.GetStringData().Data) + default: + return 0, fmt.Errorf("%s is not supported now", scalarType) + } + case *schemapb.FieldData_Vectors: + vectorField := fieldData.GetVectors() + switch vectorFieldType := vectorField.Data.(type) { + case *schemapb.VectorField_FloatVector: + dim := vectorField.GetDim() + fieldNumRows, err = getNumRowsOfFloatVectorField(vectorField.GetFloatVector().Data, dim) + if err != nil { + return 0, err + } + case *schemapb.VectorField_BinaryVector: + dim := vectorField.GetDim() + fieldNumRows, err = getNumRowsOfBinaryVectorField(vectorField.GetBinaryVector(), dim) + if err != nil { + return 0, err + } + default: + return 0, fmt.Errorf("%s is not supported now", vectorFieldType) + } + default: + return 0, fmt.Errorf("%s is not supported now", fieldType) + } + + return fieldNumRows, nil +} diff --git a/internal/util/funcutil/func_test.go b/internal/util/funcutil/func_test.go index 002d95de69..902bdae7aa 100644 --- a/internal/util/funcutil/func_test.go +++ b/internal/util/funcutil/func_test.go @@ -314,3 +314,88 @@ func Test_ConvertChannelName(t *testing.T) { assert.Nil(t, err) assert.Equal(t, deltaChanName, str) } + +func TestGetNumRowsOfScalarField(t *testing.T) { + cases := []struct { + datas interface{} + want uint64 + }{ + {[]bool{}, 0}, + {[]bool{true, false}, 2}, + {[]int32{}, 0}, + {[]int32{1, 2}, 2}, + {[]int64{}, 0}, + {[]int64{1, 2}, 2}, + {[]float32{}, 0}, + {[]float32{1.0, 2.0}, 2}, + {[]float64{}, 0}, + {[]float64{1.0, 2.0}, 2}, + } + + for _, test := range cases { + if got := getNumRowsOfScalarField(test.datas); got != test.want { + t.Errorf("getNumRowsOfScalarField(%v) = %v", test.datas, test.want) + } + } +} + +func TestGetNumRowsOfFloatVectorField(t *testing.T) { + cases := []struct { + fDatas []float32 + dim int64 + want uint64 + errIsNil bool + }{ + {[]float32{}, -1, 0, false}, // dim <= 0 + {[]float32{}, 0, 0, false}, // dim <= 0 + {[]float32{1.0}, 128, 0, false}, // length % dim != 0 + {[]float32{}, 128, 0, true}, + {[]float32{1.0, 2.0}, 2, 1, true}, + {[]float32{1.0, 2.0, 3.0, 4.0}, 2, 2, true}, + } + + for _, test := range cases { + got, err := getNumRowsOfFloatVectorField(test.fDatas, test.dim) + if test.errIsNil { + assert.Equal(t, nil, err) + if got != test.want { + t.Errorf("getNumRowsOfFloatVectorField(%v, %v) = %v, %v", test.fDatas, test.dim, test.want, nil) + } + } else { + assert.NotEqual(t, nil, err) + } + } +} + +func TestGetNumRowsOfBinaryVectorField(t *testing.T) { + cases := []struct { + bDatas []byte + dim int64 + want uint64 + errIsNil bool + }{ + {[]byte{}, -1, 0, false}, // dim <= 0 + {[]byte{}, 0, 0, false}, // dim <= 0 + {[]byte{1.0}, 128, 0, false}, // length % dim != 0 + {[]byte{}, 128, 0, true}, + {[]byte{1.0}, 1, 0, false}, // dim % 8 != 0 + {[]byte{1.0}, 4, 0, false}, // dim % 8 != 0 + {[]byte{1.0, 2.0}, 8, 2, true}, + {[]byte{1.0, 2.0}, 16, 1, true}, + {[]byte{1.0, 2.0, 3.0, 4.0}, 8, 4, true}, + {[]byte{1.0, 2.0, 3.0, 4.0}, 16, 2, true}, + {[]byte{1.0}, 128, 0, false}, // (8*l) % dim != 0 + } + + 
for _, test := range cases { + got, err := getNumRowsOfBinaryVectorField(test.bDatas, test.dim) + if test.errIsNil { + assert.Equal(t, nil, err) + if got != test.want { + t.Errorf("getNumRowsOfBinaryVectorField(%v, %v) = %v, %v", test.bDatas, test.dim, test.want, nil) + } + } else { + assert.NotEqual(t, nil, err) + } + } +} diff --git a/internal/util/typeutil/data_format.go b/internal/util/typeutil/data_format.go index 512ad43c0d..3e9bf23886 100644 --- a/internal/util/typeutil/data_format.go +++ b/internal/util/typeutil/data_format.go @@ -20,6 +20,7 @@ import ( "bytes" "encoding/binary" "errors" + "fmt" "io" "reflect" @@ -102,15 +103,29 @@ func writeToBuffer(w io.Writer, endian binary.ByteOrder, d interface{}) error { return binary.Write(w, endian, d) } -func TransferColumnBasedDataToRowBasedData(columns []*schemapb.FieldData) (rows []*commonpb.Blob, err error) { +func TransferColumnBasedDataToRowBasedData(schema *schemapb.CollectionSchema, columns []*schemapb.FieldData) (rows []*commonpb.Blob, err error) { dTypes := make([]schemapb.DataType, 0, len(columns)) datas := make([][]interface{}, 0, len(columns)) rowNum := 0 + fieldID2FieldData := make(map[int64]schemapb.FieldData) for _, field := range columns { - switch field.Field.(type) { + fieldID2FieldData[field.FieldId] = *field + } + + // reorder field data by schema field order + for _, field := range schema.Fields { + if field.FieldID == common.RowIDField || field.FieldID == common.TimeStampField { + continue + } + fieldData, ok := fieldID2FieldData[field.FieldID] + if !ok { + return nil, fmt.Errorf("field %s data not exist", field.Name) + } + + switch fieldData.Field.(type) { case *schemapb.FieldData_Scalars: - scalarField := field.GetScalars() + scalarField := fieldData.GetScalars() switch scalarField.Data.(type) { case *schemapb.ScalarField_BoolData: err := appendScalarField(&datas, &rowNum, func() interface{} { @@ -157,7 +172,7 @@ func TransferColumnBasedDataToRowBasedData(columns []*schemapb.FieldData) (rows continue } case *schemapb.FieldData_Vectors: - vectorField := field.GetVectors() + vectorField := fieldData.GetVectors() switch vectorField.Data.(type) { case *schemapb.VectorField_FloatVector: floatVectorFieldData := vectorField.GetFloatVector().Data @@ -184,8 +199,7 @@ func TransferColumnBasedDataToRowBasedData(columns []*schemapb.FieldData) (rows continue } - dTypes = append(dTypes, field.Type) - + dTypes = append(dTypes, field.DataType) } rows = make([]*commonpb.Blob, 0, rowNum) diff --git a/internal/util/typeutil/data_format_test.go b/internal/util/typeutil/data_format_test.go index cf6f13644e..2206d11c83 100644 --- a/internal/util/typeutil/data_format_test.go +++ b/internal/util/typeutil/data_format_test.go @@ -24,13 +24,75 @@ import ( "github.com/stretchr/testify/assert" + "github.com/milvus-io/milvus/internal/proto/commonpb" "github.com/milvus-io/milvus/internal/proto/schemapb" ) func TestTransferColumnBasedDataToRowBasedData(t *testing.T) { + fieldSchema := []*schemapb.FieldSchema{ + { + FieldID: 100, + Name: "bool_field", + DataType: schemapb.DataType_Bool, + }, + { + FieldID: 101, + Name: "int8_field", + DataType: schemapb.DataType_Int8, + }, + { + FieldID: 102, + Name: "int16_field", + DataType: schemapb.DataType_Int16, + }, + { + FieldID: 103, + Name: "int32_field", + DataType: schemapb.DataType_Int32, + }, + { + FieldID: 104, + Name: "int64_field", + DataType: schemapb.DataType_Int64, + }, + { + FieldID: 105, + Name: "float32_field", + DataType: schemapb.DataType_Float, + }, + { + FieldID: 106, + Name:
"float64_field", + DataType: schemapb.DataType_Double, + }, + { + FieldID: 107, + Name: "float_vector_field", + DataType: schemapb.DataType_FloatVector, + TypeParams: []*commonpb.KeyValuePair{ + { + Key: "dim", + Value: "1", + }, + }, + }, + { + FieldID: 108, + Name: "binary_vector_field", + DataType: schemapb.DataType_BinaryVector, + TypeParams: []*commonpb.KeyValuePair{ + { + Key: "dim", + Value: "8", + }, + }, + }, + } + columns := []*schemapb.FieldData{ { - Type: schemapb.DataType_Bool, + FieldId: 100, + Type: schemapb.DataType_Bool, Field: &schemapb.FieldData_Scalars{ Scalars: &schemapb.ScalarField{ Data: &schemapb.ScalarField_BoolData{ @@ -42,7 +104,8 @@ func TestTransferColumnBasedDataToRowBasedData(t *testing.T) { }, }, { - Type: schemapb.DataType_Int8, + FieldId: 101, + Type: schemapb.DataType_Int8, Field: &schemapb.FieldData_Scalars{ Scalars: &schemapb.ScalarField{ Data: &schemapb.ScalarField_IntData{ @@ -54,7 +117,8 @@ func TestTransferColumnBasedDataToRowBasedData(t *testing.T) { }, }, { - Type: schemapb.DataType_Int16, + FieldId: 102, + Type: schemapb.DataType_Int16, Field: &schemapb.FieldData_Scalars{ Scalars: &schemapb.ScalarField{ Data: &schemapb.ScalarField_IntData{ @@ -66,7 +130,8 @@ func TestTransferColumnBasedDataToRowBasedData(t *testing.T) { }, }, { - Type: schemapb.DataType_Int32, + FieldId: 103, + Type: schemapb.DataType_Int32, Field: &schemapb.FieldData_Scalars{ Scalars: &schemapb.ScalarField{ Data: &schemapb.ScalarField_IntData{ @@ -78,7 +143,8 @@ func TestTransferColumnBasedDataToRowBasedData(t *testing.T) { }, }, { - Type: schemapb.DataType_Int64, + FieldId: 104, + Type: schemapb.DataType_Int64, Field: &schemapb.FieldData_Scalars{ Scalars: &schemapb.ScalarField{ Data: &schemapb.ScalarField_LongData{ @@ -90,7 +156,8 @@ func TestTransferColumnBasedDataToRowBasedData(t *testing.T) { }, }, { - Type: schemapb.DataType_Float, + FieldId: 105, + Type: schemapb.DataType_Float, Field: &schemapb.FieldData_Scalars{ Scalars: &schemapb.ScalarField{ Data: &schemapb.ScalarField_FloatData{ @@ -102,7 +169,8 @@ func TestTransferColumnBasedDataToRowBasedData(t *testing.T) { }, }, { - Type: schemapb.DataType_Double, + FieldId: 106, + Type: schemapb.DataType_Double, Field: &schemapb.FieldData_Scalars{ Scalars: &schemapb.ScalarField{ Data: &schemapb.ScalarField_DoubleData{ @@ -114,7 +182,8 @@ func TestTransferColumnBasedDataToRowBasedData(t *testing.T) { }, }, { - Type: schemapb.DataType_FloatVector, + FieldId: 107, + Type: schemapb.DataType_FloatVector, Field: &schemapb.FieldData_Vectors{ Vectors: &schemapb.VectorField{ Dim: 1, @@ -127,7 +196,8 @@ func TestTransferColumnBasedDataToRowBasedData(t *testing.T) { }, }, { - Type: schemapb.DataType_BinaryVector, + FieldId: 108, + Type: schemapb.DataType_BinaryVector, Field: &schemapb.FieldData_Vectors{ Vectors: &schemapb.VectorField{ Dim: 8, @@ -138,7 +208,7 @@ func TestTransferColumnBasedDataToRowBasedData(t *testing.T) { }, }, } - rows, err := TransferColumnBasedDataToRowBasedData(columns) + rows, err := TransferColumnBasedDataToRowBasedData(&schemapb.CollectionSchema{Fields: fieldSchema}, columns) assert.NoError(t, err) assert.Equal(t, 3, len(rows)) if common.Endian == binary.LittleEndian { diff --git a/internal/util/typeutil/hash.go b/internal/util/typeutil/hash.go index 68b06766df..a55308a25c 100644 --- a/internal/util/typeutil/hash.go +++ b/internal/util/typeutil/hash.go @@ -17,12 +17,17 @@ package typeutil import ( + "hash/crc32" "unsafe" - "github.com/milvus-io/milvus/internal/common" "github.com/spaolacci/murmur3" + + 
"github.com/milvus-io/milvus/internal/common" + "github.com/milvus-io/milvus/internal/proto/schemapb" ) +const substringLengthForCRC = 100 + // Hash32Bytes hashing a byte array to uint32 func Hash32Bytes(b []byte) (uint32, error) { h := murmur3.New32() @@ -55,3 +60,37 @@ func Hash32String(s string) (int64, error) { } return int64(v), nil } + +// HashString2Uint32 hashing a string to uint32 +func HashString2Uint32(v string) uint32 { + subString := v + if len(v) > substringLengthForCRC { + subString = v[:substringLengthForCRC] + } + + return crc32.ChecksumIEEE([]byte(subString)) +} + +// HashPK2Channels hash primary keys to channels +func HashPK2Channels(primaryKeys *schemapb.IDs, shardNames []string) []uint32 { + numShard := uint32(len(shardNames)) + var hashValues []uint32 + switch primaryKeys.IdField.(type) { + case *schemapb.IDs_IntId: + pks := primaryKeys.GetIntId().Data + for _, pk := range pks { + value, _ := Hash32Int64(pk) + hashValues = append(hashValues, value%numShard) + } + case *schemapb.IDs_StrId: + pks := primaryKeys.GetStrId().Data + for _, pk := range pks { + hash := HashString2Uint32(pk) + hashValues = append(hashValues, hash%numShard) + } + default: + //TODO:: + } + + return hashValues +} diff --git a/internal/util/typeutil/hash_test.go b/internal/util/typeutil/hash_test.go index 06c0025d0d..4ad2b7f050 100644 --- a/internal/util/typeutil/hash_test.go +++ b/internal/util/typeutil/hash_test.go @@ -21,6 +21,7 @@ import ( "testing" "unsafe" + "github.com/milvus-io/milvus/internal/proto/schemapb" "github.com/stretchr/testify/assert" ) @@ -67,3 +68,29 @@ func TestHash32_String(t *testing.T) { assert.Equal(t, uint32(h), h2) } + +func TestHashPK2Channels(t *testing.T) { + channels := []string{"test1", "test2"} + int64IDs := &schemapb.IDs{ + IdField: &schemapb.IDs_IntId{ + IntId: &schemapb.LongArray{ + Data: []int64{100, 102, 102, 103, 104}, + }, + }, + } + ret := HashPK2Channels(int64IDs, channels) + assert.Equal(t, 5, len(ret)) + //same pk hash to same channel + assert.Equal(t, ret[1], ret[2]) + + stringIDs := &schemapb.IDs{ + IdField: &schemapb.IDs_StrId{ + StrId: &schemapb.StringArray{ + Data: []string{"ab", "bc", "bc", "abd", "milvus"}, + }, + }, + } + ret = HashPK2Channels(stringIDs, channels) + assert.Equal(t, 5, len(ret)) + assert.Equal(t, ret[1], ret[2]) +} diff --git a/internal/util/typeutil/schema.go b/internal/util/typeutil/schema.go index 1b44eca3d6..919ba85800 100644 --- a/internal/util/typeutil/schema.go +++ b/internal/util/typeutil/schema.go @@ -100,6 +100,33 @@ func EstimateSizePerRecord(schema *schemapb.CollectionSchema) (int, error) { return res, nil } +func EstimateEntitySize(fieldsData []*schemapb.FieldData, rowOffset int) (int, error) { + res := 0 + for _, fs := range fieldsData { + switch fs.GetType() { + case schemapb.DataType_Bool, schemapb.DataType_Int8: + res++ + case schemapb.DataType_Int16: + res += 2 + case schemapb.DataType_Int32, schemapb.DataType_Float: + res += 4 + case schemapb.DataType_Int64, schemapb.DataType_Double: + res += 8 + case schemapb.DataType_VarChar: + if rowOffset >= len(fs.GetScalars().GetStringData().GetData()) { + return 0, fmt.Errorf("offset out range of field datas") + } + //TODO:: check len(varChar) <= maxLengthPerRow + res += len(fs.GetScalars().GetStringData().Data[rowOffset]) + case schemapb.DataType_BinaryVector: + res += int(fs.GetVectors().GetDim()) + case schemapb.DataType_FloatVector: + res += int(fs.GetVectors().GetDim() * 4) + } + } + return res, nil +} + // SchemaHelper provides methods to get the schema of fields 
 type SchemaHelper struct {
 	schema *schemapb.CollectionSchema
@@ -288,6 +315,16 @@ func AppendFieldData(dst []*schemapb.FieldData, src []*schemapb.FieldData, idx i
 			} else {
 				dstScalar.GetDoubleData().Data = append(dstScalar.GetDoubleData().Data, srcScalar.DoubleData.Data[idx])
 			}
+		case *schemapb.ScalarField_StringData:
+			if dstScalar.GetStringData() == nil {
+				dstScalar.Data = &schemapb.ScalarField_StringData{
+					StringData: &schemapb.StringArray{
+						Data: []string{srcScalar.StringData.Data[idx]},
+					},
+				}
+			} else {
+				dstScalar.GetStringData().Data = append(dstScalar.GetStringData().Data, srcScalar.StringData.Data[idx])
+			}
 		default:
 			log.Error("Not supported field type", zap.String("field type", fieldData.Type.String()))
 		}
@@ -337,17 +374,13 @@ func AppendFieldData(dst []*schemapb.FieldData, src []*schemapb.FieldData, idx i
 	}
 }
 
-func FillFieldBySchema(columns []*schemapb.FieldData, schema *schemapb.CollectionSchema) error {
-	if len(columns) != len(schema.GetFields()) {
-		return fmt.Errorf("len(columns) mismatch the len(fields), len(columns): %d, len(fields): %d",
-			len(columns), len(schema.GetFields()))
+// GetPrimaryFieldSchema gets the primary field schema from the collection schema
+func GetPrimaryFieldSchema(schema *schemapb.CollectionSchema) (*schemapb.FieldSchema, error) {
+	for _, fieldSchema := range schema.Fields {
+		if fieldSchema.IsPrimaryKey {
+			return fieldSchema, nil
+		}
 	}
-	for idx, f := range schema.GetFields() {
-		columns[idx].FieldName = f.Name
-		columns[idx].Type = f.DataType
-		columns[idx].FieldId = f.FieldID
-	}
-
-	return nil
+	return nil, errors.New("primary field not found")
 }
diff --git a/internal/util/typeutil/schema_test.go b/internal/util/typeutil/schema_test.go
index 3b66fb7c4b..6730a63800 100644
--- a/internal/util/typeutil/schema_test.go
+++ b/internal/util/typeutil/schema_test.go
@@ -512,29 +512,29 @@ func TestAppendFieldData(t *testing.T) {
 	assert.Equal(t, FloatVector, result[6].GetVectors().GetFloatVector().Data)
 }
 
-func TestFillFieldBySchema(t *testing.T) {
-	columns := []*schemapb.FieldData{
-		{},
+func TestGetPrimaryFieldSchema(t *testing.T) {
+	int64Field := &schemapb.FieldSchema{
+		FieldID:  1,
+		Name:     "int64Field",
+		DataType: schemapb.DataType_Int64,
 	}
-	schema := &schemapb.CollectionSchema{}
-	// length mismatch
-	assert.Error(t, FillFieldBySchema(columns, schema))
-	columns = []*schemapb.FieldData{
-		{
-			FieldId: 0,
-		},
+
+	floatField := &schemapb.FieldSchema{
+		FieldID:  2,
+		Name:     "floatField",
+		DataType: schemapb.DataType_Float,
 	}
-	schema = &schemapb.CollectionSchema{
-		Fields: []*schemapb.FieldSchema{
-			{
-				Name:     "TestFillFieldIDBySchema",
-				DataType: schemapb.DataType_Int64,
-				FieldID:  1,
-			},
-		},
+
+	schema := &schemapb.CollectionSchema{
+		Fields: []*schemapb.FieldSchema{int64Field, floatField},
 	}
-	assert.NoError(t, FillFieldBySchema(columns, schema))
-	assert.Equal(t, "TestFillFieldIDBySchema", columns[0].FieldName)
-	assert.Equal(t, schemapb.DataType_Int64, columns[0].Type)
-	assert.Equal(t, int64(1), columns[0].FieldId)
+
+	// no primary field error
+	_, err := GetPrimaryFieldSchema(schema)
+	assert.Error(t, err)
+
+	int64Field.IsPrimaryKey = true
+	primaryField, err := GetPrimaryFieldSchema(schema)
+	assert.NoError(t, err)
+	assert.Equal(t, schemapb.DataType_Int64, primaryField.DataType)
 }
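
A note for reviewers: the channel-routing change is easiest to sanity-check in isolation. Below is a minimal standalone sketch of the string primary-key branch of HashPK2Channels, assuming only the standard library; hashStringPK2Channel is a hypothetical stand-in, not part of the patch, and the int64 branch goes through murmur3 via Hash32Int64 instead. Duplicate keys always route to the same shard, which is what TestHashPK2Channels asserts.

package main

import (
	"fmt"
	"hash/crc32"
)

// Mirrors substringLengthForCRC from the patch: only the first 100 bytes
// of a long key feed the checksum.
const substringLengthForCRC = 100

// hashStringPK2Channel (hypothetical helper) reproduces the string-PK
// branch of HashPK2Channels: CRC32 (IEEE) over at most the first 100
// bytes of the key, reduced modulo the shard count.
func hashStringPK2Channel(pk string, numShard uint32) uint32 {
	subString := pk
	if len(pk) > substringLengthForCRC {
		subString = pk[:substringLengthForCRC]
	}
	return crc32.ChecksumIEEE([]byte(subString)) % numShard
}

func main() {
	shardNames := []string{"test1", "test2"}
	for _, pk := range []string{"ab", "bc", "bc", "abd", "milvus"} {
		shard := hashStringPK2Channel(pk, uint32(len(shardNames)))
		// the duplicate key "bc" prints the same shard name both times
		fmt.Printf("pk %q -> %s\n", pk, shardNames[shard])
	}
}

The size arithmetic in EstimateEntitySize can be checked by hand the same way: a row with an Int64 primary key (8 bytes), a VarChar value of length 6 (6 bytes), and a 128-dimensional FloatVector (128 * 4 = 512 bytes) is estimated at 526 bytes, while a 128-dimensional BinaryVector contributes 128 / 8 = 16 bytes, since binary vectors pack 8 dimensions per byte.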