Add datanode import (#16414)

Signed-off-by: godchen0212 <qingxiang.chen@zilliz.com>
pull/16477/head
godchen 2022-04-12 22:19:34 +08:00 committed by GitHub
parent a2011c1f25
commit 4781db8a2a
18 changed files with 600 additions and 249 deletions

View File

@ -24,6 +24,8 @@ import (
"sync/atomic"
"time"
"go.uber.org/zap"
"github.com/milvus-io/milvus/internal/common"
"github.com/milvus-io/milvus/internal/log"
"github.com/milvus-io/milvus/internal/proto/commonpb"
@ -33,7 +35,6 @@ import (
"github.com/milvus-io/milvus/internal/util/logutil"
"github.com/milvus-io/milvus/internal/util/metricsinfo"
"github.com/milvus-io/milvus/internal/util/trace"
"go.uber.org/zap"
)
const moduleName = "DataCoord"

View File

@ -39,6 +39,7 @@ import (
clientv3 "go.etcd.io/etcd/client/v3"
"go.uber.org/zap"
allocator2 "github.com/milvus-io/milvus/internal/allocator"
"github.com/milvus-io/milvus/internal/common"
"github.com/milvus-io/milvus/internal/kv"
etcdkv "github.com/milvus-io/milvus/internal/kv/etcd"
@ -46,12 +47,15 @@ import (
"github.com/milvus-io/milvus/internal/metrics"
"github.com/milvus-io/milvus/internal/proto/commonpb"
"github.com/milvus-io/milvus/internal/proto/datapb"
"github.com/milvus-io/milvus/internal/proto/etcdpb"
"github.com/milvus-io/milvus/internal/proto/internalpb"
"github.com/milvus-io/milvus/internal/proto/milvuspb"
"github.com/milvus-io/milvus/internal/proto/rootcoordpb"
"github.com/milvus-io/milvus/internal/proto/schemapb"
"github.com/milvus-io/milvus/internal/storage"
"github.com/milvus-io/milvus/internal/types"
"github.com/milvus-io/milvus/internal/util/dependency"
"github.com/milvus-io/milvus/internal/util/importutil"
"github.com/milvus-io/milvus/internal/util/logutil"
"github.com/milvus-io/milvus/internal/util/metricsinfo"
"github.com/milvus-io/milvus/internal/util/paramtable"
@ -787,9 +791,221 @@ func (node *DataNode) Import(ctx context.Context, req *datapb.ImportTaskRequest)
Reason: msgDataNodeIsUnhealthy(Params.DataNodeCfg.NodeID),
}, nil
}
rep, err := node.rootCoord.AllocTimestamp(node.ctx, &rootcoordpb.AllocTimestampRequest{
Base: &commonpb.MsgBase{
MsgType: commonpb.MsgType_RequestTSO,
MsgID: 0,
Timestamp: 0,
SourceID: node.NodeID,
},
Count: 1,
})
if err != nil || rep.Status.ErrorCode != commonpb.ErrorCode_Success {
return &commonpb.Status{
ErrorCode: commonpb.ErrorCode_UnexpectedError,
Reason: "DataNode alloc ts failed",
}, nil
}
ts := rep.GetTimestamp()
metaService := newMetaService(node.rootCoord, req.GetImportTask().GetCollectionId())
schema, err := metaService.getCollectionSchema(ctx, req.GetImportTask().GetCollectionId(), 0)
if err != nil {
return &commonpb.Status{
ErrorCode: commonpb.ErrorCode_UnexpectedError,
Reason: err.Error(),
}, nil
}
idAllocator, err := allocator2.NewIDAllocator(node.ctx, node.rootCoord, Params.DataNodeCfg.NodeID)
if err != nil {
return &commonpb.Status{
ErrorCode: commonpb.ErrorCode_UnexpectedError,
Reason: err.Error(),
}, nil
}
importWrapper := importutil.NewImportWrapper(ctx, schema, 2, Params.DataNodeCfg.FlushInsertBufferSize/(1024*1024), idAllocator, node.chunkManager, importFlushReqFunc(node, req, schema, ts))
err = importWrapper.Import(req.GetImportTask().GetFiles(), req.GetImportTask().GetRowBased(), false)
if err != nil {
return &commonpb.Status{
ErrorCode: commonpb.ErrorCode_UnexpectedError,
Reason: err.Error(),
}, nil
}
resp := &commonpb.Status{
ErrorCode: commonpb.ErrorCode_Success,
}
return resp, nil
}
type importFlushFunc func(fields map[storage.FieldID]storage.FieldData) error
func importFlushReqFunc(node *DataNode, req *datapb.ImportTaskRequest, schema *schemapb.CollectionSchema, ts Timestamp) importFlushFunc {
return func(fields map[storage.FieldID]storage.FieldData) error {
segReqs := []*datapb.SegmentIDRequest{
{
ChannelName: "test-channel",
Count: 1,
CollectionID: req.GetImportTask().GetCollectionId(),
PartitionID: req.GetImportTask().GetCollectionId(),
},
}
segmentIDReq := &datapb.AssignSegmentIDRequest{
NodeID: 0,
PeerRole: typeutil.ProxyRole,
SegmentIDRequests: segReqs,
}
resp, err := node.dataCoord.AssignSegmentID(context.Background(), segmentIDReq)
if err != nil {
return fmt.Errorf("syncSegmentID Failed:%w", err)
}
if resp.Status.ErrorCode != commonpb.ErrorCode_Success {
return fmt.Errorf("syncSegmentID Failed:%s", resp.Status.Reason)
}
segmentID := resp.SegIDAssignments[0].SegID
var rowNum int
for _, field := range fields {
rowNum = field.RowNum()
break
}
tsFieldData := make([]int64, rowNum)
for i := range tsFieldData {
tsFieldData[i] = int64(ts)
}
fields[common.TimeStampField] = &storage.Int64FieldData{
Data: tsFieldData,
NumRows: []int64{int64(rowNum)},
}
var pkFieldID int64
for _, field := range schema.Fields {
if field.IsPrimaryKey {
pkFieldID = field.GetFieldID()
break
}
}
fields[common.RowIDField] = fields[pkFieldID]
data := BufferData{buffer: &InsertData{
Data: fields,
}}
meta := &etcdpb.CollectionMeta{
ID: req.GetImportTask().GetCollectionId(),
Schema: schema,
}
inCodec := storage.NewInsertCodec(meta)
binLogs, statsBinlogs, err := inCodec.Serialize(req.GetImportTask().GetPartitionId(), segmentID, data.buffer)
if err != nil {
return err
}
var alloc allocatorInterface = newAllocator(node.rootCoord)
start, _, err := alloc.allocIDBatch(uint32(len(binLogs)))
if err != nil {
return err
}
field2Insert := make(map[UniqueID]*datapb.Binlog, len(binLogs))
kvs := make(map[string][]byte, len(binLogs))
field2Logidx := make(map[UniqueID]UniqueID, len(binLogs))
for idx, blob := range binLogs {
fieldID, err := strconv.ParseInt(blob.GetKey(), 10, 64)
if err != nil {
log.Error("Flush failed ... cannot parse string to fieldID ..", zap.Error(err))
return err
}
logidx := start + int64(idx)
// no error is raised if alloc=false
k := JoinIDPath(req.GetImportTask().GetCollectionId(), req.GetImportTask().GetPartitionId(), segmentID, fieldID, logidx)
key := path.Join(Params.DataNodeCfg.InsertBinlogRootPath, k)
kvs[key] = blob.Value[:]
field2Insert[fieldID] = &datapb.Binlog{
EntriesNum: data.size,
TimestampFrom: 0, //TODO
TimestampTo: 0, //TODO,
LogPath: key,
LogSize: int64(len(blob.Value)),
}
field2Logidx[fieldID] = logidx
}
field2Stats := make(map[UniqueID]*datapb.Binlog)
// write stats binlog
for _, blob := range statsBinlogs {
fieldID, err := strconv.ParseInt(blob.GetKey(), 10, 64)
if err != nil {
log.Error("Flush failed ... cannot parse string to fieldID ..", zap.Error(err))
return err
}
logidx := field2Logidx[fieldID]
// no error is raised if alloc=false
k := JoinIDPath(req.GetImportTask().GetCollectionId(), req.GetImportTask().GetPartitionId(), segmentID, fieldID, logidx)
key := path.Join(Params.DataNodeCfg.StatsBinlogRootPath, k)
kvs[key] = blob.Value
field2Stats[fieldID] = &datapb.Binlog{
EntriesNum: 0,
TimestampFrom: 0, //TODO
TimestampTo: 0, //TODO,
LogPath: key,
LogSize: int64(len(blob.Value)),
}
}
err = node.chunkManager.MultiWrite(kvs)
if err != nil {
return err
}
var (
fieldInsert []*datapb.FieldBinlog
fieldStats []*datapb.FieldBinlog
)
for k, v := range field2Insert {
fieldInsert = append(fieldInsert, &datapb.FieldBinlog{FieldID: k, Binlogs: []*datapb.Binlog{v}})
}
for k, v := range field2Stats {
fieldStats = append(fieldStats, &datapb.FieldBinlog{FieldID: k, Binlogs: []*datapb.Binlog{v}})
}
binlogReq := &datapb.SaveBinlogPathsRequest{
Base: &commonpb.MsgBase{
MsgType: 0, //TODO msg type
MsgID: 0, //TODO msg id
Timestamp: 0, //TODO time stamp
SourceID: Params.DataNodeCfg.NodeID,
},
SegmentID: segmentID,
CollectionID: req.ImportTask.GetCollectionId(),
Field2BinlogPaths: fieldInsert,
Field2StatslogPaths: fieldStats,
Importing: true,
}
err = retry.Do(context.Background(), func() error {
rsp, err := node.dataCoord.SaveBinlogPaths(context.Background(), binlogReq)
// should be network issue, return error and retry
if err != nil {
return err
}
// TODO should retry only when datacoord status is unhealthy
if rsp.ErrorCode != commonpb.ErrorCode_Success {
return fmt.Errorf("data service save bin log path failed, reason = %s", rsp.Reason)
}
return nil
})
if err != nil {
log.Warn("failed to SaveBinlogPaths", zap.Error(err))
return err
}
return nil
}
}
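
For orientation, here is a minimal sketch of how the new Import RPC is driven end to end. It is a hypothetical helper (the name runImport and the literal IDs are illustrative, mirroring the unit test below) and assumes a started DataNode, inside the datanode package, with rootCoord, dataCoord and chunkManager already wired up.

// runImport is a hypothetical helper inside the datanode package; the imports it
// relies on (context, fmt, commonpb, datapb) are already present in datanode.go.
func runImport(ctx context.Context, node *DataNode, jsonPath string) error {
	req := &datapb.ImportTaskRequest{
		ImportTask: &datapb.ImportTask{
			CollectionId: 100, // example IDs, taken from the test below
			PartitionId:  100,
			Files:        []string{jsonPath},
			RowBased:     true,
		},
	}
	stat, err := node.Import(ctx, req)
	if err != nil {
		return err
	}
	if stat.ErrorCode != commonpb.ErrorCode_Success {
		return fmt.Errorf("import failed: %s", stat.Reason)
	}
	return nil
}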

View File

@ -32,6 +32,7 @@ import (
etcdkv "github.com/milvus-io/milvus/internal/kv/etcd"
"github.com/milvus-io/milvus/internal/log"
"github.com/milvus-io/milvus/internal/mq/msgstream"
"github.com/milvus-io/milvus/internal/storage"
"github.com/milvus-io/milvus/internal/types"
"github.com/milvus-io/milvus/internal/util/dependency"
@ -80,6 +81,9 @@ func TestDataNode(t *testing.T) {
err = node.Start()
assert.Nil(t, err)
node.chunkManager = storage.NewLocalChunkManager(storage.RootPath("/tmp/lib/milvus"))
Params.DataNodeCfg.NodeID = 1
t.Run("Test WatchDmChannels ", func(t *testing.T) {
emptyNode := &DataNode{}
@ -316,6 +320,34 @@ func TestDataNode(t *testing.T) {
})
t.Run("Test Import", func(t *testing.T) {
content := []byte(`{
"rows":[
{"bool_field": true, "int8_field": 10, "int16_field": 101, "int32_field": 1001, "int64_field": 10001, "float32_field": 3.14, "float64_field": 1.56, "varChar_field": "hello world", "binary_vector_field": [254, 0, 254, 0], "float_vector_field": [1.1, 1.2]},
{"bool_field": false, "int8_field": 11, "int16_field": 102, "int32_field": 1002, "int64_field": 10002, "float32_field": 3.15, "float64_field": 2.56, "varChar_field": "hello world", "binary_vector_field": [253, 0, 253, 0], "float_vector_field": [2.1, 2.2]},
{"bool_field": true, "int8_field": 12, "int16_field": 103, "int32_field": 1003, "int64_field": 10003, "float32_field": 3.16, "float64_field": 3.56, "varChar_field": "hello world", "binary_vector_field": [252, 0, 252, 0], "float_vector_field": [3.1, 3.2]},
{"bool_field": false, "int8_field": 13, "int16_field": 104, "int32_field": 1004, "int64_field": 10004, "float32_field": 3.17, "float64_field": 4.56, "varChar_field": "hello world", "binary_vector_field": [251, 0, 251, 0], "float_vector_field": [4.1, 4.2]},
{"bool_field": true, "int8_field": 14, "int16_field": 105, "int32_field": 1005, "int64_field": 10005, "float32_field": 3.18, "float64_field": 5.56, "varChar_field": "hello world", "binary_vector_field": [250, 0, 250, 0], "float_vector_field": [5.1, 5.2]}
]
}`)
filePath := "import/rows_1.json"
err = node.chunkManager.Write(filePath, content)
assert.NoError(t, err)
req := &datapb.ImportTaskRequest{
ImportTask: &datapb.ImportTask{
CollectionId: 100,
PartitionId: 100,
Files: []string{filePath},
RowBased: true,
},
}
stat, err := node.Import(node.ctx, req)
assert.NoError(t, err)
assert.Equal(t, commonpb.ErrorCode_Success, stat.ErrorCode)
})
t.Run("Test Import error", func(t *testing.T) {
node.rootCoord = &RootCoordFactory{collectionID: -1}
req := &datapb.ImportTaskRequest{
ImportTask: &datapb.ImportTask{
CollectionId: 100,
@ -324,7 +356,7 @@ func TestDataNode(t *testing.T) {
}
stat, err := node.Import(node.ctx, req)
assert.NoError(t, err)
assert.Equal(t, commonpb.ErrorCode_Success, stat.ErrorCode)
assert.Equal(t, commonpb.ErrorCode_UnexpectedError, stat.ErrorCode)
})
t.Run("Test BackGroundGC", func(t *testing.T) {
@ -585,7 +617,6 @@ func TestWatchChannel(t *testing.T) {
exist := node.flowgraphManager.exist("test3")
assert.False(t, exist)
})
}
func TestDataNode_GetComponentStates(t *testing.T) {

View File

@ -173,6 +173,19 @@ type DataCoordFactory struct {
DropVirtualChannelNotSuccess bool
}
func (ds *DataCoordFactory) AssignSegmentID(ctx context.Context, req *datapb.AssignSegmentIDRequest) (*datapb.AssignSegmentIDResponse, error) {
return &datapb.AssignSegmentIDResponse{
Status: &commonpb.Status{
ErrorCode: commonpb.ErrorCode_Success,
},
SegIDAssignments: []*datapb.SegmentIDAssignment{
{
SegID: 666,
},
},
}, nil
}
func (ds *DataCoordFactory) CompleteCompaction(ctx context.Context, req *datapb.CompactionResult) (*commonpb.Status, error) {
if ds.CompleteCompactionError {
return nil, errors.New("Error")
@ -843,6 +856,12 @@ func (m *RootCoordFactory) AllocID(ctx context.Context, in *rootcoordpb.AllocIDR
ErrorCode: commonpb.ErrorCode_UnexpectedError,
}}
if in.Count == 12 {
resp.Status.ErrorCode = commonpb.ErrorCode_Success
resp.ID = 1
resp.Count = 12
}
if m.ID == 0 {
resp.Status.Reason = "Zero ID"
return resp, nil

View File

@ -48,8 +48,8 @@ func NewLocalChunkManager(opts ...Option) *LocalChunkManager {
}
}
// GetPath returns the path of local data if exists.
func (lcm *LocalChunkManager) GetPath(filePath string) (string, error) {
// Path returns the path of local data if exists.
func (lcm *LocalChunkManager) Path(filePath string) (string, error) {
if !lcm.Exist(filePath) {
return "", errors.New("local file cannot be found with filePath:" + filePath)
}
@ -57,6 +57,14 @@ func (lcm *LocalChunkManager) GetPath(filePath string) (string, error) {
return absPath, nil
}
func (lcm *LocalChunkManager) Reader(filePath string) (FileReader, error) {
if !lcm.Exist(filePath) {
return nil, errors.New("local file cannot be found with filePath:" + filePath)
}
absPath := path.Join(lcm.localPath, filePath)
return os.Open(absPath)
}
// Write writes the data to local storage.
func (lcm *LocalChunkManager) Write(filePath string, content []byte) error {
absPath := path.Join(lcm.localPath, filePath)
@ -181,7 +189,7 @@ func (lcm *LocalChunkManager) Mmap(filePath string) (*mmap.ReaderAt, error) {
return mmap.Open(path.Clean(absPath))
}
func (lcm *LocalChunkManager) GetSize(filePath string) (int64, error) {
func (lcm *LocalChunkManager) Size(filePath string) (int64, error) {
absPath := path.Join(lcm.localPath, filePath)
fi, err := os.Stat(absPath)
if err != nil {

View File

@ -325,7 +325,7 @@ func TestLocalCM(t *testing.T) {
assert.Error(t, err)
})
t.Run("test GetSize", func(t *testing.T) {
t.Run("test Size", func(t *testing.T) {
testGetSizeRoot := "get_size"
testCM := NewLocalChunkManager(RootPath(localPath))
@ -337,18 +337,18 @@ func TestLocalCM(t *testing.T) {
err := testCM.Write(key, value)
assert.NoError(t, err)
size, err := testCM.GetSize(key)
size, err := testCM.Size(key)
assert.NoError(t, err)
assert.Equal(t, size, int64(len(value)))
key2 := path.Join(testGetSizeRoot, "TestMemoryKV_GetSize_key2")
size, err = testCM.GetSize(key2)
size, err = testCM.Size(key2)
assert.Error(t, err)
assert.Equal(t, int64(0), size)
})
t.Run("test GetPath", func(t *testing.T) {
t.Run("test Path", func(t *testing.T) {
testGetSizeRoot := "get_path"
testCM := NewLocalChunkManager(RootPath(localPath))
@ -360,13 +360,13 @@ func TestLocalCM(t *testing.T) {
err := testCM.Write(key, value)
assert.NoError(t, err)
p, err := testCM.GetPath(key)
p, err := testCM.Path(key)
assert.NoError(t, err)
assert.Equal(t, p, path.Join(localPath, key))
key2 := path.Join(testGetSizeRoot, "TestMemoryKV_GetSize_key2")
p, err = testCM.GetPath(key2)
p, err = testCM.Path(key2)
assert.Error(t, err)
assert.Equal(t, p, "")
})

View File

@ -94,15 +94,23 @@ func newMinioChunkManagerWithConfig(ctx context.Context, c *config) (*MinioChunk
return mcm, nil
}
// GetPath returns the path of minio data if exists.
func (mcm *MinioChunkManager) GetPath(filePath string) (string, error) {
// Path returns the path of minio data if exists.
func (mcm *MinioChunkManager) Path(filePath string) (string, error) {
if !mcm.Exist(filePath) {
return "", errors.New("minio file manage cannot be found with filePath:" + filePath)
}
return filePath, nil
}
func (mcm *MinioChunkManager) GetSize(filePath string) (int64, error) {
// Reader returns a reader for the minio data if exists.
func (mcm *MinioChunkManager) Reader(filePath string) (FileReader, error) {
if !mcm.Exist(filePath) {
return nil, errors.New("minio file cannot be found with filePath:" + filePath)
}
return mcm.Client.GetObject(mcm.ctx, mcm.bucketName, filePath, minio.GetObjectOptions{})
}
func (mcm *MinioChunkManager) Size(filePath string) (int64, error) {
objectInfo, err := mcm.Client.StatObject(mcm.ctx, mcm.bucketName, filePath, minio.StatObjectOptions{})
if err != nil {
return 0, err

View File

@ -354,7 +354,7 @@ func TestMinIOCM(t *testing.T) {
assert.Error(t, err)
})
t.Run("test GetSize", func(t *testing.T) {
t.Run("test Size", func(t *testing.T) {
testGetSizeRoot := path.Join(testMinIOKVRoot, "get_size")
ctx, cancel := context.WithCancel(context.Background())
defer cancel()
@ -369,18 +369,18 @@ func TestMinIOCM(t *testing.T) {
err = testCM.Write(key, value)
assert.NoError(t, err)
size, err := testCM.GetSize(key)
size, err := testCM.Size(key)
assert.NoError(t, err)
assert.Equal(t, size, int64(len(value)))
key2 := path.Join(testGetSizeRoot, "TestMemoryKV_GetSize_key2")
size, err = testCM.GetSize(key2)
size, err = testCM.Size(key2)
assert.Error(t, err)
assert.Equal(t, int64(0), size)
})
t.Run("test GetPath", func(t *testing.T) {
t.Run("test Path", func(t *testing.T) {
testGetPathRoot := path.Join(testMinIOKVRoot, "get_path")
ctx, cancel := context.WithCancel(context.Background())
defer cancel()
@ -395,13 +395,13 @@ func TestMinIOCM(t *testing.T) {
err = testCM.Write(key, value)
assert.NoError(t, err)
p, err := testCM.GetPath(key)
p, err := testCM.Path(key)
assert.NoError(t, err)
assert.Equal(t, p, key)
key2 := path.Join(testGetPathRoot, "TestMemoryKV_GetSize_key2")
p, err = testCM.GetPath(key2)
p, err = testCM.Path(key2)
assert.Error(t, err)
assert.Equal(t, p, "")
})

View File

@ -12,16 +12,23 @@
package storage
import (
"io"
"golang.org/x/exp/mmap"
)
type FileReader interface {
io.Reader
io.Closer
}
// ChunkManager manages chunks.
// It includes Read, Write and Remove operations on chunks.
type ChunkManager interface {
// GetPath returns path of @filePath.
GetPath(filePath string) (string, error)
// GetSize returns path of @filePath.
GetSize(filePath string) (int64, error)
// Path returns path of @filePath.
Path(filePath string) (string, error)
// Size returns the size of @filePath.
Size(filePath string) (int64, error)
// Write writes @content to @filePath.
Write(filePath string, content []byte) error
// MultiWrite writes multi @content to @filePath.
@ -30,6 +37,8 @@ type ChunkManager interface {
Exist(filePath string) bool
// Read reads @filePath and returns content.
Read(filePath string) ([]byte, error)
// Reader returns a reader for @filePath.
Reader(filePath string) (FileReader, error)
// MultiRead reads @filePath and returns content.
MultiRead(filePaths []string) ([][]byte, error)
ListWithPrefix(prefix string) ([]string, error)
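
A minimal usage sketch for the new streaming Reader method, applicable to any ChunkManager implementation above; the package name and the helper readAll are illustrative only.

package example // illustrative package, not part of the change

import (
	"io"

	"github.com/milvus-io/milvus/internal/storage"
)

// readAll opens a chunk through the new Reader API and drains it; the returned
// FileReader combines io.Reader and io.Closer, so the caller must close it.
func readAll(cm storage.ChunkManager, filePath string) ([]byte, error) {
	r, err := cm.Reader(filePath)
	if err != nil {
		return nil, err
	}
	defer r.Close()
	return io.ReadAll(r)
}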

View File

@ -21,12 +21,13 @@ import (
"io"
"sync"
"go.uber.org/zap"
"golang.org/x/exp/mmap"
"github.com/milvus-io/milvus/internal/common"
"github.com/milvus-io/milvus/internal/log"
"github.com/milvus-io/milvus/internal/proto/etcdpb"
"github.com/milvus-io/milvus/internal/util/cache"
"go.uber.org/zap"
"golang.org/x/exp/mmap"
)
var (
@ -116,12 +117,12 @@ func (vcm *VectorChunkManager) deserializeVectorFile(filePath string, content []
// Path returns the path of vector data. If cached, return local path.
// If not cached return remote path.
func (vcm *VectorChunkManager) GetPath(filePath string) (string, error) {
return vcm.vectorStorage.GetPath(filePath)
func (vcm *VectorChunkManager) Path(filePath string) (string, error) {
return vcm.vectorStorage.Path(filePath)
}
func (vcm *VectorChunkManager) GetSize(filePath string) (int64, error) {
return vcm.vectorStorage.GetSize(filePath)
func (vcm *VectorChunkManager) Size(filePath string) (int64, error) {
return vcm.vectorStorage.Size(filePath)
}
// Write writes the vector data to local cache if cache enabled.
@ -156,7 +157,7 @@ func (vcm *VectorChunkManager) readWithCache(filePath string) ([]byte, error) {
if err != nil {
return nil, err
}
size, err := vcm.cacheStorage.GetSize(filePath)
size, err := vcm.cacheStorage.Size(filePath)
if err != nil {
return nil, err
}
@ -239,6 +240,10 @@ func (vcm *VectorChunkManager) Mmap(filePath string) (*mmap.ReaderAt, error) {
return nil, errors.New("the file mmap has not been cached")
}
func (vcm *VectorChunkManager) Reader(filePath string) (FileReader, error) {
return nil, errors.New("this method has not been implemented")
}
// ReadAt reads specific position data of vector. If cached, it reads from local.
func (vcm *VectorChunkManager) ReadAt(filePath string, off int64, length int64) ([]byte, error) {
if vcm.cacheEnable {

View File

@ -22,11 +22,12 @@ import (
"os"
"testing"
"github.com/stretchr/testify/assert"
"github.com/milvus-io/milvus/internal/proto/etcdpb"
"github.com/milvus-io/milvus/internal/proto/schemapb"
"github.com/milvus-io/milvus/internal/util/paramtable"
"github.com/milvus-io/milvus/internal/util/typeutil"
"github.com/stretchr/testify/assert"
)
func initMeta() *etcdpb.CollectionMeta {
@ -179,13 +180,13 @@ func TestVectorChunkManager_GetPath(t *testing.T) {
key := "1"
err = vcm.Write(key, []byte{1})
assert.Nil(t, err)
pathGet, err := vcm.GetPath(key)
pathGet, err := vcm.Path(key)
assert.Nil(t, err)
assert.Equal(t, pathGet, key)
err = vcm.cacheStorage.Write(key, []byte{1})
assert.Nil(t, err)
pathGet, err = vcm.GetPath(key)
pathGet, err = vcm.Path(key)
assert.Nil(t, err)
assert.Equal(t, pathGet, key)
@ -206,13 +207,13 @@ func TestVectorChunkManager_GetSize(t *testing.T) {
key := "1"
err = vcm.Write(key, []byte{1})
assert.Nil(t, err)
sizeGet, err := vcm.GetSize(key)
sizeGet, err := vcm.Size(key)
assert.Nil(t, err)
assert.EqualValues(t, sizeGet, 1)
err = vcm.cacheStorage.Write(key, []byte{1})
assert.Nil(t, err)
sizeGet, err = vcm.GetSize(key)
sizeGet, err = vcm.Size(key)
assert.Nil(t, err)
assert.EqualValues(t, sizeGet, 1)

View File

@ -4,19 +4,19 @@ import (
"bufio"
"context"
"errors"
"os"
"path"
"strconv"
"strings"
"go.uber.org/zap"
"go.uber.org/zap/zapcore"
"github.com/milvus-io/milvus/internal/allocator"
"github.com/milvus-io/milvus/internal/common"
"github.com/milvus-io/milvus/internal/log"
"github.com/milvus-io/milvus/internal/proto/schemapb"
"github.com/milvus-io/milvus/internal/storage"
"github.com/milvus-io/milvus/internal/util/typeutil"
"go.uber.org/zap"
"go.uber.org/zap/zapcore"
)
const (
@ -29,14 +29,15 @@ type ImportWrapper struct {
cancel context.CancelFunc // for canceling parse process
collectionSchema *schemapb.CollectionSchema // collection schema
shardNum int32 // sharding number of the collection
segmentSize int32 // maximum size of a segment in MB
segmentSize int64 // maximum size of a segment in MB
rowIDAllocator *allocator.IDAllocator // autoid allocator
chunkManager storage.ChunkManager
callFlushFunc func(fields map[string]storage.FieldData) error // call back function to flush a segment
callFlushFunc func(fields map[storage.FieldID]storage.FieldData) error // call back function to flush a segment
}
func NewImportWrapper(ctx context.Context, collectionSchema *schemapb.CollectionSchema, shardNum int32, segmentSize int32,
idAlloc *allocator.IDAllocator, flushFunc func(fields map[string]storage.FieldData) error) *ImportWrapper {
func NewImportWrapper(ctx context.Context, collectionSchema *schemapb.CollectionSchema, shardNum int32, segmentSize int64,
idAlloc *allocator.IDAllocator, cm storage.ChunkManager, flushFunc func(fields map[storage.FieldID]storage.FieldData) error) *ImportWrapper {
if collectionSchema == nil {
log.Error("import error: collection schema is nil")
return nil
@ -67,6 +68,7 @@ func NewImportWrapper(ctx context.Context, collectionSchema *schemapb.Collection
segmentSize: segmentSize,
rowIDAllocator: idAlloc,
callFlushFunc: flushFunc,
chunkManager: cm,
}
return wrapper
@ -78,10 +80,10 @@ func (p *ImportWrapper) Cancel() error {
return nil
}
func (p *ImportWrapper) printFieldsDataInfo(fieldsData map[string]storage.FieldData, msg string, files []string) {
func (p *ImportWrapper) printFieldsDataInfo(fieldsData map[storage.FieldID]storage.FieldData, msg string, files []string) {
stats := make([]zapcore.Field, 0)
for k, v := range fieldsData {
stats = append(stats, zap.Int(k, v.RowNum()))
stats = append(stats, zap.Int(strconv.FormatInt(k, 10), v.RowNum()))
}
if len(files) > 0 {
@ -112,7 +114,7 @@ func (p *ImportWrapper) Import(filePaths []string, rowBased bool, onlyValidate b
if fileType == JSONFileExt {
err := func() error {
file, err := os.Open(filePath)
file, err := p.chunkManager.Reader(filePath)
if err != nil {
return err
}
@ -122,7 +124,7 @@ func (p *ImportWrapper) Import(filePaths []string, rowBased bool, onlyValidate b
parser := NewJSONParser(p.ctx, p.collectionSchema)
var consumer *JSONRowConsumer
if !onlyValidate {
flushFunc := func(fields map[string]storage.FieldData) error {
flushFunc := func(fields map[storage.FieldID]storage.FieldData) error {
p.printFieldsDataInfo(fields, "import wrapper: prepare to flush segment", filePaths)
return p.callFlushFunc(fields)
}
@ -153,14 +155,14 @@ func (p *ImportWrapper) Import(filePaths []string, rowBased bool, onlyValidate b
rowCount := 0
// function to combine column data into fieldsData
combineFunc := func(fields map[string]storage.FieldData) error {
combineFunc := func(fields map[storage.FieldID]storage.FieldData) error {
if len(fields) == 0 {
return nil
}
p.printFieldsDataInfo(fields, "import wrapper: combine field data", nil)
fieldNames := make([]string, 0)
fieldNames := make([]storage.FieldID, 0)
for k, v := range fields {
// ignore 0 row field
if v.RowNum() == 0 {
@ -170,12 +172,12 @@ func (p *ImportWrapper) Import(filePaths []string, rowBased bool, onlyValidate b
// each column should be only combined once
data, ok := fieldsData[k]
if ok && data.RowNum() > 0 {
return errors.New("the field " + k + " is duplicated")
return errors.New("the field " + strconv.FormatInt(k, 10) + " is duplicated")
}
// check the row count. only count non-zero row fields
if rowCount > 0 && rowCount != v.RowNum() {
return errors.New("the field " + k + " row count " + strconv.Itoa(v.RowNum()) + " doesn't equal " + strconv.Itoa(rowCount))
return errors.New("the field " + strconv.FormatInt(k, 10) + " row count " + strconv.Itoa(v.RowNum()) + " doesn't equal " + strconv.Itoa(rowCount))
}
rowCount = v.RowNum()
@ -195,7 +197,7 @@ func (p *ImportWrapper) Import(filePaths []string, rowBased bool, onlyValidate b
if fileType == JSONFileExt {
err := func() error {
file, err := os.Open(filePath)
file, err := p.chunkManager.Reader(filePath)
if err != nil {
log.Error("import error: "+err.Error(), zap.String("filePath", filePath))
return err
@ -224,17 +226,23 @@ func (p *ImportWrapper) Import(filePaths []string, rowBased bool, onlyValidate b
return err
}
} else if fileType == NumpyFileExt {
file, err := os.Open(filePath)
file, err := p.chunkManager.Reader(filePath)
if err != nil {
log.Error("import error: "+err.Error(), zap.String("filePath", filePath))
return err
}
defer file.Close()
var id storage.FieldID
for _, field := range p.collectionSchema.Fields {
if field.GetName() == fileName {
id = field.GetFieldID()
}
}
// the numpy parser returns a storage.FieldData, here construct a map[storage.FieldID]storage.FieldData to combine
flushFunc := func(field storage.FieldData) error {
fields := make(map[string]storage.FieldData)
fields[fileName] = field
fields := make(map[storage.FieldID]storage.FieldData)
fields[id] = field
combineFunc(fields)
return nil
}
@ -325,7 +333,7 @@ func (p *ImportWrapper) appendFunc(schema *schemapb.FieldSchema) func(src storag
arr.NumRows[0]++
return nil
}
case schemapb.DataType_String:
case schemapb.DataType_String, schemapb.DataType_VarChar:
return func(src storage.FieldData, n int, target storage.FieldData) error {
arr := target.(*storage.StringFieldData)
arr.Data = append(arr.Data, src.GetRow(n).(string))
@ -336,7 +344,7 @@ func (p *ImportWrapper) appendFunc(schema *schemapb.FieldSchema) func(src storag
}
}
func (p *ImportWrapper) splitFieldsData(fieldsData map[string]storage.FieldData, files []string) error {
func (p *ImportWrapper) splitFieldsData(fieldsData map[storage.FieldID]storage.FieldData, files []string) error {
if len(fieldsData) == 0 {
return errors.New("import error: fields data is empty")
}
@ -347,7 +355,7 @@ func (p *ImportWrapper) splitFieldsData(fieldsData map[string]storage.FieldData,
if schema.GetIsPrimaryKey() {
primaryKey = schema
} else {
_, ok := fieldsData[schema.GetName()]
_, ok := fieldsData[schema.GetFieldID()]
if !ok {
return errors.New("import error: field " + schema.GetName() + " not provided")
}
@ -363,7 +371,7 @@ func (p *ImportWrapper) splitFieldsData(fieldsData map[string]storage.FieldData,
break
}
primaryData, ok := fieldsData[primaryKey.GetName()]
primaryData, ok := fieldsData[primaryKey.GetFieldID()]
if !ok {
// generate auto id for primary key
if primaryKey.GetAutoID() {
@ -383,7 +391,7 @@ func (p *ImportWrapper) splitFieldsData(fieldsData map[string]storage.FieldData,
}
// prepare segments
segmentsData := make([]map[string]storage.FieldData, 0, p.shardNum)
segmentsData := make([]map[storage.FieldID]storage.FieldData, 0, p.shardNum)
for i := 0; i < int(p.shardNum); i++ {
segmentData := initSegmentData(p.collectionSchema)
if segmentData == nil {
@ -412,8 +420,8 @@ func (p *ImportWrapper) splitFieldsData(fieldsData map[string]storage.FieldData,
for k := 0; k < len(p.collectionSchema.Fields); k++ {
schema := p.collectionSchema.Fields[k]
srcData := fieldsData[schema.GetName()]
targetData := segmentsData[shard][schema.GetName()]
srcData := fieldsData[schema.GetFieldID()]
targetData := segmentsData[shard][schema.GetFieldID()]
appendFunc := appendFunctions[schema.GetName()]
err := appendFunc(srcData, i, targetData)
if err != nil {
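
Putting the signature changes together: the wrapper is now fed by a storage.ChunkManager, segmentSize is an int64, and the flush callback is keyed by storage.FieldID instead of field name. Below is a minimal wiring sketch; buildAndRunImport is a hypothetical helper and the shardNum and segmentSize values are placeholders.

package example // illustrative package, not part of the change

import (
	"context"
	"errors"

	"go.uber.org/zap"

	"github.com/milvus-io/milvus/internal/allocator"
	"github.com/milvus-io/milvus/internal/log"
	"github.com/milvus-io/milvus/internal/proto/schemapb"
	"github.com/milvus-io/milvus/internal/storage"
	"github.com/milvus-io/milvus/internal/util/importutil"
)

// buildAndRunImport wires up an ImportWrapper with the new constructor signature.
func buildAndRunImport(ctx context.Context, schema *schemapb.CollectionSchema,
	idAlloc *allocator.IDAllocator, cm storage.ChunkManager, files []string) error {
	// the flush callback now receives field data keyed by field ID instead of field name
	flushFunc := func(fields map[storage.FieldID]storage.FieldData) error {
		for fieldID, data := range fields {
			log.Info("flushing field", zap.Int64("fieldID", int64(fieldID)), zap.Int("rows", data.RowNum()))
		}
		return nil
	}
	wrapper := importutil.NewImportWrapper(ctx, schema, 2 /*shardNum*/, 512 /*segmentSize, MB*/, idAlloc, cm, flushFunc)
	if wrapper == nil {
		return errors.New("invalid collection schema")
	}
	return wrapper.Import(files, true /*rowBased*/, false /*onlyValidate*/)
}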

View File

@ -1,18 +1,23 @@
package importutil
import (
"bufio"
"bytes"
"context"
"encoding/json"
"os"
"strconv"
"testing"
"github.com/stretchr/testify/assert"
"go.uber.org/zap"
"github.com/milvus-io/milvus/internal/common"
"github.com/milvus-io/milvus/internal/log"
"github.com/milvus-io/milvus/internal/proto/commonpb"
"github.com/milvus-io/milvus/internal/proto/schemapb"
"github.com/milvus-io/milvus/internal/storage"
"github.com/milvus-io/milvus/internal/util/dependency"
"github.com/milvus-io/milvus/internal/util/timerecord"
"github.com/stretchr/testify/assert"
)
const (
@ -20,8 +25,11 @@ const (
)
func Test_NewImportWrapper(t *testing.T) {
f := dependency.NewDefaultFactory(true)
ctx := context.Background()
wrapper := NewImportWrapper(ctx, nil, 2, 1, nil, nil)
cm, err := f.NewVectorStorageChunkManager(ctx)
assert.NoError(t, err)
wrapper := NewImportWrapper(ctx, nil, 2, 1, nil, cm, nil)
assert.Nil(t, wrapper)
schema := &schemapb.CollectionSchema{
@ -39,28 +47,18 @@ func Test_NewImportWrapper(t *testing.T) {
Description: "int64",
DataType: schemapb.DataType_Int64,
})
wrapper = NewImportWrapper(ctx, schema, 2, 1, nil, nil)
wrapper = NewImportWrapper(ctx, schema, 2, 1, nil, cm, nil)
assert.NotNil(t, wrapper)
err := wrapper.Cancel()
err = wrapper.Cancel()
assert.Nil(t, err)
}
func saveFile(t *testing.T, filePath string, content []byte) *os.File {
fp, err := os.Create(filePath)
assert.Nil(t, err)
_, err = fp.Write(content)
assert.Nil(t, err)
return fp
}
func Test_ImportRowBased(t *testing.T) {
f := dependency.NewDefaultFactory(true)
ctx := context.Background()
err := os.MkdirAll(TempFilesPath, os.ModePerm)
assert.Nil(t, err)
defer os.RemoveAll(TempFilesPath)
cm, err := f.NewVectorStorageChunkManager(ctx)
assert.NoError(t, err)
idAllocator := newIDAllocator(ctx, t)
@ -75,11 +73,12 @@ func Test_ImportRowBased(t *testing.T) {
}`)
filePath := TempFilesPath + "rows_1.json"
fp1 := saveFile(t, filePath, content)
defer fp1.Close()
err = cm.Write(filePath, content)
assert.NoError(t, err)
defer cm.RemoveWithPrefix("")
rowCount := 0
flushFunc := func(fields map[string]storage.FieldData) error {
flushFunc := func(fields map[storage.FieldID]storage.FieldData) error {
count := 0
for _, data := range fields {
assert.Less(t, 0, data.RowNum())
@ -94,7 +93,7 @@ func Test_ImportRowBased(t *testing.T) {
}
// success case
wrapper := NewImportWrapper(ctx, sampleSchema(), 2, 1, idAllocator, flushFunc)
wrapper := NewImportWrapper(ctx, sampleSchema(), 2, 1, idAllocator, cm, flushFunc)
files := make([]string, 0)
files = append(files, filePath)
err = wrapper.Import(files, true, false)
@ -109,10 +108,10 @@ func Test_ImportRowBased(t *testing.T) {
}`)
filePath = TempFilesPath + "rows_2.json"
fp2 := saveFile(t, filePath, content)
defer fp2.Close()
err = cm.Write(filePath, content)
assert.NoError(t, err)
wrapper = NewImportWrapper(ctx, sampleSchema(), 2, 1, idAllocator, flushFunc)
wrapper = NewImportWrapper(ctx, sampleSchema(), 2, 1, idAllocator, cm, flushFunc)
files = make([]string, 0)
files = append(files, filePath)
err = wrapper.Import(files, true, false)
@ -127,10 +126,11 @@ func Test_ImportRowBased(t *testing.T) {
}
func Test_ImportColumnBased_json(t *testing.T) {
f := dependency.NewDefaultFactory(true)
ctx := context.Background()
err := os.MkdirAll(TempFilesPath, os.ModePerm)
assert.Nil(t, err)
defer os.RemoveAll(TempFilesPath)
cm, err := f.NewVectorStorageChunkManager(ctx)
assert.NoError(t, err)
defer cm.RemoveWithPrefix("")
idAllocator := newIDAllocator(ctx, t)
@ -160,11 +160,11 @@ func Test_ImportColumnBased_json(t *testing.T) {
}`)
filePath := TempFilesPath + "columns_1.json"
fp1 := saveFile(t, filePath, content)
defer fp1.Close()
err = cm.Write(filePath, content)
assert.NoError(t, err)
rowCount := 0
flushFunc := func(fields map[string]storage.FieldData) error {
flushFunc := func(fields map[storage.FieldID]storage.FieldData) error {
count := 0
for _, data := range fields {
assert.Less(t, 0, data.RowNum())
@ -179,7 +179,7 @@ func Test_ImportColumnBased_json(t *testing.T) {
}
// success case
wrapper := NewImportWrapper(ctx, sampleSchema(), 2, 1, idAllocator, flushFunc)
wrapper := NewImportWrapper(ctx, sampleSchema(), 2, 1, idAllocator, cm, flushFunc)
files := make([]string, 0)
files = append(files, filePath)
err = wrapper.Import(files, false, false)
@ -192,10 +192,10 @@ func Test_ImportColumnBased_json(t *testing.T) {
}`)
filePath = TempFilesPath + "rows_2.json"
fp2 := saveFile(t, filePath, content)
defer fp2.Close()
err = cm.Write(filePath, content)
assert.NoError(t, err)
wrapper = NewImportWrapper(ctx, sampleSchema(), 2, 1, idAllocator, flushFunc)
wrapper = NewImportWrapper(ctx, sampleSchema(), 2, 1, idAllocator, cm, flushFunc)
files = make([]string, 0)
files = append(files, filePath)
err = wrapper.Import(files, false, false)
@ -209,10 +209,11 @@ func Test_ImportColumnBased_json(t *testing.T) {
}
func Test_ImportColumnBased_numpy(t *testing.T) {
f := dependency.NewDefaultFactory(true)
ctx := context.Background()
err := os.MkdirAll(TempFilesPath, os.ModePerm)
assert.Nil(t, err)
defer os.RemoveAll(TempFilesPath)
cm, err := f.NewVectorStorageChunkManager(ctx)
assert.NoError(t, err)
defer cm.RemoveWithPrefix("")
idAllocator := newIDAllocator(ctx, t)
@ -230,24 +231,30 @@ func Test_ImportColumnBased_numpy(t *testing.T) {
files := make([]string, 0)
filePath := TempFilesPath + "scalar_fields.json"
fp1 := saveFile(t, filePath, content)
fp1.Close()
err = cm.Write(filePath, content)
assert.NoError(t, err)
files = append(files, filePath)
filePath = TempFilesPath + "field_binary_vector.npy"
bin := [][2]uint8{{1, 2}, {3, 4}, {5, 6}, {7, 8}, {9, 10}}
err = CreateNumpyFile(filePath, bin)
content, err = CreateNumpyData(bin)
assert.Nil(t, err)
log.Debug("content", zap.Any("c", content))
err = cm.Write(filePath, content)
assert.NoError(t, err)
files = append(files, filePath)
filePath = TempFilesPath + "field_float_vector.npy"
flo := [][4]float32{{1, 2, 3, 4}, {3, 4, 5, 6}, {5, 6, 7, 8}, {7, 8, 9, 10}, {9, 10, 11, 12}}
err = CreateNumpyFile(filePath, flo)
content, err = CreateNumpyData(flo)
assert.Nil(t, err)
log.Debug("content", zap.Any("c", content))
err = cm.Write(filePath, content)
assert.NoError(t, err)
files = append(files, filePath)
rowCount := 0
flushFunc := func(fields map[string]storage.FieldData) error {
flushFunc := func(fields map[storage.FieldID]storage.FieldData) error {
count := 0
for _, data := range fields {
assert.Less(t, 0, data.RowNum())
@ -262,7 +269,7 @@ func Test_ImportColumnBased_numpy(t *testing.T) {
}
// success case
wrapper := NewImportWrapper(ctx, sampleSchema(), 2, 1, idAllocator, flushFunc)
wrapper := NewImportWrapper(ctx, sampleSchema(), 2, 1, idAllocator, cm, flushFunc)
err = wrapper.Import(files, false, false)
assert.Nil(t, err)
@ -274,10 +281,10 @@ func Test_ImportColumnBased_numpy(t *testing.T) {
}`)
filePath = TempFilesPath + "rows_2.json"
fp2 := saveFile(t, filePath, content)
defer fp2.Close()
err = cm.Write(filePath, content)
assert.NoError(t, err)
wrapper = NewImportWrapper(ctx, sampleSchema(), 2, 1, idAllocator, flushFunc)
wrapper = NewImportWrapper(ctx, sampleSchema(), 2, 1, idAllocator, cm, flushFunc)
files = make([]string, 0)
files = append(files, filePath)
err = wrapper.Import(files, false, false)
@ -321,10 +328,11 @@ func perfSchema(dim int) *schemapb.CollectionSchema {
}
func Test_ImportRowBased_perf(t *testing.T) {
f := dependency.NewDefaultFactory(true)
ctx := context.Background()
err := os.MkdirAll(TempFilesPath, os.ModePerm)
assert.Nil(t, err)
defer os.RemoveAll(TempFilesPath)
cm, err := f.NewVectorStorageChunkManager(ctx)
assert.NoError(t, err)
defer cm.RemoveWithPrefix("")
idAllocator := newIDAllocator(ctx, t)
@ -365,19 +373,22 @@ func Test_ImportRowBased_perf(t *testing.T) {
// generate a json file
filePath := TempFilesPath + "row_perf.json"
func() {
fp, err := os.Create(filePath)
assert.Nil(t, err)
defer fp.Close()
var b bytes.Buffer
bw := bufio.NewWriter(&b)
encoder := json.NewEncoder(fp)
encoder := json.NewEncoder(bw)
err = encoder.Encode(entities)
assert.Nil(t, err)
err = bw.Flush()
assert.NoError(t, err)
err = cm.Write(filePath, b.Bytes())
assert.NoError(t, err)
}()
tr.Record("generate large json file " + filePath)
// parse the json file
parseCount := 0
flushFunc := func(fields map[string]storage.FieldData) error {
flushFunc := func(fields map[storage.FieldID]storage.FieldData) error {
count := 0
for _, data := range fields {
assert.Less(t, 0, data.RowNum())
@ -393,7 +404,7 @@ func Test_ImportRowBased_perf(t *testing.T) {
schema := perfSchema(dim)
wrapper := NewImportWrapper(ctx, schema, int32(shardNum), int32(segmentSize), idAllocator, flushFunc)
wrapper := NewImportWrapper(ctx, schema, int32(shardNum), int64(segmentSize), idAllocator, cm, flushFunc)
files := make([]string, 0)
files = append(files, filePath)
err = wrapper.Import(files, true, false)
@ -404,10 +415,11 @@ func Test_ImportRowBased_perf(t *testing.T) {
}
func Test_ImportColumnBased_perf(t *testing.T) {
f := dependency.NewDefaultFactory(true)
ctx := context.Background()
err := os.MkdirAll(TempFilesPath, os.ModePerm)
assert.Nil(t, err)
defer os.RemoveAll(TempFilesPath)
cm, err := f.NewVectorStorageChunkManager(ctx)
assert.NoError(t, err)
defer cm.RemoveWithPrefix("")
idAllocator := newIDAllocator(ctx, t)
@ -449,15 +461,17 @@ func Test_ImportColumnBased_perf(t *testing.T) {
// generate json files
saveFileFunc := func(filePath string, data interface{}) error {
fp, err := os.Create(filePath)
if err != nil {
return err
}
defer fp.Close()
var b bytes.Buffer
bw := bufio.NewWriter(&b)
encoder := json.NewEncoder(fp)
encoder := json.NewEncoder(bw)
err = encoder.Encode(data)
return err
assert.Nil(t, err)
err = bw.Flush()
assert.NoError(t, err)
err = cm.Write(filePath, b.Bytes())
assert.NoError(t, err)
return nil
}
filePath1 := TempFilesPath + "ids.json"
@ -472,7 +486,7 @@ func Test_ImportColumnBased_perf(t *testing.T) {
// parse the json file
parseCount := 0
flushFunc := func(fields map[string]storage.FieldData) error {
flushFunc := func(fields map[storage.FieldID]storage.FieldData) error {
count := 0
for _, data := range fields {
assert.Less(t, 0, data.RowNum())
@ -488,7 +502,7 @@ func Test_ImportColumnBased_perf(t *testing.T) {
schema := perfSchema(dim)
wrapper := NewImportWrapper(ctx, schema, int32(shardNum), int32(segmentSize), idAllocator, flushFunc)
wrapper := NewImportWrapper(ctx, schema, int32(shardNum), int64(segmentSize), idAllocator, cm, flushFunc)
files := make([]string, 0)
files = append(files, filePath1)
files = append(files, filePath2)

View File

@ -5,22 +5,23 @@ import (
"fmt"
"strconv"
"go.uber.org/zap"
"github.com/milvus-io/milvus/internal/allocator"
"github.com/milvus-io/milvus/internal/log"
"github.com/milvus-io/milvus/internal/proto/schemapb"
"github.com/milvus-io/milvus/internal/storage"
"github.com/milvus-io/milvus/internal/util/typeutil"
"go.uber.org/zap"
)
// interface to process row-based data
type JSONRowHandler interface {
Handle(rows []map[string]interface{}) error
Handle(rows []map[storage.FieldID]interface{}) error
}
// interface to process column data
type JSONColumnHandler interface {
Handle(columns map[string][]interface{}) error
Handle(columns map[storage.FieldID][]interface{}) error
}
// method to get dimension of vector field
@ -49,7 +50,7 @@ type Validator struct {
}
// method to construct validator functions
func initValidators(collectionSchema *schemapb.CollectionSchema, validators map[string]*Validator) error {
func initValidators(collectionSchema *schemapb.CollectionSchema, validators map[storage.FieldID]*Validator) error {
if collectionSchema == nil {
return errors.New("collection schema is nil")
}
@ -70,13 +71,13 @@ func initValidators(collectionSchema *schemapb.CollectionSchema, validators map[
for i := 0; i < len(collectionSchema.Fields); i++ {
schema := collectionSchema.Fields[i]
validators[schema.GetName()] = &Validator{}
validators[schema.GetName()].primaryKey = schema.GetIsPrimaryKey()
validators[schema.GetName()].autoID = schema.GetAutoID()
validators[schema.GetFieldID()] = &Validator{}
validators[schema.GetFieldID()].primaryKey = schema.GetIsPrimaryKey()
validators[schema.GetFieldID()].autoID = schema.GetAutoID()
switch schema.DataType {
case schemapb.DataType_Bool:
validators[schema.GetName()].validateFunc = func(obj interface{}) error {
validators[schema.GetFieldID()].validateFunc = func(obj interface{}) error {
switch obj.(type) {
case bool:
return nil
@ -87,55 +88,55 @@ func initValidators(collectionSchema *schemapb.CollectionSchema, validators map[
}
}
validators[schema.GetName()].convertFunc = func(obj interface{}, field storage.FieldData) error {
validators[schema.GetFieldID()].convertFunc = func(obj interface{}, field storage.FieldData) error {
value := obj.(bool)
field.(*storage.BoolFieldData).Data = append(field.(*storage.BoolFieldData).Data, value)
field.(*storage.BoolFieldData).NumRows[0]++
return nil
}
case schemapb.DataType_Float:
validators[schema.GetName()].validateFunc = numericValidator
validators[schema.GetName()].convertFunc = func(obj interface{}, field storage.FieldData) error {
validators[schema.GetFieldID()].validateFunc = numericValidator
validators[schema.GetFieldID()].convertFunc = func(obj interface{}, field storage.FieldData) error {
value := float32(obj.(float64))
field.(*storage.FloatFieldData).Data = append(field.(*storage.FloatFieldData).Data, value)
field.(*storage.FloatFieldData).NumRows[0]++
return nil
}
case schemapb.DataType_Double:
validators[schema.GetName()].validateFunc = numericValidator
validators[schema.GetName()].convertFunc = func(obj interface{}, field storage.FieldData) error {
validators[schema.GetFieldID()].validateFunc = numericValidator
validators[schema.GetFieldID()].convertFunc = func(obj interface{}, field storage.FieldData) error {
value := obj.(float64)
field.(*storage.DoubleFieldData).Data = append(field.(*storage.DoubleFieldData).Data, value)
field.(*storage.DoubleFieldData).NumRows[0]++
return nil
}
case schemapb.DataType_Int8:
validators[schema.GetName()].validateFunc = numericValidator
validators[schema.GetName()].convertFunc = func(obj interface{}, field storage.FieldData) error {
validators[schema.GetFieldID()].validateFunc = numericValidator
validators[schema.GetFieldID()].convertFunc = func(obj interface{}, field storage.FieldData) error {
value := int8(obj.(float64))
field.(*storage.Int8FieldData).Data = append(field.(*storage.Int8FieldData).Data, value)
field.(*storage.Int8FieldData).NumRows[0]++
return nil
}
case schemapb.DataType_Int16:
validators[schema.GetName()].validateFunc = numericValidator
validators[schema.GetName()].convertFunc = func(obj interface{}, field storage.FieldData) error {
validators[schema.GetFieldID()].validateFunc = numericValidator
validators[schema.GetFieldID()].convertFunc = func(obj interface{}, field storage.FieldData) error {
value := int16(obj.(float64))
field.(*storage.Int16FieldData).Data = append(field.(*storage.Int16FieldData).Data, value)
field.(*storage.Int16FieldData).NumRows[0]++
return nil
}
case schemapb.DataType_Int32:
validators[schema.GetName()].validateFunc = numericValidator
validators[schema.GetName()].convertFunc = func(obj interface{}, field storage.FieldData) error {
validators[schema.GetFieldID()].validateFunc = numericValidator
validators[schema.GetFieldID()].convertFunc = func(obj interface{}, field storage.FieldData) error {
value := int32(obj.(float64))
field.(*storage.Int32FieldData).Data = append(field.(*storage.Int32FieldData).Data, value)
field.(*storage.Int32FieldData).NumRows[0]++
return nil
}
case schemapb.DataType_Int64:
validators[schema.GetName()].validateFunc = numericValidator
validators[schema.GetName()].convertFunc = func(obj interface{}, field storage.FieldData) error {
validators[schema.GetFieldID()].validateFunc = numericValidator
validators[schema.GetFieldID()].convertFunc = func(obj interface{}, field storage.FieldData) error {
value := int64(obj.(float64))
field.(*storage.Int64FieldData).Data = append(field.(*storage.Int64FieldData).Data, value)
field.(*storage.Int64FieldData).NumRows[0]++
@ -146,9 +147,9 @@ func initValidators(collectionSchema *schemapb.CollectionSchema, validators map[
if err != nil {
return err
}
validators[schema.GetName()].dimension = dim
validators[schema.GetFieldID()].dimension = dim
validators[schema.GetName()].validateFunc = func(obj interface{}) error {
validators[schema.GetFieldID()].validateFunc = func(obj interface{}) error {
switch vt := obj.(type) {
case []interface{}:
if len(vt)*8 != dim {
@ -175,7 +176,7 @@ func initValidators(collectionSchema *schemapb.CollectionSchema, validators map[
}
}
validators[schema.GetName()].convertFunc = func(obj interface{}, field storage.FieldData) error {
validators[schema.GetFieldID()].convertFunc = func(obj interface{}, field storage.FieldData) error {
arr := obj.([]interface{})
for i := 0; i < len(arr); i++ {
value := byte(arr[i].(float64))
@ -190,9 +191,9 @@ func initValidators(collectionSchema *schemapb.CollectionSchema, validators map[
if err != nil {
return err
}
validators[schema.GetName()].dimension = dim
validators[schema.GetFieldID()].dimension = dim
validators[schema.GetName()].validateFunc = func(obj interface{}) error {
validators[schema.GetFieldID()].validateFunc = func(obj interface{}) error {
switch vt := obj.(type) {
case []interface{}:
if len(vt) != dim {
@ -213,7 +214,7 @@ func initValidators(collectionSchema *schemapb.CollectionSchema, validators map[
}
}
validators[schema.GetName()].convertFunc = func(obj interface{}, field storage.FieldData) error {
validators[schema.GetFieldID()].convertFunc = func(obj interface{}, field storage.FieldData) error {
arr := obj.([]interface{})
for i := 0; i < len(arr); i++ {
value := float32(arr[i].(float64))
@ -222,8 +223,8 @@ func initValidators(collectionSchema *schemapb.CollectionSchema, validators map[
field.(*storage.FloatVectorFieldData).NumRows[0]++
return nil
}
case schemapb.DataType_String:
validators[schema.GetName()].validateFunc = func(obj interface{}) error {
case schemapb.DataType_String, schemapb.DataType_VarChar:
validators[schema.GetFieldID()].validateFunc = func(obj interface{}) error {
switch obj.(type) {
case string:
return nil
@ -234,7 +235,7 @@ func initValidators(collectionSchema *schemapb.CollectionSchema, validators map[
}
}
validators[schema.GetName()].convertFunc = func(obj interface{}, field storage.FieldData) error {
validators[schema.GetFieldID()].convertFunc = func(obj interface{}, field storage.FieldData) error {
value := obj.(string)
field.(*storage.StringFieldData).Data = append(field.(*storage.StringFieldData).Data, value)
field.(*storage.StringFieldData).NumRows[0]++
@ -250,14 +251,14 @@ func initValidators(collectionSchema *schemapb.CollectionSchema, validators map[
// row-based json format validator class
type JSONRowValidator struct {
downstream JSONRowHandler // downstream processor, typically is a JSONRowComsumer
validators map[string]*Validator // validators for each field
rowCounter int64 // how many rows have been validated
downstream JSONRowHandler // downstream processor, typically a JSONRowConsumer
validators map[storage.FieldID]*Validator // validators for each field
rowCounter int64 // how many rows have been validated
}
func NewJSONRowValidator(collectionSchema *schemapb.CollectionSchema, downstream JSONRowHandler) *JSONRowValidator {
v := &JSONRowValidator{
validators: make(map[string]*Validator),
validators: make(map[storage.FieldID]*Validator),
downstream: downstream,
rowCounter: 0,
}
@ -270,7 +271,7 @@ func (v *JSONRowValidator) ValidateCount() int64 {
return v.rowCounter
}
func (v *JSONRowValidator) Handle(rows []map[string]interface{}) error {
func (v *JSONRowValidator) Handle(rows []map[storage.FieldID]interface{}) error {
if v.validators == nil || len(v.validators) == 0 {
return errors.New("JSON row validator is not initialized")
}
@ -286,14 +287,14 @@ func (v *JSONRowValidator) Handle(rows []map[string]interface{}) error {
for i := 0; i < len(rows); i++ {
row := rows[i]
for name, validator := range v.validators {
for id, validator := range v.validators {
if validator.primaryKey && validator.autoID {
// auto-generated primary key, ignore
continue
}
value, ok := row[name]
value, ok := row[id]
if !ok {
return errors.New("JSON row validator: field " + name + " missed at the row " + strconv.FormatInt(v.rowCounter+int64(i), 10))
return errors.New("JSON row validator: fieldID " + strconv.FormatInt(id, 10) + " missed at the row " + strconv.FormatInt(v.rowCounter+int64(i), 10))
}
if err := validator.validateFunc(value); err != nil {
@ -313,27 +314,27 @@ func (v *JSONRowValidator) Handle(rows []map[string]interface{}) error {
// column-based json format validator class
type JSONColumnValidator struct {
downstream JSONColumnHandler // downstream processor, typically is a JSONColumnComsumer
validators map[string]*Validator // validators for each field
rowCounter map[string]int64 // row count of each field
downstream JSONColumnHandler // downstream processor, typically a JSONColumnConsumer
validators map[storage.FieldID]*Validator // validators for each field
rowCounter map[storage.FieldID]int64 // row count of each field
}
func NewJSONColumnValidator(schema *schemapb.CollectionSchema, downstream JSONColumnHandler) *JSONColumnValidator {
v := &JSONColumnValidator{
validators: make(map[string]*Validator),
validators: make(map[storage.FieldID]*Validator),
downstream: downstream,
rowCounter: make(map[string]int64),
rowCounter: make(map[storage.FieldID]int64),
}
initValidators(schema, v.validators)
return v
}
func (v *JSONColumnValidator) ValidateCount() map[string]int64 {
func (v *JSONColumnValidator) ValidateCount() map[storage.FieldID]int64 {
return v.rowCounter
}
func (v *JSONColumnValidator) Handle(columns map[string][]interface{}) error {
func (v *JSONColumnValidator) Handle(columns map[storage.FieldID][]interface{}) error {
if v.validators == nil || len(v.validators) == 0 {
return errors.New("JSON column validator is not initialized")
}
@ -346,7 +347,7 @@ func (v *JSONColumnValidator) Handle(columns map[string][]interface{}) error {
if rowCount == -1 {
rowCount = counter
} else if rowCount != counter {
return errors.New("JSON column validator: the field " + k + " row count " + strconv.Itoa(int(counter)) + " is not equal to other fields " + strconv.Itoa(int(rowCount)))
return errors.New("JSON column validator: the field " + strconv.FormatInt(k, 10) + " row count " + strconv.Itoa(int(counter)) + " is not equal to other fields " + strconv.Itoa(int(rowCount)))
}
}
@ -383,74 +384,74 @@ func (v *JSONColumnValidator) Handle(columns map[string][]interface{}) error {
// row-based json format consumer class
type JSONRowConsumer struct {
collectionSchema *schemapb.CollectionSchema // collection schema
rowIDAllocator *allocator.IDAllocator // autoid allocator
validators map[string]*Validator // validators for each field
rowCounter int64 // how many rows have been consumed
shardNum int32 // sharding number of the collection
segmentsData []map[string]storage.FieldData // in-memory segments data
segmentSize int32 // maximum size of a segment in MB
primaryKey string // name of primary key
collectionSchema *schemapb.CollectionSchema // collection schema
rowIDAllocator *allocator.IDAllocator // autoid allocator
validators map[storage.FieldID]*Validator // validators for each field
rowCounter int64 // how many rows have been consumed
shardNum int32 // sharding number of the collection
segmentsData []map[storage.FieldID]storage.FieldData // in-memory segments data
segmentSize int64 // maximum size of a segment in MB
primaryKey storage.FieldID // field ID of the primary key
callFlushFunc func(fields map[string]storage.FieldData) error // call back function to flush segment
callFlushFunc func(fields map[storage.FieldID]storage.FieldData) error // call back function to flush segment
}
func initSegmentData(collectionSchema *schemapb.CollectionSchema) map[string]storage.FieldData {
segmentData := make(map[string]storage.FieldData)
func initSegmentData(collectionSchema *schemapb.CollectionSchema) map[storage.FieldID]storage.FieldData {
segmentData := make(map[storage.FieldID]storage.FieldData)
for i := 0; i < len(collectionSchema.Fields); i++ {
schema := collectionSchema.Fields[i]
switch schema.DataType {
case schemapb.DataType_Bool:
segmentData[schema.GetName()] = &storage.BoolFieldData{
segmentData[schema.GetFieldID()] = &storage.BoolFieldData{
Data: make([]bool, 0),
NumRows: []int64{0},
}
case schemapb.DataType_Float:
segmentData[schema.GetName()] = &storage.FloatFieldData{
segmentData[schema.GetFieldID()] = &storage.FloatFieldData{
Data: make([]float32, 0),
NumRows: []int64{0},
}
case schemapb.DataType_Double:
segmentData[schema.GetName()] = &storage.DoubleFieldData{
segmentData[schema.GetFieldID()] = &storage.DoubleFieldData{
Data: make([]float64, 0),
NumRows: []int64{0},
}
case schemapb.DataType_Int8:
segmentData[schema.GetName()] = &storage.Int8FieldData{
segmentData[schema.GetFieldID()] = &storage.Int8FieldData{
Data: make([]int8, 0),
NumRows: []int64{0},
}
case schemapb.DataType_Int16:
segmentData[schema.GetName()] = &storage.Int16FieldData{
segmentData[schema.GetFieldID()] = &storage.Int16FieldData{
Data: make([]int16, 0),
NumRows: []int64{0},
}
case schemapb.DataType_Int32:
segmentData[schema.GetName()] = &storage.Int32FieldData{
segmentData[schema.GetFieldID()] = &storage.Int32FieldData{
Data: make([]int32, 0),
NumRows: []int64{0},
}
case schemapb.DataType_Int64:
segmentData[schema.GetName()] = &storage.Int64FieldData{
segmentData[schema.GetFieldID()] = &storage.Int64FieldData{
Data: make([]int64, 0),
NumRows: []int64{0},
}
case schemapb.DataType_BinaryVector:
dim, _ := getFieldDimension(schema)
segmentData[schema.GetName()] = &storage.BinaryVectorFieldData{
segmentData[schema.GetFieldID()] = &storage.BinaryVectorFieldData{
Data: make([]byte, 0),
NumRows: []int64{0},
Dim: dim,
}
case schemapb.DataType_FloatVector:
dim, _ := getFieldDimension(schema)
segmentData[schema.GetName()] = &storage.FloatVectorFieldData{
segmentData[schema.GetFieldID()] = &storage.FloatVectorFieldData{
Data: make([]float32, 0),
NumRows: []int64{0},
Dim: dim,
}
case schemapb.DataType_String:
segmentData[schema.GetName()] = &storage.StringFieldData{
case schemapb.DataType_String, schemapb.DataType_VarChar:
segmentData[schema.GetFieldID()] = &storage.StringFieldData{
Data: make([]string, 0),
NumRows: []int64{0},
}
@ -463,8 +464,8 @@ func initSegmentData(collectionSchema *schemapb.CollectionSchema) map[string]sto
return segmentData
}
func NewJSONRowConsumer(collectionSchema *schemapb.CollectionSchema, idAlloc *allocator.IDAllocator, shardNum int32, segmentSize int32,
flushFunc func(fields map[string]storage.FieldData) error) *JSONRowConsumer {
func NewJSONRowConsumer(collectionSchema *schemapb.CollectionSchema, idAlloc *allocator.IDAllocator, shardNum int32, segmentSize int64,
flushFunc func(fields map[storage.FieldID]storage.FieldData) error) *JSONRowConsumer {
if collectionSchema == nil {
log.Error("JSON row consumer: collection schema is nil")
return nil
@ -473,16 +474,17 @@ func NewJSONRowConsumer(collectionSchema *schemapb.CollectionSchema, idAlloc *al
v := &JSONRowConsumer{
collectionSchema: collectionSchema,
rowIDAllocator: idAlloc,
validators: make(map[string]*Validator),
validators: make(map[storage.FieldID]*Validator),
shardNum: shardNum,
segmentSize: segmentSize,
rowCounter: 0,
primaryKey: -1,
callFlushFunc: flushFunc,
}
initValidators(collectionSchema, v.validators)
v.segmentsData = make([]map[string]storage.FieldData, 0, shardNum)
v.segmentsData = make([]map[storage.FieldID]storage.FieldData, 0, shardNum)
for i := 0; i < int(shardNum); i++ {
segmentData := initSegmentData(collectionSchema)
if segmentData == nil {
@ -494,12 +496,12 @@ func NewJSONRowConsumer(collectionSchema *schemapb.CollectionSchema, idAlloc *al
for i := 0; i < len(collectionSchema.Fields); i++ {
schema := collectionSchema.Fields[i]
if schema.GetIsPrimaryKey() {
v.primaryKey = schema.GetName()
v.primaryKey = schema.GetFieldID()
break
}
}
// primary key not found
if v.primaryKey == "" {
if v.primaryKey == -1 {
log.Error("JSON row consumer: collection schema has no primary key")
return nil
}
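
As a usage note, a hedged sketch of how a caller might supply the new FieldID-keyed flush callback (the buildRowConsumer helper, the shard count of 2, the 512 MB segment size, and the log fields are invented; it assumes the importutil package context and its usual imports):

func buildRowConsumer(schema *schemapb.CollectionSchema, idAlloc *allocator.IDAllocator) (*JSONRowConsumer, error) {
	// the callback now receives buffers keyed by field ID instead of field name
	flushFunc := func(fields map[storage.FieldID]storage.FieldData) error {
		for id, data := range fields {
			log.Info("flushing field", zap.Int64("fieldID", int64(id)), zap.Int("rows", data.RowNum()))
		}
		return nil
	}
	consumer := NewJSONRowConsumer(schema, idAlloc, 2, 512*1024*1024, flushFunc)
	if consumer == nil {
		return nil, errors.New("failed to create JSON row consumer")
	}
	return consumer, nil
}
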
@ -544,7 +546,7 @@ func (v *JSONRowConsumer) flush(force bool) error {
return nil
}
func (v *JSONRowConsumer) Handle(rows []map[string]interface{}) error {
func (v *JSONRowConsumer) Handle(rows []map[storage.FieldID]interface{}) error {
if v.validators == nil || len(v.validators) == 0 {
return errors.New("JSON row consumer is not initialized")
}
@ -614,23 +616,23 @@ func (v *JSONRowConsumer) Handle(rows []map[string]interface{}) error {
// column-based json format consumer class
type JSONColumnConsumer struct {
collectionSchema *schemapb.CollectionSchema // collection schema
validators map[string]*Validator // validators for each field
fieldsData map[string]storage.FieldData // in-memory fields data
primaryKey string // name of primary key
collectionSchema *schemapb.CollectionSchema // collection schema
validators map[storage.FieldID]*Validator // validators for each field
fieldsData map[storage.FieldID]storage.FieldData // in-memory fields data
primaryKey storage.FieldID // field id of the primary key
callFlushFunc func(fields map[string]storage.FieldData) error // call back function to flush segment
callFlushFunc func(fields map[storage.FieldID]storage.FieldData) error // callback function to flush a segment
}
func NewJSONColumnConsumer(collectionSchema *schemapb.CollectionSchema,
flushFunc func(fields map[string]storage.FieldData) error) *JSONColumnConsumer {
flushFunc func(fields map[storage.FieldID]storage.FieldData) error) *JSONColumnConsumer {
if collectionSchema == nil {
return nil
}
v := &JSONColumnConsumer{
collectionSchema: collectionSchema,
validators: make(map[string]*Validator),
validators: make(map[storage.FieldID]*Validator),
callFlushFunc: flushFunc,
}
initValidators(collectionSchema, v.validators)
@ -639,7 +641,7 @@ func NewJSONColumnConsumer(collectionSchema *schemapb.CollectionSchema,
for i := 0; i < len(collectionSchema.Fields); i++ {
schema := collectionSchema.Fields[i]
if schema.GetIsPrimaryKey() {
v.primaryKey = schema.GetName()
v.primaryKey = schema.GetFieldID()
break
}
}
@ -650,9 +652,9 @@ func NewJSONColumnConsumer(collectionSchema *schemapb.CollectionSchema,
func (v *JSONColumnConsumer) flush() error {
// check the row count of every field; they should all be equal
rowCount := 0
for name, field := range v.fieldsData {
for id, field := range v.fieldsData {
// skip the autoid field
if name == v.primaryKey && v.validators[v.primaryKey].autoID {
if id == v.primaryKey && v.validators[v.primaryKey].autoID {
continue
}
cnt := field.RowNum()
@ -665,7 +667,7 @@ func (v *JSONColumnConsumer) flush() error {
if rowCount == 0 {
rowCount = cnt
} else if rowCount != cnt {
return errors.New("JSON column consumer: " + name + " row count " + strconv.Itoa(cnt) + " doesn't equal " + strconv.Itoa(rowCount))
return errors.New("JSON column consumer: " + strconv.FormatInt(id, 10) + " row count " + strconv.Itoa(cnt) + " doesn't equal " + strconv.Itoa(rowCount))
}
}
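
Stated as a standalone rule, the check above amounts to the following sketch (checkRowCounts and its parameters are invented names, not part of the patch): every buffered field except an auto-generated primary key must report the same row count.

func checkRowCounts(fieldsData map[storage.FieldID]storage.FieldData, primaryKey storage.FieldID, autoID bool) error {
	rowCount := -1
	for id, field := range fieldsData {
		// the primary key is generated later when autoID is set, so its buffer may be empty
		if id == primaryKey && autoID {
			continue
		}
		if rowCount == -1 {
			rowCount = field.RowNum()
		} else if field.RowNum() != rowCount {
			return fmt.Errorf("field %d has %d rows, expected %d", id, field.RowNum(), rowCount)
		}
	}
	return nil
}
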
@ -678,7 +680,7 @@ func (v *JSONColumnConsumer) flush() error {
return v.callFlushFunc(v.fieldsData)
}
func (v *JSONColumnConsumer) Handle(columns map[string][]interface{}) error {
func (v *JSONColumnConsumer) Handle(columns map[storage.FieldID][]interface{}) error {
if v.validators == nil || len(v.validators) == 0 {
return errors.New("JSON column consumer is not initialized")
}
@ -691,10 +693,10 @@ func (v *JSONColumnConsumer) Handle(columns map[string][]interface{}) error {
}
// consume columns data
for name, values := range columns {
validator, ok := v.validators[name]
for id, values := range columns {
validator, ok := v.validators[id]
if !ok {
// not a valid field name
// not a valid field id
break
}
@ -705,8 +707,8 @@ func (v *JSONColumnConsumer) Handle(columns map[string][]interface{}) error {
// convert and consume data
for i := 0; i < len(values); i++ {
if err := validator.convertFunc(values[i], v.fieldsData[name]); err != nil {
return errors.New("JSON column consumer: " + err.Error() + " of field " + name)
if err := validator.convertFunc(values[i], v.fieldsData[id]); err != nil {
return errors.New("JSON column consumer: " + err.Error() + " of field " + strconv.FormatInt(id, 10))
}
}
}

View File

@ -5,12 +5,13 @@ import (
"strings"
"testing"
"github.com/stretchr/testify/assert"
"github.com/milvus-io/milvus/internal/allocator"
"github.com/milvus-io/milvus/internal/proto/commonpb"
"github.com/milvus-io/milvus/internal/proto/rootcoordpb"
"github.com/milvus-io/milvus/internal/proto/schemapb"
"github.com/milvus-io/milvus/internal/storage"
"github.com/stretchr/testify/assert"
)
type mockIDAllocator struct {
@ -68,17 +69,23 @@ func Test_GetFieldDimension(t *testing.T) {
}
func Test_InitValidators(t *testing.T) {
validators := make(map[string]*Validator)
validators := make(map[storage.FieldID]*Validator)
err := initValidators(nil, validators)
assert.NotNil(t, err)
schema := sampleSchema()
// success case
err = initValidators(sampleSchema(), validators)
err = initValidators(schema, validators)
assert.Nil(t, err)
assert.Equal(t, len(sampleSchema().Fields), len(validators))
assert.Equal(t, len(schema.Fields), len(validators))
name2ID := make(map[string]storage.FieldID)
for _, field := range schema.Fields {
name2ID[field.GetName()] = field.GetFieldID()
}
checkFunc := func(funcName string, validVal interface{}, invalidVal interface{}) {
v, ok := validators[funcName]
id := name2ID[funcName]
v, ok := validators[id]
assert.True(t, ok)
err = v.validateFunc(validVal)
assert.Nil(t, err)
@ -127,7 +134,7 @@ func Test_InitValidators(t *testing.T) {
checkFunc("field_float_vector", validVal, invalidVal)
// error cases
schema := &schemapb.CollectionSchema{
schema = &schemapb.CollectionSchema{
Name: "schema",
Description: "schema",
AutoID: true,
@ -144,7 +151,7 @@ func Test_InitValidators(t *testing.T) {
},
})
validators = make(map[string]*Validator)
validators = make(map[storage.FieldID]*Validator)
err = initValidators(schema, validators)
assert.NotNil(t, err)
@ -308,7 +315,7 @@ func Test_JSONRowConsumer(t *testing.T) {
var callTime int32
var totalCount int
consumeFunc := func(fields map[string]storage.FieldData) error {
consumeFunc := func(fields map[storage.FieldID]storage.FieldData) error {
callTime++
rowCount := 0
for _, data := range fields {
@ -370,7 +377,7 @@ func Test_JSONColumnConsumer(t *testing.T) {
callTime := 0
rowCount := 0
consumeFunc := func(fields map[string]storage.FieldData) error {
consumeFunc := func(fields map[storage.FieldID]storage.FieldData) error {
callTime++
for _, data := range fields {
if rowCount == 0 {

View File

@ -9,6 +9,7 @@ import (
"github.com/milvus-io/milvus/internal/log"
"github.com/milvus-io/milvus/internal/proto/schemapb"
"github.com/milvus-io/milvus/internal/storage"
)
const (
@ -19,23 +20,27 @@ const (
)
type JSONParser struct {
ctx context.Context // for canceling parse process
bufSize int64 // max rows in a buffer
fields map[string]int64 // fields need to be parsed
ctx context.Context // for canceling parse process
bufSize int64 // max rows in a buffer
fields map[string]int64 // fields need to be parsed
name2FieldID map[string]storage.FieldID
}
// NewJSONParser helper function to create a JSONParser
func NewJSONParser(ctx context.Context, collectionSchema *schemapb.CollectionSchema) *JSONParser {
fields := make(map[string]int64)
name2FieldID := make(map[string]storage.FieldID)
for i := 0; i < len(collectionSchema.Fields); i++ {
schema := collectionSchema.Fields[i]
fields[schema.GetName()] = 0
name2FieldID[schema.GetName()] = schema.GetFieldID()
}
parser := &JSONParser{
ctx: ctx,
bufSize: 4096,
fields: fields,
ctx: ctx,
bufSize: 4096,
fields: fields,
name2FieldID: name2FieldID,
}
return parser
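
For clarity, a toy illustration of the re-keying that name2FieldID enables (the rekeyExample helper, the age/vec field names, and the IDs 100/101 are invented):

func rekeyExample() map[storage.FieldID]interface{} {
	name2FieldID := map[string]storage.FieldID{"age": 100, "vec": 101} // invented IDs
	stringRow := map[string]interface{}{ // what json.Decoder yields for one row
		"age": float64(21),
		"vec": []interface{}{0.1, 0.2},
	}
	row := make(map[storage.FieldID]interface{})
	for name, value := range stringRow {
		row[name2FieldID[name]] = value
	}
	return row // {100: 21, 101: [0.1 0.2]}, independent of field names
}
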
@ -87,7 +92,7 @@ func (p *JSONParser) ParseRows(r io.Reader, handler JSONRowHandler) error {
}
// read buffer
buf := make([]map[string]interface{}, 0, BufferSize)
buf := make([]map[storage.FieldID]interface{}, 0, BufferSize)
for dec.More() {
var value interface{}
if err := dec.Decode(&value); err != nil {
@ -101,7 +106,11 @@ func (p *JSONParser) ParseRows(r io.Reader, handler JSONRowHandler) error {
return p.logError("JSON parse: invalid JSON format, each row should be a key-value map")
}
row := value.(map[string]interface{})
row := make(map[storage.FieldID]interface{})
stringMap := value.(map[string]interface{})
for k, v := range stringMap {
row[p.name2FieldID[k]] = v
}
buf = append(buf, row)
if len(buf) >= int(p.bufSize) {
@ -110,7 +119,7 @@ func (p *JSONParser) ParseRows(r io.Reader, handler JSONRowHandler) error {
}
// clear the buffer
buf = make([]map[string]interface{}, 0, BufferSize)
buf = make([]map[storage.FieldID]interface{}, 0, BufferSize)
}
}
@ -185,9 +194,10 @@ func (p *JSONParser) ParseColumns(r io.Reader, handler JSONColumnHandler) error
return p.logError("JSON parse: invalid column-based JSON format, each field should begin with '['")
}
id := p.name2FieldID[key]
// read buffer
buf := make(map[string][]interface{})
buf[key] = make([]interface{}, 0, BufferSize)
buf := make(map[storage.FieldID][]interface{})
buf[id] = make([]interface{}, 0, BufferSize)
for dec.More() {
var value interface{}
if err := dec.Decode(&value); err != nil {
@ -198,19 +208,19 @@ func (p *JSONParser) ParseColumns(r io.Reader, handler JSONColumnHandler) error
continue
}
buf[key] = append(buf[key], value)
if len(buf[key]) >= int(p.bufSize) {
buf[id] = append(buf[id], value)
if len(buf[id]) >= int(p.bufSize) {
if err = handler.Handle(buf); err != nil {
return p.logError(err.Error())
}
// clear the buffer
buf[key] = make([]interface{}, 0, BufferSize)
buf[id] = make([]interface{}, 0, BufferSize)
}
}
// some values may remain in the buffer; hand them to the handler
if len(buf[key]) > 0 {
if len(buf[id]) > 0 {
if err = handler.Handle(buf); err != nil {
return p.logError(err.Error())
}
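
Putting the pieces together, a hedged sketch of how a caller could wire the parser to the two consumers (the parseJSON helper, the shard count, and the segment size are invented; it assumes the importutil package context with the usual imports):

func parseJSON(ctx context.Context, schema *schemapb.CollectionSchema, idAlloc *allocator.IDAllocator,
	r io.Reader, rowBased bool, flushFunc func(fields map[storage.FieldID]storage.FieldData) error) error {
	parser := NewJSONParser(ctx, schema)
	if rowBased {
		// row-based files: each JSON row is re-keyed by field ID and batched to the consumer
		consumer := NewJSONRowConsumer(schema, idAlloc, 2, 512*1024*1024, flushFunc)
		if consumer == nil {
			return errors.New("failed to create JSON row consumer")
		}
		return parser.ParseRows(r, consumer)
	}
	// column-based files: each field array is buffered under its field ID
	consumer := NewJSONColumnConsumer(schema, flushFunc)
	if consumer == nil {
		return errors.New("failed to create JSON column consumer")
	}
	return parser.ParseColumns(r, consumer)
}
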

View File

@ -1,6 +1,7 @@
package importutil
import (
"bytes"
"encoding/binary"
"errors"
"io"
@ -25,6 +26,16 @@ func CreateNumpyFile(path string, data interface{}) error {
return nil
}
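
// CreateNumpyData encodes data into numpy (.npy) format in memory and returns the raw bytes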
func CreateNumpyData(data interface{}) ([]byte, error) {
buf := new(bytes.Buffer)
err := npyio.Write(buf, data)
if err != nil {
return nil, err
}
return buf.Bytes(), nil
}
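
A brief usage sketch for the two helpers (the writeNumpySamples function, the float32 payload, and the /tmp path are invented):

func writeNumpySamples() error {
	vectors := []float32{1.1, 2.2, 3.3, 4.4} // invented payload

	// in-memory .npy bytes, e.g. for unit tests or object storage uploads
	raw, err := CreateNumpyData(vectors)
	if err != nil {
		return err
	}
	_ = raw

	// the same encoding written to a local file
	return CreateNumpyFile("/tmp/field_vector.npy", vectors)
}
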
// a wrapper to extend the abilities of the underlying numpy libraries
// we evaluated two go-numpy libs: github.com/kshedden/gonpy and github.com/sbinet/npyio
// the npyio lib reads data one element at a time, which is slow, so we extend the read methods

View File

@ -5,11 +5,12 @@ import (
"os"
"testing"
"github.com/sbinet/npyio/npy"
"github.com/stretchr/testify/assert"
"github.com/milvus-io/milvus/internal/proto/schemapb"
"github.com/milvus-io/milvus/internal/storage"
"github.com/milvus-io/milvus/internal/util/timerecord"
"github.com/sbinet/npyio/npy"
"github.com/stretchr/testify/assert"
)
func Test_NewNumpyParser(t *testing.T) {