mirror of https://github.com/milvus-io/milvus.git
parent
df275471e1
commit
3928da6493
|
@ -26,6 +26,7 @@ import (
|
|||
"github.com/milvus-io/milvus/internal/util"
|
||||
|
||||
"go.uber.org/zap"
|
||||
"go.uber.org/zap/zapcore"
|
||||
|
||||
"github.com/milvus-io/milvus/internal/common"
|
||||
"github.com/milvus-io/milvus/internal/log"
|
||||
|
@ -38,10 +39,7 @@ import (
|
|||
"github.com/milvus-io/milvus/internal/proto/milvuspb"
|
||||
"github.com/milvus-io/milvus/internal/proto/proxypb"
|
||||
"github.com/milvus-io/milvus/internal/proto/querypb"
|
||||
"github.com/milvus-io/milvus/internal/proto/schemapb"
|
||||
"github.com/milvus-io/milvus/internal/util/crypto"
|
||||
"github.com/milvus-io/milvus/internal/util/distance"
|
||||
"github.com/milvus-io/milvus/internal/util/funcutil"
|
||||
"github.com/milvus-io/milvus/internal/util/logutil"
|
||||
"github.com/milvus-io/milvus/internal/util/metricsinfo"
|
||||
"github.com/milvus-io/milvus/internal/util/timerecord"
|
||||
|
@ -3037,16 +3035,6 @@ func (node *Proxy) CalcDistance(ctx context.Context, request *milvuspb.CalcDista
|
|||
Status: unhealthyStatus(),
|
||||
}, nil
|
||||
}
|
||||
param, _ := funcutil.GetAttrByKeyFromRepeatedKV("metric", request.GetParams())
|
||||
metric, err := distance.ValidateMetricType(param)
|
||||
if err != nil {
|
||||
return &milvuspb.CalcDistanceResults{
|
||||
Status: &commonpb.Status{
|
||||
ErrorCode: commonpb.ErrorCode_UnexpectedError,
|
||||
Reason: err.Error(),
|
||||
},
|
||||
}, nil
|
||||
}
|
||||
|
||||
sp, ctx := trace.StartSpanFromContextWithOperationName(ctx, "Proxy-CalcDistance")
|
||||
defer sp.Finish()
|
||||
|
@ -3080,15 +3068,15 @@ func (node *Proxy) CalcDistance(ctx context.Context, request *milvuspb.CalcDista
|
|||
queryShardPolicy: roundRobinPolicy,
|
||||
}
|
||||
|
||||
items := []zapcore.Field{
|
||||
zap.String("collection", queryRequest.CollectionName),
|
||||
zap.Any("partitions", queryRequest.PartitionNames),
|
||||
zap.Any("OutputFields", queryRequest.OutputFields),
|
||||
}
|
||||
|
||||
err := node.sched.dqQueue.Enqueue(qt)
|
||||
if err != nil {
|
||||
log.Debug("CalcDistance queryTask failed to enqueue",
|
||||
zap.Error(err),
|
||||
zap.String("traceID", traceID),
|
||||
zap.String("role", typeutil.ProxyRole),
|
||||
zap.String("db", queryRequest.DbName),
|
||||
zap.String("collection", queryRequest.CollectionName),
|
||||
zap.Any("partitions", queryRequest.PartitionNames))
|
||||
log.Error("CalcDistance queryTask failed to enqueue", append(items, zap.Error(err))...)
|
||||
|
||||
return &milvuspb.QueryResults{
|
||||
Status: &commonpb.Status{
|
||||
|
@ -3098,28 +3086,11 @@ func (node *Proxy) CalcDistance(ctx context.Context, request *milvuspb.CalcDista
|
|||
}, err
|
||||
}
|
||||
|
||||
log.Debug("CalcDistance queryTask enqueued",
|
||||
zap.String("traceID", traceID),
|
||||
zap.String("role", typeutil.ProxyRole),
|
||||
zap.Int64("msgID", qt.Base.MsgID),
|
||||
zap.Uint64("timestamp", qt.Base.Timestamp),
|
||||
zap.String("db", queryRequest.DbName),
|
||||
zap.String("collection", queryRequest.CollectionName),
|
||||
zap.Any("partitions", queryRequest.PartitionNames),
|
||||
zap.Any("OutputFields", queryRequest.OutputFields))
|
||||
log.Debug("CalcDistance queryTask enqueued", items...)
|
||||
|
||||
err = qt.WaitToFinish()
|
||||
if err != nil {
|
||||
log.Debug("CalcDistance queryTask failed to WaitToFinish",
|
||||
zap.Error(err),
|
||||
zap.String("traceID", traceID),
|
||||
zap.String("role", typeutil.ProxyRole),
|
||||
zap.Int64("msgID", qt.Base.MsgID),
|
||||
zap.Uint64("timestamp", qt.Base.Timestamp),
|
||||
zap.String("db", queryRequest.DbName),
|
||||
zap.String("collection", queryRequest.CollectionName),
|
||||
zap.Any("partitions", queryRequest.PartitionNames),
|
||||
zap.Any("OutputFields", queryRequest.OutputFields))
|
||||
log.Error("CalcDistance queryTask failed to WaitToFinish", append(items, zap.Error(err))...)
|
||||
|
||||
return &milvuspb.QueryResults{
|
||||
Status: &commonpb.Status{
|
||||
|
@ -3129,15 +3100,7 @@ func (node *Proxy) CalcDistance(ctx context.Context, request *milvuspb.CalcDista
|
|||
}, err
|
||||
}
|
||||
|
||||
log.Debug("CalcDistance queryTask Done",
|
||||
zap.String("traceID", traceID),
|
||||
zap.String("role", typeutil.ProxyRole),
|
||||
zap.Int64("msgID", qt.Base.MsgID),
|
||||
zap.Uint64("timestamp", qt.Base.Timestamp),
|
||||
zap.String("db", queryRequest.DbName),
|
||||
zap.String("collection", queryRequest.CollectionName),
|
||||
zap.Any("partitions", queryRequest.PartitionNames),
|
||||
zap.Any("OutputFields", queryRequest.OutputFields))
|
||||
log.Debug("CalcDistance queryTask Done", items...)
|
||||
|
||||
return &milvuspb.QueryResults{
|
||||
Status: qt.result.Status,
|
||||
|
@ -3145,328 +3108,13 @@ func (node *Proxy) CalcDistance(ctx context.Context, request *milvuspb.CalcDista
|
|||
}, nil
|
||||
}
|
||||
|
||||
// the vectors retrieved are random order, we need re-arrange the vectors by the order of input ids
|
||||
arrangeFunc := func(ids *milvuspb.VectorIDs, retrievedFields []*schemapb.FieldData) (*schemapb.VectorField, error) {
|
||||
var retrievedIds *schemapb.ScalarField
|
||||
var retrievedVectors *schemapb.VectorField
|
||||
for _, fieldData := range retrievedFields {
|
||||
if fieldData.FieldName == ids.FieldName {
|
||||
retrievedVectors = fieldData.GetVectors()
|
||||
}
|
||||
if fieldData.Type == schemapb.DataType_Int64 {
|
||||
retrievedIds = fieldData.GetScalars()
|
||||
}
|
||||
}
|
||||
|
||||
if retrievedIds == nil || retrievedVectors == nil {
|
||||
return nil, errors.New("failed to fetch vectors")
|
||||
}
|
||||
|
||||
dict := make(map[int64]int)
|
||||
for index, id := range retrievedIds.GetLongData().Data {
|
||||
dict[id] = index
|
||||
}
|
||||
|
||||
inputIds := ids.IdArray.GetIntId().Data
|
||||
if retrievedVectors.GetFloatVector() != nil {
|
||||
floatArr := retrievedVectors.GetFloatVector().Data
|
||||
element := retrievedVectors.GetDim()
|
||||
result := make([]float32, 0, int64(len(inputIds))*element)
|
||||
for _, id := range inputIds {
|
||||
index, ok := dict[id]
|
||||
if !ok {
|
||||
log.Error("id not found in CalcDistance", zap.Int64("id", id))
|
||||
return nil, errors.New("failed to fetch vectors by id: " + fmt.Sprintln(id))
|
||||
}
|
||||
result = append(result, floatArr[int64(index)*element:int64(index+1)*element]...)
|
||||
}
|
||||
|
||||
return &schemapb.VectorField{
|
||||
Dim: element,
|
||||
Data: &schemapb.VectorField_FloatVector{
|
||||
FloatVector: &schemapb.FloatArray{
|
||||
Data: result,
|
||||
},
|
||||
},
|
||||
}, nil
|
||||
}
|
||||
|
||||
if retrievedVectors.GetBinaryVector() != nil {
|
||||
binaryArr := retrievedVectors.GetBinaryVector()
|
||||
element := retrievedVectors.GetDim()
|
||||
if element%8 != 0 {
|
||||
element = element + 8 - element%8
|
||||
}
|
||||
|
||||
result := make([]byte, 0, int64(len(inputIds))*element)
|
||||
for _, id := range inputIds {
|
||||
index, ok := dict[id]
|
||||
if !ok {
|
||||
log.Error("id not found in CalcDistance", zap.Int64("id", id))
|
||||
return nil, errors.New("failed to fetch vectors by id: " + fmt.Sprintln(id))
|
||||
}
|
||||
result = append(result, binaryArr[int64(index)*element:int64(index+1)*element]...)
|
||||
}
|
||||
|
||||
return &schemapb.VectorField{
|
||||
Dim: element * 8,
|
||||
Data: &schemapb.VectorField_BinaryVector{
|
||||
BinaryVector: result,
|
||||
},
|
||||
}, nil
|
||||
}
|
||||
|
||||
return nil, errors.New("failed to fetch vectors")
|
||||
// calcDistanceTask is not a standard task, no need to enqueue
|
||||
task := &calcDistanceTask{
|
||||
traceID: traceID,
|
||||
queryFunc: query,
|
||||
}
|
||||
|
||||
log.Debug("CalcDistance received",
|
||||
zap.String("traceID", traceID),
|
||||
zap.String("role", typeutil.ProxyRole),
|
||||
zap.String("metric", metric))
|
||||
|
||||
vectorsLeft := request.GetOpLeft().GetDataArray()
|
||||
opLeft := request.GetOpLeft().GetIdArray()
|
||||
if opLeft != nil {
|
||||
log.Debug("OpLeft IdArray not empty, Get vectors by id",
|
||||
zap.String("traceID", traceID),
|
||||
zap.String("role", typeutil.ProxyRole))
|
||||
|
||||
result, err := query(opLeft)
|
||||
if err != nil {
|
||||
log.Debug("Failed to get left vectors by id",
|
||||
zap.Error(err),
|
||||
zap.String("traceID", traceID),
|
||||
zap.String("role", typeutil.ProxyRole))
|
||||
|
||||
return &milvuspb.CalcDistanceResults{
|
||||
Status: &commonpb.Status{
|
||||
ErrorCode: commonpb.ErrorCode_UnexpectedError,
|
||||
Reason: err.Error(),
|
||||
},
|
||||
}, nil
|
||||
}
|
||||
|
||||
log.Debug("OpLeft IdArray not empty, Get vectors by id done",
|
||||
zap.String("traceID", traceID),
|
||||
zap.String("role", typeutil.ProxyRole))
|
||||
|
||||
vectorsLeft, err = arrangeFunc(opLeft, result.FieldsData)
|
||||
if err != nil {
|
||||
log.Debug("Failed to re-arrange left vectors",
|
||||
zap.Error(err),
|
||||
zap.String("traceID", traceID),
|
||||
zap.String("role", typeutil.ProxyRole))
|
||||
|
||||
return &milvuspb.CalcDistanceResults{
|
||||
Status: &commonpb.Status{
|
||||
ErrorCode: commonpb.ErrorCode_UnexpectedError,
|
||||
Reason: err.Error(),
|
||||
},
|
||||
}, nil
|
||||
}
|
||||
|
||||
log.Debug("Re-arrange left vectors done",
|
||||
zap.String("traceID", traceID),
|
||||
zap.String("role", typeutil.ProxyRole))
|
||||
}
|
||||
|
||||
if vectorsLeft == nil {
|
||||
msg := "Left vectors array is empty"
|
||||
log.Debug(msg,
|
||||
zap.String("traceID", traceID),
|
||||
zap.String("role", typeutil.ProxyRole))
|
||||
|
||||
return &milvuspb.CalcDistanceResults{
|
||||
Status: &commonpb.Status{
|
||||
ErrorCode: commonpb.ErrorCode_UnexpectedError,
|
||||
Reason: msg,
|
||||
},
|
||||
}, nil
|
||||
}
|
||||
|
||||
vectorsRight := request.GetOpRight().GetDataArray()
|
||||
opRight := request.GetOpRight().GetIdArray()
|
||||
if opRight != nil {
|
||||
log.Debug("OpRight IdArray not empty, Get vectors by id",
|
||||
zap.String("traceID", traceID),
|
||||
zap.String("role", typeutil.ProxyRole))
|
||||
|
||||
result, err := query(opRight)
|
||||
if err != nil {
|
||||
log.Debug("Failed to get right vectors by id",
|
||||
zap.Error(err),
|
||||
zap.String("traceID", traceID),
|
||||
zap.String("role", typeutil.ProxyRole))
|
||||
|
||||
return &milvuspb.CalcDistanceResults{
|
||||
Status: &commonpb.Status{
|
||||
ErrorCode: commonpb.ErrorCode_UnexpectedError,
|
||||
Reason: err.Error(),
|
||||
},
|
||||
}, nil
|
||||
}
|
||||
|
||||
log.Debug("OpRight IdArray not empty, Get vectors by id done",
|
||||
zap.String("traceID", traceID),
|
||||
zap.String("role", typeutil.ProxyRole))
|
||||
|
||||
vectorsRight, err = arrangeFunc(opRight, result.FieldsData)
|
||||
if err != nil {
|
||||
log.Debug("Failed to re-arrange right vectors",
|
||||
zap.Error(err),
|
||||
zap.String("traceID", traceID),
|
||||
zap.String("role", typeutil.ProxyRole))
|
||||
|
||||
return &milvuspb.CalcDistanceResults{
|
||||
Status: &commonpb.Status{
|
||||
ErrorCode: commonpb.ErrorCode_UnexpectedError,
|
||||
Reason: err.Error(),
|
||||
},
|
||||
}, nil
|
||||
}
|
||||
|
||||
log.Debug("Re-arrange right vectors done",
|
||||
zap.String("traceID", traceID),
|
||||
zap.String("role", typeutil.ProxyRole))
|
||||
}
|
||||
|
||||
if vectorsRight == nil {
|
||||
msg := "Right vectors array is empty"
|
||||
log.Debug(msg,
|
||||
zap.String("traceID", traceID),
|
||||
zap.String("role", typeutil.ProxyRole))
|
||||
|
||||
return &milvuspb.CalcDistanceResults{
|
||||
Status: &commonpb.Status{
|
||||
ErrorCode: commonpb.ErrorCode_UnexpectedError,
|
||||
Reason: msg,
|
||||
},
|
||||
}, nil
|
||||
}
|
||||
|
||||
if vectorsLeft.Dim != vectorsRight.Dim {
|
||||
msg := "Vectors dimension is not equal"
|
||||
log.Debug(msg,
|
||||
zap.String("traceID", traceID),
|
||||
zap.String("role", typeutil.ProxyRole))
|
||||
|
||||
return &milvuspb.CalcDistanceResults{
|
||||
Status: &commonpb.Status{
|
||||
ErrorCode: commonpb.ErrorCode_UnexpectedError,
|
||||
Reason: msg,
|
||||
},
|
||||
}, nil
|
||||
}
|
||||
|
||||
if vectorsLeft.GetFloatVector() != nil && vectorsRight.GetFloatVector() != nil {
|
||||
distances, err := distance.CalcFloatDistance(vectorsLeft.Dim, vectorsLeft.GetFloatVector().Data, vectorsRight.GetFloatVector().Data, metric)
|
||||
if err != nil {
|
||||
log.Debug("Failed to CalcFloatDistance",
|
||||
zap.Error(err),
|
||||
zap.String("traceID", traceID),
|
||||
zap.String("role", typeutil.ProxyRole))
|
||||
|
||||
return &milvuspb.CalcDistanceResults{
|
||||
Status: &commonpb.Status{
|
||||
ErrorCode: commonpb.ErrorCode_UnexpectedError,
|
||||
Reason: err.Error(),
|
||||
},
|
||||
}, nil
|
||||
}
|
||||
|
||||
log.Debug("CalcFloatDistance done",
|
||||
zap.Error(err),
|
||||
zap.String("traceID", traceID),
|
||||
zap.String("role", typeutil.ProxyRole))
|
||||
|
||||
return &milvuspb.CalcDistanceResults{
|
||||
Status: &commonpb.Status{ErrorCode: commonpb.ErrorCode_Success, Reason: ""},
|
||||
Array: &milvuspb.CalcDistanceResults_FloatDist{
|
||||
FloatDist: &schemapb.FloatArray{
|
||||
Data: distances,
|
||||
},
|
||||
},
|
||||
}, nil
|
||||
}
|
||||
|
||||
if vectorsLeft.GetBinaryVector() != nil && vectorsRight.GetBinaryVector() != nil {
|
||||
hamming, err := distance.CalcHammingDistance(vectorsLeft.Dim, vectorsLeft.GetBinaryVector(), vectorsRight.GetBinaryVector())
|
||||
if err != nil {
|
||||
log.Debug("Failed to CalcHammingDistance",
|
||||
zap.Error(err),
|
||||
zap.String("traceID", traceID),
|
||||
zap.String("role", typeutil.ProxyRole))
|
||||
|
||||
return &milvuspb.CalcDistanceResults{
|
||||
Status: &commonpb.Status{
|
||||
ErrorCode: commonpb.ErrorCode_UnexpectedError,
|
||||
Reason: err.Error(),
|
||||
},
|
||||
}, nil
|
||||
}
|
||||
|
||||
if metric == distance.HAMMING {
|
||||
log.Debug("CalcHammingDistance done",
|
||||
zap.String("traceID", traceID),
|
||||
zap.String("role", typeutil.ProxyRole))
|
||||
|
||||
return &milvuspb.CalcDistanceResults{
|
||||
Status: &commonpb.Status{ErrorCode: commonpb.ErrorCode_Success, Reason: ""},
|
||||
Array: &milvuspb.CalcDistanceResults_IntDist{
|
||||
IntDist: &schemapb.IntArray{
|
||||
Data: hamming,
|
||||
},
|
||||
},
|
||||
}, nil
|
||||
}
|
||||
|
||||
if metric == distance.TANIMOTO {
|
||||
tanimoto, err := distance.CalcTanimotoCoefficient(vectorsLeft.Dim, hamming)
|
||||
if err != nil {
|
||||
log.Debug("Failed to CalcTanimotoCoefficient",
|
||||
zap.Error(err),
|
||||
zap.String("traceID", traceID),
|
||||
zap.String("role", typeutil.ProxyRole))
|
||||
|
||||
return &milvuspb.CalcDistanceResults{
|
||||
Status: &commonpb.Status{
|
||||
ErrorCode: commonpb.ErrorCode_UnexpectedError,
|
||||
Reason: err.Error(),
|
||||
},
|
||||
}, nil
|
||||
}
|
||||
|
||||
log.Debug("CalcTanimotoCoefficient done",
|
||||
zap.String("traceID", traceID),
|
||||
zap.String("role", typeutil.ProxyRole))
|
||||
|
||||
return &milvuspb.CalcDistanceResults{
|
||||
Status: &commonpb.Status{ErrorCode: commonpb.ErrorCode_Success, Reason: ""},
|
||||
Array: &milvuspb.CalcDistanceResults_FloatDist{
|
||||
FloatDist: &schemapb.FloatArray{
|
||||
Data: tanimoto,
|
||||
},
|
||||
},
|
||||
}, nil
|
||||
}
|
||||
}
|
||||
|
||||
err = errors.New("unexpected error")
|
||||
if (vectorsLeft.GetBinaryVector() != nil && vectorsRight.GetFloatVector() != nil) || (vectorsLeft.GetFloatVector() != nil && vectorsRight.GetBinaryVector() != nil) {
|
||||
err = errors.New("cannot calculate distance between binary vectors and float vectors")
|
||||
}
|
||||
|
||||
log.Debug("Failed to CalcDistance",
|
||||
zap.Error(err),
|
||||
zap.String("traceID", traceID),
|
||||
zap.String("role", typeutil.ProxyRole))
|
||||
|
||||
return &milvuspb.CalcDistanceResults{
|
||||
Status: &commonpb.Status{
|
||||
ErrorCode: commonpb.ErrorCode_UnexpectedError,
|
||||
Reason: err.Error(),
|
||||
},
|
||||
}, nil
|
||||
return task.Execute(ctx, request)
|
||||
}
|
||||
|
||||
// GetDdChannel returns the used channel for dd operations.
|
||||
|
|
|
@ -0,0 +1,434 @@
|
|||
package proxy
|
||||
|
||||
import (
|
||||
"context"
|
||||
"errors"
|
||||
"fmt"
|
||||
|
||||
"github.com/milvus-io/milvus/internal/log"
|
||||
"github.com/milvus-io/milvus/internal/proto/commonpb"
|
||||
"github.com/milvus-io/milvus/internal/proto/milvuspb"
|
||||
"github.com/milvus-io/milvus/internal/proto/schemapb"
|
||||
"github.com/milvus-io/milvus/internal/util/distance"
|
||||
"github.com/milvus-io/milvus/internal/util/funcutil"
|
||||
"github.com/milvus-io/milvus/internal/util/typeutil"
|
||||
"go.uber.org/zap"
|
||||
)
|
||||
|
||||
type calcDistanceTask struct {
|
||||
traceID string
|
||||
queryFunc func(ids *milvuspb.VectorIDs) (*milvuspb.QueryResults, error)
|
||||
}
|
||||
|
||||
func (t *calcDistanceTask) arrangeVectorsByIntID(inputIds []int64, sequence map[int64]int, retrievedVectors *schemapb.VectorField) (*schemapb.VectorField, error) {
|
||||
if retrievedVectors.GetFloatVector() != nil {
|
||||
floatArr := retrievedVectors.GetFloatVector().GetData()
|
||||
element := retrievedVectors.GetDim()
|
||||
result := make([]float32, 0, int64(len(inputIds))*element)
|
||||
for _, id := range inputIds {
|
||||
index, ok := sequence[id]
|
||||
if !ok {
|
||||
log.Error("id not found in CalcDistance", zap.Int64("id", id))
|
||||
return nil, errors.New("failed to fetch vectors by id: " + fmt.Sprintln(id))
|
||||
}
|
||||
result = append(result, floatArr[int64(index)*element:int64(index+1)*element]...)
|
||||
}
|
||||
|
||||
return &schemapb.VectorField{
|
||||
Dim: element,
|
||||
Data: &schemapb.VectorField_FloatVector{
|
||||
FloatVector: &schemapb.FloatArray{
|
||||
Data: result,
|
||||
},
|
||||
},
|
||||
}, nil
|
||||
}
|
||||
|
||||
if retrievedVectors.GetBinaryVector() != nil {
|
||||
binaryArr := retrievedVectors.GetBinaryVector()
|
||||
singleBitLen := distance.SingleBitLen(retrievedVectors.GetDim())
|
||||
numBytes := singleBitLen / 8
|
||||
|
||||
result := make([]byte, 0, int64(len(inputIds))*numBytes)
|
||||
for _, id := range inputIds {
|
||||
index, ok := sequence[id]
|
||||
if !ok {
|
||||
log.Error("id not found in CalcDistance", zap.Int64("id", id))
|
||||
return nil, errors.New("failed to fetch vectors by id: " + fmt.Sprintln(id))
|
||||
}
|
||||
result = append(result, binaryArr[int64(index)*numBytes:int64(index+1)*numBytes]...)
|
||||
}
|
||||
|
||||
return &schemapb.VectorField{
|
||||
Dim: retrievedVectors.GetDim(),
|
||||
Data: &schemapb.VectorField_BinaryVector{
|
||||
BinaryVector: result,
|
||||
},
|
||||
}, nil
|
||||
}
|
||||
|
||||
return nil, errors.New("unsupported vector type")
|
||||
}
|
||||
|
||||
func (t *calcDistanceTask) arrangeVectorsByStrID(inputIds []string, sequence map[string]int, retrievedVectors *schemapb.VectorField) (*schemapb.VectorField, error) {
|
||||
if retrievedVectors.GetFloatVector() != nil {
|
||||
floatArr := retrievedVectors.GetFloatVector().GetData()
|
||||
element := retrievedVectors.GetDim()
|
||||
result := make([]float32, 0, int64(len(inputIds))*element)
|
||||
for _, id := range inputIds {
|
||||
index, ok := sequence[id]
|
||||
if !ok {
|
||||
log.Error("id not found in CalcDistance", zap.String("id", id))
|
||||
return nil, errors.New("failed to fetch vectors by id: " + fmt.Sprintln(id))
|
||||
}
|
||||
result = append(result, floatArr[int64(index)*element:int64(index+1)*element]...)
|
||||
}
|
||||
|
||||
return &schemapb.VectorField{
|
||||
Dim: element,
|
||||
Data: &schemapb.VectorField_FloatVector{
|
||||
FloatVector: &schemapb.FloatArray{
|
||||
Data: result,
|
||||
},
|
||||
},
|
||||
}, nil
|
||||
}
|
||||
|
||||
if retrievedVectors.GetBinaryVector() != nil {
|
||||
binaryArr := retrievedVectors.GetBinaryVector()
|
||||
singleBitLen := distance.SingleBitLen(retrievedVectors.GetDim())
|
||||
numBytes := singleBitLen / 8
|
||||
|
||||
result := make([]byte, 0, int64(len(inputIds))*numBytes)
|
||||
for _, id := range inputIds {
|
||||
index, ok := sequence[id]
|
||||
if !ok {
|
||||
log.Error("id not found in CalcDistance", zap.String("id", id))
|
||||
return nil, errors.New("failed to fetch vectors by id: " + fmt.Sprintln(id))
|
||||
}
|
||||
result = append(result, binaryArr[int64(index)*numBytes:int64(index+1)*numBytes]...)
|
||||
}
|
||||
|
||||
return &schemapb.VectorField{
|
||||
Dim: retrievedVectors.GetDim(),
|
||||
Data: &schemapb.VectorField_BinaryVector{
|
||||
BinaryVector: result,
|
||||
},
|
||||
}, nil
|
||||
}
|
||||
|
||||
return nil, errors.New("unsupported vector type")
|
||||
}
|
||||
|
||||
func (t *calcDistanceTask) Execute(ctx context.Context, request *milvuspb.CalcDistanceRequest) (*milvuspb.CalcDistanceResults, error) {
|
||||
param, _ := funcutil.GetAttrByKeyFromRepeatedKV("metric", request.GetParams())
|
||||
metric, err := distance.ValidateMetricType(param)
|
||||
if err != nil {
|
||||
return &milvuspb.CalcDistanceResults{
|
||||
Status: &commonpb.Status{
|
||||
ErrorCode: commonpb.ErrorCode_UnexpectedError,
|
||||
Reason: err.Error(),
|
||||
},
|
||||
}, nil
|
||||
}
|
||||
|
||||
// the vectors retrieved are random order, we need re-arrange the vectors by the order of input ids
|
||||
arrangeFunc := func(ids *milvuspb.VectorIDs, retrievedFields []*schemapb.FieldData) (*schemapb.VectorField, error) {
|
||||
var retrievedIds *schemapb.ScalarField
|
||||
var retrievedVectors *schemapb.VectorField
|
||||
isStringID := true
|
||||
for _, fieldData := range retrievedFields {
|
||||
if fieldData.FieldName == ids.FieldName {
|
||||
retrievedVectors = fieldData.GetVectors()
|
||||
}
|
||||
if fieldData.Type == schemapb.DataType_Int64 ||
|
||||
fieldData.Type == schemapb.DataType_VarChar ||
|
||||
fieldData.Type == schemapb.DataType_String {
|
||||
retrievedIds = fieldData.GetScalars()
|
||||
|
||||
if fieldData.Type == schemapb.DataType_Int64 {
|
||||
isStringID = false
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
if retrievedIds == nil || retrievedVectors == nil {
|
||||
return nil, errors.New("failed to fetch vectors")
|
||||
}
|
||||
|
||||
if isStringID {
|
||||
dict := make(map[string]int)
|
||||
for index, id := range retrievedIds.GetStringData().GetData() {
|
||||
dict[id] = index
|
||||
}
|
||||
|
||||
inputIds := ids.IdArray.GetStrId().GetData()
|
||||
return t.arrangeVectorsByStrID(inputIds, dict, retrievedVectors)
|
||||
}
|
||||
|
||||
dict := make(map[int64]int)
|
||||
for index, id := range retrievedIds.GetLongData().GetData() {
|
||||
dict[id] = index
|
||||
}
|
||||
|
||||
inputIds := ids.IdArray.GetIntId().GetData()
|
||||
return t.arrangeVectorsByIntID(inputIds, dict, retrievedVectors)
|
||||
}
|
||||
|
||||
log.Debug("CalcDistance received",
|
||||
zap.String("traceID", t.traceID),
|
||||
zap.String("role", typeutil.ProxyRole),
|
||||
zap.String("metric", metric))
|
||||
|
||||
vectorsLeft := request.GetOpLeft().GetDataArray()
|
||||
opLeft := request.GetOpLeft().GetIdArray()
|
||||
if opLeft != nil {
|
||||
log.Debug("OpLeft IdArray not empty, Get vectors by id",
|
||||
zap.String("traceID", t.traceID),
|
||||
zap.String("role", typeutil.ProxyRole))
|
||||
|
||||
result, err := t.queryFunc(opLeft)
|
||||
if err != nil {
|
||||
log.Debug("Failed to get left vectors by id",
|
||||
zap.Error(err),
|
||||
zap.String("traceID", t.traceID),
|
||||
zap.String("role", typeutil.ProxyRole))
|
||||
|
||||
return &milvuspb.CalcDistanceResults{
|
||||
Status: &commonpb.Status{
|
||||
ErrorCode: commonpb.ErrorCode_UnexpectedError,
|
||||
Reason: err.Error(),
|
||||
},
|
||||
}, nil
|
||||
}
|
||||
|
||||
log.Debug("OpLeft IdArray not empty, Get vectors by id done",
|
||||
zap.String("traceID", t.traceID),
|
||||
zap.String("role", typeutil.ProxyRole))
|
||||
|
||||
vectorsLeft, err = arrangeFunc(opLeft, result.FieldsData)
|
||||
if err != nil {
|
||||
log.Debug("Failed to re-arrange left vectors",
|
||||
zap.Error(err),
|
||||
zap.String("traceID", t.traceID),
|
||||
zap.String("role", typeutil.ProxyRole))
|
||||
|
||||
return &milvuspb.CalcDistanceResults{
|
||||
Status: &commonpb.Status{
|
||||
ErrorCode: commonpb.ErrorCode_UnexpectedError,
|
||||
Reason: err.Error(),
|
||||
},
|
||||
}, nil
|
||||
}
|
||||
|
||||
log.Debug("Re-arrange left vectors done",
|
||||
zap.String("traceID", t.traceID),
|
||||
zap.String("role", typeutil.ProxyRole))
|
||||
}
|
||||
|
||||
if vectorsLeft == nil {
|
||||
msg := "Left vectors array is empty"
|
||||
log.Debug(msg,
|
||||
zap.String("traceID", t.traceID),
|
||||
zap.String("role", typeutil.ProxyRole))
|
||||
|
||||
return &milvuspb.CalcDistanceResults{
|
||||
Status: &commonpb.Status{
|
||||
ErrorCode: commonpb.ErrorCode_UnexpectedError,
|
||||
Reason: msg,
|
||||
},
|
||||
}, nil
|
||||
}
|
||||
|
||||
vectorsRight := request.GetOpRight().GetDataArray()
|
||||
opRight := request.GetOpRight().GetIdArray()
|
||||
if opRight != nil {
|
||||
log.Debug("OpRight IdArray not empty, Get vectors by id",
|
||||
zap.String("traceID", t.traceID),
|
||||
zap.String("role", typeutil.ProxyRole))
|
||||
|
||||
result, err := t.queryFunc(opRight)
|
||||
if err != nil {
|
||||
log.Debug("Failed to get right vectors by id",
|
||||
zap.Error(err),
|
||||
zap.String("traceID", t.traceID),
|
||||
zap.String("role", typeutil.ProxyRole))
|
||||
|
||||
return &milvuspb.CalcDistanceResults{
|
||||
Status: &commonpb.Status{
|
||||
ErrorCode: commonpb.ErrorCode_UnexpectedError,
|
||||
Reason: err.Error(),
|
||||
},
|
||||
}, nil
|
||||
}
|
||||
|
||||
log.Debug("OpRight IdArray not empty, Get vectors by id done",
|
||||
zap.String("traceID", t.traceID),
|
||||
zap.String("role", typeutil.ProxyRole))
|
||||
|
||||
vectorsRight, err = arrangeFunc(opRight, result.FieldsData)
|
||||
if err != nil {
|
||||
log.Debug("Failed to re-arrange right vectors",
|
||||
zap.Error(err),
|
||||
zap.String("traceID", t.traceID),
|
||||
zap.String("role", typeutil.ProxyRole))
|
||||
|
||||
return &milvuspb.CalcDistanceResults{
|
||||
Status: &commonpb.Status{
|
||||
ErrorCode: commonpb.ErrorCode_UnexpectedError,
|
||||
Reason: err.Error(),
|
||||
},
|
||||
}, nil
|
||||
}
|
||||
|
||||
log.Debug("Re-arrange right vectors done",
|
||||
zap.String("traceID", t.traceID),
|
||||
zap.String("role", typeutil.ProxyRole))
|
||||
}
|
||||
|
||||
if vectorsRight == nil {
|
||||
msg := "Right vectors array is empty"
|
||||
log.Debug(msg,
|
||||
zap.String("traceID", t.traceID),
|
||||
zap.String("role", typeutil.ProxyRole))
|
||||
|
||||
return &milvuspb.CalcDistanceResults{
|
||||
Status: &commonpb.Status{
|
||||
ErrorCode: commonpb.ErrorCode_UnexpectedError,
|
||||
Reason: msg,
|
||||
},
|
||||
}, nil
|
||||
}
|
||||
|
||||
if vectorsLeft.GetDim() != vectorsRight.GetDim() {
|
||||
msg := "Vectors dimension is not equal"
|
||||
log.Debug(msg,
|
||||
zap.String("traceID", t.traceID),
|
||||
zap.String("role", typeutil.ProxyRole))
|
||||
|
||||
return &milvuspb.CalcDistanceResults{
|
||||
Status: &commonpb.Status{
|
||||
ErrorCode: commonpb.ErrorCode_UnexpectedError,
|
||||
Reason: msg,
|
||||
},
|
||||
}, nil
|
||||
}
|
||||
|
||||
if vectorsLeft.GetFloatVector() != nil && vectorsRight.GetFloatVector() != nil {
|
||||
distances, err := distance.CalcFloatDistance(vectorsLeft.GetDim(), vectorsLeft.GetFloatVector().GetData(), vectorsRight.GetFloatVector().GetData(), metric)
|
||||
if err != nil {
|
||||
log.Debug("Failed to CalcFloatDistance",
|
||||
zap.Error(err),
|
||||
zap.Int64("leftDim", vectorsLeft.GetDim()),
|
||||
zap.Int("leftLen", len(vectorsLeft.GetFloatVector().GetData())),
|
||||
zap.Int64("rightDim", vectorsRight.GetDim()),
|
||||
zap.Int("rightLen", len(vectorsRight.GetFloatVector().GetData())),
|
||||
zap.String("traceID", t.traceID),
|
||||
zap.String("role", typeutil.ProxyRole))
|
||||
|
||||
return &milvuspb.CalcDistanceResults{
|
||||
Status: &commonpb.Status{
|
||||
ErrorCode: commonpb.ErrorCode_UnexpectedError,
|
||||
Reason: err.Error(),
|
||||
},
|
||||
}, nil
|
||||
}
|
||||
|
||||
log.Debug("CalcFloatDistance done",
|
||||
zap.Error(err),
|
||||
zap.String("traceID", t.traceID),
|
||||
zap.String("role", typeutil.ProxyRole))
|
||||
|
||||
return &milvuspb.CalcDistanceResults{
|
||||
Status: &commonpb.Status{ErrorCode: commonpb.ErrorCode_Success, Reason: ""},
|
||||
Array: &milvuspb.CalcDistanceResults_FloatDist{
|
||||
FloatDist: &schemapb.FloatArray{
|
||||
Data: distances,
|
||||
},
|
||||
},
|
||||
}, nil
|
||||
}
|
||||
|
||||
if vectorsLeft.GetBinaryVector() != nil && vectorsRight.GetBinaryVector() != nil {
|
||||
hamming, err := distance.CalcHammingDistance(vectorsLeft.GetDim(), vectorsLeft.GetBinaryVector(), vectorsRight.GetBinaryVector())
|
||||
if err != nil {
|
||||
log.Debug("Failed to CalcHammingDistance",
|
||||
zap.Error(err),
|
||||
zap.Int64("leftDim", vectorsLeft.GetDim()),
|
||||
zap.Int("leftLen", len(vectorsLeft.GetBinaryVector())),
|
||||
zap.Int64("rightDim", vectorsRight.GetDim()),
|
||||
zap.Int("rightLen", len(vectorsRight.GetBinaryVector())),
|
||||
zap.String("traceID", t.traceID),
|
||||
zap.String("role", typeutil.ProxyRole))
|
||||
|
||||
return &milvuspb.CalcDistanceResults{
|
||||
Status: &commonpb.Status{
|
||||
ErrorCode: commonpb.ErrorCode_UnexpectedError,
|
||||
Reason: err.Error(),
|
||||
},
|
||||
}, nil
|
||||
}
|
||||
|
||||
if metric == distance.HAMMING {
|
||||
log.Debug("CalcHammingDistance done",
|
||||
zap.String("traceID", t.traceID),
|
||||
zap.String("role", typeutil.ProxyRole))
|
||||
|
||||
return &milvuspb.CalcDistanceResults{
|
||||
Status: &commonpb.Status{ErrorCode: commonpb.ErrorCode_Success, Reason: ""},
|
||||
Array: &milvuspb.CalcDistanceResults_IntDist{
|
||||
IntDist: &schemapb.IntArray{
|
||||
Data: hamming,
|
||||
},
|
||||
},
|
||||
}, nil
|
||||
}
|
||||
|
||||
if metric == distance.TANIMOTO {
|
||||
tanimoto, err := distance.CalcTanimotoCoefficient(vectorsLeft.GetDim(), hamming)
|
||||
if err != nil {
|
||||
log.Debug("Failed to CalcTanimotoCoefficient",
|
||||
zap.Error(err),
|
||||
zap.String("traceID", t.traceID),
|
||||
zap.String("role", typeutil.ProxyRole))
|
||||
|
||||
return &milvuspb.CalcDistanceResults{
|
||||
Status: &commonpb.Status{
|
||||
ErrorCode: commonpb.ErrorCode_UnexpectedError,
|
||||
Reason: err.Error(),
|
||||
},
|
||||
}, nil
|
||||
}
|
||||
|
||||
log.Debug("CalcTanimotoCoefficient done",
|
||||
zap.String("traceID", t.traceID),
|
||||
zap.String("role", typeutil.ProxyRole))
|
||||
|
||||
return &milvuspb.CalcDistanceResults{
|
||||
Status: &commonpb.Status{ErrorCode: commonpb.ErrorCode_Success, Reason: ""},
|
||||
Array: &milvuspb.CalcDistanceResults_FloatDist{
|
||||
FloatDist: &schemapb.FloatArray{
|
||||
Data: tanimoto,
|
||||
},
|
||||
},
|
||||
}, nil
|
||||
}
|
||||
}
|
||||
|
||||
err = errors.New("unexpected error")
|
||||
if (vectorsLeft.GetBinaryVector() != nil && vectorsRight.GetFloatVector() != nil) || (vectorsLeft.GetFloatVector() != nil && vectorsRight.GetBinaryVector() != nil) {
|
||||
err = errors.New("cannot calculate distance between binary vectors and float vectors")
|
||||
}
|
||||
|
||||
log.Debug("Failed to CalcDistance",
|
||||
zap.Error(err),
|
||||
zap.String("traceID", t.traceID),
|
||||
zap.String("role", typeutil.ProxyRole))
|
||||
|
||||
return &milvuspb.CalcDistanceResults{
|
||||
Status: &commonpb.Status{
|
||||
ErrorCode: commonpb.ErrorCode_UnexpectedError,
|
||||
Reason: err.Error(),
|
||||
},
|
||||
}, nil
|
||||
}
|
|
@ -0,0 +1,490 @@
|
|||
package proxy
|
||||
|
||||
import (
|
||||
"context"
|
||||
"errors"
|
||||
"testing"
|
||||
|
||||
"github.com/milvus-io/milvus/internal/proto/commonpb"
|
||||
"github.com/milvus-io/milvus/internal/proto/milvuspb"
|
||||
"github.com/milvus-io/milvus/internal/proto/schemapb"
|
||||
"github.com/stretchr/testify/assert"
|
||||
)
|
||||
|
||||
func TestCalcDistanceTask_arrangeVectorsByStrID(t *testing.T) {
|
||||
task := &calcDistanceTask{}
|
||||
|
||||
inputIds := make([]string, 0)
|
||||
inputIds = append(inputIds, "c")
|
||||
inputIds = append(inputIds, "b")
|
||||
inputIds = append(inputIds, "a")
|
||||
|
||||
sequence := make(map[string]int)
|
||||
sequence["a"] = 0
|
||||
sequence["b"] = 1
|
||||
sequence["c"] = 2
|
||||
|
||||
dim := 16
|
||||
|
||||
// float vector
|
||||
floatValue := make([]float32, 0)
|
||||
for i := 0; i < dim*3; i++ {
|
||||
floatValue = append(floatValue, float32(i))
|
||||
}
|
||||
retrievedVectors := &schemapb.VectorField{
|
||||
Dim: int64(dim),
|
||||
Data: &schemapb.VectorField_FloatVector{
|
||||
FloatVector: &schemapb.FloatArray{
|
||||
Data: floatValue,
|
||||
},
|
||||
},
|
||||
}
|
||||
|
||||
result, err := task.arrangeVectorsByStrID(inputIds, sequence, retrievedVectors)
|
||||
assert.Nil(t, err)
|
||||
|
||||
floatResult := result.GetFloatVector().GetData()
|
||||
for i := 0; i < 3; i++ {
|
||||
for j := 0; j < dim; j++ {
|
||||
assert.Equal(t, floatValue[dim*sequence[inputIds[i]]+j], floatResult[i*dim+j])
|
||||
}
|
||||
}
|
||||
|
||||
// binary vector
|
||||
binaryValue := make([]byte, 0)
|
||||
for i := 0; i < 3*dim/8; i++ {
|
||||
binaryValue = append(binaryValue, byte(i))
|
||||
}
|
||||
retrievedVectors = &schemapb.VectorField{
|
||||
Dim: int64(dim),
|
||||
Data: &schemapb.VectorField_BinaryVector{
|
||||
BinaryVector: binaryValue,
|
||||
},
|
||||
}
|
||||
|
||||
result, err = task.arrangeVectorsByStrID(inputIds, sequence, retrievedVectors)
|
||||
assert.Nil(t, err)
|
||||
|
||||
binaryResult := result.GetBinaryVector()
|
||||
numBytes := dim / 8
|
||||
for i := 0; i < 3; i++ {
|
||||
for j := 0; j < numBytes; j++ {
|
||||
assert.Equal(t, binaryValue[sequence[inputIds[i]]*numBytes+j], binaryResult[i*numBytes+j])
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func TestCalcDistanceTask_arrangeVectorsByIntID(t *testing.T) {
|
||||
task := &calcDistanceTask{}
|
||||
|
||||
inputIds := make([]int64, 0)
|
||||
inputIds = append(inputIds, 2)
|
||||
inputIds = append(inputIds, 0)
|
||||
inputIds = append(inputIds, 1)
|
||||
|
||||
sequence := make(map[int64]int)
|
||||
sequence[0] = 0
|
||||
sequence[1] = 1
|
||||
sequence[2] = 2
|
||||
|
||||
dim := 16
|
||||
|
||||
// float vector
|
||||
floatValue := make([]float32, 0)
|
||||
for i := 0; i < dim*3; i++ {
|
||||
floatValue = append(floatValue, float32(i))
|
||||
}
|
||||
retrievedVectors := &schemapb.VectorField{
|
||||
Dim: int64(dim),
|
||||
Data: &schemapb.VectorField_FloatVector{
|
||||
FloatVector: &schemapb.FloatArray{
|
||||
Data: floatValue,
|
||||
},
|
||||
},
|
||||
}
|
||||
|
||||
result, err := task.arrangeVectorsByIntID(inputIds, sequence, retrievedVectors)
|
||||
assert.Nil(t, err)
|
||||
|
||||
floatResult := result.GetFloatVector().GetData()
|
||||
for i := 0; i < 3; i++ {
|
||||
for j := 0; j < dim; j++ {
|
||||
assert.Equal(t, floatValue[dim*sequence[inputIds[i]]+j], floatResult[i*dim+j])
|
||||
}
|
||||
}
|
||||
|
||||
// binary vector
|
||||
binaryValue := make([]byte, 0)
|
||||
for i := 0; i < dim*3; i++ {
|
||||
binaryValue = append(binaryValue, byte(i))
|
||||
}
|
||||
retrievedVectors = &schemapb.VectorField{
|
||||
Dim: int64(dim),
|
||||
Data: &schemapb.VectorField_BinaryVector{
|
||||
BinaryVector: binaryValue,
|
||||
},
|
||||
}
|
||||
|
||||
result, err = task.arrangeVectorsByIntID(inputIds, sequence, retrievedVectors)
|
||||
assert.Nil(t, err)
|
||||
|
||||
binaryResult := result.GetBinaryVector()
|
||||
numBytes := dim / 8
|
||||
for i := 0; i < 3; i++ {
|
||||
for j := 0; j < numBytes; j++ {
|
||||
assert.Equal(t, binaryValue[sequence[inputIds[i]]*numBytes+j], binaryResult[i*numBytes+j])
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func TestCalcDistanceTask_ExecuteFloat(t *testing.T) {
|
||||
ctx := context.Background()
|
||||
queryFunc := func(ids *milvuspb.VectorIDs) (*milvuspb.QueryResults, error) {
|
||||
return nil, errors.New("unexpected error")
|
||||
}
|
||||
|
||||
task := &calcDistanceTask{
|
||||
traceID: "dummy",
|
||||
queryFunc: queryFunc,
|
||||
}
|
||||
|
||||
request := &milvuspb.CalcDistanceRequest{
|
||||
OpLeft: nil,
|
||||
OpRight: nil,
|
||||
Params: []*commonpb.KeyValuePair{
|
||||
{Key: "metric", Value: "L2"},
|
||||
},
|
||||
}
|
||||
|
||||
// left-op empty
|
||||
calcResult, err := task.Execute(ctx, request)
|
||||
assert.Nil(t, err)
|
||||
assert.Equal(t, commonpb.ErrorCode_UnexpectedError, calcResult.Status.ErrorCode)
|
||||
|
||||
request = &milvuspb.CalcDistanceRequest{
|
||||
OpLeft: &milvuspb.VectorsArray{
|
||||
Array: &milvuspb.VectorsArray_IdArray{
|
||||
IdArray: &milvuspb.VectorIDs{},
|
||||
},
|
||||
},
|
||||
OpRight: nil,
|
||||
Params: []*commonpb.KeyValuePair{
|
||||
{Key: "metric", Value: "L2"},
|
||||
},
|
||||
}
|
||||
|
||||
// left-op query error
|
||||
calcResult, err = task.Execute(ctx, request)
|
||||
assert.Nil(t, err)
|
||||
assert.Equal(t, commonpb.ErrorCode_UnexpectedError, calcResult.Status.ErrorCode)
|
||||
|
||||
fieldIds := make([]int64, 0)
|
||||
fieldIds = append(fieldIds, 2)
|
||||
fieldIds = append(fieldIds, 0)
|
||||
fieldIds = append(fieldIds, 1)
|
||||
|
||||
dim := 8
|
||||
floatValue := make([]float32, 0)
|
||||
for i := 0; i < dim*3; i++ {
|
||||
floatValue = append(floatValue, float32(i))
|
||||
}
|
||||
|
||||
queryFunc = func(ids *milvuspb.VectorIDs) (*milvuspb.QueryResults, error) {
|
||||
if ids == nil {
|
||||
return &milvuspb.QueryResults{
|
||||
Status: &commonpb.Status{
|
||||
ErrorCode: commonpb.ErrorCode_UnexpectedError,
|
||||
Reason: "unexpected",
|
||||
},
|
||||
}, nil
|
||||
}
|
||||
|
||||
return &milvuspb.QueryResults{
|
||||
FieldsData: []*schemapb.FieldData{
|
||||
{
|
||||
Type: schemapb.DataType_Int64,
|
||||
FieldName: "id",
|
||||
Field: &schemapb.FieldData_Scalars{
|
||||
Scalars: &schemapb.ScalarField{
|
||||
Data: &schemapb.ScalarField_LongData{
|
||||
LongData: &schemapb.LongArray{
|
||||
Data: fieldIds,
|
||||
},
|
||||
},
|
||||
},
|
||||
},
|
||||
},
|
||||
{
|
||||
Type: schemapb.DataType_FloatVector,
|
||||
FieldName: "vec",
|
||||
Field: &schemapb.FieldData_Vectors{
|
||||
Vectors: &schemapb.VectorField{
|
||||
Dim: int64(dim),
|
||||
Data: &schemapb.VectorField_FloatVector{
|
||||
FloatVector: &schemapb.FloatArray{
|
||||
Data: floatValue,
|
||||
},
|
||||
},
|
||||
},
|
||||
},
|
||||
},
|
||||
},
|
||||
}, nil
|
||||
}
|
||||
|
||||
task.queryFunc = queryFunc
|
||||
calcResult, err = task.Execute(ctx, request)
|
||||
assert.Nil(t, err)
|
||||
assert.Equal(t, commonpb.ErrorCode_UnexpectedError, calcResult.Status.ErrorCode)
|
||||
|
||||
idArray := &milvuspb.VectorsArray{
|
||||
Array: &milvuspb.VectorsArray_IdArray{
|
||||
IdArray: &milvuspb.VectorIDs{
|
||||
FieldName: "vec",
|
||||
IdArray: &schemapb.IDs{
|
||||
IdField: &schemapb.IDs_IntId{
|
||||
IntId: &schemapb.LongArray{
|
||||
Data: fieldIds,
|
||||
},
|
||||
},
|
||||
},
|
||||
},
|
||||
},
|
||||
}
|
||||
request = &milvuspb.CalcDistanceRequest{
|
||||
OpLeft: idArray,
|
||||
OpRight: idArray,
|
||||
Params: []*commonpb.KeyValuePair{
|
||||
{Key: "metric", Value: "L2"},
|
||||
},
|
||||
}
|
||||
|
||||
// success
|
||||
calcResult, err = task.Execute(ctx, request)
|
||||
assert.Nil(t, err)
|
||||
assert.Equal(t, commonpb.ErrorCode_Success, calcResult.Status.ErrorCode)
|
||||
|
||||
// right-op query error
|
||||
request.OpRight = nil
|
||||
calcResult, err = task.Execute(ctx, request)
|
||||
assert.Nil(t, err)
|
||||
assert.Equal(t, commonpb.ErrorCode_UnexpectedError, calcResult.Status.ErrorCode)
|
||||
|
||||
request.OpRight = &milvuspb.VectorsArray{
|
||||
Array: &milvuspb.VectorsArray_IdArray{
|
||||
IdArray: &milvuspb.VectorIDs{
|
||||
FieldName: "kkk",
|
||||
IdArray: &schemapb.IDs{
|
||||
IdField: &schemapb.IDs_IntId{
|
||||
IntId: &schemapb.LongArray{
|
||||
Data: fieldIds,
|
||||
},
|
||||
},
|
||||
},
|
||||
},
|
||||
},
|
||||
}
|
||||
|
||||
// right-op arrange error
|
||||
calcResult, err = task.Execute(ctx, request)
|
||||
assert.Nil(t, err)
|
||||
assert.Equal(t, commonpb.ErrorCode_UnexpectedError, calcResult.Status.ErrorCode)
|
||||
|
||||
request.OpRight = &milvuspb.VectorsArray{
|
||||
Array: &milvuspb.VectorsArray_DataArray{
|
||||
DataArray: &schemapb.VectorField{
|
||||
Dim: 5,
|
||||
},
|
||||
},
|
||||
}
|
||||
|
||||
// different dimension
|
||||
calcResult, err = task.Execute(ctx, request)
|
||||
assert.Nil(t, err)
|
||||
assert.Equal(t, commonpb.ErrorCode_UnexpectedError, calcResult.Status.ErrorCode)
|
||||
|
||||
request.OpRight = &milvuspb.VectorsArray{
|
||||
Array: &milvuspb.VectorsArray_DataArray{
|
||||
DataArray: &schemapb.VectorField{
|
||||
Dim: int64(dim),
|
||||
Data: &schemapb.VectorField_FloatVector{
|
||||
FloatVector: &schemapb.FloatArray{
|
||||
Data: make([]float32, 0),
|
||||
},
|
||||
},
|
||||
},
|
||||
},
|
||||
}
|
||||
|
||||
// calcdistance return error
|
||||
calcResult, err = task.Execute(ctx, request)
|
||||
assert.Nil(t, err)
|
||||
assert.Equal(t, commonpb.ErrorCode_UnexpectedError, calcResult.Status.ErrorCode)
|
||||
}
|
||||
|
||||
func TestCalcDistanceTask_ExecuteBinary(t *testing.T) {
|
||||
ctx := context.Background()
|
||||
|
||||
fieldIds := make([]int64, 0)
|
||||
fieldIds = append(fieldIds, 2)
|
||||
fieldIds = append(fieldIds, 0)
|
||||
fieldIds = append(fieldIds, 1)
|
||||
|
||||
dim := 16
|
||||
binaryValue := make([]byte, 0)
|
||||
for i := 0; i < 3*dim/8; i++ {
|
||||
binaryValue = append(binaryValue, byte(i))
|
||||
}
|
||||
|
||||
queryFunc := func(ids *milvuspb.VectorIDs) (*milvuspb.QueryResults, error) {
|
||||
if ids == nil {
|
||||
return &milvuspb.QueryResults{
|
||||
Status: &commonpb.Status{
|
||||
ErrorCode: commonpb.ErrorCode_UnexpectedError,
|
||||
Reason: "unexpected",
|
||||
},
|
||||
}, nil
|
||||
}
|
||||
|
||||
return &milvuspb.QueryResults{
|
||||
FieldsData: []*schemapb.FieldData{
|
||||
{
|
||||
Type: schemapb.DataType_Int64,
|
||||
FieldName: "id",
|
||||
Field: &schemapb.FieldData_Scalars{
|
||||
Scalars: &schemapb.ScalarField{
|
||||
Data: &schemapb.ScalarField_LongData{
|
||||
LongData: &schemapb.LongArray{
|
||||
Data: fieldIds,
|
||||
},
|
||||
},
|
||||
},
|
||||
},
|
||||
},
|
||||
{
|
||||
Type: schemapb.DataType_FloatVector,
|
||||
FieldName: "vec",
|
||||
Field: &schemapb.FieldData_Vectors{
|
||||
Vectors: &schemapb.VectorField{
|
||||
Dim: int64(dim),
|
||||
Data: &schemapb.VectorField_BinaryVector{
|
||||
BinaryVector: binaryValue,
|
||||
},
|
||||
},
|
||||
},
|
||||
},
|
||||
},
|
||||
}, nil
|
||||
}
|
||||
|
||||
idArray := &milvuspb.VectorsArray{
|
||||
Array: &milvuspb.VectorsArray_IdArray{
|
||||
IdArray: &milvuspb.VectorIDs{
|
||||
FieldName: "vec",
|
||||
IdArray: &schemapb.IDs{
|
||||
IdField: &schemapb.IDs_IntId{
|
||||
IntId: &schemapb.LongArray{
|
||||
Data: fieldIds,
|
||||
},
|
||||
},
|
||||
},
|
||||
},
|
||||
},
|
||||
}
|
||||
request := &milvuspb.CalcDistanceRequest{
|
||||
OpLeft: idArray,
|
||||
OpRight: idArray,
|
||||
Params: []*commonpb.KeyValuePair{
|
||||
{Key: "metric", Value: "HAMMING"},
|
||||
},
|
||||
}
|
||||
|
||||
task := &calcDistanceTask{
|
||||
traceID: "dummy",
|
||||
queryFunc: queryFunc,
|
||||
}
|
||||
|
||||
// success
|
||||
calcResult, err := task.Execute(ctx, request)
|
||||
assert.Nil(t, err)
|
||||
assert.Equal(t, commonpb.ErrorCode_Success, calcResult.Status.ErrorCode)
|
||||
|
||||
floatArray := &milvuspb.VectorsArray{
|
||||
Array: &milvuspb.VectorsArray_DataArray{
|
||||
DataArray: &schemapb.VectorField{
|
||||
Dim: int64(dim),
|
||||
Data: &schemapb.VectorField_FloatVector{},
|
||||
},
|
||||
},
|
||||
}
|
||||
binaryArray := &milvuspb.VectorsArray{
|
||||
Array: &milvuspb.VectorsArray_DataArray{
|
||||
DataArray: &schemapb.VectorField{
|
||||
Dim: int64(dim),
|
||||
Data: &schemapb.VectorField_BinaryVector{
|
||||
BinaryVector: binaryValue,
|
||||
},
|
||||
},
|
||||
},
|
||||
}
|
||||
request = &milvuspb.CalcDistanceRequest{
|
||||
OpLeft: floatArray,
|
||||
OpRight: binaryArray,
|
||||
Params: []*commonpb.KeyValuePair{
|
||||
{Key: "metric", Value: "HAMMING"},
|
||||
},
|
||||
}
|
||||
|
||||
// float vs binary
|
||||
calcResult, err = task.Execute(ctx, request)
|
||||
assert.Nil(t, err)
|
||||
assert.Equal(t, commonpb.ErrorCode_UnexpectedError, calcResult.Status.ErrorCode)
|
||||
|
||||
request = &milvuspb.CalcDistanceRequest{
|
||||
OpLeft: binaryArray,
|
||||
OpRight: binaryArray,
|
||||
Params: []*commonpb.KeyValuePair{
|
||||
{Key: "metric", Value: "HAMMING"},
|
||||
},
|
||||
}
|
||||
|
||||
// hamming
|
||||
calcResult, err = task.Execute(ctx, request)
|
||||
assert.Nil(t, err)
|
||||
assert.Equal(t, commonpb.ErrorCode_Success, calcResult.Status.ErrorCode)
|
||||
|
||||
request = &milvuspb.CalcDistanceRequest{
|
||||
OpLeft: binaryArray,
|
||||
OpRight: binaryArray,
|
||||
Params: []*commonpb.KeyValuePair{
|
||||
{Key: "metric", Value: "TANIMOTO"},
|
||||
},
|
||||
}
|
||||
|
||||
// tanimoto
|
||||
calcResult, err = task.Execute(ctx, request)
|
||||
assert.Nil(t, err)
|
||||
assert.Equal(t, commonpb.ErrorCode_Success, calcResult.Status.ErrorCode)
|
||||
|
||||
request = &milvuspb.CalcDistanceRequest{
|
||||
OpLeft: binaryArray,
|
||||
OpRight: &milvuspb.VectorsArray{
|
||||
Array: &milvuspb.VectorsArray_DataArray{
|
||||
DataArray: &schemapb.VectorField{
|
||||
Dim: int64(dim),
|
||||
Data: &schemapb.VectorField_BinaryVector{
|
||||
BinaryVector: make([]byte, 0),
|
||||
},
|
||||
},
|
||||
},
|
||||
},
|
||||
Params: []*commonpb.KeyValuePair{
|
||||
{Key: "metric", Value: "HAMMING"},
|
||||
},
|
||||
}
|
||||
|
||||
// hamming error
|
||||
calcResult, err = task.Execute(ctx, request)
|
||||
assert.Nil(t, err)
|
||||
assert.Equal(t, commonpb.ErrorCode_UnexpectedError, calcResult.Status.ErrorCode)
|
||||
}
|
Loading…
Reference in New Issue