mirror of https://github.com/milvus-io/milvus.git
fix: fix fp16/bf16 some code missing and add more fp16/bf16 test (#31612)
issue: #31534 Signed-off-by: cqy123456 <qianya.cheng@zilliz.com>pull/31680/head
parent
3d66670619
commit
976928ecd1
1
go.mod
1
go.mod
|
@ -200,6 +200,7 @@ require (
|
|||
github.com/twmb/murmur3 v1.1.3 // indirect
|
||||
github.com/uber/jaeger-client-go v2.30.0+incompatible // indirect
|
||||
github.com/ugorji/go/codec v1.2.11 // indirect
|
||||
github.com/x448/float16 v0.8.4 // indirect
|
||||
github.com/xiang90/probing v0.0.0-20190116061207-43a291ad63a2 // indirect
|
||||
github.com/yusufpapurcu/wmi v1.2.2 // indirect
|
||||
github.com/zeebo/xxh3 v1.0.2 // indirect
|
||||
|
|
2
go.sum
2
go.sum
|
@ -868,6 +868,8 @@ github.com/valyala/fasthttp v1.6.0/go.mod h1:FstJa9V+Pj9vQ7OJie2qMHdwemEDaDiSdBn
|
|||
github.com/valyala/fasttemplate v1.0.1/go.mod h1:UQGH1tvbgY+Nz5t2n7tXsz52dQxojPUpymEIMZ47gx8=
|
||||
github.com/valyala/fasttemplate v1.2.1/go.mod h1:KHLXt3tVN2HBp8eijSv/kGJopbvo7S+qRAEEKiv+SiQ=
|
||||
github.com/valyala/tcplisten v0.0.0-20161114210144-ceec8f93295a/go.mod h1:v3UYOV9WzVtRmSR+PDvWpU/qWl4Wa5LApYYX4ZtKbio=
|
||||
github.com/x448/float16 v0.8.4 h1:qLwI1I70+NjRFUR3zs1JPUCgaCXSh3SW62uAKT1mSBM=
|
||||
github.com/x448/float16 v0.8.4/go.mod h1:14CWIYCyZA/cWjXOioeEpHeN/83MdbZDRQHoFcYsOfg=
|
||||
github.com/xeipuuv/gojsonpointer v0.0.0-20180127040702-4e3ac2762d5f/go.mod h1:N2zxlSyiKSe5eX1tZViRH5QA0qijqEDrYZiPEAiq3wU=
|
||||
github.com/xeipuuv/gojsonreference v0.0.0-20180127040603-bd5ef7bd5415/go.mod h1:GwrjFmJcFw6At/Gs6z4yjiIwzuJ1/+UwLxMQDVQXShQ=
|
||||
github.com/xeipuuv/gojsonschema v1.2.0/go.mod h1:anYRn/JVcOK2ZgGU+IjEV4nwlhoK5sQluxsYJ78Id3Y=
|
||||
|
|
|
@ -28,6 +28,7 @@ import (
|
|||
|
||||
"github.com/cockroachdb/errors"
|
||||
"github.com/golang/protobuf/proto"
|
||||
"github.com/x448/float16"
|
||||
"go.uber.org/zap"
|
||||
|
||||
"github.com/milvus-io/milvus-proto/go-api/v2/commonpb"
|
||||
|
@ -278,6 +279,8 @@ func GenTestCollectionSchema(collectionName string, pkType schemapb.DataType, wi
|
|||
fieldArray := genConstantFieldSchema(simpleArrayField)
|
||||
floatVecFieldSchema := genVectorFieldSchema(simpleFloatVecField)
|
||||
binVecFieldSchema := genVectorFieldSchema(simpleBinVecField)
|
||||
float16VecFieldSchema := genVectorFieldSchema(simpleFloat16VecField)
|
||||
bfloat16VecFieldSchema := genVectorFieldSchema(simpleBFloat16VecField)
|
||||
var pkFieldSchema *schemapb.FieldSchema
|
||||
|
||||
switch pkType {
|
||||
|
@ -302,6 +305,8 @@ func GenTestCollectionSchema(collectionName string, pkType schemapb.DataType, wi
|
|||
binVecFieldSchema,
|
||||
pkFieldSchema,
|
||||
fieldArray,
|
||||
float16VecFieldSchema,
|
||||
bfloat16VecFieldSchema,
|
||||
},
|
||||
}
|
||||
|
||||
|
@ -330,7 +335,7 @@ func GenTestIndexInfoList(collectionID int64, schema *schemapb.CollectionSchema)
|
|||
TypeParams: field.GetTypeParams(),
|
||||
}
|
||||
switch field.GetDataType() {
|
||||
case schemapb.DataType_FloatVector, schemapb.DataType_Float16Vector:
|
||||
case schemapb.DataType_FloatVector, schemapb.DataType_Float16Vector, schemapb.DataType_BFloat16Vector:
|
||||
{
|
||||
index.IndexParams = []*commonpb.KeyValuePair{
|
||||
{Key: common.MetricTypeKey, Value: metric.L2},
|
||||
|
@ -500,21 +505,28 @@ func generateBinaryVectors(numRows, dim int) []byte {
|
|||
}
|
||||
|
||||
func generateFloat16Vectors(numRows, dim int) []byte {
|
||||
total := numRows * dim * 2
|
||||
ret := make([]byte, total)
|
||||
_, err := rand.Read(ret)
|
||||
if err != nil {
|
||||
panic(err)
|
||||
total := numRows * dim
|
||||
ret := make([]byte, total*2)
|
||||
for i := 0; i < total; i++ {
|
||||
v := float16.Fromfloat32(rand.Float32()).Bits()
|
||||
binary.LittleEndian.PutUint16(ret[i*2:], v)
|
||||
}
|
||||
return ret
|
||||
}
|
||||
|
||||
// generateBFloat16Vectors returns numRows*dim random bfloat16 values
// serialized as little-endian uint16, two bytes per element.
//
// A bfloat16 element is obtained by keeping the upper 16 bits of the float32
// bit pattern of rand.Float32() (truncation, no rounding). The sign bit is
// masked off, so every encoded element is non-negative.
func generateBFloat16Vectors(numRows, dim int) []byte {
	total := numRows * dim
	ret := make([]byte, total*2)
	for i := 0; i < total; i++ {
		// Upper half of the float32 encoding: sign, 8 exponent bits, 7 mantissa bits.
		bits := math.Float32bits(rand.Float32()) >> 16
		binary.LittleEndian.PutUint16(ret[i*2:], uint16(bits&0x7FFF))
	}
	return ret
}
|
||||
|
@ -1009,6 +1021,10 @@ func GenAndSaveIndexV2(collectionID, partitionID, segmentID, buildID int64,
|
|||
dataset = indexcgowrapper.GenBinaryVecDataset(generateBinaryVectors(msgLength, defaultDim))
|
||||
case schemapb.DataType_FloatVector:
|
||||
dataset = indexcgowrapper.GenFloatVecDataset(generateFloatVectors(msgLength, defaultDim))
|
||||
case schemapb.DataType_Float16Vector:
|
||||
dataset = indexcgowrapper.GenFloat16VecDataset(generateFloat16Vectors(msgLength, defaultDim))
|
||||
case schemapb.DataType_BFloat16Vector:
|
||||
dataset = indexcgowrapper.GenBFloat16VecDataset(generateBFloat16Vectors(msgLength, defaultDim))
|
||||
case schemapb.DataType_SparseFloatVector:
|
||||
data := testutils.GenerateSparseFloatVectors(msgLength)
|
||||
dataset = indexcgowrapper.GenSparseFloatVecDataset(&storage.SparseFloatVectorFieldData{
|
||||
|
@ -1260,7 +1276,7 @@ func genBruteForceDSL(schema *schemapb.CollectionSchema, topK int64, roundDecima
|
|||
roundDecimalStr := strconv.FormatInt(roundDecimal, 10)
|
||||
var fieldID int64
|
||||
for _, f := range schema.Fields {
|
||||
if f.DataType == schemapb.DataType_FloatVector || f.DataType == schemapb.DataType_Float16Vector || f.DataType == schemapb.DataType_BFloat16Vector {
|
||||
if f.DataType == schemapb.DataType_FloatVector {
|
||||
vecFieldName = f.Name
|
||||
fieldID = f.FieldID
|
||||
for _, p := range f.IndexParams {
|
||||
|
|
|
@ -93,6 +93,28 @@ func getPKsFromRowBasedInsertMsg(msg *msgstream.InsertMsg, schema *schemapb.Coll
|
|||
break
|
||||
}
|
||||
}
|
||||
case schemapb.DataType_Float16Vector:
|
||||
for _, t := range field.TypeParams {
|
||||
if t.Key == common.DimKey {
|
||||
dim, err := strconv.Atoi(t.Value)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("strconv wrong on get dim, err = %s", err)
|
||||
}
|
||||
offset += dim * 2
|
||||
break
|
||||
}
|
||||
}
|
||||
case schemapb.DataType_BFloat16Vector:
|
||||
for _, t := range field.TypeParams {
|
||||
if t.Key == common.DimKey {
|
||||
dim, err := strconv.Atoi(t.Value)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("strconv wrong on get dim, err = %s", err)
|
||||
}
|
||||
offset += dim * 2
|
||||
break
|
||||
}
|
||||
}
|
||||
case schemapb.DataType_SparseFloatVector:
|
||||
return nil, fmt.Errorf("SparseFloatVector not support in row based message")
|
||||
}
|
||||
|
@ -280,6 +302,10 @@ func fillFieldData(ctx context.Context, vcm storage.ChunkManager, dataPath strin
|
|||
return fillBinVecFieldData(ctx, vcm, dataPath, fieldData, i, offset, endian)
|
||||
case schemapb.DataType_FloatVector:
|
||||
return fillFloatVecFieldData(ctx, vcm, dataPath, fieldData, i, offset, endian)
|
||||
case schemapb.DataType_Float16Vector:
|
||||
return fillFloatVecFieldData(ctx, vcm, dataPath, fieldData, i, offset, endian)
|
||||
case schemapb.DataType_BFloat16Vector:
|
||||
return fillFloatVecFieldData(ctx, vcm, dataPath, fieldData, i, offset, endian)
|
||||
case schemapb.DataType_SparseFloatVector:
|
||||
return fillSparseFloatVecFieldData(ctx, vcm, dataPath, fieldData, i, offset, endian)
|
||||
case schemapb.DataType_Bool:
|
||||
|
|
|
@ -307,6 +307,37 @@ func printPayloadValues(colType schemapb.DataType, reader PayloadReaderInterface
|
|||
}
|
||||
fmt.Println()
|
||||
}
|
||||
case schemapb.DataType_Float16Vector:
|
||||
val, dim, err := reader.GetFloat16VectorFromPayload()
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
dim = dim * 2
|
||||
length := len(val) / dim
|
||||
for i := 0; i < length; i++ {
|
||||
fmt.Printf("\t\t%d :", i)
|
||||
for j := 0; j < dim; j++ {
|
||||
idx := i*dim + j
|
||||
fmt.Printf(" %02x", val[idx])
|
||||
}
|
||||
fmt.Println()
|
||||
}
|
||||
case schemapb.DataType_BFloat16Vector:
|
||||
val, dim, err := reader.GetBFloat16VectorFromPayload()
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
dim = dim * 2
|
||||
length := len(val) / dim
|
||||
for i := 0; i < length; i++ {
|
||||
fmt.Printf("\t\t%d :", i)
|
||||
for j := 0; j < dim; j++ {
|
||||
idx := i*dim + j
|
||||
fmt.Printf(" %02x", val[idx])
|
||||
}
|
||||
fmt.Println()
|
||||
}
|
||||
|
||||
case schemapb.DataType_FloatVector:
|
||||
val, dim, err := reader.GetFloatVectorFromPayload()
|
||||
if err != nil {
|
||||
|
|
|
@ -184,6 +184,20 @@ func TestPrintBinlogFiles(t *testing.T) {
|
|||
Description: "description_12",
|
||||
DataType: schemapb.DataType_JSON,
|
||||
},
|
||||
{
|
||||
FieldID: 111,
|
||||
Name: "field_bfloat16_vector",
|
||||
IsPrimaryKey: false,
|
||||
Description: "description_13",
|
||||
DataType: schemapb.DataType_BFloat16Vector,
|
||||
},
|
||||
{
|
||||
FieldID: 112,
|
||||
Name: "field_float16_vector",
|
||||
IsPrimaryKey: false,
|
||||
Description: "description_14",
|
||||
DataType: schemapb.DataType_Float16Vector,
|
||||
},
|
||||
},
|
||||
},
|
||||
}
|
||||
|
@ -234,6 +248,14 @@ func TestPrintBinlogFiles(t *testing.T) {
|
|||
[]byte(`{"key":"hello"}`),
|
||||
},
|
||||
},
|
||||
111: &BFloat16VectorFieldData{
|
||||
Data: []byte("12345678"),
|
||||
Dim: 4,
|
||||
},
|
||||
112: &Float16VectorFieldData{
|
||||
Data: []byte("12345678"),
|
||||
Dim: 4,
|
||||
},
|
||||
},
|
||||
}
|
||||
|
||||
|
@ -283,6 +305,14 @@ func TestPrintBinlogFiles(t *testing.T) {
|
|||
[]byte(`{"key":"world"}`),
|
||||
},
|
||||
},
|
||||
111: &BFloat16VectorFieldData{
|
||||
Data: []byte("abcdefgh"),
|
||||
Dim: 4,
|
||||
},
|
||||
112: &Float16VectorFieldData{
|
||||
Data: []byte("abcdefgh"),
|
||||
Dim: 4,
|
||||
},
|
||||
},
|
||||
}
|
||||
firstBlobs, err := insertCodec.Serialize(1, 1, insertDataFirst)
|
||||
|
|
|
@ -1201,6 +1201,32 @@ func TransferInsertDataToInsertRecord(insertData *InsertData) (*segcorepb.Insert
|
|||
},
|
||||
},
|
||||
}
|
||||
case *Float16VectorFieldData:
|
||||
fieldData = &schemapb.FieldData{
|
||||
Type: schemapb.DataType_Float16Vector,
|
||||
FieldId: fieldID,
|
||||
Field: &schemapb.FieldData_Vectors{
|
||||
Vectors: &schemapb.VectorField{
|
||||
Data: &schemapb.VectorField_Float16Vector{
|
||||
Float16Vector: rawData.Data,
|
||||
},
|
||||
Dim: int64(rawData.Dim),
|
||||
},
|
||||
},
|
||||
}
|
||||
case *BFloat16VectorFieldData:
|
||||
fieldData = &schemapb.FieldData{
|
||||
Type: schemapb.DataType_BFloat16Vector,
|
||||
FieldId: fieldID,
|
||||
Field: &schemapb.FieldData_Vectors{
|
||||
Vectors: &schemapb.VectorField{
|
||||
Data: &schemapb.VectorField_Bfloat16Vector{
|
||||
Bfloat16Vector: rawData.Data,
|
||||
},
|
||||
Dim: int64(rawData.Dim),
|
||||
},
|
||||
},
|
||||
}
|
||||
case *SparseFloatVectorFieldData:
|
||||
fieldData = &schemapb.FieldData{
|
||||
Type: schemapb.DataType_SparseFloatVector,
|
||||
|
|
|
@ -993,6 +993,15 @@ func TestRowBasedInsertMsgToInsertData(t *testing.T) {
|
|||
}
|
||||
}
|
||||
|
||||
func TestRowBasedTransferInsertMsgToInsertRecord(t *testing.T) {
|
||||
numRows, fVecDim, bVecDim, f16VecDim, bf16VecDim := 10, 8, 8, 8, 8
|
||||
schema, _, _ := genAllFieldsSchema(fVecDim, bVecDim, f16VecDim, bf16VecDim, false)
|
||||
msg, _, _ := genRowBasedInsertMsg(numRows, fVecDim, bVecDim, f16VecDim, bf16VecDim)
|
||||
|
||||
_, err := TransferInsertMsgToInsertRecord(schema, msg)
|
||||
assert.NoError(t, err)
|
||||
}
|
||||
|
||||
func TestRowBasedInsertMsgToInsertFloat16VectorDataError(t *testing.T) {
|
||||
msg := &msgstream.InsertMsg{
|
||||
BaseMsg: msgstream.BaseMsg{
|
||||
|
|
|
@ -89,6 +89,15 @@ func ConvertToArrowSchema(fields []*schemapb.FieldSchema) (*arrow.Schema, error)
|
|||
Name: field.Name,
|
||||
Type: &arrow.FixedSizeBinaryType{ByteWidth: dim * 2},
|
||||
})
|
||||
case schemapb.DataType_BFloat16Vector:
|
||||
dim, err := storage.GetDimFromParams(field.TypeParams)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
arrowFields = append(arrowFields, arrow.Field{
|
||||
Name: field.Name,
|
||||
Type: &arrow.FixedSizeBinaryType{ByteWidth: dim * 2},
|
||||
})
|
||||
default:
|
||||
return nil, merr.WrapErrParameterInvalidMsg("unknown type %v", field.DataType.String())
|
||||
}
|
||||
|
|
|
@ -41,9 +41,33 @@ func TestConvertArrowSchema(t *testing.T) {
|
|||
{FieldID: 12, Name: "field11", DataType: schemapb.DataType_Array, ElementType: schemapb.DataType_Int64},
|
||||
{FieldID: 13, Name: "field12", DataType: schemapb.DataType_JSON},
|
||||
{FieldID: 14, Name: "field13", DataType: schemapb.DataType_Float16Vector, TypeParams: []*commonpb.KeyValuePair{{Key: "dim", Value: "128"}}},
|
||||
{FieldID: 15, Name: "field14", DataType: schemapb.DataType_BFloat16Vector, TypeParams: []*commonpb.KeyValuePair{{Key: "dim", Value: "128"}}},
|
||||
}
|
||||
|
||||
schema, err := ConvertToArrowSchema(fieldSchemas)
|
||||
assert.NoError(t, err)
|
||||
assert.Equal(t, len(fieldSchemas), len(schema.Fields()))
|
||||
}
|
||||
|
||||
func TestConvertArrowSchemaWithoutDim(t *testing.T) {
|
||||
fieldSchemas := []*schemapb.FieldSchema{
|
||||
{FieldID: 1, Name: "field0", DataType: schemapb.DataType_Bool},
|
||||
{FieldID: 2, Name: "field1", DataType: schemapb.DataType_Int8},
|
||||
{FieldID: 3, Name: "field2", DataType: schemapb.DataType_Int16},
|
||||
{FieldID: 4, Name: "field3", DataType: schemapb.DataType_Int32},
|
||||
{FieldID: 5, Name: "field4", DataType: schemapb.DataType_Int64},
|
||||
{FieldID: 6, Name: "field5", DataType: schemapb.DataType_Float},
|
||||
{FieldID: 7, Name: "field6", DataType: schemapb.DataType_Double},
|
||||
{FieldID: 8, Name: "field7", DataType: schemapb.DataType_String},
|
||||
{FieldID: 9, Name: "field8", DataType: schemapb.DataType_VarChar},
|
||||
{FieldID: 10, Name: "field9", DataType: schemapb.DataType_BinaryVector, TypeParams: []*commonpb.KeyValuePair{{Key: "dim", Value: "128"}}},
|
||||
{FieldID: 11, Name: "field10", DataType: schemapb.DataType_FloatVector, TypeParams: []*commonpb.KeyValuePair{{Key: "dim", Value: "128"}}},
|
||||
{FieldID: 12, Name: "field11", DataType: schemapb.DataType_Array, ElementType: schemapb.DataType_Int64},
|
||||
{FieldID: 13, Name: "field12", DataType: schemapb.DataType_JSON},
|
||||
{FieldID: 14, Name: "field13", DataType: schemapb.DataType_Float16Vector, TypeParams: []*commonpb.KeyValuePair{}},
|
||||
{FieldID: 15, Name: "field14", DataType: schemapb.DataType_BFloat16Vector, TypeParams: []*commonpb.KeyValuePair{}},
|
||||
}
|
||||
|
||||
_, err := ConvertToArrowSchema(fieldSchemas)
|
||||
assert.Error(t, err)
|
||||
}
|
||||
|
|
|
@ -108,6 +108,18 @@ func BuildRecord(b *array.RecordBuilder, data *storage.InsertData, fields []*sch
|
|||
byteLength := dim * 2
|
||||
length := len(data) / byteLength
|
||||
|
||||
builder.Reserve(length)
|
||||
for i := 0; i < length; i++ {
|
||||
builder.Append(data[i*byteLength : (i+1)*byteLength])
|
||||
}
|
||||
case schemapb.DataType_BFloat16Vector:
|
||||
vecData := data.Data[field.FieldID].(*storage.BFloat16VectorFieldData)
|
||||
builder := fBuilder.(*array.FixedSizeBinaryBuilder)
|
||||
dim := vecData.Dim
|
||||
data := vecData.Data
|
||||
byteLength := dim * 2
|
||||
length := len(data) / byteLength
|
||||
|
||||
builder.Reserve(length)
|
||||
for i := 0; i < length; i++ {
|
||||
builder.Append(data[i*byteLength : (i+1)*byteLength])
|
||||
|
|
|
@ -146,7 +146,7 @@ func CheckCtxValid(ctx context.Context) bool {
|
|||
func GetVecFieldIDs(schema *schemapb.CollectionSchema) []int64 {
|
||||
var vecFieldIDs []int64
|
||||
for _, field := range schema.Fields {
|
||||
if field.DataType == schemapb.DataType_BinaryVector || field.DataType == schemapb.DataType_FloatVector || field.DataType == schemapb.DataType_Float16Vector || field.DataType == schemapb.DataType_SparseFloatVector {
|
||||
if field.DataType == schemapb.DataType_BinaryVector || field.DataType == schemapb.DataType_FloatVector || field.DataType == schemapb.DataType_Float16Vector || field.DataType == schemapb.DataType_BFloat16Vector || field.DataType == schemapb.DataType_SparseFloatVector {
|
||||
vecFieldIDs = append(vecFieldIDs, field.FieldID)
|
||||
}
|
||||
}
|
||||
|
|
|
@ -241,6 +241,10 @@ func EstimateEntitySize(fieldsData []*schemapb.FieldData, rowOffset int) (int, e
|
|||
res += int(fs.GetVectors().GetDim())
|
||||
case schemapb.DataType_FloatVector:
|
||||
res += int(fs.GetVectors().GetDim() * 4)
|
||||
case schemapb.DataType_Float16Vector:
|
||||
res += int(fs.GetVectors().GetDim() * 2)
|
||||
case schemapb.DataType_BFloat16Vector:
|
||||
res += int(fs.GetVectors().GetDim() * 2)
|
||||
case schemapb.DataType_SparseFloatVector:
|
||||
vec := fs.GetVectors().GetSparseFloatVector()
|
||||
// counting only the size of the vector data, ignoring other
|
||||
|
@ -527,6 +531,10 @@ func PrepareResultFieldData(sample []*schemapb.FieldData, topK int64) []*schemap
|
|||
vectors.Vectors.Data = &schemapb.VectorField_Float16Vector{
|
||||
Float16Vector: make([]byte, 0, topK*dim*2),
|
||||
}
|
||||
case *schemapb.VectorField_Bfloat16Vector:
|
||||
vectors.Vectors.Data = &schemapb.VectorField_Bfloat16Vector{
|
||||
Bfloat16Vector: make([]byte, 0, topK*dim*2),
|
||||
}
|
||||
case *schemapb.VectorField_BinaryVector:
|
||||
vectors.Vectors.Data = &schemapb.VectorField_BinaryVector{
|
||||
BinaryVector: make([]byte, 0, topK*dim/8),
|
||||
|
@ -957,6 +965,24 @@ func MergeFieldData(dst []*schemapb.FieldData, src []*schemapb.FieldData) error
|
|||
dstBinaryVector := dstVector.Data.(*schemapb.VectorField_BinaryVector)
|
||||
dstBinaryVector.BinaryVector = append(dstBinaryVector.BinaryVector, srcVector.BinaryVector...)
|
||||
}
|
||||
case *schemapb.VectorField_Float16Vector:
|
||||
if dstVector.GetFloat16Vector() == nil {
|
||||
dstVector.Data = &schemapb.VectorField_Float16Vector{
|
||||
Float16Vector: srcVector.Float16Vector,
|
||||
}
|
||||
} else {
|
||||
dstFloat16Vector := dstVector.Data.(*schemapb.VectorField_Float16Vector)
|
||||
dstFloat16Vector.Float16Vector = append(dstFloat16Vector.Float16Vector, srcVector.Float16Vector...)
|
||||
}
|
||||
case *schemapb.VectorField_Bfloat16Vector:
|
||||
if dstVector.GetBfloat16Vector() == nil {
|
||||
dstVector.Data = &schemapb.VectorField_Bfloat16Vector{
|
||||
Bfloat16Vector: srcVector.Bfloat16Vector,
|
||||
}
|
||||
} else {
|
||||
dstBfloat16Vector := dstVector.Data.(*schemapb.VectorField_Bfloat16Vector)
|
||||
dstBfloat16Vector.Bfloat16Vector = append(dstBfloat16Vector.Bfloat16Vector, srcVector.Bfloat16Vector...)
|
||||
}
|
||||
case *schemapb.VectorField_FloatVector:
|
||||
if dstVector.GetFloatVector() == nil {
|
||||
dstVector.Data = &schemapb.VectorField_FloatVector{
|
||||
|
|
|
@ -984,6 +984,36 @@ func TestDeleteFieldData(t *testing.T) {
|
|||
assert.Equal(t, tmpSparseFloatVector, result2[SparseFloatVectorFieldID-common.StartOfUserFieldID].GetVectors().GetSparseFloatVector())
|
||||
}
|
||||
|
||||
func TestEstimateEntitySize(t *testing.T) {
|
||||
samples := []*schemapb.FieldData{
|
||||
{
|
||||
FieldId: 111,
|
||||
FieldName: "float16_vector",
|
||||
Type: schemapb.DataType_Float16Vector,
|
||||
Field: &schemapb.FieldData_Vectors{
|
||||
Vectors: &schemapb.VectorField{
|
||||
Dim: 64,
|
||||
Data: &schemapb.VectorField_Float16Vector{},
|
||||
},
|
||||
},
|
||||
},
|
||||
{
|
||||
FieldId: 112,
|
||||
FieldName: "bfloat16_vector",
|
||||
Type: schemapb.DataType_BFloat16Vector,
|
||||
Field: &schemapb.FieldData_Vectors{
|
||||
Vectors: &schemapb.VectorField{
|
||||
Dim: 128,
|
||||
Data: &schemapb.VectorField_Bfloat16Vector{},
|
||||
},
|
||||
},
|
||||
},
|
||||
}
|
||||
size, error := EstimateEntitySize(samples, int(0))
|
||||
assert.NoError(t, error)
|
||||
assert.True(t, size == 384)
|
||||
}
|
||||
|
||||
func TestGetPrimaryFieldSchema(t *testing.T) {
|
||||
int64Field := &schemapb.FieldSchema{
|
||||
FieldID: 1,
|
||||
|
@ -1461,6 +1491,8 @@ func TestMergeFieldData(t *testing.T) {
|
|||
},
|
||||
FieldId: 106,
|
||||
},
|
||||
genFieldData("float16_vector", 111, schemapb.DataType_Float16Vector, []byte("12345678"), 4),
|
||||
genFieldData("bfloat16_vector", 112, schemapb.DataType_BFloat16Vector, []byte("12345678"), 4),
|
||||
}
|
||||
|
||||
srcFields := []*schemapb.FieldData{
|
||||
|
@ -1520,6 +1552,8 @@ func TestMergeFieldData(t *testing.T) {
|
|||
},
|
||||
FieldId: 106,
|
||||
},
|
||||
genFieldData("float16_vector", 111, schemapb.DataType_Float16Vector, []byte("abcdefgh"), 4),
|
||||
genFieldData("bfloat16_vector", 112, schemapb.DataType_BFloat16Vector, []byte("ABCDEFGH"), 4),
|
||||
}
|
||||
|
||||
err := MergeFieldData(dstFields, srcFields)
|
||||
|
@ -1552,6 +1586,8 @@ func TestMergeFieldData(t *testing.T) {
|
|||
Dim: 2301,
|
||||
Contents: sparseFloatRows,
|
||||
}, dstFields[6].GetVectors().GetSparseFloatVector())
|
||||
assert.Equal(t, []byte("12345678abcdefgh"), dstFields[7].GetVectors().GetFloat16Vector())
|
||||
assert.Equal(t, []byte("12345678ABCDEFGH"), dstFields[8].GetVectors().GetBfloat16Vector())
|
||||
})
|
||||
|
||||
t.Run("merge with nil", func(t *testing.T) {
|
||||
|
@ -1584,6 +1620,8 @@ func TestMergeFieldData(t *testing.T) {
|
|||
},
|
||||
FieldId: 104,
|
||||
},
|
||||
genFieldData("float16_vector", 111, schemapb.DataType_Float16Vector, []byte("12345678"), 4),
|
||||
genFieldData("bfloat16_vector", 112, schemapb.DataType_BFloat16Vector, []byte("12345678"), 4),
|
||||
}
|
||||
|
||||
dstFields := []*schemapb.FieldData{
|
||||
|
@ -1592,6 +1630,8 @@ func TestMergeFieldData(t *testing.T) {
|
|||
{Type: schemapb.DataType_JSON, FieldName: "json", Field: &schemapb.FieldData_Scalars{Scalars: &schemapb.ScalarField{Data: &schemapb.ScalarField_JsonData{}}}, FieldId: 102},
|
||||
{Type: schemapb.DataType_Array, FieldName: "array", Field: &schemapb.FieldData_Scalars{Scalars: &schemapb.ScalarField{Data: &schemapb.ScalarField_ArrayData{}}}, FieldId: 103},
|
||||
{Type: schemapb.DataType_SparseFloatVector, FieldName: "sparseFloat", Field: &schemapb.FieldData_Vectors{Vectors: &schemapb.VectorField{Data: &schemapb.VectorField_SparseFloatVector{}}}, FieldId: 104},
|
||||
{Type: schemapb.DataType_Float16Vector, FieldName: "float16_vector", Field: &schemapb.FieldData_Vectors{Vectors: &schemapb.VectorField{Data: &schemapb.VectorField_Float16Vector{}}}, FieldId: 111},
|
||||
{Type: schemapb.DataType_BFloat16Vector, FieldName: "bfloat16_vector", Field: &schemapb.FieldData_Vectors{Vectors: &schemapb.VectorField{Data: &schemapb.VectorField_Bfloat16Vector{}}}, FieldId: 112},
|
||||
}
|
||||
|
||||
err := MergeFieldData(dstFields, srcFields)
|
||||
|
@ -1615,6 +1655,8 @@ func TestMergeFieldData(t *testing.T) {
|
|||
Dim: 521,
|
||||
Contents: sparseFloatRows[:3],
|
||||
}, dstFields[4].GetVectors().GetSparseFloatVector())
|
||||
assert.Equal(t, []byte("12345678"), dstFields[5].GetVectors().GetFloat16Vector())
|
||||
assert.Equal(t, []byte("12345678"), dstFields[6].GetVectors().GetBfloat16Vector())
|
||||
})
|
||||
|
||||
t.Run("error case", func(t *testing.T) {
|
||||
|
@ -1903,6 +1945,32 @@ func (s *FieldDataSuite) TestPrepareFieldData() {
|
|||
s.EqualValues(topK*128*2, cap(field.GetVectors().GetFloat16Vector()))
|
||||
})
|
||||
|
||||
s.Run("bfloat16_vector", func() {
|
||||
samples := []*schemapb.FieldData{
|
||||
{
|
||||
FieldId: fieldID,
|
||||
FieldName: fieldName,
|
||||
Type: schemapb.DataType_BFloat16Vector,
|
||||
Field: &schemapb.FieldData_Vectors{
|
||||
Vectors: &schemapb.VectorField{
|
||||
Dim: 128,
|
||||
Data: &schemapb.VectorField_Bfloat16Vector{},
|
||||
},
|
||||
},
|
||||
},
|
||||
}
|
||||
|
||||
fields := PrepareResultFieldData(samples, topK)
|
||||
s.Require().Len(fields, 1)
|
||||
field := fields[0]
|
||||
s.Equal(fieldID, field.GetFieldId())
|
||||
s.Equal(fieldName, field.GetFieldName())
|
||||
s.Equal(schemapb.DataType_BFloat16Vector, field.GetType())
|
||||
|
||||
s.EqualValues(128, field.GetVectors().GetDim())
|
||||
s.EqualValues(topK*128*2, cap(field.GetVectors().GetBfloat16Vector()))
|
||||
})
|
||||
|
||||
s.Run("binary_vector", func() {
|
||||
samples := []*schemapb.FieldData{
|
||||
{
|
||||
|
|
Loading…
Reference in New Issue