fix: add missing fp16/bf16 code paths and add more fp16/bf16 tests (#31612)

issue: #31534

Signed-off-by: cqy123456 <qianya.cheng@zilliz.com>
pull/31680/head
cqy123456 2024-03-28 01:11:10 -05:00 committed by GitHub
parent 3d66670619
commit 976928ecd1
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
14 changed files with 293 additions and 13 deletions

1
go.mod
View File

@ -200,6 +200,7 @@ require (
github.com/twmb/murmur3 v1.1.3 // indirect
github.com/uber/jaeger-client-go v2.30.0+incompatible // indirect
github.com/ugorji/go/codec v1.2.11 // indirect
github.com/x448/float16 v0.8.4 // indirect
github.com/xiang90/probing v0.0.0-20190116061207-43a291ad63a2 // indirect
github.com/yusufpapurcu/wmi v1.2.2 // indirect
github.com/zeebo/xxh3 v1.0.2 // indirect

2
go.sum
View File

@ -868,6 +868,8 @@ github.com/valyala/fasthttp v1.6.0/go.mod h1:FstJa9V+Pj9vQ7OJie2qMHdwemEDaDiSdBn
github.com/valyala/fasttemplate v1.0.1/go.mod h1:UQGH1tvbgY+Nz5t2n7tXsz52dQxojPUpymEIMZ47gx8=
github.com/valyala/fasttemplate v1.2.1/go.mod h1:KHLXt3tVN2HBp8eijSv/kGJopbvo7S+qRAEEKiv+SiQ=
github.com/valyala/tcplisten v0.0.0-20161114210144-ceec8f93295a/go.mod h1:v3UYOV9WzVtRmSR+PDvWpU/qWl4Wa5LApYYX4ZtKbio=
github.com/x448/float16 v0.8.4 h1:qLwI1I70+NjRFUR3zs1JPUCgaCXSh3SW62uAKT1mSBM=
github.com/x448/float16 v0.8.4/go.mod h1:14CWIYCyZA/cWjXOioeEpHeN/83MdbZDRQHoFcYsOfg=
github.com/xeipuuv/gojsonpointer v0.0.0-20180127040702-4e3ac2762d5f/go.mod h1:N2zxlSyiKSe5eX1tZViRH5QA0qijqEDrYZiPEAiq3wU=
github.com/xeipuuv/gojsonreference v0.0.0-20180127040603-bd5ef7bd5415/go.mod h1:GwrjFmJcFw6At/Gs6z4yjiIwzuJ1/+UwLxMQDVQXShQ=
github.com/xeipuuv/gojsonschema v1.2.0/go.mod h1:anYRn/JVcOK2ZgGU+IjEV4nwlhoK5sQluxsYJ78Id3Y=

View File

@ -28,6 +28,7 @@ import (
"github.com/cockroachdb/errors"
"github.com/golang/protobuf/proto"
"github.com/x448/float16"
"go.uber.org/zap"
"github.com/milvus-io/milvus-proto/go-api/v2/commonpb"
@ -278,6 +279,8 @@ func GenTestCollectionSchema(collectionName string, pkType schemapb.DataType, wi
fieldArray := genConstantFieldSchema(simpleArrayField)
floatVecFieldSchema := genVectorFieldSchema(simpleFloatVecField)
binVecFieldSchema := genVectorFieldSchema(simpleBinVecField)
float16VecFieldSchema := genVectorFieldSchema(simpleFloat16VecField)
bfloat16VecFieldSchema := genVectorFieldSchema(simpleBFloat16VecField)
var pkFieldSchema *schemapb.FieldSchema
switch pkType {
@ -302,6 +305,8 @@ func GenTestCollectionSchema(collectionName string, pkType schemapb.DataType, wi
binVecFieldSchema,
pkFieldSchema,
fieldArray,
float16VecFieldSchema,
bfloat16VecFieldSchema,
},
}
@ -330,7 +335,7 @@ func GenTestIndexInfoList(collectionID int64, schema *schemapb.CollectionSchema)
TypeParams: field.GetTypeParams(),
}
switch field.GetDataType() {
case schemapb.DataType_FloatVector, schemapb.DataType_Float16Vector:
case schemapb.DataType_FloatVector, schemapb.DataType_Float16Vector, schemapb.DataType_BFloat16Vector:
{
index.IndexParams = []*commonpb.KeyValuePair{
{Key: common.MetricTypeKey, Value: metric.L2},
@ -500,21 +505,28 @@ func generateBinaryVectors(numRows, dim int) []byte {
}
func generateFloat16Vectors(numRows, dim int) []byte {
total := numRows * dim * 2
ret := make([]byte, total)
_, err := rand.Read(ret)
if err != nil {
panic(err)
total := numRows * dim
ret := make([]byte, total*2)
for i := 0; i < total; i++ {
v := float16.Fromfloat32(rand.Float32()).Bits()
binary.LittleEndian.PutUint16(ret[i*2:], v)
}
return ret
}
// generateBFloat16Vectors returns numRows*dim random bfloat16 values packed
// back to back as little-endian uint16s, i.e. 2 bytes per element and
// numRows*dim*2 bytes in total.
//
// Every element is a rand.Float32 sample (a value in [0, 1)) reduced to
// bfloat16 by keeping only the upper 16 bits of its float32 encoding; the
// lower mantissa bits are truncated, not rounded.
func generateBFloat16Vectors(numRows, dim int) []byte {
	total := numRows * dim
	ret := make([]byte, total*2)
	for i := 0; i < total; i++ {
		bits := math.Float32bits(rand.Float32()) >> 16
		// Keep only the exponent and mantissa bits; the sign bit stays clear.
		bits &= 0x7FFF
		binary.LittleEndian.PutUint16(ret[i*2:], uint16(bits))
	}
	return ret
}
@ -1009,6 +1021,10 @@ func GenAndSaveIndexV2(collectionID, partitionID, segmentID, buildID int64,
dataset = indexcgowrapper.GenBinaryVecDataset(generateBinaryVectors(msgLength, defaultDim))
case schemapb.DataType_FloatVector:
dataset = indexcgowrapper.GenFloatVecDataset(generateFloatVectors(msgLength, defaultDim))
case schemapb.DataType_Float16Vector:
dataset = indexcgowrapper.GenFloat16VecDataset(generateFloat16Vectors(msgLength, defaultDim))
case schemapb.DataType_BFloat16Vector:
dataset = indexcgowrapper.GenBFloat16VecDataset(generateBFloat16Vectors(msgLength, defaultDim))
case schemapb.DataType_SparseFloatVector:
data := testutils.GenerateSparseFloatVectors(msgLength)
dataset = indexcgowrapper.GenSparseFloatVecDataset(&storage.SparseFloatVectorFieldData{
@ -1260,7 +1276,7 @@ func genBruteForceDSL(schema *schemapb.CollectionSchema, topK int64, roundDecima
roundDecimalStr := strconv.FormatInt(roundDecimal, 10)
var fieldID int64
for _, f := range schema.Fields {
if f.DataType == schemapb.DataType_FloatVector || f.DataType == schemapb.DataType_Float16Vector || f.DataType == schemapb.DataType_BFloat16Vector {
if f.DataType == schemapb.DataType_FloatVector {
vecFieldName = f.Name
fieldID = f.FieldID
for _, p := range f.IndexParams {

View File

@ -93,6 +93,28 @@ func getPKsFromRowBasedInsertMsg(msg *msgstream.InsertMsg, schema *schemapb.Coll
break
}
}
case schemapb.DataType_Float16Vector:
for _, t := range field.TypeParams {
if t.Key == common.DimKey {
dim, err := strconv.Atoi(t.Value)
if err != nil {
return nil, fmt.Errorf("strconv wrong on get dim, err = %s", err)
}
offset += dim * 2
break
}
}
case schemapb.DataType_BFloat16Vector:
for _, t := range field.TypeParams {
if t.Key == common.DimKey {
dim, err := strconv.Atoi(t.Value)
if err != nil {
return nil, fmt.Errorf("strconv wrong on get dim, err = %s", err)
}
offset += dim * 2
break
}
}
case schemapb.DataType_SparseFloatVector:
return nil, fmt.Errorf("SparseFloatVector not support in row based message")
}
@ -280,6 +302,10 @@ func fillFieldData(ctx context.Context, vcm storage.ChunkManager, dataPath strin
return fillBinVecFieldData(ctx, vcm, dataPath, fieldData, i, offset, endian)
case schemapb.DataType_FloatVector:
return fillFloatVecFieldData(ctx, vcm, dataPath, fieldData, i, offset, endian)
case schemapb.DataType_Float16Vector:
return fillFloatVecFieldData(ctx, vcm, dataPath, fieldData, i, offset, endian)
case schemapb.DataType_BFloat16Vector:
return fillFloatVecFieldData(ctx, vcm, dataPath, fieldData, i, offset, endian)
case schemapb.DataType_SparseFloatVector:
return fillSparseFloatVecFieldData(ctx, vcm, dataPath, fieldData, i, offset, endian)
case schemapb.DataType_Bool:

View File

@ -307,6 +307,37 @@ func printPayloadValues(colType schemapb.DataType, reader PayloadReaderInterface
}
fmt.Println()
}
case schemapb.DataType_Float16Vector:
val, dim, err := reader.GetFloat16VectorFromPayload()
if err != nil {
return err
}
dim = dim * 2
length := len(val) / dim
for i := 0; i < length; i++ {
fmt.Printf("\t\t%d :", i)
for j := 0; j < dim; j++ {
idx := i*dim + j
fmt.Printf(" %02x", val[idx])
}
fmt.Println()
}
case schemapb.DataType_BFloat16Vector:
val, dim, err := reader.GetBFloat16VectorFromPayload()
if err != nil {
return err
}
dim = dim * 2
length := len(val) / dim
for i := 0; i < length; i++ {
fmt.Printf("\t\t%d :", i)
for j := 0; j < dim; j++ {
idx := i*dim + j
fmt.Printf(" %02x", val[idx])
}
fmt.Println()
}
case schemapb.DataType_FloatVector:
val, dim, err := reader.GetFloatVectorFromPayload()
if err != nil {

View File

@ -184,6 +184,20 @@ func TestPrintBinlogFiles(t *testing.T) {
Description: "description_12",
DataType: schemapb.DataType_JSON,
},
{
FieldID: 111,
Name: "field_bfloat16_vector",
IsPrimaryKey: false,
Description: "description_13",
DataType: schemapb.DataType_BFloat16Vector,
},
{
FieldID: 112,
Name: "field_float16_vector",
IsPrimaryKey: false,
Description: "description_14",
DataType: schemapb.DataType_Float16Vector,
},
},
},
}
@ -234,6 +248,14 @@ func TestPrintBinlogFiles(t *testing.T) {
[]byte(`{"key":"hello"}`),
},
},
111: &BFloat16VectorFieldData{
Data: []byte("12345678"),
Dim: 4,
},
112: &Float16VectorFieldData{
Data: []byte("12345678"),
Dim: 4,
},
},
}
@ -283,6 +305,14 @@ func TestPrintBinlogFiles(t *testing.T) {
[]byte(`{"key":"world"}`),
},
},
111: &BFloat16VectorFieldData{
Data: []byte("abcdefgh"),
Dim: 4,
},
112: &Float16VectorFieldData{
Data: []byte("abcdefgh"),
Dim: 4,
},
},
}
firstBlobs, err := insertCodec.Serialize(1, 1, insertDataFirst)

View File

@ -1201,6 +1201,32 @@ func TransferInsertDataToInsertRecord(insertData *InsertData) (*segcorepb.Insert
},
},
}
case *Float16VectorFieldData:
fieldData = &schemapb.FieldData{
Type: schemapb.DataType_Float16Vector,
FieldId: fieldID,
Field: &schemapb.FieldData_Vectors{
Vectors: &schemapb.VectorField{
Data: &schemapb.VectorField_Float16Vector{
Float16Vector: rawData.Data,
},
Dim: int64(rawData.Dim),
},
},
}
case *BFloat16VectorFieldData:
fieldData = &schemapb.FieldData{
Type: schemapb.DataType_BFloat16Vector,
FieldId: fieldID,
Field: &schemapb.FieldData_Vectors{
Vectors: &schemapb.VectorField{
Data: &schemapb.VectorField_Bfloat16Vector{
Bfloat16Vector: rawData.Data,
},
Dim: int64(rawData.Dim),
},
},
}
case *SparseFloatVectorFieldData:
fieldData = &schemapb.FieldData{
Type: schemapb.DataType_SparseFloatVector,

View File

@ -993,6 +993,15 @@ func TestRowBasedInsertMsgToInsertData(t *testing.T) {
}
}
// TestRowBasedTransferInsertMsgToInsertRecord builds a schema with
// genAllFieldsSchema and a matching row-based insert message with
// genRowBasedInsertMsg (10 rows, every vector dimension set to 8), then checks
// that TransferInsertMsgToInsertRecord converts the message without error.
func TestRowBasedTransferInsertMsgToInsertRecord(t *testing.T) {
	const (
		rowCount    = 10
		floatDim    = 8
		binaryDim   = 8
		float16Dim  = 8
		bfloat16Dim = 8
	)

	collSchema, _, _ := genAllFieldsSchema(floatDim, binaryDim, float16Dim, bfloat16Dim, false)
	insertMsg, _, _ := genRowBasedInsertMsg(rowCount, floatDim, binaryDim, float16Dim, bfloat16Dim)

	_, err := TransferInsertMsgToInsertRecord(collSchema, insertMsg)
	assert.NoError(t, err)
}
func TestRowBasedInsertMsgToInsertFloat16VectorDataError(t *testing.T) {
msg := &msgstream.InsertMsg{
BaseMsg: msgstream.BaseMsg{

View File

@ -89,6 +89,15 @@ func ConvertToArrowSchema(fields []*schemapb.FieldSchema) (*arrow.Schema, error)
Name: field.Name,
Type: &arrow.FixedSizeBinaryType{ByteWidth: dim * 2},
})
case schemapb.DataType_BFloat16Vector:
dim, err := storage.GetDimFromParams(field.TypeParams)
if err != nil {
return nil, err
}
arrowFields = append(arrowFields, arrow.Field{
Name: field.Name,
Type: &arrow.FixedSizeBinaryType{ByteWidth: dim * 2},
})
default:
return nil, merr.WrapErrParameterInvalidMsg("unknown type %v", field.DataType.String())
}

View File

@ -41,9 +41,33 @@ func TestConvertArrowSchema(t *testing.T) {
{FieldID: 12, Name: "field11", DataType: schemapb.DataType_Array, ElementType: schemapb.DataType_Int64},
{FieldID: 13, Name: "field12", DataType: schemapb.DataType_JSON},
{FieldID: 14, Name: "field13", DataType: schemapb.DataType_Float16Vector, TypeParams: []*commonpb.KeyValuePair{{Key: "dim", Value: "128"}}},
{FieldID: 15, Name: "field14", DataType: schemapb.DataType_BFloat16Vector, TypeParams: []*commonpb.KeyValuePair{{Key: "dim", Value: "128"}}},
}
schema, err := ConvertToArrowSchema(fieldSchemas)
assert.NoError(t, err)
assert.Equal(t, len(fieldSchemas), len(schema.Fields()))
}
// TestConvertArrowSchemaWithoutDim checks that ConvertToArrowSchema rejects a
// Float16Vector or BFloat16Vector field that carries no "dim" type parameter.
//
// Each dimless vector type gets its own subtest. ConvertToArrowSchema returns
// on the first error it meets (the BFloat16Vector branch returns as soon as
// the dim lookup fails), so putting both dimless fields in one schema would
// only ever exercise whichever field comes first.
func TestConvertArrowSchemaWithoutDim(t *testing.T) {
	// Well-formed fields shared by every subtest; none of them should fail.
	baseFields := []*schemapb.FieldSchema{
		{FieldID: 1, Name: "field0", DataType: schemapb.DataType_Bool},
		{FieldID: 2, Name: "field1", DataType: schemapb.DataType_Int8},
		{FieldID: 3, Name: "field2", DataType: schemapb.DataType_Int16},
		{FieldID: 4, Name: "field3", DataType: schemapb.DataType_Int32},
		{FieldID: 5, Name: "field4", DataType: schemapb.DataType_Int64},
		{FieldID: 6, Name: "field5", DataType: schemapb.DataType_Float},
		{FieldID: 7, Name: "field6", DataType: schemapb.DataType_Double},
		{FieldID: 8, Name: "field7", DataType: schemapb.DataType_String},
		{FieldID: 9, Name: "field8", DataType: schemapb.DataType_VarChar},
		{FieldID: 10, Name: "field9", DataType: schemapb.DataType_BinaryVector, TypeParams: []*commonpb.KeyValuePair{{Key: "dim", Value: "128"}}},
		{FieldID: 11, Name: "field10", DataType: schemapb.DataType_FloatVector, TypeParams: []*commonpb.KeyValuePair{{Key: "dim", Value: "128"}}},
		{FieldID: 12, Name: "field11", DataType: schemapb.DataType_Array, ElementType: schemapb.DataType_Int64},
		{FieldID: 13, Name: "field12", DataType: schemapb.DataType_JSON},
	}

	// Vector fields with empty TypeParams, i.e. no dim to read.
	dimlessVectors := []*schemapb.FieldSchema{
		{FieldID: 14, Name: "field13", DataType: schemapb.DataType_Float16Vector, TypeParams: []*commonpb.KeyValuePair{}},
		{FieldID: 15, Name: "field14", DataType: schemapb.DataType_BFloat16Vector, TypeParams: []*commonpb.KeyValuePair{}},
	}

	for _, vecField := range dimlessVectors {
		t.Run(vecField.GetDataType().String(), func(t *testing.T) {
			// Copy so subtests never share a backing array with baseFields.
			fieldSchemas := make([]*schemapb.FieldSchema, 0, len(baseFields)+1)
			fieldSchemas = append(fieldSchemas, baseFields...)
			fieldSchemas = append(fieldSchemas, vecField)

			_, err := ConvertToArrowSchema(fieldSchemas)
			assert.Error(t, err)
		})
	}
}

View File

@ -108,6 +108,18 @@ func BuildRecord(b *array.RecordBuilder, data *storage.InsertData, fields []*sch
byteLength := dim * 2
length := len(data) / byteLength
builder.Reserve(length)
for i := 0; i < length; i++ {
builder.Append(data[i*byteLength : (i+1)*byteLength])
}
case schemapb.DataType_BFloat16Vector:
vecData := data.Data[field.FieldID].(*storage.BFloat16VectorFieldData)
builder := fBuilder.(*array.FixedSizeBinaryBuilder)
dim := vecData.Dim
data := vecData.Data
byteLength := dim * 2
length := len(data) / byteLength
builder.Reserve(length)
for i := 0; i < length; i++ {
builder.Append(data[i*byteLength : (i+1)*byteLength])

View File

@ -146,7 +146,7 @@ func CheckCtxValid(ctx context.Context) bool {
func GetVecFieldIDs(schema *schemapb.CollectionSchema) []int64 {
var vecFieldIDs []int64
for _, field := range schema.Fields {
if field.DataType == schemapb.DataType_BinaryVector || field.DataType == schemapb.DataType_FloatVector || field.DataType == schemapb.DataType_Float16Vector || field.DataType == schemapb.DataType_SparseFloatVector {
if field.DataType == schemapb.DataType_BinaryVector || field.DataType == schemapb.DataType_FloatVector || field.DataType == schemapb.DataType_Float16Vector || field.DataType == schemapb.DataType_BFloat16Vector || field.DataType == schemapb.DataType_SparseFloatVector {
vecFieldIDs = append(vecFieldIDs, field.FieldID)
}
}

View File

@ -241,6 +241,10 @@ func EstimateEntitySize(fieldsData []*schemapb.FieldData, rowOffset int) (int, e
res += int(fs.GetVectors().GetDim())
case schemapb.DataType_FloatVector:
res += int(fs.GetVectors().GetDim() * 4)
case schemapb.DataType_Float16Vector:
res += int(fs.GetVectors().GetDim() * 2)
case schemapb.DataType_BFloat16Vector:
res += int(fs.GetVectors().GetDim() * 2)
case schemapb.DataType_SparseFloatVector:
vec := fs.GetVectors().GetSparseFloatVector()
// counting only the size of the vector data, ignoring other
@ -527,6 +531,10 @@ func PrepareResultFieldData(sample []*schemapb.FieldData, topK int64) []*schemap
vectors.Vectors.Data = &schemapb.VectorField_Float16Vector{
Float16Vector: make([]byte, 0, topK*dim*2),
}
case *schemapb.VectorField_Bfloat16Vector:
vectors.Vectors.Data = &schemapb.VectorField_Bfloat16Vector{
Bfloat16Vector: make([]byte, 0, topK*dim*2),
}
case *schemapb.VectorField_BinaryVector:
vectors.Vectors.Data = &schemapb.VectorField_BinaryVector{
BinaryVector: make([]byte, 0, topK*dim/8),
@ -957,6 +965,24 @@ func MergeFieldData(dst []*schemapb.FieldData, src []*schemapb.FieldData) error
dstBinaryVector := dstVector.Data.(*schemapb.VectorField_BinaryVector)
dstBinaryVector.BinaryVector = append(dstBinaryVector.BinaryVector, srcVector.BinaryVector...)
}
case *schemapb.VectorField_Float16Vector:
if dstVector.GetFloat16Vector() == nil {
dstVector.Data = &schemapb.VectorField_Float16Vector{
Float16Vector: srcVector.Float16Vector,
}
} else {
dstFloat16Vector := dstVector.Data.(*schemapb.VectorField_Float16Vector)
dstFloat16Vector.Float16Vector = append(dstFloat16Vector.Float16Vector, srcVector.Float16Vector...)
}
case *schemapb.VectorField_Bfloat16Vector:
if dstVector.GetBfloat16Vector() == nil {
dstVector.Data = &schemapb.VectorField_Bfloat16Vector{
Bfloat16Vector: srcVector.Bfloat16Vector,
}
} else {
dstBfloat16Vector := dstVector.Data.(*schemapb.VectorField_Bfloat16Vector)
dstBfloat16Vector.Bfloat16Vector = append(dstBfloat16Vector.Bfloat16Vector, srcVector.Bfloat16Vector...)
}
case *schemapb.VectorField_FloatVector:
if dstVector.GetFloatVector() == nil {
dstVector.Data = &schemapb.VectorField_FloatVector{

View File

@ -984,6 +984,36 @@ func TestDeleteFieldData(t *testing.T) {
assert.Equal(t, tmpSparseFloatVector, result2[SparseFloatVectorFieldID-common.StartOfUserFieldID].GetVectors().GetSparseFloatVector())
}
// TestEstimateEntitySize checks that EstimateEntitySize counts two bytes per
// dimension for Float16Vector and BFloat16Vector fields.
func TestEstimateEntitySize(t *testing.T) {
	samples := []*schemapb.FieldData{
		{
			FieldId:   111,
			FieldName: "float16_vector",
			Type:      schemapb.DataType_Float16Vector,
			Field: &schemapb.FieldData_Vectors{
				Vectors: &schemapb.VectorField{
					Dim:  64,
					Data: &schemapb.VectorField_Float16Vector{},
				},
			},
		},
		{
			FieldId:   112,
			FieldName: "bfloat16_vector",
			Type:      schemapb.DataType_BFloat16Vector,
			Field: &schemapb.FieldData_Vectors{
				Vectors: &schemapb.VectorField{
					Dim:  128,
					Data: &schemapb.VectorField_Bfloat16Vector{},
				},
			},
		},
	}

	// Named err, not error, so the predeclared error type is not shadowed.
	size, err := EstimateEntitySize(samples, 0)
	assert.NoError(t, err)
	// 64 dims * 2 bytes (float16) + 128 dims * 2 bytes (bfloat16) = 384.
	assert.Equal(t, 64*2+128*2, size)
}
func TestGetPrimaryFieldSchema(t *testing.T) {
int64Field := &schemapb.FieldSchema{
FieldID: 1,
@ -1461,6 +1491,8 @@ func TestMergeFieldData(t *testing.T) {
},
FieldId: 106,
},
genFieldData("float16_vector", 111, schemapb.DataType_Float16Vector, []byte("12345678"), 4),
genFieldData("bfloat16_vector", 112, schemapb.DataType_BFloat16Vector, []byte("12345678"), 4),
}
srcFields := []*schemapb.FieldData{
@ -1520,6 +1552,8 @@ func TestMergeFieldData(t *testing.T) {
},
FieldId: 106,
},
genFieldData("float16_vector", 111, schemapb.DataType_Float16Vector, []byte("abcdefgh"), 4),
genFieldData("bfloat16_vector", 112, schemapb.DataType_BFloat16Vector, []byte("ABCDEFGH"), 4),
}
err := MergeFieldData(dstFields, srcFields)
@ -1552,6 +1586,8 @@ func TestMergeFieldData(t *testing.T) {
Dim: 2301,
Contents: sparseFloatRows,
}, dstFields[6].GetVectors().GetSparseFloatVector())
assert.Equal(t, []byte("12345678abcdefgh"), dstFields[7].GetVectors().GetFloat16Vector())
assert.Equal(t, []byte("12345678ABCDEFGH"), dstFields[8].GetVectors().GetBfloat16Vector())
})
t.Run("merge with nil", func(t *testing.T) {
@ -1584,6 +1620,8 @@ func TestMergeFieldData(t *testing.T) {
},
FieldId: 104,
},
genFieldData("float16_vector", 111, schemapb.DataType_Float16Vector, []byte("12345678"), 4),
genFieldData("bfloat16_vector", 112, schemapb.DataType_BFloat16Vector, []byte("12345678"), 4),
}
dstFields := []*schemapb.FieldData{
@ -1592,6 +1630,8 @@ func TestMergeFieldData(t *testing.T) {
{Type: schemapb.DataType_JSON, FieldName: "json", Field: &schemapb.FieldData_Scalars{Scalars: &schemapb.ScalarField{Data: &schemapb.ScalarField_JsonData{}}}, FieldId: 102},
{Type: schemapb.DataType_Array, FieldName: "array", Field: &schemapb.FieldData_Scalars{Scalars: &schemapb.ScalarField{Data: &schemapb.ScalarField_ArrayData{}}}, FieldId: 103},
{Type: schemapb.DataType_SparseFloatVector, FieldName: "sparseFloat", Field: &schemapb.FieldData_Vectors{Vectors: &schemapb.VectorField{Data: &schemapb.VectorField_SparseFloatVector{}}}, FieldId: 104},
{Type: schemapb.DataType_Float16Vector, FieldName: "float16_vector", Field: &schemapb.FieldData_Vectors{Vectors: &schemapb.VectorField{Data: &schemapb.VectorField_Float16Vector{}}}, FieldId: 111},
{Type: schemapb.DataType_BFloat16Vector, FieldName: "bfloat16_vector", Field: &schemapb.FieldData_Vectors{Vectors: &schemapb.VectorField{Data: &schemapb.VectorField_Bfloat16Vector{}}}, FieldId: 112},
}
err := MergeFieldData(dstFields, srcFields)
@ -1615,6 +1655,8 @@ func TestMergeFieldData(t *testing.T) {
Dim: 521,
Contents: sparseFloatRows[:3],
}, dstFields[4].GetVectors().GetSparseFloatVector())
assert.Equal(t, []byte("12345678"), dstFields[5].GetVectors().GetFloat16Vector())
assert.Equal(t, []byte("12345678"), dstFields[6].GetVectors().GetBfloat16Vector())
})
t.Run("error case", func(t *testing.T) {
@ -1903,6 +1945,32 @@ func (s *FieldDataSuite) TestPrepareFieldData() {
s.EqualValues(topK*128*2, cap(field.GetVectors().GetFloat16Vector()))
})
s.Run("bfloat16_vector", func() {
samples := []*schemapb.FieldData{
{
FieldId: fieldID,
FieldName: fieldName,
Type: schemapb.DataType_BFloat16Vector,
Field: &schemapb.FieldData_Vectors{
Vectors: &schemapb.VectorField{
Dim: 128,
Data: &schemapb.VectorField_Bfloat16Vector{},
},
},
},
}
fields := PrepareResultFieldData(samples, topK)
s.Require().Len(fields, 1)
field := fields[0]
s.Equal(fieldID, field.GetFieldId())
s.Equal(fieldName, field.GetFieldName())
s.Equal(schemapb.DataType_BFloat16Vector, field.GetType())
s.EqualValues(128, field.GetVectors().GetDim())
s.EqualValues(topK*128*2, cap(field.GetVectors().GetBfloat16Vector()))
})
s.Run("binary_vector", func() {
samples := []*schemapb.FieldData{
{