milvus/internal/util/importutil/numpy_parser_test.go

527 lines
13 KiB
Go

package importutil
import (
"context"
"os"
"testing"
"github.com/sbinet/npyio/npy"
"github.com/stretchr/testify/assert"
"github.com/milvus-io/milvus/internal/proto/schemapb"
"github.com/milvus-io/milvus/internal/storage"
"github.com/milvus-io/milvus/internal/util/timerecord"
)
func Test_NewNumpyParser(t *testing.T) {
ctx := context.Background()
parser := NewNumpyParser(ctx, nil, nil)
assert.Nil(t, parser)
}
func Test_ConvertNumpyType(t *testing.T) {
checkFunc := func(inputs []string, output schemapb.DataType) {
for i := 0; i < len(inputs); i++ {
dt, err := convertNumpyType(inputs[i])
assert.Nil(t, err)
assert.Equal(t, output, dt)
}
}
checkFunc([]string{"b1", "<b1", "|b1", "bool"}, schemapb.DataType_Bool)
checkFunc([]string{"i1", "<i1", "|i1", ">i1", "int8"}, schemapb.DataType_Int8)
checkFunc([]string{"i2", "<i2", "|i2", ">i2", "int16"}, schemapb.DataType_Int16)
checkFunc([]string{"i4", "<i4", "|i4", ">i4", "int32"}, schemapb.DataType_Int32)
checkFunc([]string{"i8", "<i8", "|i8", ">i8", "int64"}, schemapb.DataType_Int64)
checkFunc([]string{"f4", "<f4", "|f4", ">f4", "float32"}, schemapb.DataType_Float)
checkFunc([]string{"f8", "<f8", "|f8", ">f8", "float64"}, schemapb.DataType_Double)
dt, err := convertNumpyType("dummy")
assert.NotNil(t, err)
assert.Equal(t, schemapb.DataType_None, dt)
}
func Test_Validate(t *testing.T) {
ctx := context.Background()
err := os.MkdirAll(TempFilesPath, os.ModePerm)
assert.Nil(t, err)
defer os.RemoveAll(TempFilesPath)
schema := sampleSchema()
flushFunc := func(field storage.FieldData) error {
return nil
}
adapter := &NumpyAdapter{npyReader: &npy.Reader{}}
{
// string type is not supported
p := NewNumpyParser(ctx, &schemapb.CollectionSchema{
Fields: []*schemapb.FieldSchema{
{
FieldID: 109,
Name: "field_string",
IsPrimaryKey: false,
Description: "string",
DataType: schemapb.DataType_String,
},
},
}, flushFunc)
err = p.validate(adapter, "dummy")
assert.NotNil(t, err)
err = p.validate(adapter, "field_string")
assert.NotNil(t, err)
}
// reader is nil
parser := NewNumpyParser(ctx, schema, flushFunc)
err = parser.validate(nil, "")
assert.NotNil(t, err)
// validate scalar data
func() {
filePath := TempFilesPath + "scalar_1.npy"
data1 := []float64{0, 1, 2, 3, 4, 5}
CreateNumpyFile(filePath, data1)
file1, err := os.Open(filePath)
assert.Nil(t, err)
defer file1.Close()
adapter, err := NewNumpyAdapter(file1)
assert.Nil(t, err)
err = parser.validate(adapter, "field_double")
assert.Nil(t, err)
assert.Equal(t, len(data1), parser.columnDesc.elementCount)
err = parser.validate(adapter, "")
assert.NotNil(t, err)
// data type mismatch
filePath = TempFilesPath + "scalar_2.npy"
data2 := []int64{0, 1, 2, 3, 4, 5}
CreateNumpyFile(filePath, data2)
file2, err := os.Open(filePath)
assert.Nil(t, err)
defer file2.Close()
adapter, err = NewNumpyAdapter(file2)
assert.Nil(t, err)
err = parser.validate(adapter, "field_double")
assert.NotNil(t, err)
// shape mismatch
filePath = TempFilesPath + "scalar_2.npy"
data3 := [][2]float64{{1, 1}}
CreateNumpyFile(filePath, data3)
file3, err := os.Open(filePath)
assert.Nil(t, err)
defer file2.Close()
adapter, err = NewNumpyAdapter(file3)
assert.Nil(t, err)
err = parser.validate(adapter, "field_double")
assert.NotNil(t, err)
}()
// validate binary vector data
func() {
filePath := TempFilesPath + "binary_vector_1.npy"
data1 := [][2]uint8{{0, 1}, {2, 3}, {4, 5}}
CreateNumpyFile(filePath, data1)
file1, err := os.Open(filePath)
assert.Nil(t, err)
defer file1.Close()
adapter, err := NewNumpyAdapter(file1)
assert.Nil(t, err)
err = parser.validate(adapter, "field_binary_vector")
assert.Nil(t, err)
assert.Equal(t, len(data1)*len(data1[0]), parser.columnDesc.elementCount)
// data type mismatch
filePath = TempFilesPath + "binary_vector_2.npy"
data2 := [][2]uint16{{0, 1}, {2, 3}, {4, 5}}
CreateNumpyFile(filePath, data2)
file2, err := os.Open(filePath)
assert.Nil(t, err)
defer file2.Close()
adapter, err = NewNumpyAdapter(file2)
assert.Nil(t, err)
err = parser.validate(adapter, "field_binary_vector")
assert.NotNil(t, err)
// shape mismatch
filePath = TempFilesPath + "binary_vector_3.npy"
data3 := []uint8{1, 2, 3}
CreateNumpyFile(filePath, data3)
file3, err := os.Open(filePath)
assert.Nil(t, err)
defer file3.Close()
adapter, err = NewNumpyAdapter(file3)
assert.Nil(t, err)
err = parser.validate(adapter, "field_binary_vector")
assert.NotNil(t, err)
// shape[1] mismatch
filePath = TempFilesPath + "binary_vector_4.npy"
data4 := [][3]uint8{{0, 1, 2}, {2, 3, 4}, {4, 5, 6}}
CreateNumpyFile(filePath, data4)
file4, err := os.Open(filePath)
assert.Nil(t, err)
defer file4.Close()
adapter, err = NewNumpyAdapter(file4)
assert.Nil(t, err)
err = parser.validate(adapter, "field_binary_vector")
assert.NotNil(t, err)
// dimension mismatch
p := NewNumpyParser(ctx, &schemapb.CollectionSchema{
Fields: []*schemapb.FieldSchema{
{
FieldID: 109,
Name: "field_binary_vector",
DataType: schemapb.DataType_BinaryVector,
},
},
}, flushFunc)
err = p.validate(adapter, "field_binary_vector")
assert.NotNil(t, err)
}()
// validate float vector data
func() {
filePath := TempFilesPath + "float_vector.npy"
data1 := [][4]float32{{0, 0, 0, 0}, {1, 1, 1, 1}, {2, 2, 2, 2}, {3, 3, 3, 3}}
CreateNumpyFile(filePath, data1)
file1, err := os.Open(filePath)
assert.Nil(t, err)
defer file1.Close()
adapter, err := NewNumpyAdapter(file1)
assert.Nil(t, err)
err = parser.validate(adapter, "field_float_vector")
assert.Nil(t, err)
assert.Equal(t, len(data1)*len(data1[0]), parser.columnDesc.elementCount)
// data type mismatch
filePath = TempFilesPath + "float_vector_2.npy"
data2 := [][4]int32{{0, 1, 2, 3}}
CreateNumpyFile(filePath, data2)
file2, err := os.Open(filePath)
assert.Nil(t, err)
defer file2.Close()
adapter, err = NewNumpyAdapter(file2)
assert.Nil(t, err)
err = parser.validate(adapter, "field_float_vector")
assert.NotNil(t, err)
// shape mismatch
filePath = TempFilesPath + "float_vector_3.npy"
data3 := []float32{1, 2, 3}
CreateNumpyFile(filePath, data3)
file3, err := os.Open(filePath)
assert.Nil(t, err)
defer file3.Close()
adapter, err = NewNumpyAdapter(file3)
assert.Nil(t, err)
err = parser.validate(adapter, "field_float_vector")
assert.NotNil(t, err)
// shape[1] mismatch
filePath = TempFilesPath + "float_vector_4.npy"
data4 := [][3]float32{{0, 0, 0}, {1, 1, 1}}
CreateNumpyFile(filePath, data4)
file4, err := os.Open(filePath)
assert.Nil(t, err)
defer file4.Close()
adapter, err = NewNumpyAdapter(file4)
assert.Nil(t, err)
err = parser.validate(adapter, "field_float_vector")
assert.NotNil(t, err)
// dimension mismatch
p := NewNumpyParser(ctx, &schemapb.CollectionSchema{
Fields: []*schemapb.FieldSchema{
{
FieldID: 109,
Name: "field_float_vector",
DataType: schemapb.DataType_FloatVector,
},
},
}, flushFunc)
err = p.validate(adapter, "field_float_vector")
assert.NotNil(t, err)
}()
}
func Test_Parse(t *testing.T) {
ctx := context.Background()
err := os.MkdirAll(TempFilesPath, os.ModePerm)
assert.Nil(t, err)
defer os.RemoveAll(TempFilesPath)
schema := sampleSchema()
checkFunc := func(data interface{}, fieldName string, callback func(field storage.FieldData) error) {
filePath := TempFilesPath + fieldName + ".npy"
CreateNumpyFile(filePath, data)
func() {
file, err := os.Open(filePath)
assert.Nil(t, err)
defer file.Close()
parser := NewNumpyParser(ctx, schema, callback)
err = parser.Parse(file, fieldName, false)
assert.Nil(t, err)
}()
// validation failed
func() {
file, err := os.Open(filePath)
assert.Nil(t, err)
defer file.Close()
parser := NewNumpyParser(ctx, schema, callback)
err = parser.Parse(file, "dummy", false)
assert.NotNil(t, err)
}()
// read data error
func() {
parser := NewNumpyParser(ctx, schema, callback)
err = parser.Parse(&MockReader{}, fieldName, false)
assert.NotNil(t, err)
}()
}
// scalar bool
data1 := []bool{true, false, true, false, true}
flushFunc := func(field storage.FieldData) error {
assert.NotNil(t, field)
assert.Equal(t, len(data1), field.RowNum())
for i := 0; i < len(data1); i++ {
assert.Equal(t, data1[i], field.GetRow(i))
}
return nil
}
checkFunc(data1, "field_bool", flushFunc)
// scalar int8
data2 := []int8{1, 2, 3, 4, 5}
flushFunc = func(field storage.FieldData) error {
assert.NotNil(t, field)
assert.Equal(t, len(data2), field.RowNum())
for i := 0; i < len(data2); i++ {
assert.Equal(t, data2[i], field.GetRow(i))
}
return nil
}
checkFunc(data2, "field_int8", flushFunc)
// scalar int16
data3 := []int16{1, 2, 3, 4, 5}
flushFunc = func(field storage.FieldData) error {
assert.NotNil(t, field)
assert.Equal(t, len(data3), field.RowNum())
for i := 0; i < len(data3); i++ {
assert.Equal(t, data3[i], field.GetRow(i))
}
return nil
}
checkFunc(data3, "field_int16", flushFunc)
// scalar int32
data4 := []int32{1, 2, 3, 4, 5}
flushFunc = func(field storage.FieldData) error {
assert.NotNil(t, field)
assert.Equal(t, len(data4), field.RowNum())
for i := 0; i < len(data4); i++ {
assert.Equal(t, data4[i], field.GetRow(i))
}
return nil
}
checkFunc(data4, "field_int32", flushFunc)
// scalar int64
data5 := []int64{1, 2, 3, 4, 5}
flushFunc = func(field storage.FieldData) error {
assert.NotNil(t, field)
assert.Equal(t, len(data5), field.RowNum())
for i := 0; i < len(data5); i++ {
assert.Equal(t, data5[i], field.GetRow(i))
}
return nil
}
checkFunc(data5, "field_int64", flushFunc)
// scalar float
data6 := []float32{1, 2, 3, 4, 5}
flushFunc = func(field storage.FieldData) error {
assert.NotNil(t, field)
assert.Equal(t, len(data6), field.RowNum())
for i := 0; i < len(data6); i++ {
assert.Equal(t, data6[i], field.GetRow(i))
}
return nil
}
checkFunc(data6, "field_float", flushFunc)
// scalar double
data7 := []float64{1, 2, 3, 4, 5}
flushFunc = func(field storage.FieldData) error {
assert.NotNil(t, field)
assert.Equal(t, len(data7), field.RowNum())
for i := 0; i < len(data7); i++ {
assert.Equal(t, data7[i], field.GetRow(i))
}
return nil
}
checkFunc(data7, "field_double", flushFunc)
// binary vector
data8 := [][2]uint8{{1, 2}, {3, 4}, {5, 6}}
flushFunc = func(field storage.FieldData) error {
assert.NotNil(t, field)
assert.Equal(t, len(data8), field.RowNum())
for i := 0; i < len(data8); i++ {
row := field.GetRow(i).([]uint8)
for k := 0; k < len(row); k++ {
assert.Equal(t, data8[i][k], row[k])
}
}
return nil
}
checkFunc(data8, "field_binary_vector", flushFunc)
// double vector(element can be float32 or float64)
data9 := [][4]float32{{1.1, 2.1, 3.1, 4.1}, {5.2, 6.2, 7.2, 8.2}}
flushFunc = func(field storage.FieldData) error {
assert.NotNil(t, field)
assert.Equal(t, len(data9), field.RowNum())
for i := 0; i < len(data9); i++ {
row := field.GetRow(i).([]float32)
for k := 0; k < len(row); k++ {
assert.Equal(t, data9[i][k], row[k])
}
}
return nil
}
checkFunc(data9, "field_float_vector", flushFunc)
data10 := [][4]float64{{1.1, 2.1, 3.1, 4.1}, {5.2, 6.2, 7.2, 8.2}}
flushFunc = func(field storage.FieldData) error {
assert.NotNil(t, field)
assert.Equal(t, len(data10), field.RowNum())
for i := 0; i < len(data10); i++ {
row := field.GetRow(i).([]float32)
for k := 0; k < len(row); k++ {
assert.Equal(t, float32(data10[i][k]), row[k])
}
}
return nil
}
checkFunc(data10, "field_float_vector", flushFunc)
}
func Test_Parse_perf(t *testing.T) {
ctx := context.Background()
err := os.MkdirAll(TempFilesPath, os.ModePerm)
assert.Nil(t, err)
defer os.RemoveAll(TempFilesPath)
tr := timerecord.NewTimeRecorder("numpy parse performance")
// change the parameter to test performance
rowCount := 10000
dotValue := float32(3.1415926)
const (
dim = 128
)
schema := perfSchema(dim)
data := make([][dim]float32, 0)
for i := 0; i < rowCount; i++ {
var row [dim]float32
for k := 0; k < dim; k++ {
row[k] = float32(i) + dotValue
}
data = append(data, row)
}
tr.Record("generate large data")
flushFunc := func(field storage.FieldData) error {
assert.Equal(t, len(data), field.RowNum())
return nil
}
filePath := TempFilesPath + "perf.npy"
CreateNumpyFile(filePath, data)
tr.Record("generate large numpy file " + filePath)
file, err := os.Open(filePath)
assert.Nil(t, err)
defer file.Close()
parser := NewNumpyParser(ctx, schema, flushFunc)
err = parser.Parse(file, "Vector", false)
assert.Nil(t, err)
tr.Record("parse large numpy files: " + filePath)
}