mirror of https://github.com/milvus-io/milvus.git

parent 6682d1b635
commit be8d9a8b6b

go.mod (5 lines changed)

@@ -32,6 +32,8 @@ require (
	github.com/pierrec/lz4 v2.5.2+incompatible // indirect
	github.com/pkg/errors v0.9.1
	github.com/prometheus/client_golang v1.11.0
	github.com/quasilyte/go-ruleguard v0.2.1 // indirect
	github.com/sbinet/npyio v0.6.0
	github.com/shirou/gopsutil v3.21.8+incompatible
	github.com/spaolacci/murmur3 v1.1.0
	github.com/spf13/cast v1.3.1
@@ -59,4 +61,5 @@ replace (
	github.com/dgrijalva/jwt-go => github.com/golang-jwt/jwt v3.2.2+incompatible // Fix security alert for jwt-go 3.2.0
	github.com/keybase/go-keychain => github.com/99designs/go-keychain v0.0.0-20191008050251-8e49817e8af4
	google.golang.org/grpc => google.golang.org/grpc v1.38.0
)

go.sum (6 lines changed)

@@ -110,6 +110,7 @@ github.com/bketelsen/crypt v0.0.3-0.20200106085610-5cbc8cc4026c/go.mod h1:MKsuJm
github.com/bketelsen/crypt v0.0.4/go.mod h1:aI6NrJ0pMGgvZKL1iVgXLnfIFJtfV+bKCoqOes/6LfM=
github.com/bmizerany/perks v0.0.0-20141205001514-d9a9656a3a4b/go.mod h1:ac9efd0D1fsDb3EJvhqgXRbFx7bs2wqZ10HQPeU8U/Q=
github.com/boombuler/barcode v1.0.0/go.mod h1:paBWMcWSl3LHKBqUq+rly7CNSldXjb2rDl3JlRe0mD8=
github.com/campoy/embedmd v1.0.0/go.mod h1:oxyr9RCiSXg0M3VJ3ks0UGfp98BpSSGr0kpiX3MzVl8=
github.com/casbin/casbin/v2 v2.1.2/go.mod h1:YcPU1XXisHhLzuxH9coDNf2FbKpjGlbCg3n9yuLkIJQ=
github.com/cenkalti/backoff v2.2.1+incompatible/go.mod h1:90ReRw6GdpyfrHakVjL/QHaoyV4aDUVVkXQJJJ3NXXM=
github.com/census-instrumentation/opencensus-proto v0.2.1/go.mod h1:f6KPmirojxKA12rnyqOA5BBL4O983OfeGPqjHWSTneU=
@@ -584,6 +585,8 @@ github.com/prometheus/procfs v0.1.3/go.mod h1:lV6e/gmhEcM9IjHGsFOCxxuZ+z1YqCvr4O
github.com/prometheus/procfs v0.6.0 h1:mxy4L2jP6qMonqmq+aTtOx1ifVWUgG/TAmntgbh3xv4=
github.com/prometheus/procfs v0.6.0/go.mod h1:cz+aTbrPOrUb4q7XlbU9ygM+/jj0fzG6c1xBZuNvfVA=
github.com/prometheus/tsdb v0.7.1/go.mod h1:qhTCs0VvXwvX/y3TZrWD7rabWM+ijKTux40TwIPHuXU=
github.com/quasilyte/go-ruleguard v0.2.1 h1:56eRm0daAyny9UhJnmtJW/UyLZQusukBAB8oT8AHKHo=
github.com/quasilyte/go-ruleguard v0.2.1/go.mod h1:hN2rVc/uS4bQhQKTio2XaSJSafJwqBUWWwtssT3cQmc=
github.com/rcrowley/go-metrics v0.0.0-20181016184325-3113b8401b8a/go.mod h1:bCqnVzQkZxMG4s8nGwiZ5l3QUCyqpo9Y+/ZMZ9VjZe4=
github.com/rivo/tview v0.0.0-20200219210816-cd38d7432498/go.mod h1:6lkG1x+13OShEf0EaOCaTQYyB7d5nSbb181KtjlS+84=
github.com/rivo/uniseg v0.1.0/go.mod h1:J6wj4VEh+S6ZtnVlnTBMWIodfgj8LQOQFoIToxlJtxc=
@@ -598,6 +601,8 @@ github.com/ruudk/golang-pdf417 v0.0.0-20181029194003-1af4ab5afa58/go.mod h1:6lfF
github.com/ryanuber/columnize v0.0.0-20160712163229-9b3edd62028f/go.mod h1:sm1tb6uqfes/u+d4ooFouqFdy9/2g9QGwK3SQygK0Ts=
github.com/samuel/go-zookeeper v0.0.0-20190923202752-2cc03de413da/go.mod h1:gi+0XIa01GRL2eRQVjQkKGqKF3SF9vZR/HnPullcV2E=
github.com/sanity-io/litter v1.2.0/go.mod h1:JF6pZUFgu2Q0sBZ+HSV35P8TVPI1TTzEwyu9FXAw2W4=
github.com/sbinet/npyio v0.6.0 h1:IyqqQIzRjDym9xnIXsToCKei/qCzxDP+Y74KoMlMgXo=
github.com/sbinet/npyio v0.6.0/go.mod h1:/q3BNr6dJOy+t6h7RZchTJ0nwRJO52mivaem29WE1j8=
github.com/sean-/seed v0.0.0-20170313163322-e2103e2c3529/go.mod h1:DxrIzT+xaE7yg65j358z/aeFdxmN0P9QXhEzd20vsDc=
github.com/shirou/gopsutil v3.21.8+incompatible h1:sh0foI8tMRlCidUJR+KzqWYWxrkuuPIGiO6Vp+KXdCU=
github.com/shirou/gopsutil v3.21.8+incompatible/go.mod h1:5b4v6he4MtMOwMlS0TUMTu2PcXUg8+E1lC7eC3UO/RA=
@@ -1038,6 +1043,7 @@ golang.org/x/tools v0.0.0-20200618134242-20370b0cb4b2/go.mod h1:EkVYQZoAsY45+roY
golang.org/x/tools v0.0.0-20200619180055-7c47624df98f/go.mod h1:EkVYQZoAsY45+roYkvgYkIh4xh/qjgUK9TdY2XT94GE=
golang.org/x/tools v0.0.0-20200729194436-6467de6f59a7/go.mod h1:njjCfa9FT2d7l9Bc6FUM5FLjQPp3cFF28FI3qnDFljA=
golang.org/x/tools v0.0.0-20200804011535-6c149bb5ef0d/go.mod h1:njjCfa9FT2d7l9Bc6FUM5FLjQPp3cFF28FI3qnDFljA=
golang.org/x/tools v0.0.0-20200812195022-5ae4c3c160a0/go.mod h1:njjCfa9FT2d7l9Bc6FUM5FLjQPp3cFF28FI3qnDFljA=
golang.org/x/tools v0.0.0-20200825202427-b303f430e36d/go.mod h1:njjCfa9FT2d7l9Bc6FUM5FLjQPp3cFF28FI3qnDFljA=
golang.org/x/tools v0.0.0-20200904185747-39188db58858/go.mod h1:Cj7w3i3Rnn0Xh82ur9kSqwfTHTeVxaDqrfMjpcNT6bE=
golang.org/x/tools v0.0.0-20201110124207-079ba7bd75cd/go.mod h1:emZCQorbCU4vsT4fOWvOPXz4eW1wZW4PmDk9uLelYpA=

@@ -7,6 +7,7 @@ import (
	"os"
	"path"
	"strconv"
	"strings"

	"github.com/milvus-io/milvus/internal/allocator"
	"github.com/milvus-io/milvus/internal/common"
@@ -89,6 +90,13 @@ func (p *ImportWrapper) printFieldsDataInfo(fieldsData map[string]storage.FieldD
	log.Debug(msg, stats...)
}

func getFileNameAndExt(filePath string) (string, string) {
	fileName := path.Base(filePath)
	fileType := path.Ext(fileName)
	fileNameWithoutExt := strings.TrimSuffix(fileName, fileType)
	return fileNameWithoutExt, fileType
}

// import process entry
// filePath and rowBased are from ImportTask
// if onlyValidate is true, this process only does validation; no data is generated and callFlushFunc will not be called
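A minimal sketch of how the new helper behaves; the path below is illustrative and not taken from the commit:

	name, ext := getFileNameAndExt("/tmp/field_float_vector.npy")
	// name == "field_float_vector", ext == ".npy"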
@@ -99,8 +107,7 @@ func (p *ImportWrapper) Import(filePaths []string, rowBased bool, onlyValidate b
		// according to shard number, so the callFlushFunc will be called in the JSONRowConsumer
		for i := 0; i < len(filePaths); i++ {
			filePath := filePaths[i]
			fileName := path.Base(filePath)
			fileType := path.Ext(fileName)
			_, fileType := getFileNameAndExt(filePath)
			log.Debug("import wrapper: row-based file ", zap.Any("filePath", filePath), zap.Any("fileType", fileType))

			if fileType == JSONFileExt {
@@ -183,8 +190,7 @@ func (p *ImportWrapper) Import(filePaths []string, rowBased bool, onlyValidate b
		// parse/validate/consume data
		for i := 0; i < len(filePaths); i++ {
			filePath := filePaths[i]
			fileName := path.Base(filePath)
			fileType := path.Ext(fileName)
			fileName, fileType := getFileNameAndExt(filePath)
			log.Debug("import wrapper: column-based file ", zap.Any("filePath", filePath), zap.Any("fileType", fileType))

			if fileType == JSONFileExt {
@@ -218,7 +224,28 @@ func (p *ImportWrapper) Import(filePaths []string, rowBased bool, onlyValidate b
					return err
				}
			} else if fileType == NumpyFileExt {
				file, err := os.Open(filePath)
				if err != nil {
					log.Error("import error: "+err.Error(), zap.String("filePath", filePath))
					return err
				}
				defer file.Close()

				// the numpy parser returns a storage.FieldData; construct a map[string]storage.FieldData here to combine
				flushFunc := func(field storage.FieldData) error {
					fields := make(map[string]storage.FieldData)
					fields[fileName] = field
					combineFunc(fields)
					return nil
				}

				// for a numpy file, the file name (without extension) is treated as the field name
				parser := NewNumpyParser(p.ctx, p.collectionSchema, flushFunc)
				err = parser.Parse(file, fileName, onlyValidate)
				if err != nil {
					log.Error("import error: "+err.Error(), zap.String("filePath", filePath))
					return err
				}
			}
		}

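A rough sketch of how a caller drives this numpy branch; the helpers used below (sampleSchema, idAllocator, flushFunc) and the NewImportWrapper argument list are assumed from the new test added later in this commit, and the file name is illustrative:

	wrapper := NewImportWrapper(ctx, sampleSchema(), 2, 1, idAllocator, flushFunc)
	// column-based import: one .npy file per field; the file name (minus ".npy") must match a field name
	err := wrapper.Import([]string{"field_float_vector.npy"}, false, false)
	_ = err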

@@ -126,7 +126,7 @@ func Test_ImportRowBased(t *testing.T) {

}

func Test_ImportColumnBased(t *testing.T) {
func Test_ImportColumnBased_json(t *testing.T) {
	ctx := context.Background()
	err := os.MkdirAll(TempFilesPath, os.ModePerm)
	assert.Nil(t, err)
@@ -208,6 +208,88 @@ func Test_ImportColumnBased(t *testing.T) {
	assert.NotNil(t, err)
}

func Test_ImportColumnBased_numpy(t *testing.T) {
	ctx := context.Background()
	err := os.MkdirAll(TempFilesPath, os.ModePerm)
	assert.Nil(t, err)
	defer os.RemoveAll(TempFilesPath)

	idAllocator := newIDAllocator(ctx, t)

	content := []byte(`{
		"field_bool": [true, false, true, true, true],
		"field_int8": [10, 11, 12, 13, 14],
		"field_int16": [100, 101, 102, 103, 104],
		"field_int32": [1000, 1001, 1002, 1003, 1004],
		"field_int64": [10000, 10001, 10002, 10003, 10004],
		"field_float": [3.14, 3.15, 3.16, 3.17, 3.18],
		"field_double": [5.1, 5.2, 5.3, 5.4, 5.5],
		"field_string": ["a", "b", "c", "d", "e"]
	}`)

	files := make([]string, 0)

	filePath := TempFilesPath + "scalar_fields.json"
	fp1 := saveFile(t, filePath, content)
	fp1.Close()
	files = append(files, filePath)

	filePath = TempFilesPath + "field_binary_vector.npy"
	bin := [][2]uint8{{1, 2}, {3, 4}, {5, 6}, {7, 8}, {9, 10}}
	err = CreateNumpyFile(filePath, bin)
	assert.Nil(t, err)
	files = append(files, filePath)

	filePath = TempFilesPath + "field_float_vector.npy"
	flo := [][4]float32{{1, 2, 3, 4}, {3, 4, 5, 6}, {5, 6, 7, 8}, {7, 8, 9, 10}, {9, 10, 11, 12}}
	err = CreateNumpyFile(filePath, flo)
	assert.Nil(t, err)
	files = append(files, filePath)

	rowCount := 0
	flushFunc := func(fields map[string]storage.FieldData) error {
		count := 0
		for _, data := range fields {
			assert.Less(t, 0, data.RowNum())
			if count == 0 {
				count = data.RowNum()
			} else {
				assert.Equal(t, count, data.RowNum())
			}
		}
		rowCount += count
		return nil
	}

	// success case
	wrapper := NewImportWrapper(ctx, sampleSchema(), 2, 1, idAllocator, flushFunc)

	err = wrapper.Import(files, false, false)
	assert.Nil(t, err)
	assert.Equal(t, 5, rowCount)

	// parse error
	content = []byte(`{
		"field_bool": [true, false, true, true, true]
	}`)

	filePath = TempFilesPath + "rows_2.json"
	fp2 := saveFile(t, filePath, content)
	defer fp2.Close()

	wrapper = NewImportWrapper(ctx, sampleSchema(), 2, 1, idAllocator, flushFunc)
	files = make([]string, 0)
	files = append(files, filePath)
	err = wrapper.Import(files, false, false)
	assert.NotNil(t, err)

	// file doesn't exist
	files = make([]string, 0)
	files = append(files, "/dummy/dummy.json")
	err = wrapper.Import(files, false, false)
	assert.NotNil(t, err)
}

func perfSchema(dim int) *schemapb.CollectionSchema {
	schema := &schemapb.CollectionSchema{
		Name:        "schema",

@@ -24,7 +24,7 @@ type JSONParser struct {
	fields  map[string]int64 // fields need to be parsed
}

// newImportManager helper function to create a importManager
// NewJSONParser helper function to create a JSONParser
func NewJSONParser(ctx context.Context, collectionSchema *schemapb.CollectionSchema) *JSONParser {
	fields := make(map[string]int64)
	for i := 0; i < len(collectionSchema.Fields); i++ {

@@ -0,0 +1,356 @@
package importutil

import (
	"encoding/binary"
	"errors"
	"io"
	"os"

	"github.com/sbinet/npyio"
	"github.com/sbinet/npyio/npy"
)

func CreateNumpyFile(path string, data interface{}) error {
	f, err := os.Create(path)
	if err != nil {
		return err
	}
	defer f.Close()

	err = npyio.Write(f, data)
	if err != nil {
		return err
	}

	return nil
}

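A minimal round-trip sketch of this helper together with the adapter defined below; the file path is illustrative and error handling is omitted:

	_ = CreateNumpyFile("/tmp/demo.npy", []float32{1, 2, 3, 4})
	f, _ := os.Open("/tmp/demo.npy")
	defer f.Close()
	adapter, _ := NewNumpyAdapter(f)
	vals, _ := adapter.ReadFloat32(4) // reads the whole data section in one batch
	_ = vals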
// a helper type to expand the reading ability of a go-numpy lib
// we evaluated two go-numpy libs: github.com/kshedden/gonpy and github.com/sbinet/npyio
// the npyio lib reads data one element at a time, so its performance is poor; we expand the read methods
// to read data in one batch, which is about 100X faster
// the gonpy lib also reads data in one batch, but it has no method to read bool data, and its ability
// to handle different data types is not as strong as npyio, so we choose the npyio lib to expand.
type NumpyAdapter struct {
	reader       io.Reader        // data source, typically is os.File
	npyReader    *npy.Reader      // reader of npyio lib
	order        binary.ByteOrder // LittleEndian or BigEndian
	readPosition int              // how many elements have been read
}

func NewNumpyAdapter(reader io.Reader) (*NumpyAdapter, error) {
	r, err := npyio.NewReader(reader)
	if err != nil {
		return nil, err
	}
	adapter := &NumpyAdapter{
		reader:       reader,
		npyReader:    r,
		readPosition: 0,
	}
	adapter.setByteOrder()

	return adapter, err
}

// the logic of this method is copied from the npyio lib
func (n *NumpyAdapter) setByteOrder() {
	var nativeEndian binary.ByteOrder
	v := uint16(1)
	switch byte(v >> 8) {
	case 0:
		nativeEndian = binary.LittleEndian
	case 1:
		nativeEndian = binary.BigEndian
	}

	switch n.npyReader.Header.Descr.Type[0] {
	case '<':
		n.order = binary.LittleEndian
	case '>':
		n.order = binary.BigEndian
	default:
		n.order = nativeEndian
	}
}

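To make the endianness probe above concrete: the high byte of uint16(1) is 0 on a little-endian host and 1 on a big-endian host, and an explicit '<' or '>' prefix in the numpy dtype string overrides that native default. A tiny illustration:

	v := uint16(1)     // stored as bytes 0x01 0x00 on a little-endian host
	hi := byte(v >> 8) // 0 => little-endian host, 1 => big-endian host
	_ = hi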
func (n *NumpyAdapter) Reader() io.Reader {
	return n.reader
}

func (n *NumpyAdapter) NpyReader() *npy.Reader {
	return n.npyReader
}

func (n *NumpyAdapter) GetType() string {
	return n.npyReader.Header.Descr.Type
}

func (n *NumpyAdapter) GetShape() []int {
	return n.npyReader.Header.Descr.Shape
}

func (n *NumpyAdapter) checkSize(size int) int {
	shape := n.GetShape()

	// empty file?
	if len(shape) == 0 {
		return 0
	}

	total := 1
	for i := 0; i < len(shape); i++ {
		total *= shape[i]
	}

	if total == 0 {
		return 0
	}

	// overflow?
	if size > (total - n.readPosition) {
		return total - n.readPosition
	}

	return size
}

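A worked example of the clamping above, with assumed numbers: for shape [4] (total = 4 elements) and readPosition = 3, a request for 6 elements exceeds what remains, so checkSize returns 4 - 3 = 1 and the caller reads only the final element.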
func (n *NumpyAdapter) ReadBool(size int) ([]bool, error) {
	if n.npyReader == nil {
		return nil, errors.New("reader is not initialized")
	}

	// incorrect type
	switch n.npyReader.Header.Descr.Type {
	case "b1", "<b1", "|b1", "bool":
	default:
		return nil, errors.New("numpy data is not bool type")
	}

	// avoid read overflow
	readSize := n.checkSize(size)
	if readSize <= 0 {
		return nil, errors.New("nothing to read")
	}

	data := make([]bool, readSize)
	err := binary.Read(n.reader, n.order, &data)
	if err != nil {
		return nil, err
	}

	// update read position after a successful read
	n.readPosition += readSize

	return data, nil
}

func (n *NumpyAdapter) ReadUint8(size int) ([]uint8, error) {
	if n.npyReader == nil {
		return nil, errors.New("reader is not initialized")
	}

	// incorrect type
	switch n.npyReader.Header.Descr.Type {
	case "u1", "<u1", "|u1", "uint8":
	default:
		return nil, errors.New("numpy data is not uint8 type")
	}

	// avoid read overflow
	readSize := n.checkSize(size)
	if readSize <= 0 {
		return nil, errors.New("nothing to read")
	}

	data := make([]uint8, readSize)
	err := binary.Read(n.reader, n.order, &data)
	if err != nil {
		return nil, err
	}

	// update read position after a successful read
	n.readPosition += readSize

	return data, nil
}

func (n *NumpyAdapter) ReadInt8(size int) ([]int8, error) {
	if n.npyReader == nil {
		return nil, errors.New("reader is not initialized")
	}

	// incorrect type
	switch n.npyReader.Header.Descr.Type {
	case "i1", "<i1", "|i1", ">i1", "int8":
	default:
		return nil, errors.New("numpy data is not int8 type")
	}

	// avoid read overflow
	readSize := n.checkSize(size)
	if readSize <= 0 {
		return nil, errors.New("nothing to read")
	}

	data := make([]int8, readSize)
	err := binary.Read(n.reader, n.order, &data)
	if err != nil {
		return nil, err
	}

	// update read position after a successful read
	n.readPosition += readSize

	return data, nil
}

func (n *NumpyAdapter) ReadInt16(size int) ([]int16, error) {
	if n.npyReader == nil {
		return nil, errors.New("reader is not initialized")
	}

	// incorrect type
	switch n.npyReader.Header.Descr.Type {
	case "i2", "<i2", "|i2", ">i2", "int16":
	default:
		return nil, errors.New("numpy data is not int16 type")
	}

	// avoid read overflow
	readSize := n.checkSize(size)
	if readSize <= 0 {
		return nil, errors.New("nothing to read")
	}

	data := make([]int16, readSize)
	err := binary.Read(n.reader, n.order, &data)
	if err != nil {
		return nil, err
	}

	// update read position after a successful read
	n.readPosition += readSize

	return data, nil
}

func (n *NumpyAdapter) ReadInt32(size int) ([]int32, error) {
	if n.npyReader == nil {
		return nil, errors.New("reader is not initialized")
	}

	// incorrect type
	switch n.npyReader.Header.Descr.Type {
	case "i4", "<i4", "|i4", ">i4", "int32":
	default:
		return nil, errors.New("numpy data is not int32 type")
	}

	// avoid read overflow
	readSize := n.checkSize(size)
	if readSize <= 0 {
		return nil, errors.New("nothing to read")
	}

	data := make([]int32, readSize)
	err := binary.Read(n.reader, n.order, &data)
	if err != nil {
		return nil, err
	}

	// update read position after a successful read
	n.readPosition += readSize

	return data, nil
}

func (n *NumpyAdapter) ReadInt64(size int) ([]int64, error) {
	if n.npyReader == nil {
		return nil, errors.New("reader is not initialized")
	}

	// incorrect type
	switch n.npyReader.Header.Descr.Type {
	case "i8", "<i8", "|i8", ">i8", "int64":
	default:
		return nil, errors.New("numpy data is not int64 type")
	}

	// avoid read overflow
	readSize := n.checkSize(size)
	if readSize <= 0 {
		return nil, errors.New("nothing to read")
	}

	data := make([]int64, readSize)
	err := binary.Read(n.reader, n.order, &data)
	if err != nil {
		return nil, err
	}

	// update read position after a successful read
	n.readPosition += readSize

	return data, nil
}

func (n *NumpyAdapter) ReadFloat32(size int) ([]float32, error) {
	if n.npyReader == nil {
		return nil, errors.New("reader is not initialized")
	}

	// incorrect type
	switch n.npyReader.Header.Descr.Type {
	case "f4", "<f4", "|f4", ">f4", "float32":
	default:
		return nil, errors.New("numpy data is not float32 type")
	}

	// avoid read overflow
	readSize := n.checkSize(size)
	if readSize <= 0 {
		return nil, errors.New("nothing to read")
	}

	data := make([]float32, readSize)
	err := binary.Read(n.reader, n.order, &data)
	if err != nil {
		return nil, err
	}

	// update read position after a successful read
	n.readPosition += readSize

	return data, nil
}

func (n *NumpyAdapter) ReadFloat64(size int) ([]float64, error) {
	if n.npyReader == nil {
		return nil, errors.New("reader is not initialized")
	}

	// incorrect type
	switch n.npyReader.Header.Descr.Type {
	case "f8", "<f8", "|f8", ">f8", "float64":
	default:
		return nil, errors.New("numpy data is not float64 type")
	}

	// avoid read overflow
	readSize := n.checkSize(size)
	if readSize <= 0 {
		return nil, errors.New("nothing to read")
	}

	data := make([]float64, readSize)
	err := binary.Read(n.reader, n.order, &data)
	if err != nil {
		return nil, err
	}

	// update read position after a successful read
	n.readPosition += readSize

	return data, nil
}

@@ -0,0 +1,455 @@
package importutil

import (
	"encoding/binary"
	"io"
	"os"
	"testing"

	"github.com/sbinet/npyio/npy"
	"github.com/stretchr/testify/assert"
)

type MockReader struct {
}

func (r *MockReader) Read(p []byte) (n int, err error) {
	return 0, io.EOF
}

func Test_CreateNumpyFile(t *testing.T) {
	// directory doesn't exist
	data1 := []float32{1, 2, 3, 4, 5}
	err := CreateNumpyFile("/dummy_not_exist/dummy.npy", data1)
	assert.NotNil(t, err)

	// invalid data type
	data2 := make(map[string]int)
	err = CreateNumpyFile("/tmp/dummy.npy", data2)
	assert.NotNil(t, err)
}

func Test_SetByteOrder(t *testing.T) {
	adapter := &NumpyAdapter{
		reader:    nil,
		npyReader: &npy.Reader{},
	}
	assert.Nil(t, adapter.Reader())
	assert.NotNil(t, adapter.NpyReader())

	adapter.npyReader.Header.Descr.Type = "<i8"
	adapter.setByteOrder()
	assert.Equal(t, binary.LittleEndian, adapter.order)

	adapter.npyReader.Header.Descr.Type = ">i8"
	adapter.setByteOrder()
	assert.Equal(t, binary.BigEndian, adapter.order)
}

func Test_ReadError(t *testing.T) {
	adapter := &NumpyAdapter{
		reader:    nil,
		npyReader: nil,
	}

	// reader is nil
	{
		_, err := adapter.ReadBool(1)
		assert.NotNil(t, err)
		_, err = adapter.ReadUint8(1)
		assert.NotNil(t, err)
		_, err = adapter.ReadInt8(1)
		assert.NotNil(t, err)
		_, err = adapter.ReadInt16(1)
		assert.NotNil(t, err)
		_, err = adapter.ReadInt32(1)
		assert.NotNil(t, err)
		_, err = adapter.ReadInt64(1)
		assert.NotNil(t, err)
		_, err = adapter.ReadFloat32(1)
		assert.NotNil(t, err)
		_, err = adapter.ReadFloat64(1)
		assert.NotNil(t, err)
	}

	adapter = &NumpyAdapter{
		reader:    &MockReader{},
		npyReader: &npy.Reader{},
	}

	{
		adapter.npyReader.Header.Descr.Type = "bool"
		data, err := adapter.ReadBool(1)
		assert.Nil(t, data)
		assert.NotNil(t, err)

		adapter.npyReader.Header.Descr.Type = "dummy"
		data, err = adapter.ReadBool(1)
		assert.Nil(t, data)
		assert.NotNil(t, err)
	}

	{
		adapter.npyReader.Header.Descr.Type = "u1"
		data, err := adapter.ReadUint8(1)
		assert.Nil(t, data)
		assert.NotNil(t, err)

		adapter.npyReader.Header.Descr.Type = "dummy"
		data, err = adapter.ReadUint8(1)
		assert.Nil(t, data)
		assert.NotNil(t, err)
	}

	{
		adapter.npyReader.Header.Descr.Type = "i1"
		data, err := adapter.ReadInt8(1)
		assert.Nil(t, data)
		assert.NotNil(t, err)

		adapter.npyReader.Header.Descr.Type = "dummy"
		data, err = adapter.ReadInt8(1)
		assert.Nil(t, data)
		assert.NotNil(t, err)
	}

	{
		adapter.npyReader.Header.Descr.Type = "i2"
		data, err := adapter.ReadInt16(1)
		assert.Nil(t, data)
		assert.NotNil(t, err)

		adapter.npyReader.Header.Descr.Type = "dummy"
		data, err = adapter.ReadInt16(1)
		assert.Nil(t, data)
		assert.NotNil(t, err)
	}

	{
		adapter.npyReader.Header.Descr.Type = "i4"
		data, err := adapter.ReadInt32(1)
		assert.Nil(t, data)
		assert.NotNil(t, err)

		adapter.npyReader.Header.Descr.Type = "dummy"
		data, err = adapter.ReadInt32(1)
		assert.Nil(t, data)
		assert.NotNil(t, err)
	}

	{
		adapter.npyReader.Header.Descr.Type = "i8"
		data, err := adapter.ReadInt64(1)
		assert.Nil(t, data)
		assert.NotNil(t, err)

		adapter.npyReader.Header.Descr.Type = "dummy"
		data, err = adapter.ReadInt64(1)
		assert.Nil(t, data)
		assert.NotNil(t, err)
	}

	{
		adapter.npyReader.Header.Descr.Type = "f4"
		data, err := adapter.ReadFloat32(1)
		assert.Nil(t, data)
		assert.NotNil(t, err)

		adapter.npyReader.Header.Descr.Type = "dummy"
		data, err = adapter.ReadFloat32(1)
		assert.Nil(t, data)
		assert.NotNil(t, err)
	}

	{
		adapter.npyReader.Header.Descr.Type = "f8"
		data, err := adapter.ReadFloat64(1)
		assert.Nil(t, data)
		assert.NotNil(t, err)

		adapter.npyReader.Header.Descr.Type = "dummy"
		data, err = adapter.ReadFloat64(1)
		assert.Nil(t, data)
		assert.NotNil(t, err)
	}
}

func Test_Read(t *testing.T) {
	err := os.MkdirAll(TempFilesPath, os.ModePerm)
	assert.Nil(t, err)
	defer os.RemoveAll(TempFilesPath)

	{
		filePath := TempFilesPath + "bool.npy"
		data := []bool{true, false, true, false}
		CreateNumpyFile(filePath, data)

		file, err := os.Open(filePath)
		assert.Nil(t, err)
		defer file.Close()

		adapter, err := NewNumpyAdapter(file)
		assert.Nil(t, err)

		res, err := adapter.ReadBool(len(data) - 1)
		assert.Nil(t, err)
		assert.Equal(t, len(data)-1, len(res))

		for i := 0; i < len(res); i++ {
			assert.Equal(t, data[i], res[i])
		}

		res, err = adapter.ReadBool(len(data))
		assert.Nil(t, err)
		assert.Equal(t, 1, len(res))
		assert.Equal(t, data[len(data)-1], res[0])

		res, err = adapter.ReadBool(len(data))
		assert.NotNil(t, err)
		assert.Nil(t, res)

		// incorrect type read
		resu1, err := adapter.ReadUint8(len(data))
		assert.NotNil(t, err)
		assert.Nil(t, resu1)

		resi1, err := adapter.ReadInt8(len(data))
		assert.NotNil(t, err)
		assert.Nil(t, resi1)

		resi2, err := adapter.ReadInt16(len(data))
		assert.NotNil(t, err)
		assert.Nil(t, resi2)

		resi4, err := adapter.ReadInt32(len(data))
		assert.NotNil(t, err)
		assert.Nil(t, resi4)

		resi8, err := adapter.ReadInt64(len(data))
		assert.NotNil(t, err)
		assert.Nil(t, resi8)

		resf4, err := adapter.ReadFloat32(len(data))
		assert.NotNil(t, err)
		assert.Nil(t, resf4)

		resf8, err := adapter.ReadFloat64(len(data))
		assert.NotNil(t, err)
		assert.Nil(t, resf8)
	}

	{
		filePath := TempFilesPath + "uint8.npy"
		data := []uint8{1, 2, 3, 4, 5, 6}
		CreateNumpyFile(filePath, data)

		file, err := os.Open(filePath)
		assert.Nil(t, err)
		defer file.Close()

		adapter, err := NewNumpyAdapter(file)
		assert.Nil(t, err)

		res, err := adapter.ReadUint8(len(data) - 1)
		assert.Nil(t, err)
		assert.Equal(t, len(data)-1, len(res))

		for i := 0; i < len(res); i++ {
			assert.Equal(t, data[i], res[i])
		}

		res, err = adapter.ReadUint8(len(data))
		assert.Nil(t, err)
		assert.Equal(t, 1, len(res))
		assert.Equal(t, data[len(data)-1], res[0])

		res, err = adapter.ReadUint8(len(data))
		assert.NotNil(t, err)
		assert.Nil(t, res)

		// incorrect type read
		resb, err := adapter.ReadBool(len(data))
		assert.NotNil(t, err)
		assert.Nil(t, resb)
	}

	{
		filePath := TempFilesPath + "int8.npy"
		data := []int8{1, 2, 3, 4, 5, 6}
		CreateNumpyFile(filePath, data)

		file, err := os.Open(filePath)
		assert.Nil(t, err)
		defer file.Close()

		adapter, err := NewNumpyAdapter(file)
		assert.Nil(t, err)

		res, err := adapter.ReadInt8(len(data) - 1)
		assert.Nil(t, err)
		assert.Equal(t, len(data)-1, len(res))

		for i := 0; i < len(res); i++ {
			assert.Equal(t, data[i], res[i])
		}

		res, err = adapter.ReadInt8(len(data))
		assert.Nil(t, err)
		assert.Equal(t, 1, len(res))
		assert.Equal(t, data[len(data)-1], res[0])

		res, err = adapter.ReadInt8(len(data))
		assert.NotNil(t, err)
		assert.Nil(t, res)
	}

	{
		filePath := TempFilesPath + "int16.npy"
		data := []int16{1, 2, 3, 4, 5, 6}
		CreateNumpyFile(filePath, data)

		file, err := os.Open(filePath)
		assert.Nil(t, err)
		defer file.Close()

		adapter, err := NewNumpyAdapter(file)
		assert.Nil(t, err)

		res, err := adapter.ReadInt16(len(data) - 1)
		assert.Nil(t, err)
		assert.Equal(t, len(data)-1, len(res))

		for i := 0; i < len(res); i++ {
			assert.Equal(t, data[i], res[i])
		}

		res, err = adapter.ReadInt16(len(data))
		assert.Nil(t, err)
		assert.Equal(t, 1, len(res))
		assert.Equal(t, data[len(data)-1], res[0])

		res, err = adapter.ReadInt16(len(data))
		assert.NotNil(t, err)
		assert.Nil(t, res)
	}

	{
		filePath := TempFilesPath + "int32.npy"
		data := []int32{1, 2, 3, 4, 5, 6}
		CreateNumpyFile(filePath, data)

		file, err := os.Open(filePath)
		assert.Nil(t, err)
		defer file.Close()

		adapter, err := NewNumpyAdapter(file)
		assert.Nil(t, err)

		res, err := adapter.ReadInt32(len(data) - 1)
		assert.Nil(t, err)
		assert.Equal(t, len(data)-1, len(res))

		for i := 0; i < len(res); i++ {
			assert.Equal(t, data[i], res[i])
		}

		res, err = adapter.ReadInt32(len(data))
		assert.Nil(t, err)
		assert.Equal(t, 1, len(res))
		assert.Equal(t, data[len(data)-1], res[0])

		res, err = adapter.ReadInt32(len(data))
		assert.NotNil(t, err)
		assert.Nil(t, res)
	}

	{
		filePath := TempFilesPath + "int64.npy"
		data := []int64{1, 2, 3, 4, 5, 6}
		CreateNumpyFile(filePath, data)

		file, err := os.Open(filePath)
		assert.Nil(t, err)
		defer file.Close()

		adapter, err := NewNumpyAdapter(file)
		assert.Nil(t, err)

		res, err := adapter.ReadInt64(len(data) - 1)
		assert.Nil(t, err)
		assert.Equal(t, len(data)-1, len(res))

		for i := 0; i < len(res); i++ {
			assert.Equal(t, data[i], res[i])
		}

		res, err = adapter.ReadInt64(len(data))
		assert.Nil(t, err)
		assert.Equal(t, 1, len(res))
		assert.Equal(t, data[len(data)-1], res[0])

		res, err = adapter.ReadInt64(len(data))
		assert.NotNil(t, err)
		assert.Nil(t, res)
	}

	{
		filePath := TempFilesPath + "float.npy"
		data := []float32{1, 2, 3, 4, 5, 6}
		CreateNumpyFile(filePath, data)

		file, err := os.Open(filePath)
		assert.Nil(t, err)
		defer file.Close()

		adapter, err := NewNumpyAdapter(file)
		assert.Nil(t, err)

		res, err := adapter.ReadFloat32(len(data) - 1)
		assert.Nil(t, err)
		assert.Equal(t, len(data)-1, len(res))

		for i := 0; i < len(res); i++ {
			assert.Equal(t, data[i], res[i])
		}

		res, err = adapter.ReadFloat32(len(data))
		assert.Nil(t, err)
		assert.Equal(t, 1, len(res))
		assert.Equal(t, data[len(data)-1], res[0])

		res, err = adapter.ReadFloat32(len(data))
		assert.NotNil(t, err)
		assert.Nil(t, res)
	}

	{
		filePath := TempFilesPath + "double.npy"
		data := []float64{1, 2, 3, 4, 5, 6}
		CreateNumpyFile(filePath, data)

		file, err := os.Open(filePath)
		assert.Nil(t, err)
		defer file.Close()

		adapter, err := NewNumpyAdapter(file)
		assert.Nil(t, err)

		res, err := adapter.ReadFloat64(len(data) - 1)
		assert.Nil(t, err)
		assert.Equal(t, len(data)-1, len(res))

		for i := 0; i < len(res); i++ {
			assert.Equal(t, data[i], res[i])
		}

		res, err = adapter.ReadFloat64(len(data))
		assert.Nil(t, err)
		assert.Equal(t, 1, len(res))
		assert.Equal(t, data[len(data)-1], res[0])

		res, err = adapter.ReadFloat64(len(data))
		assert.NotNil(t, err)
		assert.Nil(t, res)
	}
}

@@ -0,0 +1,290 @@
package importutil

import (
	"context"
	"errors"
	"io"
	"strconv"

	"github.com/milvus-io/milvus/internal/log"
	"github.com/milvus-io/milvus/internal/proto/schemapb"
	"github.com/milvus-io/milvus/internal/storage"
)

type ColumnDesc struct {
	name         string            // name of the target column
	dt           schemapb.DataType // data type of the target column
	elementCount int               // how many elements need to be read
	dimension    int               // only for vector
}

type NumpyParser struct {
	ctx              context.Context            // for canceling parse process
	collectionSchema *schemapb.CollectionSchema // collection schema
	columnDesc       *ColumnDesc                // description for target column

	columnData    storage.FieldData                   // in-memory column data
	callFlushFunc func(field storage.FieldData) error // call back function to output column data
}

// NewNumpyParser helper function to create a NumpyParser
func NewNumpyParser(ctx context.Context, collectionSchema *schemapb.CollectionSchema,
	flushFunc func(field storage.FieldData) error) *NumpyParser {
	if collectionSchema == nil || flushFunc == nil {
		return nil
	}

	parser := &NumpyParser{
		ctx:              ctx,
		collectionSchema: collectionSchema,
		columnDesc:       &ColumnDesc{},
		callFlushFunc:    flushFunc,
	}

	return parser
}

func (p *NumpyParser) logError(msg string) error {
	log.Error(msg)
	return errors.New(msg)
}

// data type converted from numpy header description, for vector field, the type is int8(binary vector) or float32(float vector)
func convertNumpyType(str string) (schemapb.DataType, error) {
	switch str {
	case "b1", "<b1", "|b1", "bool":
		return schemapb.DataType_Bool, nil
	case "u1", "<u1", "|u1", "uint8": // binary vector data type is uint8
		return schemapb.DataType_BinaryVector, nil
	case "i1", "<i1", "|i1", ">i1", "int8":
		return schemapb.DataType_Int8, nil
	case "i2", "<i2", "|i2", ">i2", "int16":
		return schemapb.DataType_Int16, nil
	case "i4", "<i4", "|i4", ">i4", "int32":
		return schemapb.DataType_Int32, nil
	case "i8", "<i8", "|i8", ">i8", "int64":
		return schemapb.DataType_Int64, nil
	case "f4", "<f4", "|f4", ">f4", "float32":
		return schemapb.DataType_Float, nil
	case "f8", "<f8", "|f8", ">f8", "float64":
		return schemapb.DataType_Double, nil
	default:
		return schemapb.DataType_None, errors.New("unsupported data type " + str)
	}
}

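A small usage sketch of convertNumpyType; the dtype strings are the numpy descriptor forms listed in the switch above:

	dt, _ := convertNumpyType("<f4") // schemapb.DataType_Float
	dt, _ = convertNumpyType("u1")   // schemapb.DataType_BinaryVector: uint8 payloads are treated as binary-vector bytes
	_ = dt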
func (p *NumpyParser) validate(adapter *NumpyAdapter, fieldName string) error {
	if adapter == nil {
		return errors.New("numpy adapter is nil")
	}

	// check existence of the target field
	var schema *schemapb.FieldSchema
	for i := 0; i < len(p.collectionSchema.Fields); i++ {
		schema = p.collectionSchema.Fields[i]
		if schema.GetName() == fieldName {
			p.columnDesc.name = fieldName
			break
		}
	}

	if p.columnDesc.name == "" {
		return errors.New("the field " + fieldName + " doesn't exist")
	}

	p.columnDesc.dt = schema.DataType
	elementType, err := convertNumpyType(adapter.GetType())
	if err != nil {
		return err
	}

	shape := adapter.GetShape()

	// 1. field data type should be consistent with the numpy data type
	// 2. vector field dimension should be consistent with the numpy shape
	if schemapb.DataType_FloatVector == schema.DataType {
		if elementType != schemapb.DataType_Float {
			return errors.New("illegal data type " + adapter.GetType() + " for field " + schema.GetName())
		}

		// vector field, the shape should be 2
		if len(shape) != 2 {
			return errors.New("illegal numpy shape " + strconv.Itoa(len(shape)) + " for field " + schema.GetName())
		}

		// shape[0] is row count, shape[1] is element count per row
		p.columnDesc.elementCount = shape[0] * shape[1]

		p.columnDesc.dimension, err = getFieldDimension(schema)
		if err != nil {
			return err
		}

		if shape[1] != p.columnDesc.dimension {
			return errors.New("illegal row width " + strconv.Itoa(shape[1]) + " for field " + schema.GetName() + " dimension " + strconv.Itoa(p.columnDesc.dimension))
		}
	} else if schemapb.DataType_BinaryVector == schema.DataType {
		if elementType != schemapb.DataType_BinaryVector {
			return errors.New("illegal data type " + adapter.GetType() + " for field " + schema.GetName())
		}

		// vector field, the shape should be 2
		if len(shape) != 2 {
			return errors.New("illegal numpy shape " + strconv.Itoa(len(shape)) + " for field " + schema.GetName())
		}

		// shape[0] is row count, shape[1] is element count per row
		p.columnDesc.elementCount = shape[0] * shape[1]

		p.columnDesc.dimension, err = getFieldDimension(schema)
		if err != nil {
			return err
		}

		if shape[1] != p.columnDesc.dimension/8 {
			return errors.New("illegal row width " + strconv.Itoa(shape[1]) + " for field " + schema.GetName() + " dimension " + strconv.Itoa(p.columnDesc.dimension))
		}
	} else {
		if elementType != schema.DataType {
			return errors.New("illegal data type " + adapter.GetType() + " for field " + schema.GetName())
		}

		// scalar field, the shape should be 1
		if len(shape) != 1 {
			return errors.New("illegal numpy shape " + strconv.Itoa(len(shape)) + " for field " + schema.GetName())
		}

		p.columnDesc.elementCount = shape[0]
	}

	return nil
}

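To make the shape rules above concrete (numbers are illustrative): a FloatVector field with dim = 4 expects a 2-D array of shape [rowCount, 4]; a BinaryVector field with dim = 16 stores each row as dim/8 packed bytes, so the expected shape is [rowCount, 2]; a scalar field expects a 1-D array of shape [rowCount].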
// this method read numpy data section into a storage.FieldData
 | 
			
		||||
// please note it will require a large memory block(the memory size is almost equal to numpy file size)
 | 
			
		||||
func (p *NumpyParser) consume(adapter *NumpyAdapter) error {
 | 
			
		||||
	switch p.columnDesc.dt {
 | 
			
		||||
	case schemapb.DataType_Bool:
 | 
			
		||||
		data, err := adapter.ReadBool(p.columnDesc.elementCount)
 | 
			
		||||
		if err != nil {
 | 
			
		||||
			return err
 | 
			
		||||
		}
 | 
			
		||||
 | 
			
		||||
		p.columnData = &storage.BoolFieldData{
 | 
			
		||||
			NumRows: []int64{int64(p.columnDesc.elementCount)},
 | 
			
		||||
			Data:    data,
 | 
			
		||||
		}
 | 
			
		||||
 | 
			
		||||
	case schemapb.DataType_Int8:
 | 
			
		||||
		data, err := adapter.ReadInt8(p.columnDesc.elementCount)
 | 
			
		||||
		if err != nil {
 | 
			
		||||
			return err
 | 
			
		||||
		}
 | 
			
		||||
 | 
			
		||||
		p.columnData = &storage.Int8FieldData{
 | 
			
		||||
			NumRows: []int64{int64(p.columnDesc.elementCount)},
 | 
			
		||||
			Data:    data,
 | 
			
		||||
		}
 | 
			
		||||
	case schemapb.DataType_Int16:
 | 
			
		||||
		data, err := adapter.ReadInt16(p.columnDesc.elementCount)
 | 
			
		||||
		if err != nil {
			return err
		}

		p.columnData = &storage.Int16FieldData{
			NumRows: []int64{int64(p.columnDesc.elementCount)},
			Data:    data,
		}
	case schemapb.DataType_Int32:
		data, err := adapter.ReadInt32(p.columnDesc.elementCount)
		if err != nil {
			return err
		}

		p.columnData = &storage.Int32FieldData{
			NumRows: []int64{int64(p.columnDesc.elementCount)},
			Data:    data,
		}
	case schemapb.DataType_Int64:
		data, err := adapter.ReadInt64(p.columnDesc.elementCount)
		if err != nil {
			return err
		}

		p.columnData = &storage.Int64FieldData{
			NumRows: []int64{int64(p.columnDesc.elementCount)},
			Data:    data,
		}
	case schemapb.DataType_Float:
		data, err := adapter.ReadFloat32(p.columnDesc.elementCount)
		if err != nil {
			return err
		}

		p.columnData = &storage.FloatFieldData{
			NumRows: []int64{int64(p.columnDesc.elementCount)},
			Data:    data,
		}
	case schemapb.DataType_Double:
		data, err := adapter.ReadFloat64(p.columnDesc.elementCount)
		if err != nil {
			return err
		}

		p.columnData = &storage.DoubleFieldData{
			NumRows: []int64{int64(p.columnDesc.elementCount)},
			Data:    data,
		}
	case schemapb.DataType_BinaryVector:
		data, err := adapter.ReadUint8(p.columnDesc.elementCount)
		if err != nil {
			return err
		}

		p.columnData = &storage.BinaryVectorFieldData{
			NumRows: []int64{int64(p.columnDesc.elementCount)},
			Data:    data,
			Dim:     p.columnDesc.dimension,
		}
	case schemapb.DataType_FloatVector:
		data, err := adapter.ReadFloat32(p.columnDesc.elementCount)
		if err != nil {
			return err
		}

		p.columnData = &storage.FloatVectorFieldData{
			NumRows: []int64{int64(p.columnDesc.elementCount)},
			Data:    data,
			Dim:     p.columnDesc.dimension,
		}
	default:
		return errors.New("unsupported data type: " + strconv.Itoa(int(p.columnDesc.dt)))
	}

	return nil
}

func (p *NumpyParser) Parse(reader io.Reader, fieldName string, onlyValidate bool) error {
	adapter, err := NewNumpyAdapter(reader)
	if err != nil {
		return p.logError("Numpy parse: " + err.Error())
	}

	// the validation method only checks the file header information
	err = p.validate(adapter, fieldName)
	if err != nil {
		return p.logError("Numpy parse: " + err.Error())
	}

	if onlyValidate {
		return nil
	}

	// read all data from the numpy file
	err = p.consume(adapter)
	if err != nil {
		return p.logError("Numpy parse: " + err.Error())
	}

	return p.callFlushFunc(p.columnData)
}
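Usage note (not part of the commit): Parse drives the full import of a single .npy column — it wraps the reader in a NumpyAdapter, validates the numpy header against the collection schema, reads the whole column in consume, and finally hands the column to the flush callback. Below is a minimal sketch of calling it from inside the Milvus module; the schema, field ID, and input file name are illustrative assumptions, not part of this diff.

package main

import (
	"context"
	"fmt"
	"os"

	"github.com/milvus-io/milvus/internal/proto/schemapb"
	"github.com/milvus-io/milvus/internal/storage"
	"github.com/milvus-io/milvus/internal/util/importutil"
)

func main() {
	// hypothetical single-field schema; real callers pass the collection schema
	schema := &schemapb.CollectionSchema{
		Fields: []*schemapb.FieldSchema{
			{FieldID: 102, Name: "field_double", DataType: schemapb.DataType_Double},
		},
	}

	// the parser invokes this once with the fully-read column
	flushFunc := func(field storage.FieldData) error {
		fmt.Println("rows imported:", field.RowNum())
		return nil
	}

	file, err := os.Open("field_double.npy") // hypothetical input file
	if err != nil {
		panic(err)
	}
	defer file.Close()

	parser := importutil.NewNumpyParser(context.Background(), schema, flushFunc)
	if err := parser.Parse(file, "field_double", false); err != nil {
		panic(err)
	}
}
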
@@ -0,0 +1,509 @@
package importutil

import (
	"context"
	"os"
	"testing"

	"github.com/milvus-io/milvus/internal/proto/schemapb"
	"github.com/milvus-io/milvus/internal/storage"
	"github.com/milvus-io/milvus/internal/util/timerecord"
	"github.com/sbinet/npyio/npy"
	"github.com/stretchr/testify/assert"
)

func Test_NewNumpyParser(t *testing.T) {
	ctx := context.Background()

	parser := NewNumpyParser(ctx, nil, nil)
	assert.Nil(t, parser)
}

func Test_ConvertNumpyType(t *testing.T) {
	checkFunc := func(inputs []string, output schemapb.DataType) {
		for i := 0; i < len(inputs); i++ {
			dt, err := convertNumpyType(inputs[i])
			assert.Nil(t, err)
			assert.Equal(t, output, dt)
		}
	}

	checkFunc([]string{"b1", "<b1", "|b1", "bool"}, schemapb.DataType_Bool)
	checkFunc([]string{"i1", "<i1", "|i1", ">i1", "int8"}, schemapb.DataType_Int8)
	checkFunc([]string{"i2", "<i2", "|i2", ">i2", "int16"}, schemapb.DataType_Int16)
	checkFunc([]string{"i4", "<i4", "|i4", ">i4", "int32"}, schemapb.DataType_Int32)
	checkFunc([]string{"i8", "<i8", "|i8", ">i8", "int64"}, schemapb.DataType_Int64)
	checkFunc([]string{"f4", "<f4", "|f4", ">f4", "float32"}, schemapb.DataType_Float)
	checkFunc([]string{"f8", "<f8", "|f8", ">f8", "float64"}, schemapb.DataType_Double)

	dt, err := convertNumpyType("dummy")
	assert.NotNil(t, err)
	assert.Equal(t, schemapb.DataType_None, dt)
}

func Test_Validate(t *testing.T) {
	ctx := context.Background()
	err := os.MkdirAll(TempFilesPath, os.ModePerm)
	assert.Nil(t, err)
	defer os.RemoveAll(TempFilesPath)

	schema := sampleSchema()
	flushFunc := func(field storage.FieldData) error {
		return nil
	}

	adapter := &NumpyAdapter{npyReader: &npy.Reader{}}

	{
		// string type is not supported
		p := NewNumpyParser(ctx, &schemapb.CollectionSchema{
			Fields: []*schemapb.FieldSchema{
				{
					FieldID:      109,
					Name:         "field_string",
					IsPrimaryKey: false,
					Description:  "string",
					DataType:     schemapb.DataType_String,
				},
			},
		}, flushFunc)
		err = p.validate(adapter, "dummy")
		assert.NotNil(t, err)
		err = p.validate(adapter, "field_string")
		assert.NotNil(t, err)
	}

	// reader is nil
	parser := NewNumpyParser(ctx, schema, flushFunc)
	err = parser.validate(nil, "")
	assert.NotNil(t, err)

	// validate scalar data
	func() {
		filePath := TempFilesPath + "scalar_1.npy"
		data1 := []float64{0, 1, 2, 3, 4, 5}
		CreateNumpyFile(filePath, data1)

		file1, err := os.Open(filePath)
		assert.Nil(t, err)
		defer file1.Close()

		adapter, err := NewNumpyAdapter(file1)
		assert.Nil(t, err)

		err = parser.validate(adapter, "field_double")
		assert.Nil(t, err)
		assert.Equal(t, len(data1), parser.columnDesc.elementCount)

		err = parser.validate(adapter, "")
		assert.NotNil(t, err)

		// data type mismatch
		filePath = TempFilesPath + "scalar_2.npy"
		data2 := []int64{0, 1, 2, 3, 4, 5}
		CreateNumpyFile(filePath, data2)

		file2, err := os.Open(filePath)
		assert.Nil(t, err)
		defer file2.Close()

		adapter, err = NewNumpyAdapter(file2)
		assert.Nil(t, err)

		err = parser.validate(adapter, "field_double")
		assert.NotNil(t, err)

		// shape mismatch
		filePath = TempFilesPath + "scalar_3.npy"
		data3 := [][2]float64{{1, 1}}
		CreateNumpyFile(filePath, data3)

		file3, err := os.Open(filePath)
		assert.Nil(t, err)
		defer file3.Close()

		adapter, err = NewNumpyAdapter(file3)
		assert.Nil(t, err)

		err = parser.validate(adapter, "field_double")
		assert.NotNil(t, err)
	}()

	// validate binary vector data
	func() {
		filePath := TempFilesPath + "binary_vector_1.npy"
		data1 := [][2]uint8{{0, 1}, {2, 3}, {4, 5}}
		CreateNumpyFile(filePath, data1)

		file1, err := os.Open(filePath)
		assert.Nil(t, err)
		defer file1.Close()

		adapter, err := NewNumpyAdapter(file1)
		assert.Nil(t, err)

		err = parser.validate(adapter, "field_binary_vector")
		assert.Nil(t, err)
		assert.Equal(t, len(data1)*len(data1[0]), parser.columnDesc.elementCount)

		// data type mismatch
		filePath = TempFilesPath + "binary_vector_2.npy"
		data2 := [][2]uint16{{0, 1}, {2, 3}, {4, 5}}
		CreateNumpyFile(filePath, data2)

		file2, err := os.Open(filePath)
		assert.Nil(t, err)
		defer file2.Close()

		adapter, err = NewNumpyAdapter(file2)
		assert.Nil(t, err)

		err = parser.validate(adapter, "field_binary_vector")
		assert.NotNil(t, err)

		// shape mismatch
		filePath = TempFilesPath + "binary_vector_3.npy"
		data3 := []uint8{1, 2, 3}
		CreateNumpyFile(filePath, data3)

		file3, err := os.Open(filePath)
		assert.Nil(t, err)
		defer file3.Close()

		adapter, err = NewNumpyAdapter(file3)
		assert.Nil(t, err)

		err = parser.validate(adapter, "field_binary_vector")
		assert.NotNil(t, err)

		// shape[1] mismatch
		filePath = TempFilesPath + "binary_vector_4.npy"
		data4 := [][3]uint8{{0, 1, 2}, {2, 3, 4}, {4, 5, 6}}
		CreateNumpyFile(filePath, data4)

		file4, err := os.Open(filePath)
		assert.Nil(t, err)
		defer file4.Close()

		adapter, err = NewNumpyAdapter(file4)
		assert.Nil(t, err)

		err = parser.validate(adapter, "field_binary_vector")
		assert.NotNil(t, err)

		// dimension mismatch
		p := NewNumpyParser(ctx, &schemapb.CollectionSchema{
			Fields: []*schemapb.FieldSchema{
				{
					FieldID:  109,
					Name:     "field_binary_vector",
					DataType: schemapb.DataType_BinaryVector,
				},
			},
		}, flushFunc)

		err = p.validate(adapter, "field_binary_vector")
		assert.NotNil(t, err)
	}()

	// validate float vector data
	func() {
		filePath := TempFilesPath + "float_vector.npy"
		data1 := [][4]float32{{0, 0, 0, 0}, {1, 1, 1, 1}, {2, 2, 2, 2}, {3, 3, 3, 3}}
		CreateNumpyFile(filePath, data1)

		file1, err := os.Open(filePath)
		assert.Nil(t, err)
		defer file1.Close()

		adapter, err := NewNumpyAdapter(file1)
		assert.Nil(t, err)

		err = parser.validate(adapter, "field_float_vector")
		assert.Nil(t, err)
		assert.Equal(t, len(data1)*len(data1[0]), parser.columnDesc.elementCount)

		// data type mismatch
		filePath = TempFilesPath + "float_vector_2.npy"
		data2 := [][4]int32{{0, 1, 2, 3}}
		CreateNumpyFile(filePath, data2)

		file2, err := os.Open(filePath)
		assert.Nil(t, err)
		defer file2.Close()

		adapter, err = NewNumpyAdapter(file2)
		assert.Nil(t, err)

		err = parser.validate(adapter, "field_float_vector")
		assert.NotNil(t, err)

		// shape mismatch
		filePath = TempFilesPath + "float_vector_3.npy"
		data3 := []float32{1, 2, 3}
		CreateNumpyFile(filePath, data3)

		file3, err := os.Open(filePath)
		assert.Nil(t, err)
		defer file3.Close()

		adapter, err = NewNumpyAdapter(file3)
		assert.Nil(t, err)

		err = parser.validate(adapter, "field_float_vector")
		assert.NotNil(t, err)

		// shape[1] mismatch
		filePath = TempFilesPath + "float_vector_4.npy"
		data4 := [][3]float32{{0, 0, 0}, {1, 1, 1}}
		CreateNumpyFile(filePath, data4)

		file4, err := os.Open(filePath)
		assert.Nil(t, err)
		defer file4.Close()

		adapter, err = NewNumpyAdapter(file4)
		assert.Nil(t, err)

		err = parser.validate(adapter, "field_float_vector")
		assert.NotNil(t, err)

		// dimension mismatch
		p := NewNumpyParser(ctx, &schemapb.CollectionSchema{
			Fields: []*schemapb.FieldSchema{
				{
					FieldID:  109,
					Name:     "field_float_vector",
					DataType: schemapb.DataType_FloatVector,
				},
			},
		}, flushFunc)

		err = p.validate(adapter, "field_float_vector")
		assert.NotNil(t, err)
	}()
}

func Test_Parse(t *testing.T) {
	ctx := context.Background()
	err := os.MkdirAll(TempFilesPath, os.ModePerm)
	assert.Nil(t, err)
	defer os.RemoveAll(TempFilesPath)

	schema := sampleSchema()

	checkFunc := func(data interface{}, fieldName string, callback func(field storage.FieldData) error) {

		filePath := TempFilesPath + fieldName + ".npy"
		CreateNumpyFile(filePath, data)

		func() {
			file, err := os.Open(filePath)
			assert.Nil(t, err)
			defer file.Close()

			parser := NewNumpyParser(ctx, schema, callback)
			err = parser.Parse(file, fieldName, false)
			assert.Nil(t, err)
		}()

		// validation failed
		func() {
			file, err := os.Open(filePath)
			assert.Nil(t, err)
			defer file.Close()

			parser := NewNumpyParser(ctx, schema, callback)
			err = parser.Parse(file, "dummy", false)
			assert.NotNil(t, err)
		}()

		// read data error
		func() {
			parser := NewNumpyParser(ctx, schema, callback)
			err = parser.Parse(&MockReader{}, fieldName, false)
			assert.NotNil(t, err)
		}()
	}

	// scalar bool
	data1 := []bool{true, false, true, false, true}
	flushFunc := func(field storage.FieldData) error {
		assert.NotNil(t, field)
		assert.Equal(t, len(data1), field.RowNum())

		for i := 0; i < len(data1); i++ {
			assert.Equal(t, data1[i], field.GetRow(i))
		}

		return nil
	}
	checkFunc(data1, "field_bool", flushFunc)

	// scalar int8
	data2 := []int8{1, 2, 3, 4, 5}
	flushFunc = func(field storage.FieldData) error {
		assert.NotNil(t, field)
		assert.Equal(t, len(data2), field.RowNum())

		for i := 0; i < len(data2); i++ {
			assert.Equal(t, data2[i], field.GetRow(i))
		}

		return nil
	}
	checkFunc(data2, "field_int8", flushFunc)

	// scalar int16
	data3 := []int16{1, 2, 3, 4, 5}
	flushFunc = func(field storage.FieldData) error {
		assert.NotNil(t, field)
		assert.Equal(t, len(data3), field.RowNum())

		for i := 0; i < len(data3); i++ {
			assert.Equal(t, data3[i], field.GetRow(i))
		}

		return nil
	}
	checkFunc(data3, "field_int16", flushFunc)

	// scalar int32
	data4 := []int32{1, 2, 3, 4, 5}
	flushFunc = func(field storage.FieldData) error {
		assert.NotNil(t, field)
		assert.Equal(t, len(data4), field.RowNum())

		for i := 0; i < len(data4); i++ {
			assert.Equal(t, data4[i], field.GetRow(i))
		}

		return nil
	}
	checkFunc(data4, "field_int32", flushFunc)

	// scalar int64
	data5 := []int64{1, 2, 3, 4, 5}
	flushFunc = func(field storage.FieldData) error {
		assert.NotNil(t, field)
		assert.Equal(t, len(data5), field.RowNum())

		for i := 0; i < len(data5); i++ {
			assert.Equal(t, data5[i], field.GetRow(i))
		}

		return nil
	}
	checkFunc(data5, "field_int64", flushFunc)

	// scalar float
	data6 := []float32{1, 2, 3, 4, 5}
	flushFunc = func(field storage.FieldData) error {
		assert.NotNil(t, field)
		assert.Equal(t, len(data6), field.RowNum())

		for i := 0; i < len(data6); i++ {
			assert.Equal(t, data6[i], field.GetRow(i))
		}

		return nil
	}
	checkFunc(data6, "field_float", flushFunc)

	// scalar double
	data7 := []float64{1, 2, 3, 4, 5}
	flushFunc = func(field storage.FieldData) error {
		assert.NotNil(t, field)
		assert.Equal(t, len(data7), field.RowNum())

		for i := 0; i < len(data7); i++ {
			assert.Equal(t, data7[i], field.GetRow(i))
		}

		return nil
	}
	checkFunc(data7, "field_double", flushFunc)

	// binary vector
	data8 := [][2]uint8{{1, 2}, {3, 4}, {5, 6}}
	flushFunc = func(field storage.FieldData) error {
		assert.NotNil(t, field)
		assert.Equal(t, len(data8), field.RowNum())

		for i := 0; i < len(data8); i++ {
			row := field.GetRow(i).([]uint8)
			for k := 0; k < len(row); k++ {
				assert.Equal(t, data8[i][k], row[k])
			}
		}

		return nil
	}
	checkFunc(data8, "field_binary_vector", flushFunc)

	// float vector
	data9 := [][4]float32{{1.1, 2.1, 3.1, 4.1}, {5.2, 6.2, 7.2, 8.2}}
	flushFunc = func(field storage.FieldData) error {
		assert.NotNil(t, field)
		assert.Equal(t, len(data9), field.RowNum())

		for i := 0; i < len(data9); i++ {
			row := field.GetRow(i).([]float32)
			for k := 0; k < len(row); k++ {
				assert.Equal(t, data9[i][k], row[k])
			}
		}

		return nil
	}
	checkFunc(data9, "field_float_vector", flushFunc)
}

func Test_Parse_perf(t *testing.T) {
	ctx := context.Background()
	err := os.MkdirAll(TempFilesPath, os.ModePerm)
	assert.Nil(t, err)
	defer os.RemoveAll(TempFilesPath)

	tr := timerecord.NewTimeRecorder("numpy parse performance")

	// change the parameter to test performance
	rowCount := 10000
	dotValue := float32(3.1415926)
	const (
		dim = 128
	)

	schema := perfSchema(dim)

	data := make([][dim]float32, 0)
	for i := 0; i < rowCount; i++ {
		var row [dim]float32
		for k := 0; k < dim; k++ {
			row[k] = float32(i) + dotValue
		}
		data = append(data, row)
	}

	tr.Record("generate large data")

	flushFunc := func(field storage.FieldData) error {
		assert.Equal(t, len(data), field.RowNum())
		return nil
	}

	filePath := TempFilesPath + "perf.npy"
	CreateNumpyFile(filePath, data)

	tr.Record("generate large numpy file " + filePath)

	file, err := os.Open(filePath)
	assert.Nil(t, err)
	defer file.Close()

	parser := NewNumpyParser(ctx, schema, flushFunc)
	err = parser.Parse(file, "Vector", false)
	assert.Nil(t, err)

	tr.Record("parse large numpy files: " + filePath)
}
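Note: the .npy fixtures above are produced by a CreateNumpyFile helper defined elsewhere in the importutil test code. As a rough, assumed sketch only, such a helper can be built on the github.com/sbinet/npyio dependency added in go.mod; the real helper may differ.

package main

import (
	"os"

	"github.com/sbinet/npyio"
)

// createNumpyFile writes a Go slice as a NumPy .npy file.
// Illustrative stand-in for the CreateNumpyFile test helper, not the actual implementation.
func createNumpyFile(path string, data interface{}) error {
	f, err := os.Create(path)
	if err != nil {
		return err
	}
	defer f.Close()

	// npyio.Write emits the NumPy header followed by the raw element data
	return npyio.Write(f, data)
}

func main() {
	if err := createNumpyFile("scalar_1.npy", []float64{0, 1, 2, 3, 4, 5}); err != nil {
		panic(err)
	}
}
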
@@ -47,6 +47,7 @@ go test -race -cover "${MILVUS_DIR}/util/retry/..." -failfast
go test -race -cover "${MILVUS_DIR}/util/sessionutil/..." -failfast
go test -race -cover "${MILVUS_DIR}/util/trace/..." -failfast
go test -race -cover "${MILVUS_DIR}/util/typeutil/..." -failfast
go test -race -cover "${MILVUS_DIR}/util/importutil/..." -failfast

# TODO: remove to distributed
#go test -race -cover "${MILVUS_DIR}/proxy/..." -failfast