Verify vector float data for bulkinsert and insert (#22728)

Signed-off-by: yhmo <yihua.mo@zilliz.com>
pull/22745/head
groot 2023-03-14 14:03:58 +08:00 committed by GitHub
parent ebc173cfb8
commit 6f6bd98c27
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
8 changed files with 292 additions and 4 deletions

View File

@ -101,6 +101,30 @@ func (it *insertTask) OnEnqueue() error {
return nil
}
func (it *insertTask) checkVectorFieldData() error {
fields := it.insertMsg.GetFieldsData()
for _, field := range fields {
if field.GetType() != schemapb.DataType_FloatVector {
continue
}
vectorField := field.GetVectors()
if vectorField == nil || vectorField.GetFloatVector() == nil {
log.Error("float vector field is illegal, array type mismatch", zap.String("field name", field.GetFieldName()))
return fmt.Errorf("float vector field '%v' is illegal, array type mismatch", field.GetFieldName())
}
floatArray := vectorField.GetFloatVector()
err := typeutil.VerifyFloats32(floatArray.GetData())
if err != nil {
log.Error("float vector field data is illegal", zap.String("field name", field.GetFieldName()), zap.Error(err))
return fmt.Errorf("float vector field data is illegal, error: %w", err)
}
}
return nil
}
func (it *insertTask) PreExecute(ctx context.Context) error {
ctx, sp := otel.Tracer(typeutil.ProxyRole).Start(ctx, "Insert-PreExecute")
defer sp.End()
@ -187,6 +211,13 @@ func (it *insertTask) PreExecute(ctx context.Context) error {
return err
}
// check vector field data
err = it.checkVectorFieldData()
if err != nil {
log.Error("vector field data is illegal", zap.Error(err))
return err
}
log.Debug("Proxy Insert PreExecute done")
return nil

View File

@ -1,6 +1,7 @@
package proxy
import (
"math"
"testing"
"github.com/stretchr/testify/assert"
@ -219,3 +220,84 @@ func TestInsertTask_CheckAligned(t *testing.T) {
err = case2.insertMsg.CheckAligned()
assert.NoError(t, err)
}
func TestInsertTask_CheckVectorFieldData(t *testing.T) {
fieldName := "embeddings"
numRows := 10
dim := 32
task := insertTask{
insertMsg: &BaseInsertTask{
InsertRequest: msgpb.InsertRequest{
Base: &commonpb.MsgBase{
MsgType: commonpb.MsgType_Insert,
},
Version: msgpb.InsertDataVersion_ColumnBased,
NumRows: uint64(numRows),
},
},
schema: &schemapb.CollectionSchema{
Name: "TestInsertTask_CheckVectorFieldData",
Description: "TestInsertTask_CheckVectorFieldData",
Fields: []*schemapb.FieldSchema{
{
FieldID: 100,
Name: fieldName,
IsPrimaryKey: false,
AutoID: false,
DataType: schemapb.DataType_FloatVector,
},
},
},
}
// success case
task.insertMsg.FieldsData = []*schemapb.FieldData{
newFloatVectorFieldData(fieldName, numRows, dim),
}
err := task.checkVectorFieldData()
assert.NoError(t, err)
// field is nil
task.insertMsg.FieldsData = []*schemapb.FieldData{
{
Type: schemapb.DataType_FloatVector,
FieldName: fieldName,
Field: &schemapb.FieldData_Vectors{
Vectors: nil,
},
},
}
err = task.checkVectorFieldData()
assert.Error(t, err)
// vector data is not a number
values := generateFloatVectors(numRows, dim)
values[5] = float32(math.NaN())
task.insertMsg.FieldsData[0].Field = &schemapb.FieldData_Vectors{
Vectors: &schemapb.VectorField{
Dim: int64(dim),
Data: &schemapb.VectorField_FloatVector{
FloatVector: &schemapb.FloatArray{
Data: values,
},
},
},
}
err = task.checkVectorFieldData()
assert.Error(t, err)
// vector data is infinity
values[5] = float32(math.Inf(1))
task.insertMsg.FieldsData[0].Field = &schemapb.FieldData_Vectors{
Vectors: &schemapb.VectorField{
Dim: int64(dim),
Data: &schemapb.VectorField_FloatVector{
FloatVector: &schemapb.FloatArray{
Data: values,
},
},
},
}
err = task.checkVectorFieldData()
assert.Error(t, err)
}

View File

@ -20,7 +20,6 @@ import (
"context"
"encoding/json"
"fmt"
"math"
"path"
"runtime/debug"
"strconv"
@ -120,9 +119,9 @@ func parseFloat(s string, bitsize int, fieldName string) (float64, error) {
return 0, fmt.Errorf("failed to parse value '%s' for field '%s', error: %w", s, fieldName, err)
}
// not allow not-a-number and infinity
if math.IsNaN(value) || math.IsInf(value, -1) || math.IsInf(value, 1) {
return 0, fmt.Errorf("value '%s' is not a number or infinity, field '%s', error: %w", s, fieldName, err)
err = typeutil.VerifyFloat(value)
if err != nil {
return 0, fmt.Errorf("illegal value '%s' for field '%s', error: %w", s, fieldName, err)
}
return value, nil

View File

@ -299,6 +299,14 @@ func Test_parseFloat(t *testing.T) {
value, err = parseFloat("2.718281828459045", 64, "")
assert.True(t, math.Abs(value-2.718281828459045) < 0.0000000000000001)
assert.Nil(t, err)
value, err = parseFloat("Inf", 32, "")
assert.Zero(t, value)
assert.Error(t, err)
value, err = parseFloat("NaN", 64, "")
assert.Zero(t, value)
assert.Error(t, err)
}
func Test_InitValidators(t *testing.T) {

View File

@ -494,6 +494,12 @@ func (p *NumpyParser) readData(columnReader *NumpyColumnReader, rowCount int) (s
return nil, fmt.Errorf("failed to read float array: %s", err.Error())
}
err = typeutil.VerifyFloats32(data)
if err != nil {
log.Error("Numpy parser: illegal value in float array", zap.Error(err))
return nil, fmt.Errorf("illegal value in float array: %s", err.Error())
}
return &storage.FloatFieldData{
Data: data,
}, nil
@ -504,6 +510,12 @@ func (p *NumpyParser) readData(columnReader *NumpyColumnReader, rowCount int) (s
return nil, fmt.Errorf("failed to read double array: %s", err.Error())
}
err = typeutil.VerifyFloats64(data)
if err != nil {
log.Error("Numpy parser: illegal value in double array", zap.Error(err))
return nil, fmt.Errorf("illegal value in double array: %s", err.Error())
}
return &storage.DoubleFieldData{
Data: data,
}, nil
@ -542,6 +554,13 @@ func (p *NumpyParser) readData(columnReader *NumpyColumnReader, rowCount int) (s
log.Error("Numpy parser: failed to read float vector array", zap.Error(err))
return nil, fmt.Errorf("failed to read float vector array: %s", err.Error())
}
err = typeutil.VerifyFloats32(data)
if err != nil {
log.Error("Numpy parser: illegal value in float vector array", zap.Error(err))
return nil, fmt.Errorf("illegal value in float vector array: %s", err.Error())
}
} else if elementType == schemapb.DataType_Double {
data = make([]float32, 0, columnReader.rowCount)
data64, err := columnReader.reader.ReadFloat64(rowCount * columnReader.dimension)
@ -551,6 +570,12 @@ func (p *NumpyParser) readData(columnReader *NumpyColumnReader, rowCount int) (s
}
for _, f64 := range data64 {
err = typeutil.VerifyFloat(f64)
if err != nil {
log.Error("Numpy parser: illegal value in float vector array", zap.Error(err))
return nil, fmt.Errorf("illegal value in float vector array: %s", err.Error())
}
data = append(data, float32(f64))
}
}

View File

@ -18,6 +18,7 @@ package importutil
import (
"context"
"math"
"os"
"testing"
@ -403,6 +404,22 @@ func Test_NumpyParserReadData(t *testing.T) {
}
}
readErrorFunc := func(filedName string, data interface{}) {
filePath := TempFilesPath + filedName + ".npy"
err = CreateNumpyFile(filePath, data)
assert.Nil(t, err)
readers, err := parser.createReaders([]string{filePath})
assert.NoError(t, err)
assert.Equal(t, 1, len(readers))
defer closeReaders(readers)
// encounter error
fieldData, err := parser.readData(readers[0], 1000)
assert.Error(t, err)
assert.Nil(t, fieldData)
}
t.Run("read bool", func(t *testing.T) {
readEmptyFunc("FieldBool", []bool{})
@ -443,6 +460,8 @@ func Test_NumpyParserReadData(t *testing.T) {
data := []float32{2.5, 32.2, 53.254, 3.45, 65.23421, 54.8978}
readBatchFunc("FieldFloat", data, len(data), func(k int) interface{} { return data[k] })
data = []float32{2.5, 32.2, float32(math.NaN())}
readErrorFunc("FieldFloat", data)
})
t.Run("read double", func(t *testing.T) {
@ -450,6 +469,8 @@ func Test_NumpyParserReadData(t *testing.T) {
data := []float64{65.24454, 343.4365, 432.6556}
readBatchFunc("FieldDouble", data, len(data), func(k int) interface{} { return data[k] })
data = []float64{65.24454, math.Inf(1)}
readErrorFunc("FieldDouble", data)
})
specialReadEmptyFunc := func(filedName string, data interface{}) {
@ -482,6 +503,9 @@ func Test_NumpyParserReadData(t *testing.T) {
t.Run("read float vector", func(t *testing.T) {
specialReadEmptyFunc("FieldFloatVector", [][4]float32{{1, 2, 3, 4}, {3, 4, 5, 6}})
specialReadEmptyFunc("FieldFloatVector", [][4]float64{{1, 2, 3, 4}, {3, 4, 5, 6}})
readErrorFunc("FieldFloatVector", [][4]float32{{1, 2, 3, float32(math.NaN())}, {3, 4, 5, 6}})
readErrorFunc("FieldFloatVector", [][4]float64{{1, 2, 3, 4}, {3, 4, math.Inf(1), 6}})
})
}

View File

@ -0,0 +1,53 @@
// Licensed to the LF AI & Data foundation under one
// or more contributor license agreements. See the NOTICE file
// distributed with this work for additional information
// regarding copyright ownership. The ASF licenses this file
// to you under the Apache License, Version 2.0 (the
// "License"); you may not use this file except in compliance
// with the License. You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
package typeutil
import (
"fmt"
"math"
)
func VerifyFloat(value float64) error {
// not allow not-a-number and infinity
if math.IsNaN(value) || math.IsInf(value, -1) || math.IsInf(value, 1) {
return fmt.Errorf("value '%f' is not a number or infinity", value)
}
return nil
}
func VerifyFloats32(values []float32) error {
for _, f := range values {
err := VerifyFloat(float64(f))
if err != nil {
return err
}
}
return nil
}
func VerifyFloats64(values []float64) error {
for _, f := range values {
err := VerifyFloat(f)
if err != nil {
return err
}
}
return nil
}

View File

@ -0,0 +1,66 @@
// Licensed to the LF AI & Data foundation under one
// or more contributor license agreements. See the NOTICE file
// distributed with this work for additional information
// regarding copyright ownership. The ASF licenses this file
// to you under the Apache License, Version 2.0 (the
// "License"); you may not use this file except in compliance
// with the License. You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
package typeutil
import (
"math"
"testing"
"github.com/stretchr/testify/assert"
)
func Test_VerifyFloat(t *testing.T) {
var value = math.NaN()
err := VerifyFloat(value)
assert.Error(t, err)
value = math.Inf(1)
err = VerifyFloat(value)
assert.Error(t, err)
value = math.Inf(-1)
err = VerifyFloat(value)
assert.Error(t, err)
}
func Test_VerifyFloats32(t *testing.T) {
data := []float32{2.5, 32.2, 53.254}
err := VerifyFloats32(data)
assert.NoError(t, err)
data = []float32{2.5, 32.2, 53.254, float32(math.NaN())}
err = VerifyFloats32(data)
assert.Error(t, err)
data = []float32{2.5, 32.2, 53.254, float32(math.Inf(1))}
err = VerifyFloats32(data)
assert.Error(t, err)
}
func Test_VerifyFloats64(t *testing.T) {
data := []float64{2.5, 32.2, 53.254}
err := VerifyFloats64(data)
assert.NoError(t, err)
data = []float64{2.5, 32.2, 53.254, math.NaN()}
err = VerifyFloats64(data)
assert.Error(t, err)
data = []float64{2.5, 32.2, 53.254, math.Inf(-1)}
err = VerifyFloats64(data)
assert.Error(t, err)
}