// milvus/internal/proxy/task_calc_distance_test.go
//
// 491 lines
// 12 KiB
// Go

package proxy
import (
"context"
"errors"
"testing"
"github.com/milvus-io/milvus-proto/go-api/commonpb"
"github.com/milvus-io/milvus-proto/go-api/milvuspb"
"github.com/milvus-io/milvus-proto/go-api/schemapb"
"github.com/stretchr/testify/assert"
)
// TestCalcDistanceTask_arrangeVectorsByStrID checks that arrangeVectorsByStrID
// returns the retrieved vectors reordered to follow the order of the requested
// string IDs, for both a float-vector and a binary-vector payload.
func TestCalcDistanceTask_arrangeVectorsByStrID(t *testing.T) {
	const (
		dim      = 16 // vector dimension
		rowCount = 3  // number of vectors in every fixture
	)
	task := &calcDistanceTask{}

	// IDs are requested as c, b, a, while sequence maps each ID to the row it
	// occupies in the retrieved payload (a=0, b=1, c=2).
	inputIds := []string{"c", "b", "a"}
	sequence := map[string]int{"a": 0, "b": 1, "c": 2}

	// float vector
	floatValue := make([]float32, 0, dim*rowCount)
	for i := 0; i < dim*rowCount; i++ {
		floatValue = append(floatValue, float32(i))
	}
	retrievedVectors := &schemapb.VectorField{
		Dim: int64(dim),
		Data: &schemapb.VectorField_FloatVector{
			FloatVector: &schemapb.FloatArray{
				Data: floatValue,
			},
		},
	}
	result, err := task.arrangeVectorsByStrID(inputIds, sequence, retrievedVectors)
	// Stop on failure: slicing a missing or short result below would panic and
	// hide the real cause.
	if !assert.NoError(t, err) {
		return
	}
	floatResult := result.GetFloatVector().GetData()
	if len(floatResult) < dim*rowCount {
		t.Fatalf("float result holds %d values, want at least %d", len(floatResult), dim*rowCount)
	}
	for i, id := range inputIds {
		// Row i of the output must equal the retrieved row registered for id.
		src := sequence[id] * dim
		assert.Equal(t, floatValue[src:src+dim], floatResult[i*dim:(i+1)*dim])
	}

	// binary vector
	const numBytes = dim / 8 // bytes compared per binary vector
	binaryValue := make([]byte, 0, numBytes*rowCount)
	for i := 0; i < numBytes*rowCount; i++ {
		binaryValue = append(binaryValue, byte(i))
	}
	retrievedVectors = &schemapb.VectorField{
		Dim: int64(dim),
		Data: &schemapb.VectorField_BinaryVector{
			BinaryVector: binaryValue,
		},
	}
	result, err = task.arrangeVectorsByStrID(inputIds, sequence, retrievedVectors)
	if !assert.NoError(t, err) {
		return
	}
	binaryResult := result.GetBinaryVector()
	if len(binaryResult) < numBytes*rowCount {
		t.Fatalf("binary result holds %d bytes, want at least %d", len(binaryResult), numBytes*rowCount)
	}
	for i, id := range inputIds {
		src := sequence[id] * numBytes
		assert.Equal(t, binaryValue[src:src+numBytes], binaryResult[i*numBytes:(i+1)*numBytes])
	}
}
// TestCalcDistanceTask_arrangeVectorsByIntID checks that arrangeVectorsByIntID
// returns the retrieved vectors reordered to follow the order of the requested
// int64 IDs, for both a float-vector and a binary-vector payload.
func TestCalcDistanceTask_arrangeVectorsByIntID(t *testing.T) {
	const (
		dim      = 16 // vector dimension
		rowCount = 3  // number of vectors in every fixture
	)
	task := &calcDistanceTask{}

	// IDs are requested as 2, 0, 1, while sequence maps each ID to the row it
	// occupies in the retrieved payload (here: the ID itself).
	inputIds := []int64{2, 0, 1}
	sequence := map[int64]int{0: 0, 1: 1, 2: 2}

	// float vector
	floatValue := make([]float32, 0, dim*rowCount)
	for i := 0; i < dim*rowCount; i++ {
		floatValue = append(floatValue, float32(i))
	}
	retrievedVectors := &schemapb.VectorField{
		Dim: int64(dim),
		Data: &schemapb.VectorField_FloatVector{
			FloatVector: &schemapb.FloatArray{
				Data: floatValue,
			},
		},
	}
	result, err := task.arrangeVectorsByIntID(inputIds, sequence, retrievedVectors)
	// Stop on failure: slicing a missing or short result below would panic and
	// hide the real cause.
	if !assert.NoError(t, err) {
		return
	}
	floatResult := result.GetFloatVector().GetData()
	if len(floatResult) < dim*rowCount {
		t.Fatalf("float result holds %d values, want at least %d", len(floatResult), dim*rowCount)
	}
	for i, id := range inputIds {
		// Row i of the output must equal the retrieved row registered for id.
		src := sequence[id] * dim
		assert.Equal(t, floatValue[src:src+dim], floatResult[i*dim:(i+1)*dim])
	}

	// binary vector
	// The fixture holds exactly rowCount vectors of dim/8 bytes each, the same
	// sizing the string-ID variant of this test uses.
	const numBytes = dim / 8 // bytes compared per binary vector
	binaryValue := make([]byte, 0, numBytes*rowCount)
	for i := 0; i < numBytes*rowCount; i++ {
		binaryValue = append(binaryValue, byte(i))
	}
	retrievedVectors = &schemapb.VectorField{
		Dim: int64(dim),
		Data: &schemapb.VectorField_BinaryVector{
			BinaryVector: binaryValue,
		},
	}
	result, err = task.arrangeVectorsByIntID(inputIds, sequence, retrievedVectors)
	if !assert.NoError(t, err) {
		return
	}
	binaryResult := result.GetBinaryVector()
	if len(binaryResult) < numBytes*rowCount {
		t.Fatalf("binary result holds %d bytes, want at least %d", len(binaryResult), numBytes*rowCount)
	}
	for i, id := range inputIds {
		src := sequence[id] * numBytes
		assert.Equal(t, binaryValue[src:src+numBytes], binaryResult[i*numBytes:(i+1)*numBytes])
	}
}
func TestCalcDistanceTask_ExecuteFloat(t *testing.T) {
ctx := context.Background()
queryFunc := func(ids *milvuspb.VectorIDs) (*milvuspb.QueryResults, error) {
return nil, errors.New("unexpected error")
}
task := &calcDistanceTask{
traceID: "dummy",
queryFunc: queryFunc,
}
request := &milvuspb.CalcDistanceRequest{
OpLeft: nil,
OpRight: nil,
Params: []*commonpb.KeyValuePair{
{Key: "metric", Value: "L2"},
},
}
// left-op empty
calcResult, err := task.Execute(ctx, request)
assert.Nil(t, err)
assert.Equal(t, commonpb.ErrorCode_UnexpectedError, calcResult.Status.ErrorCode)
request = &milvuspb.CalcDistanceRequest{
OpLeft: &milvuspb.VectorsArray{
Array: &milvuspb.VectorsArray_IdArray{
IdArray: &milvuspb.VectorIDs{},
},
},
OpRight: nil,
Params: []*commonpb.KeyValuePair{
{Key: "metric", Value: "L2"},
},
}
// left-op query error
calcResult, err = task.Execute(ctx, request)
assert.Nil(t, err)
assert.Equal(t, commonpb.ErrorCode_UnexpectedError, calcResult.Status.ErrorCode)
fieldIds := make([]int64, 0)
fieldIds = append(fieldIds, 2)
fieldIds = append(fieldIds, 0)
fieldIds = append(fieldIds, 1)
dim := 8
floatValue := make([]float32, 0)
for i := 0; i < dim*3; i++ {
floatValue = append(floatValue, float32(i))
}
queryFunc = func(ids *milvuspb.VectorIDs) (*milvuspb.QueryResults, error) {
if ids == nil {
return &milvuspb.QueryResults{
Status: &commonpb.Status{
ErrorCode: commonpb.ErrorCode_UnexpectedError,
Reason: "unexpected",
},
}, nil
}
return &milvuspb.QueryResults{
FieldsData: []*schemapb.FieldData{
{
Type: schemapb.DataType_Int64,
FieldName: "id",
Field: &schemapb.FieldData_Scalars{
Scalars: &schemapb.ScalarField{
Data: &schemapb.ScalarField_LongData{
LongData: &schemapb.LongArray{
Data: fieldIds,
},
},
},
},
},
{
Type: schemapb.DataType_FloatVector,
FieldName: "vec",
Field: &schemapb.FieldData_Vectors{
Vectors: &schemapb.VectorField{
Dim: int64(dim),
Data: &schemapb.VectorField_FloatVector{
FloatVector: &schemapb.FloatArray{
Data: floatValue,
},
},
},
},
},
},
}, nil
}
task.queryFunc = queryFunc
calcResult, err = task.Execute(ctx, request)
assert.Nil(t, err)
assert.Equal(t, commonpb.ErrorCode_UnexpectedError, calcResult.Status.ErrorCode)
idArray := &milvuspb.VectorsArray{
Array: &milvuspb.VectorsArray_IdArray{
IdArray: &milvuspb.VectorIDs{
FieldName: "vec",
IdArray: &schemapb.IDs{
IdField: &schemapb.IDs_IntId{
IntId: &schemapb.LongArray{
Data: fieldIds,
},
},
},
},
},
}
request = &milvuspb.CalcDistanceRequest{
OpLeft: idArray,
OpRight: idArray,
Params: []*commonpb.KeyValuePair{
{Key: "metric", Value: "L2"},
},
}
// success
calcResult, err = task.Execute(ctx, request)
assert.Nil(t, err)
assert.Equal(t, commonpb.ErrorCode_Success, calcResult.Status.ErrorCode)
// right-op query error
request.OpRight = nil
calcResult, err = task.Execute(ctx, request)
assert.Nil(t, err)
assert.Equal(t, commonpb.ErrorCode_UnexpectedError, calcResult.Status.ErrorCode)
request.OpRight = &milvuspb.VectorsArray{
Array: &milvuspb.VectorsArray_IdArray{
IdArray: &milvuspb.VectorIDs{
FieldName: "kkk",
IdArray: &schemapb.IDs{
IdField: &schemapb.IDs_IntId{
IntId: &schemapb.LongArray{
Data: fieldIds,
},
},
},
},
},
}
// right-op arrange error
calcResult, err = task.Execute(ctx, request)
assert.Nil(t, err)
assert.Equal(t, commonpb.ErrorCode_UnexpectedError, calcResult.Status.ErrorCode)
request.OpRight = &milvuspb.VectorsArray{
Array: &milvuspb.VectorsArray_DataArray{
DataArray: &schemapb.VectorField{
Dim: 5,
},
},
}
// different dimension
calcResult, err = task.Execute(ctx, request)
assert.Nil(t, err)
assert.Equal(t, commonpb.ErrorCode_UnexpectedError, calcResult.Status.ErrorCode)
request.OpRight = &milvuspb.VectorsArray{
Array: &milvuspb.VectorsArray_DataArray{
DataArray: &schemapb.VectorField{
Dim: int64(dim),
Data: &schemapb.VectorField_FloatVector{
FloatVector: &schemapb.FloatArray{
Data: make([]float32, 0),
},
},
},
},
}
// calcdistance return error
calcResult, err = task.Execute(ctx, request)
assert.Nil(t, err)
assert.Equal(t, commonpb.ErrorCode_UnexpectedError, calcResult.Status.ErrorCode)
}
// TestCalcDistanceTask_ExecuteBinary drives calcDistanceTask.Execute with
// binary-vector operands under the "HAMMING" and "TANIMOTO" metrics. Every call
// is expected to return a nil error; the outcome is read from the error code of
// the returned status.
func TestCalcDistanceTask_ExecuteBinary(t *testing.T) {
	ctx := context.Background()

	fieldIds := []int64{2, 0, 1}
	dim := 16
	// Binary payload for three vectors: 3*dim/8 bytes in total.
	binaryValue := make([]byte, 0, 3*dim/8)
	for i := 0; i < 3*dim/8; i++ {
		binaryValue = append(binaryValue, byte(i))
	}

	// The query stub reports a failure status for a nil argument and otherwise
	// returns an int64 "id" column plus a binary-vector "vec" column.
	queryFunc := func(ids *milvuspb.VectorIDs) (*milvuspb.QueryResults, error) {
		if ids == nil {
			return &milvuspb.QueryResults{
				Status: &commonpb.Status{
					ErrorCode: commonpb.ErrorCode_UnexpectedError,
					Reason:    "unexpected",
				},
			}, nil
		}
		return &milvuspb.QueryResults{
			FieldsData: []*schemapb.FieldData{
				{
					Type:      schemapb.DataType_Int64,
					FieldName: "id",
					Field: &schemapb.FieldData_Scalars{
						Scalars: &schemapb.ScalarField{
							Data: &schemapb.ScalarField_LongData{
								LongData: &schemapb.LongArray{
									Data: fieldIds,
								},
							},
						},
					},
				},
				{
					// Declared type matches the binary payload carried below.
					Type:      schemapb.DataType_BinaryVector,
					FieldName: "vec",
					Field: &schemapb.FieldData_Vectors{
						Vectors: &schemapb.VectorField{
							Dim: int64(dim),
							Data: &schemapb.VectorField_BinaryVector{
								BinaryVector: binaryValue,
							},
						},
					},
				},
			},
		}, nil
	}
	task := &calcDistanceTask{
		traceID:   "dummy",
		queryFunc: queryFunc,
	}

	// metricParams builds a fresh parameter list selecting the given metric.
	metricParams := func(metric string) []*commonpb.KeyValuePair {
		return []*commonpb.KeyValuePair{
			{Key: "metric", Value: metric},
		}
	}

	// expectCode executes req and checks the error code carried by the result.
	expectCode := func(req *milvuspb.CalcDistanceRequest, want commonpb.ErrorCode) {
		got, err := task.Execute(ctx, req)
		assert.Nil(t, err)
		assert.Equal(t, want, got.Status.ErrorCode)
	}

	// Both operands refer to the same IDs of the "vec" field: expect success.
	idArray := &milvuspb.VectorsArray{
		Array: &milvuspb.VectorsArray_IdArray{
			IdArray: &milvuspb.VectorIDs{
				FieldName: "vec",
				IdArray: &schemapb.IDs{
					IdField: &schemapb.IDs_IntId{
						IntId: &schemapb.LongArray{
							Data: fieldIds,
						},
					},
				},
			},
		},
	}
	expectCode(&milvuspb.CalcDistanceRequest{
		OpLeft:  idArray,
		OpRight: idArray,
		Params:  metricParams("HAMMING"),
	}, commonpb.ErrorCode_Success)

	// Raw-data operands: one tagged as float vector (without values), one binary.
	floatArray := &milvuspb.VectorsArray{
		Array: &milvuspb.VectorsArray_DataArray{
			DataArray: &schemapb.VectorField{
				Dim:  int64(dim),
				Data: &schemapb.VectorField_FloatVector{},
			},
		},
	}
	binaryArray := &milvuspb.VectorsArray{
		Array: &milvuspb.VectorsArray_DataArray{
			DataArray: &schemapb.VectorField{
				Dim: int64(dim),
				Data: &schemapb.VectorField_BinaryVector{
					BinaryVector: binaryValue,
				},
			},
		},
	}

	// float vs binary
	expectCode(&milvuspb.CalcDistanceRequest{
		OpLeft:  floatArray,
		OpRight: binaryArray,
		Params:  metricParams("HAMMING"),
	}, commonpb.ErrorCode_UnexpectedError)

	// hamming
	expectCode(&milvuspb.CalcDistanceRequest{
		OpLeft:  binaryArray,
		OpRight: binaryArray,
		Params:  metricParams("HAMMING"),
	}, commonpb.ErrorCode_Success)

	// tanimoto
	expectCode(&milvuspb.CalcDistanceRequest{
		OpLeft:  binaryArray,
		OpRight: binaryArray,
		Params:  metricParams("TANIMOTO"),
	}, commonpb.ErrorCode_Success)

	// hamming error: the right operand declares the same Dim but carries no bytes
	expectCode(&milvuspb.CalcDistanceRequest{
		OpLeft: binaryArray,
		OpRight: &milvuspb.VectorsArray{
			Array: &milvuspb.VectorsArray_DataArray{
				DataArray: &schemapb.VectorField{
					Dim: int64(dim),
					Data: &schemapb.VectorField_BinaryVector{
						BinaryVector: make([]byte, 0),
					},
				},
			},
		},
		Params: metricParams("HAMMING"),
	}, commonpb.ErrorCode_UnexpectedError)
}