mirror of https://github.com/milvus-io/milvus.git
Verify vector float data for bulkinsert and insert (#22728)
Signed-off-by: yhmo <yihua.mo@zilliz.com>
pull/22745/head
parent
ebc173cfb8
commit
6f6bd98c27
|
@ -101,6 +101,30 @@ func (it *insertTask) OnEnqueue() error {
|
|||
return nil
|
||||
}
|
||||
|
||||
func (it *insertTask) checkVectorFieldData() error {
|
||||
fields := it.insertMsg.GetFieldsData()
|
||||
for _, field := range fields {
|
||||
if field.GetType() != schemapb.DataType_FloatVector {
|
||||
continue
|
||||
}
|
||||
|
||||
vectorField := field.GetVectors()
|
||||
if vectorField == nil || vectorField.GetFloatVector() == nil {
|
||||
log.Error("float vector field is illegal, array type mismatch", zap.String("field name", field.GetFieldName()))
|
||||
return fmt.Errorf("float vector field '%v' is illegal, array type mismatch", field.GetFieldName())
|
||||
}
|
||||
|
||||
floatArray := vectorField.GetFloatVector()
|
||||
err := typeutil.VerifyFloats32(floatArray.GetData())
|
||||
if err != nil {
|
||||
log.Error("float vector field data is illegal", zap.String("field name", field.GetFieldName()), zap.Error(err))
|
||||
return fmt.Errorf("float vector field data is illegal, error: %w", err)
|
||||
}
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
func (it *insertTask) PreExecute(ctx context.Context) error {
|
||||
ctx, sp := otel.Tracer(typeutil.ProxyRole).Start(ctx, "Insert-PreExecute")
|
||||
defer sp.End()
|
||||
|
@ -187,6 +211,13 @@ func (it *insertTask) PreExecute(ctx context.Context) error {
|
|||
return err
|
||||
}
|
||||
|
||||
// check vector field data
|
||||
err = it.checkVectorFieldData()
|
||||
if err != nil {
|
||||
log.Error("vector field data is illegal", zap.Error(err))
|
||||
return err
|
||||
}
|
||||
|
||||
log.Debug("Proxy Insert PreExecute done")
|
||||
|
||||
return nil
|
||||
|
|
|
@ -1,6 +1,7 @@
|
|||
package proxy
|
||||
|
||||
import (
|
||||
"math"
|
||||
"testing"
|
||||
|
||||
"github.com/stretchr/testify/assert"
|
||||
|
@ -219,3 +220,84 @@ func TestInsertTask_CheckAligned(t *testing.T) {
|
|||
err = case2.insertMsg.CheckAligned()
|
||||
assert.NoError(t, err)
|
||||
}
|
||||
|
||||
func TestInsertTask_CheckVectorFieldData(t *testing.T) {
|
||||
fieldName := "embeddings"
|
||||
numRows := 10
|
||||
dim := 32
|
||||
task := insertTask{
|
||||
insertMsg: &BaseInsertTask{
|
||||
InsertRequest: msgpb.InsertRequest{
|
||||
Base: &commonpb.MsgBase{
|
||||
MsgType: commonpb.MsgType_Insert,
|
||||
},
|
||||
Version: msgpb.InsertDataVersion_ColumnBased,
|
||||
NumRows: uint64(numRows),
|
||||
},
|
||||
},
|
||||
schema: &schemapb.CollectionSchema{
|
||||
Name: "TestInsertTask_CheckVectorFieldData",
|
||||
Description: "TestInsertTask_CheckVectorFieldData",
|
||||
Fields: []*schemapb.FieldSchema{
|
||||
{
|
||||
FieldID: 100,
|
||||
Name: fieldName,
|
||||
IsPrimaryKey: false,
|
||||
AutoID: false,
|
||||
DataType: schemapb.DataType_FloatVector,
|
||||
},
|
||||
},
|
||||
},
|
||||
}
|
||||
|
||||
// success case
|
||||
task.insertMsg.FieldsData = []*schemapb.FieldData{
|
||||
newFloatVectorFieldData(fieldName, numRows, dim),
|
||||
}
|
||||
err := task.checkVectorFieldData()
|
||||
assert.NoError(t, err)
|
||||
|
||||
// field is nil
|
||||
task.insertMsg.FieldsData = []*schemapb.FieldData{
|
||||
{
|
||||
Type: schemapb.DataType_FloatVector,
|
||||
FieldName: fieldName,
|
||||
Field: &schemapb.FieldData_Vectors{
|
||||
Vectors: nil,
|
||||
},
|
||||
},
|
||||
}
|
||||
err = task.checkVectorFieldData()
|
||||
assert.Error(t, err)
|
||||
|
||||
// vector data is not a number
|
||||
values := generateFloatVectors(numRows, dim)
|
||||
values[5] = float32(math.NaN())
|
||||
task.insertMsg.FieldsData[0].Field = &schemapb.FieldData_Vectors{
|
||||
Vectors: &schemapb.VectorField{
|
||||
Dim: int64(dim),
|
||||
Data: &schemapb.VectorField_FloatVector{
|
||||
FloatVector: &schemapb.FloatArray{
|
||||
Data: values,
|
||||
},
|
||||
},
|
||||
},
|
||||
}
|
||||
err = task.checkVectorFieldData()
|
||||
assert.Error(t, err)
|
||||
|
||||
// vector data is infinity
|
||||
values[5] = float32(math.Inf(1))
|
||||
task.insertMsg.FieldsData[0].Field = &schemapb.FieldData_Vectors{
|
||||
Vectors: &schemapb.VectorField{
|
||||
Dim: int64(dim),
|
||||
Data: &schemapb.VectorField_FloatVector{
|
||||
FloatVector: &schemapb.FloatArray{
|
||||
Data: values,
|
||||
},
|
||||
},
|
||||
},
|
||||
}
|
||||
err = task.checkVectorFieldData()
|
||||
assert.Error(t, err)
|
||||
}
|
||||
|
|
|
@ -20,7 +20,6 @@ import (
|
|||
"context"
|
||||
"encoding/json"
|
||||
"fmt"
|
||||
"math"
|
||||
"path"
|
||||
"runtime/debug"
|
||||
"strconv"
|
||||
|
@ -120,9 +119,9 @@ func parseFloat(s string, bitsize int, fieldName string) (float64, error) {
|
|||
return 0, fmt.Errorf("failed to parse value '%s' for field '%s', error: %w", s, fieldName, err)
|
||||
}
|
||||
|
||||
// not allow not-a-number and infinity
|
||||
if math.IsNaN(value) || math.IsInf(value, -1) || math.IsInf(value, 1) {
|
||||
return 0, fmt.Errorf("value '%s' is not a number or infinity, field '%s', error: %w", s, fieldName, err)
|
||||
err = typeutil.VerifyFloat(value)
|
||||
if err != nil {
|
||||
return 0, fmt.Errorf("illegal value '%s' for field '%s', error: %w", s, fieldName, err)
|
||||
}
|
||||
|
||||
return value, nil
|
||||
|
|
|
@ -299,6 +299,14 @@ func Test_parseFloat(t *testing.T) {
|
|||
value, err = parseFloat("2.718281828459045", 64, "")
|
||||
assert.True(t, math.Abs(value-2.718281828459045) < 0.0000000000000001)
|
||||
assert.Nil(t, err)
|
||||
|
||||
value, err = parseFloat("Inf", 32, "")
|
||||
assert.Zero(t, value)
|
||||
assert.Error(t, err)
|
||||
|
||||
value, err = parseFloat("NaN", 64, "")
|
||||
assert.Zero(t, value)
|
||||
assert.Error(t, err)
|
||||
}
|
||||
|
||||
func Test_InitValidators(t *testing.T) {
|
||||
|
|
|
@ -494,6 +494,12 @@ func (p *NumpyParser) readData(columnReader *NumpyColumnReader, rowCount int) (s
|
|||
return nil, fmt.Errorf("failed to read float array: %s", err.Error())
|
||||
}
|
||||
|
||||
err = typeutil.VerifyFloats32(data)
|
||||
if err != nil {
|
||||
log.Error("Numpy parser: illegal value in float array", zap.Error(err))
|
||||
return nil, fmt.Errorf("illegal value in float array: %s", err.Error())
|
||||
}
|
||||
|
||||
return &storage.FloatFieldData{
|
||||
Data: data,
|
||||
}, nil
|
||||
|
@ -504,6 +510,12 @@ func (p *NumpyParser) readData(columnReader *NumpyColumnReader, rowCount int) (s
|
|||
return nil, fmt.Errorf("failed to read double array: %s", err.Error())
|
||||
}
|
||||
|
||||
err = typeutil.VerifyFloats64(data)
|
||||
if err != nil {
|
||||
log.Error("Numpy parser: illegal value in double array", zap.Error(err))
|
||||
return nil, fmt.Errorf("illegal value in double array: %s", err.Error())
|
||||
}
|
||||
|
||||
return &storage.DoubleFieldData{
|
||||
Data: data,
|
||||
}, nil
|
||||
|
@ -542,6 +554,13 @@ func (p *NumpyParser) readData(columnReader *NumpyColumnReader, rowCount int) (s
|
|||
log.Error("Numpy parser: failed to read float vector array", zap.Error(err))
|
||||
return nil, fmt.Errorf("failed to read float vector array: %s", err.Error())
|
||||
}
|
||||
|
||||
err = typeutil.VerifyFloats32(data)
|
||||
if err != nil {
|
||||
log.Error("Numpy parser: illegal value in float vector array", zap.Error(err))
|
||||
return nil, fmt.Errorf("illegal value in float vector array: %s", err.Error())
|
||||
}
|
||||
|
||||
} else if elementType == schemapb.DataType_Double {
|
||||
data = make([]float32, 0, columnReader.rowCount)
|
||||
data64, err := columnReader.reader.ReadFloat64(rowCount * columnReader.dimension)
|
||||
|
@ -551,6 +570,12 @@ func (p *NumpyParser) readData(columnReader *NumpyColumnReader, rowCount int) (s
|
|||
}
|
||||
|
||||
for _, f64 := range data64 {
|
||||
err = typeutil.VerifyFloat(f64)
|
||||
if err != nil {
|
||||
log.Error("Numpy parser: illegal value in float vector array", zap.Error(err))
|
||||
return nil, fmt.Errorf("illegal value in float vector array: %s", err.Error())
|
||||
}
|
||||
|
||||
data = append(data, float32(f64))
|
||||
}
|
||||
}
|
||||
|
|
|
@ -18,6 +18,7 @@ package importutil
|
|||
|
||||
import (
|
||||
"context"
|
||||
"math"
|
||||
"os"
|
||||
"testing"
|
||||
|
||||
|
@ -403,6 +404,22 @@ func Test_NumpyParserReadData(t *testing.T) {
|
|||
}
|
||||
}
|
||||
|
||||
readErrorFunc := func(filedName string, data interface{}) {
|
||||
filePath := TempFilesPath + filedName + ".npy"
|
||||
err = CreateNumpyFile(filePath, data)
|
||||
assert.Nil(t, err)
|
||||
|
||||
readers, err := parser.createReaders([]string{filePath})
|
||||
assert.NoError(t, err)
|
||||
assert.Equal(t, 1, len(readers))
|
||||
defer closeReaders(readers)
|
||||
|
||||
// encounter error
|
||||
fieldData, err := parser.readData(readers[0], 1000)
|
||||
assert.Error(t, err)
|
||||
assert.Nil(t, fieldData)
|
||||
}
|
||||
|
||||
t.Run("read bool", func(t *testing.T) {
|
||||
readEmptyFunc("FieldBool", []bool{})
|
||||
|
||||
|
@ -443,6 +460,8 @@ func Test_NumpyParserReadData(t *testing.T) {
|
|||
|
||||
data := []float32{2.5, 32.2, 53.254, 3.45, 65.23421, 54.8978}
|
||||
readBatchFunc("FieldFloat", data, len(data), func(k int) interface{} { return data[k] })
|
||||
data = []float32{2.5, 32.2, float32(math.NaN())}
|
||||
readErrorFunc("FieldFloat", data)
|
||||
})
|
||||
|
||||
t.Run("read double", func(t *testing.T) {
|
||||
|
@ -450,6 +469,8 @@ func Test_NumpyParserReadData(t *testing.T) {
|
|||
|
||||
data := []float64{65.24454, 343.4365, 432.6556}
|
||||
readBatchFunc("FieldDouble", data, len(data), func(k int) interface{} { return data[k] })
|
||||
data = []float64{65.24454, math.Inf(1)}
|
||||
readErrorFunc("FieldDouble", data)
|
||||
})
|
||||
|
||||
specialReadEmptyFunc := func(filedName string, data interface{}) {
|
||||
|
@ -482,6 +503,9 @@ func Test_NumpyParserReadData(t *testing.T) {
|
|||
t.Run("read float vector", func(t *testing.T) {
|
||||
specialReadEmptyFunc("FieldFloatVector", [][4]float32{{1, 2, 3, 4}, {3, 4, 5, 6}})
|
||||
specialReadEmptyFunc("FieldFloatVector", [][4]float64{{1, 2, 3, 4}, {3, 4, 5, 6}})
|
||||
|
||||
readErrorFunc("FieldFloatVector", [][4]float32{{1, 2, 3, float32(math.NaN())}, {3, 4, 5, 6}})
|
||||
readErrorFunc("FieldFloatVector", [][4]float64{{1, 2, 3, 4}, {3, 4, math.Inf(1), 6}})
|
||||
})
|
||||
}
|
||||
|
||||
|
|
|
@ -0,0 +1,53 @@
|
|||
// Licensed to the LF AI & Data foundation under one
|
||||
// or more contributor license agreements. See the NOTICE file
|
||||
// distributed with this work for additional information
|
||||
// regarding copyright ownership. The ASF licenses this file
|
||||
// to you under the Apache License, Version 2.0 (the
|
||||
// "License"); you may not use this file except in compliance
|
||||
// with the License. You may obtain a copy of the License at
|
||||
//
|
||||
// http://www.apache.org/licenses/LICENSE-2.0
|
||||
//
|
||||
// Unless required by applicable law or agreed to in writing, software
|
||||
// distributed under the License is distributed on an "AS IS" BASIS,
|
||||
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
// See the License for the specific language governing permissions and
|
||||
// limitations under the License.
|
||||
|
||||
package typeutil
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"math"
|
||||
)
|
||||
|
||||
// VerifyFloat checks that value is an ordinary finite number.
// It returns a non-nil error when value is NaN, +Inf or -Inf, and nil
// otherwise.
func VerifyFloat(value float64) error {
	// Passing sign 0 to IsInf matches both positive and negative infinity.
	if finite := !math.IsNaN(value) && !math.IsInf(value, 0); finite {
		return nil
	}

	return fmt.Errorf("value '%f' is not a number or infinity", value)
}
|
||||
|
||||
func VerifyFloats32(values []float32) error {
|
||||
for _, f := range values {
|
||||
err := VerifyFloat(float64(f))
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
func VerifyFloats64(values []float64) error {
|
||||
for _, f := range values {
|
||||
err := VerifyFloat(f)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
|
@ -0,0 +1,66 @@
|
|||
// Licensed to the LF AI & Data foundation under one
|
||||
// or more contributor license agreements. See the NOTICE file
|
||||
// distributed with this work for additional information
|
||||
// regarding copyright ownership. The ASF licenses this file
|
||||
// to you under the Apache License, Version 2.0 (the
|
||||
// "License"); you may not use this file except in compliance
|
||||
// with the License. You may obtain a copy of the License at
|
||||
//
|
||||
// http://www.apache.org/licenses/LICENSE-2.0
|
||||
//
|
||||
// Unless required by applicable law or agreed to in writing, software
|
||||
// distributed under the License is distributed on an "AS IS" BASIS,
|
||||
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
// See the License for the specific language governing permissions and
|
||||
// limitations under the License.
|
||||
|
||||
package typeutil
|
||||
|
||||
import (
|
||||
"math"
|
||||
"testing"
|
||||
|
||||
"github.com/stretchr/testify/assert"
|
||||
)
|
||||
|
||||
func Test_VerifyFloat(t *testing.T) {
|
||||
var value = math.NaN()
|
||||
err := VerifyFloat(value)
|
||||
assert.Error(t, err)
|
||||
|
||||
value = math.Inf(1)
|
||||
err = VerifyFloat(value)
|
||||
assert.Error(t, err)
|
||||
|
||||
value = math.Inf(-1)
|
||||
err = VerifyFloat(value)
|
||||
assert.Error(t, err)
|
||||
}
|
||||
|
||||
func Test_VerifyFloats32(t *testing.T) {
|
||||
data := []float32{2.5, 32.2, 53.254}
|
||||
err := VerifyFloats32(data)
|
||||
assert.NoError(t, err)
|
||||
|
||||
data = []float32{2.5, 32.2, 53.254, float32(math.NaN())}
|
||||
err = VerifyFloats32(data)
|
||||
assert.Error(t, err)
|
||||
|
||||
data = []float32{2.5, 32.2, 53.254, float32(math.Inf(1))}
|
||||
err = VerifyFloats32(data)
|
||||
assert.Error(t, err)
|
||||
}
|
||||
|
||||
func Test_VerifyFloats64(t *testing.T) {
|
||||
data := []float64{2.5, 32.2, 53.254}
|
||||
err := VerifyFloats64(data)
|
||||
assert.NoError(t, err)
|
||||
|
||||
data = []float64{2.5, 32.2, 53.254, math.NaN()}
|
||||
err = VerifyFloats64(data)
|
||||
assert.Error(t, err)
|
||||
|
||||
data = []float64{2.5, 32.2, 53.254, math.Inf(-1)}
|
||||
err = VerifyFloats64(data)
|
||||
assert.Error(t, err)
|
||||
}
|
Loading…
Reference in New Issue