enhance: check fp16/bf16 nan or inf value (#31840)

issue:https://github.com/milvus-io/milvus/issues/22837

Signed-off-by: cqy123456 <qianya.cheng@zilliz.com>
pull/31816/head
cqy123456 2024-04-09 01:19:27 -05:00 committed by GitHub
parent 1b767669a4
commit 8fda3cbeda
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
4 changed files with 305 additions and 13 deletions

View File

@ -18,10 +18,13 @@ package proxy
import (
"context"
"encoding/binary"
"math"
"math/rand"
"sync"
"time"
"github.com/x448/float16"
"google.golang.org/grpc"
"github.com/milvus-io/milvus-proto/go-api/v2/commonpb"
@ -527,21 +530,63 @@ func generateBinaryVectors(numRows, dim int) []byte {
}
func generateFloat16Vectors(numRows, dim int) []byte {
total := numRows * dim * 2
ret := make([]byte, total)
_, err := rand.Read(ret)
if err != nil {
panic(err)
total := numRows * dim
ret := make([]byte, total*2)
for i := 0; i < total; i++ {
v := float16.Fromfloat32(rand.Float32()).Bits()
binary.LittleEndian.PutUint16(ret[i*2:], v)
}
return ret
}
// generateBFloat16Vectors returns numRows*dim random bfloat16 values encoded
// as little-endian bytes, two bytes per element.
//
// Each element is derived from rand.Float32(), which yields values in
// [0.0, 1.0), so the generated data never contains NaN or infinity.
func generateBFloat16Vectors(numRows, dim int) []byte {
	total := numRows * dim
	ret := make([]byte, total*2)
	for i := 0; i < total; i++ {
		// A bfloat16 is the upper 16 bits of the float32 encoding (truncation).
		bits := math.Float32bits(rand.Float32()) >> 16
		// Mask off the sign bit so only non-negative values are emitted.
		bits &= 0x7FFF
		binary.LittleEndian.PutUint16(ret[i*2:], uint16(bits))
	}
	return ret
}
// generateBFloat16VectorsWithInvalidData builds numRows*dim bfloat16 elements
// (little-endian, two bytes each) that are all non-finite: even positions
// hold NaN and odd positions hold +Inf.
func generateBFloat16VectorsWithInvalidData(numRows, dim int) []byte {
	count := numRows * dim
	// Source values indexed by element parity: 0 -> NaN, 1 -> +Inf.
	invalid := [2]float32{float32(math.NaN()), float32(math.Inf(1))}

	out := make([]byte, count*2)
	for idx := 0; idx < count; idx++ {
		// Keep the upper half of the float32 encoding and clear the sign bit.
		hi := (math.Float32bits(invalid[idx%2]) >> 16) & 0x7FFF
		binary.LittleEndian.PutUint16(out[idx*2:], uint16(hi))
	}
	return out
}
// generateFloat16VectorsWithInvalidData builds numRows*dim float16 elements
// (little-endian, two bytes each) that are all non-finite: even positions
// hold +Inf and odd positions hold NaN.
func generateFloat16VectorsWithInvalidData(numRows, dim int) []byte {
	count := numRows * dim
	// Raw bit patterns indexed by element parity: 0 -> +Inf, 1 -> NaN.
	patterns := [2]uint16{uint16(float16.Inf(1)), uint16(float16.NaN())}

	out := make([]byte, count*2)
	for idx := 0; idx < count; idx++ {
		binary.LittleEndian.PutUint16(out[idx*2:], patterns[idx%2])
	}
	return out
}
@ -551,7 +596,6 @@ func generateVarCharArray(numRows int, maxLen int) []string {
for i := 0; i < numRows; i++ {
ret[i] = funcutil.RandomString(rand.Intn(maxLen))
}
return ret
}

View File

@ -342,12 +342,26 @@ func (v *validateUtil) checkFloatVectorFieldData(field *schemapb.FieldData, fiel
}
// checkFloat16VectorFieldData validates the float16 vector payload of field.
//
// It returns the error built by merr.WrapErrParameterInvalid when the field
// carries no float16 vector data. When v.checkNAN is set, the raw bytes are
// passed to typeutil.VerifyFloats16 and its result is returned; otherwise the
// data is accepted as-is. fieldSchema is not consulted.
func (v *validateUtil) checkFloat16VectorFieldData(field *schemapb.FieldData, fieldSchema *schemapb.FieldSchema) error {
	raw := field.GetVectors().GetFloat16Vector()
	if raw == nil {
		detail := fmt.Sprintf("float16 float field '%v' is illegal, nil Vector_Float16 type", field.GetFieldName())
		return merr.WrapErrParameterInvalid("need vector_float16 array", "got nil", detail)
	}

	if !v.checkNAN {
		return nil
	}
	return typeutil.VerifyFloats16(raw)
}
// checkBFloat16VectorFieldData validates the bfloat16 vector payload of field.
//
// It returns the error built by merr.WrapErrParameterInvalid when the field
// carries no bfloat16 vector data. When v.checkNAN is set, the raw bytes are
// passed to typeutil.VerifyBFloats16 and its result is returned; otherwise
// the data is accepted as-is. fieldSchema is not consulted.
func (v *validateUtil) checkBFloat16VectorFieldData(field *schemapb.FieldData, fieldSchema *schemapb.FieldSchema) error {
	raw := field.GetVectors().GetBfloat16Vector()
	if raw == nil {
		detail := fmt.Sprintf("bfloat16 float field '%v' is illegal, nil Vector_BFloat16 type", field.GetFieldName())
		return merr.WrapErrParameterInvalid("need vector_bfloat16 array", "got nil", detail)
	}

	if !v.checkNAN {
		return nil
	}
	return typeutil.VerifyBFloats16(raw)
}

View File

@ -270,6 +270,187 @@ func Test_validateUtil_checkFloatVectorFieldData(t *testing.T) {
})
}
// Test_validateUtil_checkFloat16VectorFieldData covers checkFloat16VectorFieldData
// for missing vector data, NaN/Inf data with the check disabled and enabled,
// and well-formed data, plus fillWithDefaultValue on a float16 vector field.
func Test_validateUtil_checkFloat16VectorFieldData(t *testing.T) {
	const (
		rows = 5
		dim  = int64(8)
	)
	validVec := generateFloat16Vectors(rows, int(dim))
	invalidVec := generateFloat16VectorsWithInvalidData(rows, int(dim))

	// wrap builds a FieldData whose vector payload is the given float16 bytes.
	wrap := func(raw []byte) *schemapb.FieldData {
		return &schemapb.FieldData{
			Field: &schemapb.FieldData_Vectors{
				Vectors: &schemapb.VectorField{
					Dim: dim,
					Data: &schemapb.VectorField_Float16Vector{
						Float16Vector: raw,
					},
				},
			},
		}
	}

	t.Run("not float16 vector", func(t *testing.T) {
		// An empty FieldData has no float16 payload at all.
		checker := newValidateUtil()
		assert.Error(t, checker.checkFloat16VectorFieldData(&schemapb.FieldData{}, nil))
	})

	t.Run("no check", func(t *testing.T) {
		// With the NaN check off, non-finite data must be accepted.
		checker := newValidateUtil()
		checker.checkNAN = false
		assert.NoError(t, checker.checkFloat16VectorFieldData(wrap(invalidVec), nil))
	})

	t.Run("has nan", func(t *testing.T) {
		checker := newValidateUtil(withNANCheck())
		assert.Error(t, checker.checkFloat16VectorFieldData(wrap(invalidVec), nil))
	})

	t.Run("normal case", func(t *testing.T) {
		checker := newValidateUtil(withNANCheck())
		assert.NoError(t, checker.checkFloat16VectorFieldData(wrap(validVec), nil))
	})

	t.Run("default", func(t *testing.T) {
		// Asserts that fillWithDefaultValue returns an error for a float16
		// vector field with no vector payload and a DefaultValue in the schema.
		fields := []*schemapb.FieldData{
			{
				FieldId:   100,
				FieldName: "vec",
				Type:      schemapb.DataType_Float16Vector,
				Field:     &schemapb.FieldData_Vectors{},
			},
		}
		collSchema := &schemapb.CollectionSchema{
			Fields: []*schemapb.FieldSchema{
				{
					FieldID:      100,
					Name:         "vec",
					DataType:     schemapb.DataType_Float16Vector,
					DefaultValue: &schemapb.ValueField{},
				},
			},
		}
		helper, err := typeutil.CreateSchemaHelper(collSchema)
		assert.NoError(t, err)

		checker := newValidateUtil()
		err = checker.fillWithDefaultValue(fields, helper, 1)
		assert.Error(t, err)
	})
}
func Test_validateUtil_checkBfloatVectorFieldData(t *testing.T) {
nb := 5
dim := int64(8)
data := generateFloat16Vectors(nb, int(dim))
invalidData := generateBFloat16VectorsWithInvalidData(nb, int(dim))
t.Run("not float vector", func(t *testing.T) {
f := &schemapb.FieldData{}
v := newValidateUtil()
err := v.checkBFloat16VectorFieldData(f, nil)
assert.Error(t, err)
})
t.Run("no check", func(t *testing.T) {
f := &schemapb.FieldData{
Field: &schemapb.FieldData_Vectors{
Vectors: &schemapb.VectorField{
Dim: dim,
Data: &schemapb.VectorField_Bfloat16Vector{
Bfloat16Vector: invalidData,
},
},
},
}
v := newValidateUtil()
v.checkNAN = false
err := v.checkBFloat16VectorFieldData(f, nil)
assert.NoError(t, err)
})
t.Run("has nan", func(t *testing.T) {
f := &schemapb.FieldData{
Field: &schemapb.FieldData_Vectors{
Vectors: &schemapb.VectorField{
Dim: dim,
Data: &schemapb.VectorField_Bfloat16Vector{
Bfloat16Vector: invalidData,
},
},
},
}
v := newValidateUtil(withNANCheck())
err := v.checkBFloat16VectorFieldData(f, nil)
assert.Error(t, err)
})
t.Run("normal case", func(t *testing.T) {
f := &schemapb.FieldData{
Field: &schemapb.FieldData_Vectors{
Vectors: &schemapb.VectorField{
Dim: dim,
Data: &schemapb.VectorField_Bfloat16Vector{
Bfloat16Vector: data,
},
},
},
}
v := newValidateUtil(withNANCheck())
err := v.checkBFloat16VectorFieldData(f, nil)
assert.NoError(t, err)
})
t.Run("default", func(t *testing.T) {
data := []*schemapb.FieldData{
{
FieldId: 100,
FieldName: "vec",
Type: schemapb.DataType_BFloat16Vector,
Field: &schemapb.FieldData_Vectors{},
},
}
schema := &schemapb.CollectionSchema{
Fields: []*schemapb.FieldSchema{
{
FieldID: 100,
Name: "vec",
DataType: schemapb.DataType_BFloat16Vector,
DefaultValue: &schemapb.ValueField{},
},
},
}
h, err := typeutil.CreateSchemaHelper(schema)
assert.NoError(t, err)
v := newValidateUtil()
err = v.fillWithDefaultValue(data, h, 1)
assert.Error(t, err)
})
}
func Test_validateUtil_checkAligned(t *testing.T) {
t.Run("float vector column not found", func(t *testing.T) {
data := []*schemapb.FieldData{

View File

@ -17,10 +17,35 @@
package typeutil
import (
"encoding/binary"
"fmt"
"math"
)
// bfloat16IsNaN reports whether the raw bfloat16 bit pattern f is a NaN:
// all eight exponent bits set (x111 1111 1xxx xxxx) and a non-zero 7-bit
// mantissa. The sign bit is ignored.
func bfloat16IsNaN(f uint16) bool {
	const (
		expMask  = 0x7F80
		fracMask = 0x007f
	)
	if f&expMask != expMask {
		return false
	}
	return f&fracMask != 0
}
// bfloat16IsInf reports whether the raw bfloat16 bit pattern f is an infinity
// matching sign: sign > 0 matches only +Inf (0x7F80), sign < 0 matches only
// -Inf (0xFF80), and sign == 0 matches either.
func bfloat16IsInf(f uint16, sign int) bool {
	switch f {
	case 0x7F80: // +inf: 0111 1111 1000 0000
		return sign >= 0
	case 0xFF80: // -inf: 1111 1111 1000 0000
		return sign <= 0
	default:
		return false
	}
}
// float16IsNaN reports whether the raw float16 bit pattern f is a NaN:
// all five exponent bits set (x111 11xx xxxx xxxx) and a non-zero 10-bit
// mantissa. The sign bit is ignored.
func float16IsNaN(f uint16) bool {
	const (
		expMask  = 0x7c00
		fracMask = 0x03ff
	)
	if f&expMask != expMask {
		return false
	}
	return f&fracMask != 0
}
// float16IsInf reports whether the raw float16 bit pattern f is an infinity
// matching sign: sign > 0 matches only +Inf (0x7c00), sign < 0 matches only
// -Inf (0xfc00), and sign == 0 matches either.
func float16IsInf(f uint16, sign int) bool {
	switch f {
	case 0x7c00: // +inf: 0111 1100 0000 0000
		return sign >= 0
	case 0xfc00: // -inf: 1111 1100 0000 0000
		return sign <= 0
	default:
		return false
	}
}
func VerifyFloat(value float64) error {
// not allow not-a-number and infinity
if math.IsNaN(value) || math.IsInf(value, -1) || math.IsInf(value, 1) {
@ -51,3 +76,31 @@ func VerifyFloats64(values []float64) error {
return nil
}
// VerifyFloats16 checks a buffer of little-endian float16 values, two bytes
// per element. It returns an error when the buffer length is odd or when any
// element is NaN, +Inf or -Inf; otherwise it returns nil.
func VerifyFloats16(value []byte) error {
	if len(value)%2 == 1 {
		return fmt.Errorf("The length of float16 is not aligned to 2.")
	}

	// Walk the buffer one 2-byte element at a time.
	for off := 0; off < len(value); off += 2 {
		elem := binary.LittleEndian.Uint16(value[off:])
		switch {
		case float16IsNaN(elem), float16IsInf(elem, -1), float16IsInf(elem, 1):
			return fmt.Errorf("float16 vector contain nan or infinity value.")
		}
	}
	return nil
}
// VerifyBFloats16 checks a buffer of little-endian bfloat16 values, two bytes
// per element. It returns an error when the buffer length is odd or when any
// element is NaN, +Inf or -Inf; otherwise it returns nil.
func VerifyBFloats16(value []byte) error {
	if len(value)%2 != 0 {
		return fmt.Errorf("The length of bfloat16 is not aligned to 2")
	}
	dataSize := len(value) / 2
	for i := 0; i < dataSize; i++ {
		// Decode the i-th element's raw bit pattern.
		v := binary.LittleEndian.Uint16(value[i*2:])
		if bfloat16IsNaN(v) || bfloat16IsInf(v, -1) || bfloat16IsInf(v, 1) {
			return fmt.Errorf("bfloat16 vector contain nan or infinity value.")
		}
	}
	return nil
}