mirror of https://github.com/milvus-io/milvus.git
parent
9a61c0291b
commit
73512c72fd
|
@ -3150,88 +3150,12 @@ func (node *Proxy) AlterAlias(ctx context.Context, request *milvuspb.AlterAliasR
|
|||
|
||||
// CalcDistance calculates the distances between vectors.
|
||||
func (node *Proxy) CalcDistance(ctx context.Context, request *milvuspb.CalcDistanceRequest) (*milvuspb.CalcDistanceResults, error) {
|
||||
if !node.checkHealthy() {
|
||||
return &milvuspb.CalcDistanceResults{
|
||||
Status: unhealthyStatus(),
|
||||
}, nil
|
||||
}
|
||||
|
||||
ctx, sp := otel.Tracer(typeutil.ProxyRole).Start(ctx, "Proxy-CalcDistance")
|
||||
defer sp.End()
|
||||
|
||||
query := func(ids *milvuspb.VectorIDs) (*milvuspb.QueryResults, error) {
|
||||
outputFields := []string{ids.FieldName}
|
||||
|
||||
queryRequest := &milvuspb.QueryRequest{
|
||||
DbName: "",
|
||||
CollectionName: ids.CollectionName,
|
||||
PartitionNames: ids.PartitionNames,
|
||||
OutputFields: outputFields,
|
||||
}
|
||||
|
||||
qt := &queryTask{
|
||||
ctx: ctx,
|
||||
Condition: NewTaskCondition(ctx),
|
||||
RetrieveRequest: &internalpb.RetrieveRequest{
|
||||
Base: commonpbutil.NewMsgBase(
|
||||
commonpbutil.WithMsgType(commonpb.MsgType_Retrieve),
|
||||
commonpbutil.WithSourceID(paramtable.GetNodeID()),
|
||||
),
|
||||
ReqID: paramtable.GetNodeID(),
|
||||
},
|
||||
request: queryRequest,
|
||||
qc: node.queryCoord,
|
||||
ids: ids.IdArray,
|
||||
}
|
||||
|
||||
log := log.Ctx(ctx).With(
|
||||
zap.String("collection", queryRequest.CollectionName),
|
||||
zap.Any("partitions", queryRequest.PartitionNames),
|
||||
zap.Any("OutputFields", queryRequest.OutputFields))
|
||||
|
||||
err := node.sched.dqQueue.Enqueue(qt)
|
||||
if err != nil {
|
||||
log.Error("CalcDistance queryTask failed to enqueue",
|
||||
zap.Error(err))
|
||||
|
||||
return &milvuspb.QueryResults{
|
||||
Status: &commonpb.Status{
|
||||
ErrorCode: commonpb.ErrorCode_UnexpectedError,
|
||||
Reason: err.Error(),
|
||||
},
|
||||
}, err
|
||||
}
|
||||
|
||||
log.Debug("CalcDistance queryTask enqueued")
|
||||
|
||||
err = qt.WaitToFinish()
|
||||
if err != nil {
|
||||
log.Error("CalcDistance queryTask failed to WaitToFinish",
|
||||
zap.Error(err))
|
||||
|
||||
return &milvuspb.QueryResults{
|
||||
Status: &commonpb.Status{
|
||||
ErrorCode: commonpb.ErrorCode_UnexpectedError,
|
||||
Reason: err.Error(),
|
||||
},
|
||||
}, err
|
||||
}
|
||||
|
||||
log.Debug("CalcDistance queryTask Done")
|
||||
|
||||
return &milvuspb.QueryResults{
|
||||
Status: qt.result.Status,
|
||||
FieldsData: qt.result.FieldsData,
|
||||
}, nil
|
||||
}
|
||||
|
||||
// calcDistanceTask is not a standard task, no need to enqueue
|
||||
task := &calcDistanceTask{
|
||||
traceID: sp.SpanContext().TraceID().String(),
|
||||
queryFunc: query,
|
||||
}
|
||||
|
||||
return task.Execute(ctx, request)
|
||||
return &milvuspb.CalcDistanceResults{
|
||||
Status: &commonpb.Status{
|
||||
ErrorCode: commonpb.ErrorCode_UnexpectedError,
|
||||
Reason: "interface obsolete",
|
||||
},
|
||||
}, nil
|
||||
}
|
||||
|
||||
// FlushAll notifies Proxy to flush all collection's DML messages.
|
||||
|
|
|
@ -63,9 +63,9 @@ import (
|
|||
"github.com/milvus-io/milvus/pkg/tracer"
|
||||
"github.com/milvus-io/milvus/pkg/util"
|
||||
"github.com/milvus-io/milvus/pkg/util/crypto"
|
||||
"github.com/milvus-io/milvus/pkg/util/distance"
|
||||
"github.com/milvus-io/milvus/pkg/util/etcd"
|
||||
"github.com/milvus-io/milvus/pkg/util/funcutil"
|
||||
"github.com/milvus-io/milvus/pkg/util/metric"
|
||||
"github.com/milvus-io/milvus/pkg/util/metricsinfo"
|
||||
"github.com/milvus-io/milvus/pkg/util/paramtable"
|
||||
"github.com/milvus-io/milvus/pkg/util/typeutil"
|
||||
|
@ -639,7 +639,7 @@ func TestProxy(t *testing.T) {
|
|||
},
|
||||
{
|
||||
Key: common.MetricTypeKey,
|
||||
Value: distance.L2,
|
||||
Value: metric.L2,
|
||||
},
|
||||
{
|
||||
Key: common.IndexTypeKey,
|
||||
|
@ -1543,7 +1543,7 @@ func TestProxy(t *testing.T) {
|
|||
Params: []*commonpb.KeyValuePair{
|
||||
{
|
||||
Key: common.MetricTypeKey,
|
||||
Value: distance.L2,
|
||||
Value: metric.L2,
|
||||
},
|
||||
},
|
||||
})
|
||||
|
|
|
@ -1,412 +0,0 @@
|
|||
package proxy
|
||||
|
||||
import (
|
||||
"context"
|
||||
"fmt"
|
||||
|
||||
"github.com/cockroachdb/errors"
|
||||
"go.uber.org/zap"
|
||||
|
||||
"github.com/milvus-io/milvus-proto/go-api/v2/commonpb"
|
||||
"github.com/milvus-io/milvus-proto/go-api/v2/milvuspb"
|
||||
"github.com/milvus-io/milvus-proto/go-api/v2/schemapb"
|
||||
"github.com/milvus-io/milvus/pkg/log"
|
||||
"github.com/milvus-io/milvus/pkg/util/distance"
|
||||
"github.com/milvus-io/milvus/pkg/util/funcutil"
|
||||
"github.com/milvus-io/milvus/pkg/util/typeutil"
|
||||
)
|
||||
|
||||
type calcDistanceTask struct {
|
||||
traceID string
|
||||
queryFunc func(ids *milvuspb.VectorIDs) (*milvuspb.QueryResults, error)
|
||||
}
|
||||
|
||||
func (t *calcDistanceTask) arrangeVectorsByIntID(inputIds []int64, sequence map[int64]int, retrievedVectors *schemapb.VectorField) (*schemapb.VectorField, error) {
|
||||
if retrievedVectors.GetFloatVector() != nil {
|
||||
floatArr := retrievedVectors.GetFloatVector().GetData()
|
||||
element := retrievedVectors.GetDim()
|
||||
result := make([]float32, 0, int64(len(inputIds))*element)
|
||||
for _, id := range inputIds {
|
||||
index, ok := sequence[id]
|
||||
if !ok {
|
||||
log.Error("id not found in CalcDistance", zap.Int64("id", id))
|
||||
return nil, errors.New("failed to fetch vectors by id: " + fmt.Sprintln(id))
|
||||
}
|
||||
result = append(result, floatArr[int64(index)*element:int64(index+1)*element]...)
|
||||
}
|
||||
|
||||
return &schemapb.VectorField{
|
||||
Dim: element,
|
||||
Data: &schemapb.VectorField_FloatVector{
|
||||
FloatVector: &schemapb.FloatArray{
|
||||
Data: result,
|
||||
},
|
||||
},
|
||||
}, nil
|
||||
}
|
||||
|
||||
if retrievedVectors.GetBinaryVector() != nil {
|
||||
binaryArr := retrievedVectors.GetBinaryVector()
|
||||
singleBitLen := distance.SingleBitLen(retrievedVectors.GetDim())
|
||||
numBytes := singleBitLen / 8
|
||||
|
||||
result := make([]byte, 0, int64(len(inputIds))*numBytes)
|
||||
for _, id := range inputIds {
|
||||
index, ok := sequence[id]
|
||||
if !ok {
|
||||
log.Error("id not found in CalcDistance", zap.Int64("id", id))
|
||||
return nil, errors.New("failed to fetch vectors by id: " + fmt.Sprintln(id))
|
||||
}
|
||||
result = append(result, binaryArr[int64(index)*numBytes:int64(index+1)*numBytes]...)
|
||||
}
|
||||
|
||||
return &schemapb.VectorField{
|
||||
Dim: retrievedVectors.GetDim(),
|
||||
Data: &schemapb.VectorField_BinaryVector{
|
||||
BinaryVector: result,
|
||||
},
|
||||
}, nil
|
||||
}
|
||||
|
||||
return nil, errors.New("unsupported vector type")
|
||||
}
|
||||
|
||||
func (t *calcDistanceTask) arrangeVectorsByStrID(inputIds []string, sequence map[string]int, retrievedVectors *schemapb.VectorField) (*schemapb.VectorField, error) {
|
||||
if retrievedVectors.GetFloatVector() != nil {
|
||||
floatArr := retrievedVectors.GetFloatVector().GetData()
|
||||
element := retrievedVectors.GetDim()
|
||||
result := make([]float32, 0, int64(len(inputIds))*element)
|
||||
for _, id := range inputIds {
|
||||
index, ok := sequence[id]
|
||||
if !ok {
|
||||
log.Error("id not found in CalcDistance", zap.String("id", id))
|
||||
return nil, errors.New("failed to fetch vectors by id: " + fmt.Sprintln(id))
|
||||
}
|
||||
result = append(result, floatArr[int64(index)*element:int64(index+1)*element]...)
|
||||
}
|
||||
|
||||
return &schemapb.VectorField{
|
||||
Dim: element,
|
||||
Data: &schemapb.VectorField_FloatVector{
|
||||
FloatVector: &schemapb.FloatArray{
|
||||
Data: result,
|
||||
},
|
||||
},
|
||||
}, nil
|
||||
}
|
||||
|
||||
if retrievedVectors.GetBinaryVector() != nil {
|
||||
binaryArr := retrievedVectors.GetBinaryVector()
|
||||
singleBitLen := distance.SingleBitLen(retrievedVectors.GetDim())
|
||||
numBytes := singleBitLen / 8
|
||||
|
||||
result := make([]byte, 0, int64(len(inputIds))*numBytes)
|
||||
for _, id := range inputIds {
|
||||
index, ok := sequence[id]
|
||||
if !ok {
|
||||
log.Error("id not found in CalcDistance", zap.String("id", id))
|
||||
return nil, errors.New("failed to fetch vectors by id: " + fmt.Sprintln(id))
|
||||
}
|
||||
result = append(result, binaryArr[int64(index)*numBytes:int64(index+1)*numBytes]...)
|
||||
}
|
||||
|
||||
return &schemapb.VectorField{
|
||||
Dim: retrievedVectors.GetDim(),
|
||||
Data: &schemapb.VectorField_BinaryVector{
|
||||
BinaryVector: result,
|
||||
},
|
||||
}, nil
|
||||
}
|
||||
|
||||
return nil, errors.New("unsupported vector type")
|
||||
}
|
||||
|
||||
func (t *calcDistanceTask) Execute(ctx context.Context, request *milvuspb.CalcDistanceRequest) (*milvuspb.CalcDistanceResults, error) {
|
||||
param, _ := funcutil.GetAttrByKeyFromRepeatedKV("metric", request.GetParams())
|
||||
metric, err := distance.ValidateMetricType(param)
|
||||
if err != nil {
|
||||
return &milvuspb.CalcDistanceResults{
|
||||
Status: &commonpb.Status{
|
||||
ErrorCode: commonpb.ErrorCode_UnexpectedError,
|
||||
Reason: err.Error(),
|
||||
},
|
||||
}, nil
|
||||
}
|
||||
|
||||
// the vectors retrieved are random order, we need re-arrange the vectors by the order of input ids
|
||||
arrangeFunc := func(ids *milvuspb.VectorIDs, retrievedFields []*schemapb.FieldData) (*schemapb.VectorField, error) {
|
||||
var retrievedIds *schemapb.ScalarField
|
||||
var retrievedVectors *schemapb.VectorField
|
||||
isStringID := true
|
||||
for _, fieldData := range retrievedFields {
|
||||
if fieldData.FieldName == ids.FieldName {
|
||||
retrievedVectors = fieldData.GetVectors()
|
||||
}
|
||||
if fieldData.Type == schemapb.DataType_Int64 ||
|
||||
fieldData.Type == schemapb.DataType_VarChar ||
|
||||
fieldData.Type == schemapb.DataType_String {
|
||||
retrievedIds = fieldData.GetScalars()
|
||||
|
||||
if fieldData.Type == schemapb.DataType_Int64 {
|
||||
isStringID = false
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
if retrievedIds == nil || retrievedVectors == nil {
|
||||
return nil, errors.New("failed to fetch vectors")
|
||||
}
|
||||
|
||||
if isStringID {
|
||||
dict := make(map[string]int)
|
||||
for index, id := range retrievedIds.GetStringData().GetData() {
|
||||
dict[id] = index
|
||||
}
|
||||
|
||||
inputIds := ids.IdArray.GetStrId().GetData()
|
||||
return t.arrangeVectorsByStrID(inputIds, dict, retrievedVectors)
|
||||
}
|
||||
|
||||
dict := make(map[int64]int)
|
||||
for index, id := range retrievedIds.GetLongData().GetData() {
|
||||
dict[id] = index
|
||||
}
|
||||
|
||||
inputIds := ids.IdArray.GetIntId().GetData()
|
||||
return t.arrangeVectorsByIntID(inputIds, dict, retrievedVectors)
|
||||
}
|
||||
|
||||
log := log.Ctx(ctx)
|
||||
log.Debug("CalcDistance received",
|
||||
zap.String("role", typeutil.ProxyRole),
|
||||
zap.String("metric", metric))
|
||||
|
||||
vectorsLeft := request.GetOpLeft().GetDataArray()
|
||||
opLeft := request.GetOpLeft().GetIdArray()
|
||||
if opLeft != nil {
|
||||
log.Debug("OpLeft IdArray not empty, Get vectors by id", zap.String("role", typeutil.ProxyRole))
|
||||
|
||||
result, err := t.queryFunc(opLeft)
|
||||
if err != nil {
|
||||
log.Warn("Failed to get left vectors by id",
|
||||
zap.String("role", typeutil.ProxyRole),
|
||||
zap.Error(err))
|
||||
|
||||
return &milvuspb.CalcDistanceResults{
|
||||
Status: &commonpb.Status{
|
||||
ErrorCode: commonpb.ErrorCode_UnexpectedError,
|
||||
Reason: err.Error(),
|
||||
},
|
||||
}, nil
|
||||
}
|
||||
|
||||
log.Debug("OpLeft IdArray not empty, Get vectors by id done",
|
||||
zap.String("role", typeutil.ProxyRole))
|
||||
|
||||
vectorsLeft, err = arrangeFunc(opLeft, result.FieldsData)
|
||||
if err != nil {
|
||||
log.Debug("Failed to re-arrange left vectors",
|
||||
zap.String("role", typeutil.ProxyRole),
|
||||
zap.Error(err))
|
||||
|
||||
return &milvuspb.CalcDistanceResults{
|
||||
Status: &commonpb.Status{
|
||||
ErrorCode: commonpb.ErrorCode_UnexpectedError,
|
||||
Reason: err.Error(),
|
||||
},
|
||||
}, nil
|
||||
}
|
||||
|
||||
log.Debug("Re-arrange left vectors done",
|
||||
zap.String("role", typeutil.ProxyRole))
|
||||
}
|
||||
|
||||
if vectorsLeft == nil {
|
||||
msg := "Left vectors array is empty"
|
||||
log.Debug(msg,
|
||||
zap.String("role", typeutil.ProxyRole))
|
||||
|
||||
return &milvuspb.CalcDistanceResults{
|
||||
Status: &commonpb.Status{
|
||||
ErrorCode: commonpb.ErrorCode_UnexpectedError,
|
||||
Reason: msg,
|
||||
},
|
||||
}, nil
|
||||
}
|
||||
|
||||
vectorsRight := request.GetOpRight().GetDataArray()
|
||||
opRight := request.GetOpRight().GetIdArray()
|
||||
if opRight != nil {
|
||||
log.Debug("OpRight IdArray not empty, Get vectors by id",
|
||||
zap.String("role", typeutil.ProxyRole))
|
||||
|
||||
result, err := t.queryFunc(opRight)
|
||||
if err != nil {
|
||||
log.Debug("Failed to get right vectors by id",
|
||||
zap.String("role", typeutil.ProxyRole),
|
||||
zap.Error(err))
|
||||
|
||||
return &milvuspb.CalcDistanceResults{
|
||||
Status: &commonpb.Status{
|
||||
ErrorCode: commonpb.ErrorCode_UnexpectedError,
|
||||
Reason: err.Error(),
|
||||
},
|
||||
}, nil
|
||||
}
|
||||
|
||||
log.Debug("OpRight IdArray not empty, Get vectors by id done",
|
||||
zap.String("role", typeutil.ProxyRole))
|
||||
|
||||
vectorsRight, err = arrangeFunc(opRight, result.FieldsData)
|
||||
if err != nil {
|
||||
log.Debug("Failed to re-arrange right vectors",
|
||||
zap.String("role", typeutil.ProxyRole),
|
||||
zap.Error(err))
|
||||
|
||||
return &milvuspb.CalcDistanceResults{
|
||||
Status: &commonpb.Status{
|
||||
ErrorCode: commonpb.ErrorCode_UnexpectedError,
|
||||
Reason: err.Error(),
|
||||
},
|
||||
}, nil
|
||||
}
|
||||
|
||||
log.Debug("Re-arrange right vectors done",
|
||||
zap.String("role", typeutil.ProxyRole))
|
||||
}
|
||||
|
||||
if vectorsRight == nil {
|
||||
msg := "Right vectors array is empty"
|
||||
log.Warn(msg, zap.String("role", typeutil.ProxyRole))
|
||||
|
||||
return &milvuspb.CalcDistanceResults{
|
||||
Status: &commonpb.Status{
|
||||
ErrorCode: commonpb.ErrorCode_UnexpectedError,
|
||||
Reason: msg,
|
||||
},
|
||||
}, nil
|
||||
}
|
||||
|
||||
if vectorsLeft.GetDim() != vectorsRight.GetDim() {
|
||||
msg := "Vectors dimension is not equal"
|
||||
log.Debug(msg, zap.String("role", typeutil.ProxyRole))
|
||||
|
||||
return &milvuspb.CalcDistanceResults{
|
||||
Status: &commonpb.Status{
|
||||
ErrorCode: commonpb.ErrorCode_UnexpectedError,
|
||||
Reason: msg,
|
||||
},
|
||||
}, nil
|
||||
}
|
||||
|
||||
if vectorsLeft.GetFloatVector() != nil && vectorsRight.GetFloatVector() != nil {
|
||||
distances, err := distance.CalcFloatDistance(vectorsLeft.GetDim(), vectorsLeft.GetFloatVector().GetData(), vectorsRight.GetFloatVector().GetData(), metric)
|
||||
if err != nil {
|
||||
log.Warn("Failed to CalcFloatDistance",
|
||||
zap.Int64("leftDim", vectorsLeft.GetDim()),
|
||||
zap.Int("leftLen", len(vectorsLeft.GetFloatVector().GetData())),
|
||||
zap.Int64("rightDim", vectorsRight.GetDim()),
|
||||
zap.Int("rightLen", len(vectorsRight.GetFloatVector().GetData())),
|
||||
zap.String("traceID", t.traceID),
|
||||
zap.String("role", typeutil.ProxyRole),
|
||||
zap.Error(err))
|
||||
|
||||
return &milvuspb.CalcDistanceResults{
|
||||
Status: &commonpb.Status{
|
||||
ErrorCode: commonpb.ErrorCode_UnexpectedError,
|
||||
Reason: err.Error(),
|
||||
},
|
||||
}, nil
|
||||
}
|
||||
|
||||
log.Debug("CalcFloatDistance done",
|
||||
zap.String("traceID", t.traceID),
|
||||
zap.String("role", typeutil.ProxyRole))
|
||||
|
||||
return &milvuspb.CalcDistanceResults{
|
||||
Status: &commonpb.Status{ErrorCode: commonpb.ErrorCode_Success, Reason: ""},
|
||||
Array: &milvuspb.CalcDistanceResults_FloatDist{
|
||||
FloatDist: &schemapb.FloatArray{
|
||||
Data: distances,
|
||||
},
|
||||
},
|
||||
}, nil
|
||||
}
|
||||
|
||||
if vectorsLeft.GetBinaryVector() != nil && vectorsRight.GetBinaryVector() != nil {
|
||||
hamming, err := distance.CalcHammingDistance(vectorsLeft.GetDim(), vectorsLeft.GetBinaryVector(), vectorsRight.GetBinaryVector())
|
||||
if err != nil {
|
||||
log.Debug("Failed to CalcHammingDistance",
|
||||
zap.Int64("leftDim", vectorsLeft.GetDim()),
|
||||
zap.Int("leftLen", len(vectorsLeft.GetBinaryVector())),
|
||||
zap.Int64("rightDim", vectorsRight.GetDim()),
|
||||
zap.Int("rightLen", len(vectorsRight.GetBinaryVector())),
|
||||
zap.String("role", typeutil.ProxyRole),
|
||||
zap.Error(err))
|
||||
|
||||
return &milvuspb.CalcDistanceResults{
|
||||
Status: &commonpb.Status{
|
||||
ErrorCode: commonpb.ErrorCode_UnexpectedError,
|
||||
Reason: err.Error(),
|
||||
},
|
||||
}, nil
|
||||
}
|
||||
|
||||
if metric == distance.HAMMING {
|
||||
log.Debug("CalcHammingDistance done", zap.String("role", typeutil.ProxyRole))
|
||||
|
||||
return &milvuspb.CalcDistanceResults{
|
||||
Status: &commonpb.Status{ErrorCode: commonpb.ErrorCode_Success, Reason: ""},
|
||||
Array: &milvuspb.CalcDistanceResults_IntDist{
|
||||
IntDist: &schemapb.IntArray{
|
||||
Data: hamming,
|
||||
},
|
||||
},
|
||||
}, nil
|
||||
}
|
||||
|
||||
if metric == distance.TANIMOTO {
|
||||
tanimoto, err := distance.CalcTanimotoCoefficient(vectorsLeft.GetDim(), hamming)
|
||||
if err != nil {
|
||||
log.Warn("Failed to CalcTanimotoCoefficient",
|
||||
zap.String("role", typeutil.ProxyRole),
|
||||
zap.Error(err))
|
||||
|
||||
return &milvuspb.CalcDistanceResults{
|
||||
Status: &commonpb.Status{
|
||||
ErrorCode: commonpb.ErrorCode_UnexpectedError,
|
||||
Reason: err.Error(),
|
||||
},
|
||||
}, nil
|
||||
}
|
||||
|
||||
log.Debug("CalcTanimotoCoefficient done",
|
||||
zap.String("role", typeutil.ProxyRole))
|
||||
|
||||
return &milvuspb.CalcDistanceResults{
|
||||
Status: &commonpb.Status{ErrorCode: commonpb.ErrorCode_Success, Reason: ""},
|
||||
Array: &milvuspb.CalcDistanceResults_FloatDist{
|
||||
FloatDist: &schemapb.FloatArray{
|
||||
Data: tanimoto,
|
||||
},
|
||||
},
|
||||
}, nil
|
||||
}
|
||||
}
|
||||
|
||||
err = errors.New("unexpected error")
|
||||
if (vectorsLeft.GetBinaryVector() != nil && vectorsRight.GetFloatVector() != nil) || (vectorsLeft.GetFloatVector() != nil && vectorsRight.GetBinaryVector() != nil) {
|
||||
err = errors.New("cannot calculate distance between binary vectors and float vectors")
|
||||
}
|
||||
|
||||
log.Warn("Failed to CalcDistance",
|
||||
zap.String("role", typeutil.ProxyRole),
|
||||
zap.Error(err))
|
||||
|
||||
return &milvuspb.CalcDistanceResults{
|
||||
Status: &commonpb.Status{
|
||||
ErrorCode: commonpb.ErrorCode_UnexpectedError,
|
||||
Reason: err.Error(),
|
||||
},
|
||||
}, nil
|
||||
}
|
|
@ -1,491 +0,0 @@
|
|||
package proxy
|
||||
|
||||
import (
|
||||
"context"
|
||||
"testing"
|
||||
|
||||
"github.com/cockroachdb/errors"
|
||||
|
||||
"github.com/milvus-io/milvus-proto/go-api/v2/commonpb"
|
||||
"github.com/milvus-io/milvus-proto/go-api/v2/milvuspb"
|
||||
"github.com/milvus-io/milvus-proto/go-api/v2/schemapb"
|
||||
"github.com/stretchr/testify/assert"
|
||||
)
|
||||
|
||||
func TestCalcDistanceTask_arrangeVectorsByStrID(t *testing.T) {
|
||||
task := &calcDistanceTask{}
|
||||
|
||||
inputIds := make([]string, 0)
|
||||
inputIds = append(inputIds, "c")
|
||||
inputIds = append(inputIds, "b")
|
||||
inputIds = append(inputIds, "a")
|
||||
|
||||
sequence := make(map[string]int)
|
||||
sequence["a"] = 0
|
||||
sequence["b"] = 1
|
||||
sequence["c"] = 2
|
||||
|
||||
dim := 16
|
||||
|
||||
// float vector
|
||||
floatValue := make([]float32, 0)
|
||||
for i := 0; i < dim*3; i++ {
|
||||
floatValue = append(floatValue, float32(i))
|
||||
}
|
||||
retrievedVectors := &schemapb.VectorField{
|
||||
Dim: int64(dim),
|
||||
Data: &schemapb.VectorField_FloatVector{
|
||||
FloatVector: &schemapb.FloatArray{
|
||||
Data: floatValue,
|
||||
},
|
||||
},
|
||||
}
|
||||
|
||||
result, err := task.arrangeVectorsByStrID(inputIds, sequence, retrievedVectors)
|
||||
assert.NoError(t, err)
|
||||
|
||||
floatResult := result.GetFloatVector().GetData()
|
||||
for i := 0; i < 3; i++ {
|
||||
for j := 0; j < dim; j++ {
|
||||
assert.Equal(t, floatValue[dim*sequence[inputIds[i]]+j], floatResult[i*dim+j])
|
||||
}
|
||||
}
|
||||
|
||||
// binary vector
|
||||
binaryValue := make([]byte, 0)
|
||||
for i := 0; i < 3*dim/8; i++ {
|
||||
binaryValue = append(binaryValue, byte(i))
|
||||
}
|
||||
retrievedVectors = &schemapb.VectorField{
|
||||
Dim: int64(dim),
|
||||
Data: &schemapb.VectorField_BinaryVector{
|
||||
BinaryVector: binaryValue,
|
||||
},
|
||||
}
|
||||
|
||||
result, err = task.arrangeVectorsByStrID(inputIds, sequence, retrievedVectors)
|
||||
assert.NoError(t, err)
|
||||
|
||||
binaryResult := result.GetBinaryVector()
|
||||
numBytes := dim / 8
|
||||
for i := 0; i < 3; i++ {
|
||||
for j := 0; j < numBytes; j++ {
|
||||
assert.Equal(t, binaryValue[sequence[inputIds[i]]*numBytes+j], binaryResult[i*numBytes+j])
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func TestCalcDistanceTask_arrangeVectorsByIntID(t *testing.T) {
|
||||
task := &calcDistanceTask{}
|
||||
|
||||
inputIds := make([]int64, 0)
|
||||
inputIds = append(inputIds, 2)
|
||||
inputIds = append(inputIds, 0)
|
||||
inputIds = append(inputIds, 1)
|
||||
|
||||
sequence := make(map[int64]int)
|
||||
sequence[0] = 0
|
||||
sequence[1] = 1
|
||||
sequence[2] = 2
|
||||
|
||||
dim := 16
|
||||
|
||||
// float vector
|
||||
floatValue := make([]float32, 0)
|
||||
for i := 0; i < dim*3; i++ {
|
||||
floatValue = append(floatValue, float32(i))
|
||||
}
|
||||
retrievedVectors := &schemapb.VectorField{
|
||||
Dim: int64(dim),
|
||||
Data: &schemapb.VectorField_FloatVector{
|
||||
FloatVector: &schemapb.FloatArray{
|
||||
Data: floatValue,
|
||||
},
|
||||
},
|
||||
}
|
||||
|
||||
result, err := task.arrangeVectorsByIntID(inputIds, sequence, retrievedVectors)
|
||||
assert.NoError(t, err)
|
||||
|
||||
floatResult := result.GetFloatVector().GetData()
|
||||
for i := 0; i < 3; i++ {
|
||||
for j := 0; j < dim; j++ {
|
||||
assert.Equal(t, floatValue[dim*sequence[inputIds[i]]+j], floatResult[i*dim+j])
|
||||
}
|
||||
}
|
||||
|
||||
// binary vector
|
||||
binaryValue := make([]byte, 0)
|
||||
for i := 0; i < dim*3; i++ {
|
||||
binaryValue = append(binaryValue, byte(i))
|
||||
}
|
||||
retrievedVectors = &schemapb.VectorField{
|
||||
Dim: int64(dim),
|
||||
Data: &schemapb.VectorField_BinaryVector{
|
||||
BinaryVector: binaryValue,
|
||||
},
|
||||
}
|
||||
|
||||
result, err = task.arrangeVectorsByIntID(inputIds, sequence, retrievedVectors)
|
||||
assert.NoError(t, err)
|
||||
|
||||
binaryResult := result.GetBinaryVector()
|
||||
numBytes := dim / 8
|
||||
for i := 0; i < 3; i++ {
|
||||
for j := 0; j < numBytes; j++ {
|
||||
assert.Equal(t, binaryValue[sequence[inputIds[i]]*numBytes+j], binaryResult[i*numBytes+j])
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func TestCalcDistanceTask_ExecuteFloat(t *testing.T) {
|
||||
ctx := context.Background()
|
||||
queryFunc := func(ids *milvuspb.VectorIDs) (*milvuspb.QueryResults, error) {
|
||||
return nil, errors.New("unexpected error")
|
||||
}
|
||||
|
||||
task := &calcDistanceTask{
|
||||
traceID: "dummy",
|
||||
queryFunc: queryFunc,
|
||||
}
|
||||
|
||||
request := &milvuspb.CalcDistanceRequest{
|
||||
OpLeft: nil,
|
||||
OpRight: nil,
|
||||
Params: []*commonpb.KeyValuePair{
|
||||
{Key: "metric", Value: "L2"},
|
||||
},
|
||||
}
|
||||
|
||||
// left-op empty
|
||||
calcResult, err := task.Execute(ctx, request)
|
||||
assert.NoError(t, err)
|
||||
assert.Equal(t, commonpb.ErrorCode_UnexpectedError, calcResult.Status.ErrorCode)
|
||||
|
||||
request = &milvuspb.CalcDistanceRequest{
|
||||
OpLeft: &milvuspb.VectorsArray{
|
||||
Array: &milvuspb.VectorsArray_IdArray{
|
||||
IdArray: &milvuspb.VectorIDs{},
|
||||
},
|
||||
},
|
||||
OpRight: nil,
|
||||
Params: []*commonpb.KeyValuePair{
|
||||
{Key: "metric", Value: "L2"},
|
||||
},
|
||||
}
|
||||
|
||||
// left-op query error
|
||||
calcResult, err = task.Execute(ctx, request)
|
||||
assert.NoError(t, err)
|
||||
assert.Equal(t, commonpb.ErrorCode_UnexpectedError, calcResult.Status.ErrorCode)
|
||||
|
||||
fieldIds := make([]int64, 0)
|
||||
fieldIds = append(fieldIds, 2)
|
||||
fieldIds = append(fieldIds, 0)
|
||||
fieldIds = append(fieldIds, 1)
|
||||
|
||||
dim := 8
|
||||
floatValue := make([]float32, 0)
|
||||
for i := 0; i < dim*3; i++ {
|
||||
floatValue = append(floatValue, float32(i))
|
||||
}
|
||||
|
||||
queryFunc = func(ids *milvuspb.VectorIDs) (*milvuspb.QueryResults, error) {
|
||||
if ids == nil {
|
||||
return &milvuspb.QueryResults{
|
||||
Status: &commonpb.Status{
|
||||
ErrorCode: commonpb.ErrorCode_UnexpectedError,
|
||||
Reason: "unexpected",
|
||||
},
|
||||
}, nil
|
||||
}
|
||||
|
||||
return &milvuspb.QueryResults{
|
||||
FieldsData: []*schemapb.FieldData{
|
||||
{
|
||||
Type: schemapb.DataType_Int64,
|
||||
FieldName: "id",
|
||||
Field: &schemapb.FieldData_Scalars{
|
||||
Scalars: &schemapb.ScalarField{
|
||||
Data: &schemapb.ScalarField_LongData{
|
||||
LongData: &schemapb.LongArray{
|
||||
Data: fieldIds,
|
||||
},
|
||||
},
|
||||
},
|
||||
},
|
||||
},
|
||||
{
|
||||
Type: schemapb.DataType_FloatVector,
|
||||
FieldName: "vec",
|
||||
Field: &schemapb.FieldData_Vectors{
|
||||
Vectors: &schemapb.VectorField{
|
||||
Dim: int64(dim),
|
||||
Data: &schemapb.VectorField_FloatVector{
|
||||
FloatVector: &schemapb.FloatArray{
|
||||
Data: floatValue,
|
||||
},
|
||||
},
|
||||
},
|
||||
},
|
||||
},
|
||||
},
|
||||
}, nil
|
||||
}
|
||||
|
||||
task.queryFunc = queryFunc
|
||||
calcResult, err = task.Execute(ctx, request)
|
||||
assert.NoError(t, err)
|
||||
assert.Equal(t, commonpb.ErrorCode_UnexpectedError, calcResult.Status.ErrorCode)
|
||||
|
||||
idArray := &milvuspb.VectorsArray{
|
||||
Array: &milvuspb.VectorsArray_IdArray{
|
||||
IdArray: &milvuspb.VectorIDs{
|
||||
FieldName: "vec",
|
||||
IdArray: &schemapb.IDs{
|
||||
IdField: &schemapb.IDs_IntId{
|
||||
IntId: &schemapb.LongArray{
|
||||
Data: fieldIds,
|
||||
},
|
||||
},
|
||||
},
|
||||
},
|
||||
},
|
||||
}
|
||||
request = &milvuspb.CalcDistanceRequest{
|
||||
OpLeft: idArray,
|
||||
OpRight: idArray,
|
||||
Params: []*commonpb.KeyValuePair{
|
||||
{Key: "metric", Value: "L2"},
|
||||
},
|
||||
}
|
||||
|
||||
// success
|
||||
calcResult, err = task.Execute(ctx, request)
|
||||
assert.NoError(t, err)
|
||||
assert.Equal(t, commonpb.ErrorCode_Success, calcResult.Status.ErrorCode)
|
||||
|
||||
// right-op query error
|
||||
request.OpRight = nil
|
||||
calcResult, err = task.Execute(ctx, request)
|
||||
assert.NoError(t, err)
|
||||
assert.Equal(t, commonpb.ErrorCode_UnexpectedError, calcResult.Status.ErrorCode)
|
||||
|
||||
request.OpRight = &milvuspb.VectorsArray{
|
||||
Array: &milvuspb.VectorsArray_IdArray{
|
||||
IdArray: &milvuspb.VectorIDs{
|
||||
FieldName: "kkk",
|
||||
IdArray: &schemapb.IDs{
|
||||
IdField: &schemapb.IDs_IntId{
|
||||
IntId: &schemapb.LongArray{
|
||||
Data: fieldIds,
|
||||
},
|
||||
},
|
||||
},
|
||||
},
|
||||
},
|
||||
}
|
||||
|
||||
// right-op arrange error
|
||||
calcResult, err = task.Execute(ctx, request)
|
||||
assert.NoError(t, err)
|
||||
assert.Equal(t, commonpb.ErrorCode_UnexpectedError, calcResult.Status.ErrorCode)
|
||||
|
||||
request.OpRight = &milvuspb.VectorsArray{
|
||||
Array: &milvuspb.VectorsArray_DataArray{
|
||||
DataArray: &schemapb.VectorField{
|
||||
Dim: 5,
|
||||
},
|
||||
},
|
||||
}
|
||||
|
||||
// different dimension
|
||||
calcResult, err = task.Execute(ctx, request)
|
||||
assert.NoError(t, err)
|
||||
assert.Equal(t, commonpb.ErrorCode_UnexpectedError, calcResult.Status.ErrorCode)
|
||||
|
||||
request.OpRight = &milvuspb.VectorsArray{
|
||||
Array: &milvuspb.VectorsArray_DataArray{
|
||||
DataArray: &schemapb.VectorField{
|
||||
Dim: int64(dim),
|
||||
Data: &schemapb.VectorField_FloatVector{
|
||||
FloatVector: &schemapb.FloatArray{
|
||||
Data: make([]float32, 0),
|
||||
},
|
||||
},
|
||||
},
|
||||
},
|
||||
}
|
||||
|
||||
// calcdistance return error
|
||||
calcResult, err = task.Execute(ctx, request)
|
||||
assert.NoError(t, err)
|
||||
assert.Equal(t, commonpb.ErrorCode_UnexpectedError, calcResult.Status.ErrorCode)
|
||||
}
|
||||
|
||||
func TestCalcDistanceTask_ExecuteBinary(t *testing.T) {
|
||||
ctx := context.Background()
|
||||
|
||||
fieldIds := make([]int64, 0)
|
||||
fieldIds = append(fieldIds, 2)
|
||||
fieldIds = append(fieldIds, 0)
|
||||
fieldIds = append(fieldIds, 1)
|
||||
|
||||
dim := 16
|
||||
binaryValue := make([]byte, 0)
|
||||
for i := 0; i < 3*dim/8; i++ {
|
||||
binaryValue = append(binaryValue, byte(i))
|
||||
}
|
||||
|
||||
queryFunc := func(ids *milvuspb.VectorIDs) (*milvuspb.QueryResults, error) {
|
||||
if ids == nil {
|
||||
return &milvuspb.QueryResults{
|
||||
Status: &commonpb.Status{
|
||||
ErrorCode: commonpb.ErrorCode_UnexpectedError,
|
||||
Reason: "unexpected",
|
||||
},
|
||||
}, nil
|
||||
}
|
||||
|
||||
return &milvuspb.QueryResults{
|
||||
FieldsData: []*schemapb.FieldData{
|
||||
{
|
||||
Type: schemapb.DataType_Int64,
|
||||
FieldName: "id",
|
||||
Field: &schemapb.FieldData_Scalars{
|
||||
Scalars: &schemapb.ScalarField{
|
||||
Data: &schemapb.ScalarField_LongData{
|
||||
LongData: &schemapb.LongArray{
|
||||
Data: fieldIds,
|
||||
},
|
||||
},
|
||||
},
|
||||
},
|
||||
},
|
||||
{
|
||||
Type: schemapb.DataType_FloatVector,
|
||||
FieldName: "vec",
|
||||
Field: &schemapb.FieldData_Vectors{
|
||||
Vectors: &schemapb.VectorField{
|
||||
Dim: int64(dim),
|
||||
Data: &schemapb.VectorField_BinaryVector{
|
||||
BinaryVector: binaryValue,
|
||||
},
|
||||
},
|
||||
},
|
||||
},
|
||||
},
|
||||
}, nil
|
||||
}
|
||||
|
||||
idArray := &milvuspb.VectorsArray{
|
||||
Array: &milvuspb.VectorsArray_IdArray{
|
||||
IdArray: &milvuspb.VectorIDs{
|
||||
FieldName: "vec",
|
||||
IdArray: &schemapb.IDs{
|
||||
IdField: &schemapb.IDs_IntId{
|
||||
IntId: &schemapb.LongArray{
|
||||
Data: fieldIds,
|
||||
},
|
||||
},
|
||||
},
|
||||
},
|
||||
},
|
||||
}
|
||||
request := &milvuspb.CalcDistanceRequest{
|
||||
OpLeft: idArray,
|
||||
OpRight: idArray,
|
||||
Params: []*commonpb.KeyValuePair{
|
||||
{Key: "metric", Value: "HAMMING"},
|
||||
},
|
||||
}
|
||||
|
||||
task := &calcDistanceTask{
|
||||
traceID: "dummy",
|
||||
queryFunc: queryFunc,
|
||||
}
|
||||
|
||||
// success
|
||||
calcResult, err := task.Execute(ctx, request)
|
||||
assert.NoError(t, err)
|
||||
assert.Equal(t, commonpb.ErrorCode_Success, calcResult.Status.ErrorCode)
|
||||
|
||||
floatArray := &milvuspb.VectorsArray{
|
||||
Array: &milvuspb.VectorsArray_DataArray{
|
||||
DataArray: &schemapb.VectorField{
|
||||
Dim: int64(dim),
|
||||
Data: &schemapb.VectorField_FloatVector{},
|
||||
},
|
||||
},
|
||||
}
|
||||
binaryArray := &milvuspb.VectorsArray{
|
||||
Array: &milvuspb.VectorsArray_DataArray{
|
||||
DataArray: &schemapb.VectorField{
|
||||
Dim: int64(dim),
|
||||
Data: &schemapb.VectorField_BinaryVector{
|
||||
BinaryVector: binaryValue,
|
||||
},
|
||||
},
|
||||
},
|
||||
}
|
||||
request = &milvuspb.CalcDistanceRequest{
|
||||
OpLeft: floatArray,
|
||||
OpRight: binaryArray,
|
||||
Params: []*commonpb.KeyValuePair{
|
||||
{Key: "metric", Value: "HAMMING"},
|
||||
},
|
||||
}
|
||||
|
||||
// float vs binary
|
||||
calcResult, err = task.Execute(ctx, request)
|
||||
assert.NoError(t, err)
|
||||
assert.Equal(t, commonpb.ErrorCode_UnexpectedError, calcResult.Status.ErrorCode)
|
||||
|
||||
request = &milvuspb.CalcDistanceRequest{
|
||||
OpLeft: binaryArray,
|
||||
OpRight: binaryArray,
|
||||
Params: []*commonpb.KeyValuePair{
|
||||
{Key: "metric", Value: "HAMMING"},
|
||||
},
|
||||
}
|
||||
|
||||
// hamming
|
||||
calcResult, err = task.Execute(ctx, request)
|
||||
assert.NoError(t, err)
|
||||
assert.Equal(t, commonpb.ErrorCode_Success, calcResult.Status.ErrorCode)
|
||||
|
||||
request = &milvuspb.CalcDistanceRequest{
|
||||
OpLeft: binaryArray,
|
||||
OpRight: binaryArray,
|
||||
Params: []*commonpb.KeyValuePair{
|
||||
{Key: "metric", Value: "TANIMOTO"},
|
||||
},
|
||||
}
|
||||
|
||||
// tanimoto
|
||||
calcResult, err = task.Execute(ctx, request)
|
||||
assert.NoError(t, err)
|
||||
assert.Equal(t, commonpb.ErrorCode_Success, calcResult.Status.ErrorCode)
|
||||
|
||||
request = &milvuspb.CalcDistanceRequest{
|
||||
OpLeft: binaryArray,
|
||||
OpRight: &milvuspb.VectorsArray{
|
||||
Array: &milvuspb.VectorsArray_DataArray{
|
||||
DataArray: &schemapb.VectorField{
|
||||
Dim: int64(dim),
|
||||
Data: &schemapb.VectorField_BinaryVector{
|
||||
BinaryVector: make([]byte, 0),
|
||||
},
|
||||
},
|
||||
},
|
||||
},
|
||||
Params: []*commonpb.KeyValuePair{
|
||||
{Key: "metric", Value: "HAMMING"},
|
||||
},
|
||||
}
|
||||
|
||||
// hamming error
|
||||
calcResult, err = task.Execute(ctx, request)
|
||||
assert.NoError(t, err)
|
||||
assert.Equal(t, commonpb.ErrorCode_UnexpectedError, calcResult.Status.ErrorCode)
|
||||
}
|
|
@ -25,9 +25,9 @@ import (
|
|||
"github.com/milvus-io/milvus/pkg/log"
|
||||
"github.com/milvus-io/milvus/pkg/metrics"
|
||||
"github.com/milvus-io/milvus/pkg/util/commonpbutil"
|
||||
"github.com/milvus-io/milvus/pkg/util/distance"
|
||||
"github.com/milvus-io/milvus/pkg/util/funcutil"
|
||||
"github.com/milvus-io/milvus/pkg/util/merr"
|
||||
"github.com/milvus-io/milvus/pkg/util/metric"
|
||||
"github.com/milvus-io/milvus/pkg/util/paramtable"
|
||||
"github.com/milvus-io/milvus/pkg/util/timerecord"
|
||||
"github.com/milvus-io/milvus/pkg/util/tsoutil"
|
||||
|
@ -898,7 +898,7 @@ func reduceSearchResultData(ctx context.Context, subSearchResultData []*schemapb
|
|||
}
|
||||
|
||||
ret.Results.TopK = realTopK // realTopK is the topK of the nq-th query
|
||||
if !distance.PositivelyRelated(metricType) {
|
||||
if !metric.PositivelyRelated(metricType) {
|
||||
for k := range ret.Results.Scores {
|
||||
ret.Results.Scores[k] *= -1
|
||||
}
|
||||
|
|
|
@ -37,8 +37,8 @@ import (
|
|||
"github.com/milvus-io/milvus/internal/proto/querypb"
|
||||
"github.com/milvus-io/milvus/internal/types"
|
||||
"github.com/milvus-io/milvus/pkg/common"
|
||||
"github.com/milvus-io/milvus/pkg/util/distance"
|
||||
"github.com/milvus-io/milvus/pkg/util/funcutil"
|
||||
"github.com/milvus-io/milvus/pkg/util/metric"
|
||||
"github.com/milvus-io/milvus/pkg/util/paramtable"
|
||||
"github.com/milvus-io/milvus/pkg/util/timerecord"
|
||||
"github.com/milvus-io/milvus/pkg/util/typeutil"
|
||||
|
@ -121,7 +121,7 @@ func getValidSearchParams() []*commonpb.KeyValuePair {
|
|||
},
|
||||
{
|
||||
Key: common.MetricTypeKey,
|
||||
Value: distance.L2,
|
||||
Value: metric.L2,
|
||||
},
|
||||
{
|
||||
Key: SearchParamsKey,
|
||||
|
@ -1441,7 +1441,7 @@ func TestTaskSearch_reduceSearchResultData(t *testing.T) {
|
|||
|
||||
for _, test := range tests {
|
||||
t.Run(test.description, func(t *testing.T) {
|
||||
reduced, err := reduceSearchResultData(context.TODO(), results, nq, topk, distance.L2, schemapb.DataType_Int64, test.offset)
|
||||
reduced, err := reduceSearchResultData(context.TODO(), results, nq, topk, metric.L2, schemapb.DataType_Int64, test.offset)
|
||||
assert.NoError(t, err)
|
||||
assert.Equal(t, test.outData, reduced.GetResults().GetIds().GetIntId().GetData())
|
||||
assert.Equal(t, []int64{test.limit, test.limit}, reduced.GetResults().GetTopks())
|
||||
|
@ -1481,7 +1481,7 @@ func TestTaskSearch_reduceSearchResultData(t *testing.T) {
|
|||
|
||||
for _, test := range lessThanLimitTests {
|
||||
t.Run(test.description, func(t *testing.T) {
|
||||
reduced, err := reduceSearchResultData(context.TODO(), results, nq, topk, distance.L2, schemapb.DataType_Int64, test.offset)
|
||||
reduced, err := reduceSearchResultData(context.TODO(), results, nq, topk, metric.L2, schemapb.DataType_Int64, test.offset)
|
||||
assert.NoError(t, err)
|
||||
assert.Equal(t, test.outData, reduced.GetResults().GetIds().GetIntId().GetData())
|
||||
assert.Equal(t, []int64{test.outLimit, test.outLimit}, reduced.GetResults().GetTopks())
|
||||
|
@ -1505,7 +1505,7 @@ func TestTaskSearch_reduceSearchResultData(t *testing.T) {
|
|||
results = append(results, r)
|
||||
}
|
||||
|
||||
reduced, err := reduceSearchResultData(context.TODO(), results, nq, topk, distance.L2, schemapb.DataType_Int64, 0)
|
||||
reduced, err := reduceSearchResultData(context.TODO(), results, nq, topk, metric.L2, schemapb.DataType_Int64, 0)
|
||||
|
||||
assert.NoError(t, err)
|
||||
assert.Equal(t, resultData, reduced.GetResults().GetIds().GetIntId().GetData())
|
||||
|
@ -1532,7 +1532,7 @@ func TestTaskSearch_reduceSearchResultData(t *testing.T) {
|
|||
results = append(results, r)
|
||||
}
|
||||
|
||||
reduced, err := reduceSearchResultData(context.TODO(), results, nq, topk, distance.L2, schemapb.DataType_VarChar, 0)
|
||||
reduced, err := reduceSearchResultData(context.TODO(), results, nq, topk, metric.L2, schemapb.DataType_VarChar, 0)
|
||||
|
||||
assert.NoError(t, err)
|
||||
assert.Equal(t, resultData, reduced.GetResults().GetIds().GetStrId().GetData())
|
||||
|
@ -1717,7 +1717,7 @@ func TestTaskSearch_parseQueryInfo(t *testing.T) {
|
|||
noSearchParams := getBaseSearchParams()
|
||||
noSearchParams = append(noSearchParams, &commonpb.KeyValuePair{
|
||||
Key: common.MetricTypeKey,
|
||||
Value: distance.L2,
|
||||
Value: metric.L2,
|
||||
})
|
||||
|
||||
offsetParam := getValidSearchParams()
|
||||
|
@ -1775,7 +1775,7 @@ func TestTaskSearch_parseQueryInfo(t *testing.T) {
|
|||
|
||||
spNoSearchParams := append(spNoMetricType, &commonpb.KeyValuePair{
|
||||
Key: common.MetricTypeKey,
|
||||
Value: distance.L2,
|
||||
Value: metric.L2,
|
||||
})
|
||||
|
||||
// no roundDecimal is valid
|
||||
|
|
|
@ -44,9 +44,9 @@ import (
|
|||
"github.com/milvus-io/milvus/pkg/common"
|
||||
"github.com/milvus-io/milvus/pkg/mq/msgstream"
|
||||
"github.com/milvus-io/milvus/pkg/util/commonpbutil"
|
||||
"github.com/milvus-io/milvus/pkg/util/distance"
|
||||
"github.com/milvus-io/milvus/pkg/util/funcutil"
|
||||
"github.com/milvus-io/milvus/pkg/util/merr"
|
||||
"github.com/milvus-io/milvus/pkg/util/metric"
|
||||
"github.com/milvus-io/milvus/pkg/util/paramtable"
|
||||
"github.com/milvus-io/milvus/pkg/util/timerecord"
|
||||
"github.com/milvus-io/milvus/pkg/util/typeutil"
|
||||
|
@ -383,7 +383,7 @@ func constructSearchRequest(
|
|||
SearchParams: []*commonpb.KeyValuePair{
|
||||
{
|
||||
Key: common.MetricTypeKey,
|
||||
Value: distance.L2,
|
||||
Value: metric.L2,
|
||||
},
|
||||
{
|
||||
Key: SearchParamsKey,
|
||||
|
|
|
@ -1,283 +0,0 @@
|
|||
// Licensed to the LF AI & Data foundation under one
|
||||
// or more contributor license agreements. See the NOTICE file
|
||||
// distributed with this work for additional information
|
||||
// regarding copyright ownership. The ASF licenses this file
|
||||
// to you under the Apache License, Version 2.0 (the
|
||||
// "License"); you may not use this file except in compliance
|
||||
// with the License. You may obtain a copy of the License at
|
||||
//
|
||||
// http://www.apache.org/licenses/LICENSE-2.0
|
||||
//
|
||||
// Unless required by applicable law or agreed to in writing, software
|
||||
// distributed under the License is distributed on an "AS IS" BASIS,
|
||||
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
// See the License for the specific language governing permissions and
|
||||
// limitations under the License.
|
||||
|
||||
package distance
|
||||
|
||||
import (
|
||||
"strings"
|
||||
"sync"
|
||||
|
||||
"github.com/cockroachdb/errors"
|
||||
)
|
||||
|
||||
// Names of the metric types known to this package. Each constant's value is
// the upper-case string used to identify the metric.
const (
	// L2 names the Euclidean distance metric.
	L2 = "L2"
	// IP names the inner product metric.
	IP = "IP"
	// COSINE names the cosine metric.
	COSINE = "COSINE"
	// HAMMING names the Hamming distance metric.
	HAMMING = "HAMMING"
	// TANIMOTO names the Tanimoto metric.
	TANIMOTO = "TANIMOTO"
	// JACCARD names the Jaccard metric.
	JACCARD = "JACCARD"
	// SUPERSTRUCTURE names the superstructure metric.
	SUPERSTRUCTURE = "SUPERSTRUCTURE"
	// SUBSTRUCTURE names the substructure metric.
	SUBSTRUCTURE = "SUBSTRUCTURE"
)
|
||||
|
||||
// ValidateMetricType returns metric text or error
|
||||
func ValidateMetricType(metric string) (string, error) {
|
||||
if metric == "" {
|
||||
err := errors.New("metric type is empty")
|
||||
return "", err
|
||||
}
|
||||
|
||||
m := strings.ToUpper(metric)
|
||||
if m == L2 || m == IP || m == HAMMING || m == TANIMOTO {
|
||||
return m, nil
|
||||
}
|
||||
|
||||
err := errors.New("invalid metric type")
|
||||
return metric, err
|
||||
}
|
||||
|
||||
// ValidateFloatArrayLength checks that a flat float vector buffer of the given
// length holds a whole, non-zero number of dim-sized vectors.
//
// It returns an error when dim is not positive, when length is zero, or when
// length is not a multiple of dim.
func ValidateFloatArrayLength(dim int64, length int) error {
	// Guard first: a zero dim would make the modulo below panic, and a
	// negative dim would let meaningless inputs pass the divisibility test.
	if dim <= 0 {
		return errors.New("invalid dimension")
	}

	if length == 0 || int64(length)%dim != 0 {
		return errors.New("invalid float vector length")
	}

	return nil
}
|
||||
|
||||
// CalcL2 returns the squared Euclidean distance (the sum of squared
// component differences, with no square root applied) between the lIndex-th
// vector of left and the rIndex-th vector of right. Both buffers are flat
// arrays in which consecutive runs of dim values form one vector.
func CalcL2(dim int64, left []float32, lIndex int64, right []float32, rIndex int64) float32 {
	lBase, rBase := lIndex*dim, rIndex*dim

	var acc float32
	for off := int64(0); off < dim; off++ {
		diff := left[lBase+off] - right[rBase+off]
		acc += diff * diff
	}

	return acc
}
|
||||
|
||||
// CalcIP returns the inner product of the lIndex-th vector of left and the
// rIndex-th vector of right. Both buffers are flat arrays in which
// consecutive runs of dim values form one vector.
func CalcIP(dim int64, left []float32, lIndex int64, right []float32, rIndex int64) float32 {
	lBase, rBase := lIndex*dim, rIndex*dim

	var acc float32
	for off := int64(0); off < dim; off++ {
		acc += left[lBase+off] * right[rBase+off]
	}

	return acc
}
|
||||
|
||||
// CalcFFBatch calculate the distance of @left & @right vectors in batch by given @metic, store result in @result
|
||||
func CalcFFBatch(dim int64, left []float32, lIndex int64, right []float32, metric string, result *[]float32) {
|
||||
rightNum := int64(len(right)) / dim
|
||||
for i := int64(0); i < rightNum; i++ {
|
||||
var distance float32 = -1.0
|
||||
if metric == L2 {
|
||||
distance = CalcL2(dim, left, lIndex, right, i)
|
||||
} else if metric == IP {
|
||||
distance = CalcIP(dim, left, lIndex, right, i)
|
||||
}
|
||||
(*result)[lIndex*rightNum+i] = distance
|
||||
}
|
||||
}
|
||||
|
||||
// CalcFloatDistance calculate float distance by given metric
|
||||
// it will checks input, and calculate the distance concurrently
|
||||
func CalcFloatDistance(dim int64, left, right []float32, metric string) ([]float32, error) {
|
||||
if dim <= 0 {
|
||||
err := errors.New("invalid dimension")
|
||||
return nil, err
|
||||
}
|
||||
|
||||
metricUpper := strings.ToUpper(metric)
|
||||
if metricUpper != L2 && metricUpper != IP {
|
||||
err := errors.New("invalid metric type")
|
||||
return nil, err
|
||||
}
|
||||
|
||||
err := ValidateFloatArrayLength(dim, len(left))
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
err = ValidateFloatArrayLength(dim, len(right))
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
leftNum := int64(len(left)) / dim
|
||||
rightNum := int64(len(right)) / dim
|
||||
|
||||
distArray := make([]float32, leftNum*rightNum)
|
||||
|
||||
// Multi-threads to calculate distance. TODO: avoid too many go routines
|
||||
var waitGroup sync.WaitGroup
|
||||
CalcWorker := func(index int64) {
|
||||
CalcFFBatch(dim, left, index, right, metricUpper, &distArray)
|
||||
waitGroup.Done()
|
||||
}
|
||||
for i := int64(0); i < leftNum; i++ {
|
||||
waitGroup.Add(1)
|
||||
go CalcWorker(i)
|
||||
}
|
||||
waitGroup.Wait()
|
||||
|
||||
return distArray, nil
|
||||
}
|
||||
|
||||
////////////////////////////////////////////////////////////////////////////////
|
||||
|
||||
// SingleBitLen rounds dim up to the next multiple of 8, i.e. the number of
// bits occupied by one binary vector of dim bits once it is padded to whole
// bytes. A dim that is already a multiple of 8 is returned unchanged.
func SingleBitLen(dim int64) int64 {
	rem := dim % 8
	if rem == 0 {
		return dim
	}

	return dim + 8 - rem
}
|
||||
|
||||
// VectorCount counts bits by @dim & @length
|
||||
func VectorCount(dim int64, length int) int64 {
|
||||
singleBitLen := SingleBitLen(dim)
|
||||
return int64(length*8) / singleBitLen
|
||||
}
|
||||
|
||||
// ValidateBinaryArrayLength checks that a binary vector buffer of length
// bytes holds a whole, non-zero number of dim-bit vectors, where each vector
// is padded to a whole number of bytes.
//
// It returns an error when dim is not positive, when length is zero, or when
// length is not a multiple of the per-vector byte size.
func ValidateBinaryArrayLength(dim int64, length int) error {
	// Guard first: a zero dim would make the modulo below panic.
	if dim <= 0 {
		return errors.New("invalid dimension")
	}

	// Bytes per padded vector; this equals SingleBitLen(dim)/8 for positive
	// dim. Comparing byte counts instead of bit counts avoids computing
	// length*8, which can overflow a 32-bit int.
	bytesPerVector := (dim + 7) / 8
	if length == 0 || int64(length)%bytesPerVector != 0 {
		return errors.New("invalid binary vector length")
	}

	return nil
}
|
||||
|
||||
// CountOne returns the number of bits set to 1 in n.
// For 00000010, it returns 1.
// For 11111111, it returns 8.
func CountOne(n uint8) int32 {
	var ones int32
	// Examine the lowest bit, then shift it out, until no set bits remain.
	for ; n != 0; n >>= 1 {
		ones += int32(n & 1)
	}
	return ones
}
|
||||
|
||||
// CalcHamming calculate HAMMING distance
|
||||
func CalcHamming(dim int64, left []byte, lIndex int64, right []byte, rIndex int64) int32 {
|
||||
singleBitLen := SingleBitLen(dim)
|
||||
numBytes := singleBitLen / 8
|
||||
lFrom := lIndex * numBytes
|
||||
rFrom := rIndex * numBytes
|
||||
|
||||
var hamming int32
|
||||
for i := int64(0); i < numBytes; i++ {
|
||||
var xor = left[lFrom+i] ^ right[rFrom+i]
|
||||
|
||||
// The dimension "dim" may not be an integer multiple of 8
|
||||
// For example:
|
||||
// dim = 11, each vector has 2 uint8 value
|
||||
// the second uint8, only need to calculate 3 bits, the other 5 bits will be set to 0
|
||||
if i == numBytes-1 && numBytes*8 > dim {
|
||||
offset := numBytes*8 - dim
|
||||
xor = xor & (255 << offset)
|
||||
}
|
||||
|
||||
hamming += CountOne(xor)
|
||||
}
|
||||
|
||||
return hamming
|
||||
}
|
||||
|
||||
// CalcHammingBatch calculate HAMMING distance in batch, results are in @result
|
||||
func CalcHammingBatch(dim int64, left []byte, lIndex int64, right []byte, result *[]int32) {
|
||||
rightNum := VectorCount(dim, len(right))
|
||||
|
||||
for i := int64(0); i < rightNum; i++ {
|
||||
hamming := CalcHamming(dim, left, lIndex, right, i)
|
||||
(*result)[lIndex*rightNum+i] = hamming
|
||||
}
|
||||
}
|
||||
|
||||
func CalcHammingDistance(dim int64, left, right []byte) ([]int32, error) {
|
||||
if dim <= 0 {
|
||||
err := errors.New("invalid dimension")
|
||||
return nil, err
|
||||
}
|
||||
|
||||
err := ValidateBinaryArrayLength(dim, len(left))
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
err = ValidateBinaryArrayLength(dim, len(right))
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
leftNum := VectorCount(dim, len(left))
|
||||
rightNum := VectorCount(dim, len(right))
|
||||
distArray := make([]int32, leftNum*rightNum)
|
||||
|
||||
// Multi-threads to calculate distance. TODO: avoid too many go routines
|
||||
var waitGroup sync.WaitGroup
|
||||
CalcWorker := func(index int64) {
|
||||
CalcHammingBatch(dim, left, index, right, &distArray)
|
||||
waitGroup.Done()
|
||||
}
|
||||
for i := int64(0); i < leftNum; i++ {
|
||||
waitGroup.Add(1)
|
||||
go CalcWorker(i)
|
||||
}
|
||||
waitGroup.Wait()
|
||||
|
||||
return distArray, nil
|
||||
}
|
||||
|
||||
// CalcTanimotoCoefficient converts Hamming distances between dim-bit vectors
// into coefficients computed as equalBits / (2*dim - equalBits), where
// equalBits = dim - hamming[i].
//
// It returns an error when dim is not positive, when hamming is empty, or
// when any distance lies outside the range [0, dim].
func CalcTanimotoCoefficient(dim int64, hamming []int32) ([]float32, error) {
	if dim <= 0 || len(hamming) == 0 {
		return nil, errors.New("invalid input for tanimoto")
	}

	coefficients := make([]float32, len(hamming))
	for i, h := range hamming {
		// A distance cannot be negative or exceed the number of bits. A
		// negative value would push equalBits above dim, and h == -dim would
		// make the denominator zero.
		if h < 0 || h > int32(dim) {
			return nil, errors.New("invalid hamming for tanimoto")
		}
		equalBits := int32(dim) - h
		coefficients[i] = float32(equalBits) / (float32(dim)*2 - float32(equalBits))
	}

	return coefficients, nil
}
|
|
@ -1,291 +0,0 @@
|
|||
// Licensed to the LF AI & Data foundation under one
|
||||
// or more contributor license agreements. See the NOTICE file
|
||||
// distributed with this work for additional information
|
||||
// regarding copyright ownership. The ASF licenses this file
|
||||
// to you under the Apache License, Version 2.0 (the
|
||||
// "License"); you may not use this file except in compliance
|
||||
// with the License. You may obtain a copy of the License at
|
||||
//
|
||||
// http://www.apache.org/licenses/LICENSE-2.0
|
||||
//
|
||||
// Unless required by applicable law or agreed to in writing, software
|
||||
// distributed under the License is distributed on an "AS IS" BASIS,
|
||||
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
// See the License for the specific language governing permissions and
|
||||
// limitations under the License.
|
||||
|
||||
package distance
|
||||
|
||||
import (
|
||||
"math"
|
||||
"math/rand"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"github.com/stretchr/testify/assert"
|
||||
)
|
||||
|
||||
// PRECISION is the tolerance used by the tests in this file when comparing
// floating-point distances: absolute differences must be below this value.
const PRECISION = 1e-6
|
||||
|
||||
func TestValidateMetricType(t *testing.T) {
|
||||
invalidMetric := []string{"", "aaa"}
|
||||
for _, str := range invalidMetric {
|
||||
_, err := ValidateMetricType(str)
|
||||
assert.Error(t, err)
|
||||
}
|
||||
|
||||
validMetric := []string{"L2", "ip", "Hamming", "Tanimoto"}
|
||||
for _, str := range validMetric {
|
||||
metric, err := ValidateMetricType(str)
|
||||
assert.NoError(t, err)
|
||||
assert.True(t, metric == L2 || metric == IP || metric == HAMMING || metric == TANIMOTO)
|
||||
}
|
||||
}
|
||||
|
||||
func TestValidateFloatArrayLength(t *testing.T) {
|
||||
err := ValidateFloatArrayLength(3, 12)
|
||||
assert.NoError(t, err)
|
||||
|
||||
err = ValidateFloatArrayLength(5, 11)
|
||||
assert.Error(t, err)
|
||||
}
|
||||
|
||||
////////////////////////////////////////////////////////////////////////////////
|
||||
|
||||
// CreateFloatArray returns a flat buffer of n*dim pseudo-random float32
// values drawn from rand.Float32, re-seeding the generator from the current
// time on every call.
func CreateFloatArray(n, dim int64) []float32 {
	rand.Seed(time.Now().UnixNano())

	total := n * dim
	values := make([]float32, total)
	for i := range values {
		values[i] = rand.Float32()
	}

	return values
}
|
||||
|
||||
// DistanceL2 is a reference implementation used by the tests: it returns the
// sum of squared component differences of two equally sized vectors and
// panics when their lengths differ.
func DistanceL2(left, right []float32) float32 {
	if len(left) != len(right) {
		panic("array dimension not equal")
	}

	var total float32
	for i, l := range left {
		diff := l - right[i]
		total += diff * diff
	}

	return total
}
|
||||
|
||||
// DistanceIP is a reference implementation used by the tests: it returns the
// inner product of two equally sized vectors and panics when their lengths
// differ.
func DistanceIP(left, right []float32) float32 {
	if len(left) != len(right) {
		panic("array dimension not equal")
	}

	var total float32
	for i, l := range left {
		total += l * right[i]
	}

	return total
}
|
||||
|
||||
func Test_CalcL2(t *testing.T) {
|
||||
var dim int64 = 128
|
||||
var leftNum int64 = 1
|
||||
var rightNum int64 = 1
|
||||
|
||||
left := CreateFloatArray(leftNum, dim)
|
||||
right := CreateFloatArray(rightNum, dim)
|
||||
|
||||
sum := DistanceL2(left, right)
|
||||
|
||||
distance := CalcL2(dim, left, 0, right, 0)
|
||||
assert.Less(t, math.Abs(float64(sum-distance)), PRECISION)
|
||||
|
||||
distance = CalcL2(dim, left, 0, left, 0)
|
||||
assert.Less(t, float64(distance), PRECISION)
|
||||
}
|
||||
|
||||
func Test_CalcIP(t *testing.T) {
|
||||
var dim int64 = 128
|
||||
var leftNum int64 = 1
|
||||
var rightNum int64 = 1
|
||||
|
||||
left := CreateFloatArray(leftNum, dim)
|
||||
right := CreateFloatArray(rightNum, dim)
|
||||
|
||||
sum := DistanceIP(left, right)
|
||||
|
||||
distance := CalcIP(dim, left, 0, right, 0)
|
||||
assert.Less(t, math.Abs(float64(sum-distance)), PRECISION)
|
||||
}
|
||||
|
||||
func Test_CalcFloatDistance(t *testing.T) {
|
||||
var dim int64 = 128
|
||||
var leftNum int64 = 10
|
||||
var rightNum int64 = 5
|
||||
|
||||
left := CreateFloatArray(leftNum, dim)
|
||||
right := CreateFloatArray(rightNum, dim)
|
||||
|
||||
// Verify illegal cases
|
||||
_, err := CalcFloatDistance(dim, left, right, "HAMMIN")
|
||||
assert.Error(t, err)
|
||||
|
||||
_, err = CalcFloatDistance(3, left, right, "L2")
|
||||
assert.Error(t, err)
|
||||
|
||||
_, err = CalcFloatDistance(dim, left, right, "HAMMIN")
|
||||
assert.Error(t, err)
|
||||
|
||||
_, err = CalcFloatDistance(0, left, right, "L2")
|
||||
assert.Error(t, err)
|
||||
|
||||
distances, err := CalcFloatDistance(dim, left, right, "L2")
|
||||
assert.NoError(t, err)
|
||||
|
||||
// Verify the L2 distance algorithm is correct
|
||||
invalid := CreateFloatArray(rightNum, 10)
|
||||
_, err = CalcFloatDistance(dim, left, invalid, "L2")
|
||||
assert.Error(t, err)
|
||||
|
||||
for i := int64(0); i < leftNum; i++ {
|
||||
for j := int64(0); j < rightNum; j++ {
|
||||
v1 := left[i*dim : (i+1)*dim]
|
||||
v2 := right[j*dim : (j+1)*dim]
|
||||
sum := DistanceL2(v1, v2)
|
||||
assert.Less(t, math.Abs(float64(sum-distances[i*rightNum+j])), PRECISION)
|
||||
}
|
||||
}
|
||||
|
||||
// Verify the IP distance algorithm is correct
|
||||
distances, err = CalcFloatDistance(dim, left, right, "IP")
|
||||
assert.NoError(t, err)
|
||||
|
||||
for i := int64(0); i < leftNum; i++ {
|
||||
for j := int64(0); j < rightNum; j++ {
|
||||
v1 := left[i*dim : (i+1)*dim]
|
||||
v2 := right[j*dim : (j+1)*dim]
|
||||
sum := DistanceIP(v1, v2)
|
||||
assert.Less(t, math.Abs(float64(sum-distances[i*rightNum+j])), PRECISION)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// //////////////////////////////////////////////////////////////////////////////
|
||||
// CreateBinaryArray returns a buffer of pseudo-random bytes large enough to
// hold n*dim bits (n*dim/8 bytes, rounded up), re-seeding the generator from
// the current time on every call.
func CreateBinaryArray(n, dim int64) []byte {
	rand.Seed(time.Now().UnixNano())

	totalBits := n * dim
	byteCount := totalBits / 8
	if byteCount*8 < totalBits {
		byteCount++
	}

	buf := make([]byte, byteCount)
	for i := range buf {
		buf[i] = uint8(rand.Intn(256))
	}

	return buf
}
|
||||
|
||||
func Test_SingleBitLen(t *testing.T) {
|
||||
n := SingleBitLen(125)
|
||||
assert.Equal(t, n, int64(128))
|
||||
|
||||
n = SingleBitLen(133)
|
||||
assert.Equal(t, n, int64(136))
|
||||
}
|
||||
|
||||
func Test_VectorCount(t *testing.T) {
|
||||
n := VectorCount(15, 20)
|
||||
assert.Equal(t, n, int64(10))
|
||||
|
||||
n = VectorCount(8, 3)
|
||||
assert.Equal(t, n, int64(3))
|
||||
}
|
||||
|
||||
func Test_ValidateBinaryArrayLength(t *testing.T) {
|
||||
err := ValidateBinaryArrayLength(21, 12)
|
||||
assert.NoError(t, err)
|
||||
|
||||
err = ValidateBinaryArrayLength(21, 11)
|
||||
assert.Error(t, err)
|
||||
}
|
||||
|
||||
func Test_CountOne(t *testing.T) {
|
||||
n := CountOne(6)
|
||||
assert.Equal(t, n, int32(2))
|
||||
|
||||
n = CountOne(0)
|
||||
assert.Equal(t, n, int32(0))
|
||||
|
||||
n = CountOne(255)
|
||||
assert.Equal(t, n, int32(8))
|
||||
}
|
||||
|
||||
func Test_CalcHamming(t *testing.T) {
|
||||
var dim int64 = 22
|
||||
// v1 = 00000010 00000110 00001000
|
||||
v1 := make([]uint8, 3)
|
||||
v1[0] = 2
|
||||
v1[1] = 6
|
||||
v1[2] = 8
|
||||
// v2 = 00000001 00000111 00011011
|
||||
v2 := make([]uint8, 3)
|
||||
v2[0] = 1
|
||||
v2[1] = 7
|
||||
v2[2] = 27
|
||||
n := CalcHamming(dim, v1, 0, v2, 0)
|
||||
assert.Equal(t, n, int32(4))
|
||||
}
|
||||
|
||||
func Test_CalcHamminDistance(t *testing.T) {
|
||||
var dim int64 = 125
|
||||
var leftNum int64 = 2
|
||||
|
||||
left := CreateBinaryArray(leftNum, dim)
|
||||
|
||||
_, e := CalcHammingDistance(0, left, left)
|
||||
assert.Error(t, e)
|
||||
|
||||
distances, err := CalcHammingDistance(dim, left, left)
|
||||
assert.NoError(t, err)
|
||||
|
||||
n := CalcHamming(dim, left, 0, left, 0)
|
||||
assert.Equal(t, n, int32(0))
|
||||
|
||||
n = CalcHamming(dim, left, 1, left, 1)
|
||||
assert.Equal(t, n, int32(0))
|
||||
|
||||
n = CalcHamming(dim, left, 0, left, 1)
|
||||
assert.Equal(t, n, distances[1])
|
||||
|
||||
n = CalcHamming(dim, left, 1, left, 0)
|
||||
assert.Equal(t, n, distances[2])
|
||||
|
||||
invalid := CreateBinaryArray(leftNum, 200)
|
||||
_, e = CalcHammingDistance(dim, invalid, left)
|
||||
assert.Error(t, e)
|
||||
|
||||
_, e = CalcHammingDistance(dim, left, invalid)
|
||||
assert.Error(t, e)
|
||||
}
|
||||
|
||||
func Test_CalcTanimotoCoefficient(t *testing.T) {
|
||||
var dim int64 = 22
|
||||
hamming := make([]int32, 2)
|
||||
hamming[0] = 4
|
||||
hamming[1] = 17
|
||||
tanimoto, err := CalcTanimotoCoefficient(dim, hamming)
|
||||
|
||||
for i := 0; i < len(hamming); i++ {
|
||||
realTanimoto := float64(int32(dim)-hamming[i]) / (float64(dim)*2.0 - float64(int32(dim)-hamming[i]))
|
||||
assert.NoError(t, err)
|
||||
assert.Less(t, math.Abs(float64(tanimoto[i])-realTanimoto), float64(PRECISION))
|
||||
}
|
||||
|
||||
_, err = CalcTanimotoCoefficient(-1, hamming)
|
||||
assert.Error(t, err)
|
||||
|
||||
_, err = CalcTanimotoCoefficient(3, hamming)
|
||||
assert.Error(t, err)
|
||||
}
|
|
@ -5,6 +5,7 @@ import (
|
|||
"testing"
|
||||
|
||||
"github.com/milvus-io/milvus-proto/go-api/v2/schemapb"
|
||||
"github.com/milvus-io/milvus/pkg/util/metric"
|
||||
|
||||
"github.com/stretchr/testify/assert"
|
||||
)
|
||||
|
@ -12,10 +13,10 @@ import (
|
|||
func Test_baseChecker_CheckTrain(t *testing.T) {
|
||||
validParams := map[string]string{
|
||||
DIM: strconv.Itoa(128),
|
||||
Metric: L2,
|
||||
Metric: metric.L2,
|
||||
}
|
||||
paramsWithoutDim := map[string]string{
|
||||
Metric: L2,
|
||||
Metric: metric.L2,
|
||||
}
|
||||
cases := []struct {
|
||||
params map[string]string
|
||||
|
|
|
@ -5,6 +5,7 @@ import (
|
|||
"testing"
|
||||
|
||||
"github.com/milvus-io/milvus-proto/go-api/v2/schemapb"
|
||||
"github.com/milvus-io/milvus/pkg/util/metric"
|
||||
|
||||
"github.com/stretchr/testify/assert"
|
||||
)
|
||||
|
@ -12,44 +13,44 @@ import (
|
|||
func Test_binFlatChecker_CheckTrain(t *testing.T) {
|
||||
validParams := map[string]string{
|
||||
DIM: strconv.Itoa(128),
|
||||
Metric: JACCARD,
|
||||
Metric: metric.JACCARD,
|
||||
}
|
||||
paramsWithoutDim := map[string]string{
|
||||
Metric: JACCARD,
|
||||
Metric: metric.JACCARD,
|
||||
}
|
||||
|
||||
p1 := map[string]string{
|
||||
DIM: strconv.Itoa(128),
|
||||
Metric: L2,
|
||||
Metric: metric.L2,
|
||||
}
|
||||
p2 := map[string]string{
|
||||
DIM: strconv.Itoa(128),
|
||||
Metric: IP,
|
||||
Metric: metric.IP,
|
||||
}
|
||||
p3 := map[string]string{
|
||||
DIM: strconv.Itoa(128),
|
||||
Metric: COSINE,
|
||||
Metric: metric.COSINE,
|
||||
}
|
||||
|
||||
p4 := map[string]string{
|
||||
DIM: strconv.Itoa(128),
|
||||
Metric: HAMMING,
|
||||
Metric: metric.HAMMING,
|
||||
}
|
||||
p5 := map[string]string{
|
||||
DIM: strconv.Itoa(128),
|
||||
Metric: JACCARD,
|
||||
Metric: metric.JACCARD,
|
||||
}
|
||||
p6 := map[string]string{
|
||||
DIM: strconv.Itoa(128),
|
||||
Metric: TANIMOTO,
|
||||
Metric: metric.TANIMOTO,
|
||||
}
|
||||
p7 := map[string]string{
|
||||
DIM: strconv.Itoa(128),
|
||||
Metric: SUBSTRUCTURE,
|
||||
Metric: metric.SUBSTRUCTURE,
|
||||
}
|
||||
p8 := map[string]string{
|
||||
DIM: strconv.Itoa(128),
|
||||
Metric: SUPERSTRUCTURE,
|
||||
Metric: metric.SUPERSTRUCTURE,
|
||||
}
|
||||
|
||||
cases := []struct {
|
||||
|
|
|
@ -5,6 +5,7 @@ import (
|
|||
"testing"
|
||||
|
||||
"github.com/milvus-io/milvus-proto/go-api/v2/schemapb"
|
||||
"github.com/milvus-io/milvus/pkg/util/metric"
|
||||
|
||||
"github.com/stretchr/testify/assert"
|
||||
)
|
||||
|
@ -15,24 +16,24 @@ func Test_binIVFFlatChecker_CheckTrain(t *testing.T) {
|
|||
NLIST: strconv.Itoa(100),
|
||||
IVFM: strconv.Itoa(4),
|
||||
NBITS: strconv.Itoa(8),
|
||||
Metric: JACCARD,
|
||||
Metric: metric.JACCARD,
|
||||
}
|
||||
paramsWithoutDim := map[string]string{
|
||||
NLIST: strconv.Itoa(100),
|
||||
IVFM: strconv.Itoa(4),
|
||||
NBITS: strconv.Itoa(8),
|
||||
Metric: JACCARD,
|
||||
Metric: metric.JACCARD,
|
||||
}
|
||||
|
||||
invalidParams := copyParams(validParams)
|
||||
invalidParams[Metric] = L2
|
||||
invalidParams[Metric] = metric.L2
|
||||
|
||||
paramsWithLargeNlist := map[string]string{
|
||||
DIM: strconv.Itoa(128),
|
||||
NLIST: strconv.Itoa(MaxNList + 1),
|
||||
IVFM: strconv.Itoa(4),
|
||||
NBITS: strconv.Itoa(8),
|
||||
Metric: JACCARD,
|
||||
Metric: metric.JACCARD,
|
||||
}
|
||||
|
||||
paramsWithSmallNlist := map[string]string{
|
||||
|
@ -40,26 +41,26 @@ func Test_binIVFFlatChecker_CheckTrain(t *testing.T) {
|
|||
NLIST: strconv.Itoa(MinNList - 1),
|
||||
IVFM: strconv.Itoa(4),
|
||||
NBITS: strconv.Itoa(8),
|
||||
Metric: JACCARD,
|
||||
Metric: metric.JACCARD,
|
||||
}
|
||||
|
||||
p1 := map[string]string{
|
||||
DIM: strconv.Itoa(128),
|
||||
Metric: L2,
|
||||
Metric: metric.L2,
|
||||
NLIST: strconv.Itoa(100),
|
||||
IVFM: strconv.Itoa(4),
|
||||
NBITS: strconv.Itoa(8),
|
||||
}
|
||||
p2 := map[string]string{
|
||||
DIM: strconv.Itoa(128),
|
||||
Metric: IP,
|
||||
Metric: metric.IP,
|
||||
NLIST: strconv.Itoa(100),
|
||||
IVFM: strconv.Itoa(4),
|
||||
NBITS: strconv.Itoa(8),
|
||||
}
|
||||
p3 := map[string]string{
|
||||
DIM: strconv.Itoa(128),
|
||||
Metric: COSINE,
|
||||
Metric: metric.COSINE,
|
||||
NLIST: strconv.Itoa(100),
|
||||
IVFM: strconv.Itoa(4),
|
||||
NBITS: strconv.Itoa(8),
|
||||
|
@ -67,21 +68,21 @@ func Test_binIVFFlatChecker_CheckTrain(t *testing.T) {
|
|||
|
||||
p4 := map[string]string{
|
||||
DIM: strconv.Itoa(128),
|
||||
Metric: HAMMING,
|
||||
Metric: metric.HAMMING,
|
||||
NLIST: strconv.Itoa(100),
|
||||
IVFM: strconv.Itoa(4),
|
||||
NBITS: strconv.Itoa(8),
|
||||
}
|
||||
p5 := map[string]string{
|
||||
DIM: strconv.Itoa(128),
|
||||
Metric: JACCARD,
|
||||
Metric: metric.JACCARD,
|
||||
NLIST: strconv.Itoa(100),
|
||||
IVFM: strconv.Itoa(4),
|
||||
NBITS: strconv.Itoa(8),
|
||||
}
|
||||
p6 := map[string]string{
|
||||
DIM: strconv.Itoa(128),
|
||||
Metric: TANIMOTO,
|
||||
Metric: metric.TANIMOTO,
|
||||
NLIST: strconv.Itoa(100),
|
||||
IVFM: strconv.Itoa(4),
|
||||
NBITS: strconv.Itoa(8),
|
||||
|
@ -89,14 +90,14 @@ func Test_binIVFFlatChecker_CheckTrain(t *testing.T) {
|
|||
|
||||
p7 := map[string]string{
|
||||
DIM: strconv.Itoa(128),
|
||||
Metric: SUBSTRUCTURE,
|
||||
Metric: metric.SUBSTRUCTURE,
|
||||
NLIST: strconv.Itoa(100),
|
||||
IVFM: strconv.Itoa(4),
|
||||
NBITS: strconv.Itoa(8),
|
||||
}
|
||||
p8 := map[string]string{
|
||||
DIM: strconv.Itoa(128),
|
||||
Metric: SUPERSTRUCTURE,
|
||||
Metric: metric.SUPERSTRUCTURE,
|
||||
NLIST: strconv.Itoa(100),
|
||||
IVFM: strconv.Itoa(4),
|
||||
NBITS: strconv.Itoa(8),
|
||||
|
|
|
@ -1,32 +1,11 @@
|
|||
package indexparamcheck
|
||||
|
||||
import "github.com/milvus-io/milvus/pkg/common"
|
||||
import (
|
||||
"github.com/milvus-io/milvus/pkg/common"
|
||||
"github.com/milvus-io/milvus/pkg/util/metric"
|
||||
)
|
||||
|
||||
const (
|
||||
// L2 represents Euclidean distance
|
||||
L2 = "L2"
|
||||
|
||||
// IP represents inner product distance
|
||||
IP = "IP"
|
||||
|
||||
// COSINE represents cosine distance
|
||||
COSINE = "COSINE"
|
||||
|
||||
// HAMMING represents hamming distance
|
||||
HAMMING = "HAMMING"
|
||||
|
||||
// JACCARD represents jaccard distance
|
||||
JACCARD = "JACCARD"
|
||||
|
||||
// TANIMOTO represents tanimoto distance
|
||||
TANIMOTO = "TANIMOTO"
|
||||
|
||||
// SUBSTRUCTURE represents substructure distance
|
||||
SUBSTRUCTURE = "SUBSTRUCTURE"
|
||||
|
||||
// SUPERSTRUCTURE represents superstructure distance
|
||||
SUPERSTRUCTURE = "SUPERSTRUCTURE"
|
||||
|
||||
MinNBits = 1
|
||||
MaxNBits = 16
|
||||
DefaultNBits = 8
|
||||
|
@ -65,16 +44,17 @@ const (
|
|||
)
|
||||
|
||||
// METRICS is a set of all metrics types supported for float vector.
|
||||
var METRICS = []string{L2, IP, COSINE} // const
|
||||
var METRICS = []string{metric.L2, metric.IP, metric.COSINE} // const
|
||||
|
||||
// BinIDMapMetrics is a set of all metric types supported for binary vector.
|
||||
var BinIDMapMetrics = []string{HAMMING, JACCARD, TANIMOTO, SUBSTRUCTURE, SUPERSTRUCTURE} // const
|
||||
var BinIvfMetrics = []string{HAMMING, JACCARD, TANIMOTO} // const
|
||||
var HnswMetrics = []string{L2, IP, COSINE, HAMMING, JACCARD} // const
|
||||
var supportDimPerSubQuantizer = []int{32, 28, 24, 20, 16, 12, 10, 8, 6, 4, 3, 2, 1} // const
|
||||
var supportSubQuantizer = []int{96, 64, 56, 48, 40, 32, 28, 24, 20, 16, 12, 8, 4, 3, 2, 1} // const
|
||||
var BinIDMapMetrics = []string{metric.HAMMING, metric.JACCARD, metric.TANIMOTO, metric.SUBSTRUCTURE,
|
||||
metric.SUPERSTRUCTURE} // const
|
||||
var BinIvfMetrics = []string{metric.HAMMING, metric.JACCARD, metric.TANIMOTO} // const
|
||||
var HnswMetrics = []string{metric.L2, metric.IP, metric.COSINE, metric.HAMMING, metric.JACCARD} // const
|
||||
var supportDimPerSubQuantizer = []int{32, 28, 24, 20, 16, 12, 10, 8, 6, 4, 3, 2, 1} // const
|
||||
var supportSubQuantizer = []int{96, 64, 56, 48, 40, 32, 28, 24, 20, 16, 12, 8, 4, 3, 2, 1} // const
|
||||
|
||||
const (
|
||||
FloatVectorDefaultMetricType = IP
|
||||
BinaryVectorDefaultMetricType = JACCARD
|
||||
FloatVectorDefaultMetricType = metric.IP
|
||||
BinaryVectorDefaultMetricType = metric.JACCARD
|
||||
)
|
||||
|
|
|
@ -5,6 +5,7 @@ import (
|
|||
"testing"
|
||||
|
||||
"github.com/milvus-io/milvus-proto/go-api/v2/schemapb"
|
||||
"github.com/milvus-io/milvus/pkg/util/metric"
|
||||
|
||||
"github.com/stretchr/testify/assert"
|
||||
)
|
||||
|
@ -12,13 +13,13 @@ import (
|
|||
func Test_diskannChecker_CheckTrain(t *testing.T) {
|
||||
validParams := map[string]string{
|
||||
DIM: strconv.Itoa(128),
|
||||
Metric: L2,
|
||||
Metric: metric.L2,
|
||||
}
|
||||
validParamsBigDim := copyParams(validParams)
|
||||
validParamsBigDim[DIM] = strconv.Itoa(2048)
|
||||
|
||||
invalidParamsWithoutDim := map[string]string{
|
||||
Metric: L2,
|
||||
Metric: metric.L2,
|
||||
}
|
||||
|
||||
invalidParamsSmallDim := copyParams(validParams)
|
||||
|
@ -26,36 +27,36 @@ func Test_diskannChecker_CheckTrain(t *testing.T) {
|
|||
|
||||
p1 := map[string]string{
|
||||
DIM: strconv.Itoa(128),
|
||||
Metric: L2,
|
||||
Metric: metric.L2,
|
||||
}
|
||||
p2 := map[string]string{
|
||||
DIM: strconv.Itoa(128),
|
||||
Metric: IP,
|
||||
Metric: metric.IP,
|
||||
}
|
||||
p3 := map[string]string{
|
||||
DIM: strconv.Itoa(128),
|
||||
Metric: COSINE,
|
||||
Metric: metric.COSINE,
|
||||
}
|
||||
|
||||
p4 := map[string]string{
|
||||
DIM: strconv.Itoa(128),
|
||||
Metric: HAMMING,
|
||||
Metric: metric.HAMMING,
|
||||
}
|
||||
p5 := map[string]string{
|
||||
DIM: strconv.Itoa(128),
|
||||
Metric: JACCARD,
|
||||
Metric: metric.JACCARD,
|
||||
}
|
||||
p6 := map[string]string{
|
||||
DIM: strconv.Itoa(128),
|
||||
Metric: TANIMOTO,
|
||||
Metric: metric.TANIMOTO,
|
||||
}
|
||||
p7 := map[string]string{
|
||||
DIM: strconv.Itoa(128),
|
||||
Metric: SUBSTRUCTURE,
|
||||
Metric: metric.SUBSTRUCTURE,
|
||||
}
|
||||
p8 := map[string]string{
|
||||
DIM: strconv.Itoa(128),
|
||||
Metric: SUPERSTRUCTURE,
|
||||
Metric: metric.SUPERSTRUCTURE,
|
||||
}
|
||||
|
||||
cases := []struct {
|
||||
|
|
|
@ -4,6 +4,8 @@ import (
|
|||
"strconv"
|
||||
"testing"
|
||||
|
||||
"github.com/milvus-io/milvus/pkg/util/metric"
|
||||
|
||||
"github.com/stretchr/testify/assert"
|
||||
)
|
||||
|
||||
|
@ -11,36 +13,36 @@ func Test_flatChecker_CheckTrain(t *testing.T) {
|
|||
|
||||
p1 := map[string]string{
|
||||
DIM: strconv.Itoa(128),
|
||||
Metric: L2,
|
||||
Metric: metric.L2,
|
||||
}
|
||||
p2 := map[string]string{
|
||||
DIM: strconv.Itoa(128),
|
||||
Metric: IP,
|
||||
Metric: metric.IP,
|
||||
}
|
||||
p3 := map[string]string{
|
||||
DIM: strconv.Itoa(128),
|
||||
Metric: COSINE,
|
||||
Metric: metric.COSINE,
|
||||
}
|
||||
|
||||
p4 := map[string]string{
|
||||
DIM: strconv.Itoa(128),
|
||||
Metric: HAMMING,
|
||||
Metric: metric.HAMMING,
|
||||
}
|
||||
p5 := map[string]string{
|
||||
DIM: strconv.Itoa(128),
|
||||
Metric: JACCARD,
|
||||
Metric: metric.JACCARD,
|
||||
}
|
||||
p6 := map[string]string{
|
||||
DIM: strconv.Itoa(128),
|
||||
Metric: TANIMOTO,
|
||||
Metric: metric.TANIMOTO,
|
||||
}
|
||||
p7 := map[string]string{
|
||||
DIM: strconv.Itoa(128),
|
||||
Metric: SUBSTRUCTURE,
|
||||
Metric: metric.SUBSTRUCTURE,
|
||||
}
|
||||
p8 := map[string]string{
|
||||
DIM: strconv.Itoa(128),
|
||||
Metric: SUPERSTRUCTURE,
|
||||
Metric: metric.SUPERSTRUCTURE,
|
||||
}
|
||||
cases := []struct {
|
||||
params map[string]string
|
||||
|
|
|
@ -5,6 +5,7 @@ import (
|
|||
"testing"
|
||||
|
||||
"github.com/milvus-io/milvus-proto/go-api/v2/schemapb"
|
||||
"github.com/milvus-io/milvus/pkg/util/metric"
|
||||
|
||||
"github.com/stretchr/testify/assert"
|
||||
)
|
||||
|
@ -15,7 +16,7 @@ func Test_hnswChecker_CheckTrain(t *testing.T) {
|
|||
DIM: strconv.Itoa(128),
|
||||
HNSWM: strconv.Itoa(16),
|
||||
EFConstruction: strconv.Itoa(200),
|
||||
Metric: L2,
|
||||
Metric: metric.L2,
|
||||
}
|
||||
|
||||
invalidEfParamsMin := copyParams(validParams)
|
||||
|
@ -34,50 +35,50 @@ func Test_hnswChecker_CheckTrain(t *testing.T) {
|
|||
DIM: strconv.Itoa(128),
|
||||
HNSWM: strconv.Itoa(16),
|
||||
EFConstruction: strconv.Itoa(200),
|
||||
Metric: L2,
|
||||
Metric: metric.L2,
|
||||
}
|
||||
p2 := map[string]string{
|
||||
DIM: strconv.Itoa(128),
|
||||
HNSWM: strconv.Itoa(16),
|
||||
EFConstruction: strconv.Itoa(200),
|
||||
Metric: IP,
|
||||
Metric: metric.IP,
|
||||
}
|
||||
p3 := map[string]string{
|
||||
DIM: strconv.Itoa(128),
|
||||
HNSWM: strconv.Itoa(16),
|
||||
EFConstruction: strconv.Itoa(200),
|
||||
Metric: COSINE,
|
||||
Metric: metric.COSINE,
|
||||
}
|
||||
|
||||
p4 := map[string]string{
|
||||
DIM: strconv.Itoa(128),
|
||||
HNSWM: strconv.Itoa(16),
|
||||
EFConstruction: strconv.Itoa(200),
|
||||
Metric: HAMMING,
|
||||
Metric: metric.HAMMING,
|
||||
}
|
||||
p5 := map[string]string{
|
||||
DIM: strconv.Itoa(128),
|
||||
HNSWM: strconv.Itoa(16),
|
||||
EFConstruction: strconv.Itoa(200),
|
||||
Metric: JACCARD,
|
||||
Metric: metric.JACCARD,
|
||||
}
|
||||
p6 := map[string]string{
|
||||
DIM: strconv.Itoa(128),
|
||||
HNSWM: strconv.Itoa(16),
|
||||
EFConstruction: strconv.Itoa(200),
|
||||
Metric: TANIMOTO,
|
||||
Metric: metric.TANIMOTO,
|
||||
}
|
||||
p7 := map[string]string{
|
||||
DIM: strconv.Itoa(128),
|
||||
HNSWM: strconv.Itoa(16),
|
||||
EFConstruction: strconv.Itoa(200),
|
||||
Metric: SUBSTRUCTURE,
|
||||
Metric: metric.SUBSTRUCTURE,
|
||||
}
|
||||
p8 := map[string]string{
|
||||
DIM: strconv.Itoa(128),
|
||||
HNSWM: strconv.Itoa(16),
|
||||
EFConstruction: strconv.Itoa(200),
|
||||
Metric: SUPERSTRUCTURE,
|
||||
Metric: metric.SUPERSTRUCTURE,
|
||||
}
|
||||
|
||||
cases := []struct {
|
||||
|
|
|
@ -13,6 +13,8 @@ package indexparamcheck
|
|||
|
||||
import (
|
||||
"strconv"
|
||||
|
||||
"github.com/milvus-io/milvus/pkg/util/metric"
|
||||
)
|
||||
|
||||
// TODO: add more test cases for which `IndexChecker.CheckTrain` returns false,
|
||||
|
@ -22,7 +24,7 @@ func invalidIVFParamsMin() map[string]string {
|
|||
invalidIVFParams := map[string]string{
|
||||
DIM: strconv.Itoa(128),
|
||||
NLIST: strconv.Itoa(MinNList - 1),
|
||||
Metric: L2,
|
||||
Metric: metric.L2,
|
||||
}
|
||||
return invalidIVFParams
|
||||
}
|
||||
|
@ -31,7 +33,7 @@ func invalidIVFParamsMax() map[string]string {
|
|||
invalidIVFParams := map[string]string{
|
||||
DIM: strconv.Itoa(128),
|
||||
NLIST: strconv.Itoa(MaxNList + 1),
|
||||
Metric: L2,
|
||||
Metric: metric.L2,
|
||||
}
|
||||
return invalidIVFParams
|
||||
}
|
||||
|
|
|
@ -5,6 +5,7 @@ import (
|
|||
"testing"
|
||||
|
||||
"github.com/milvus-io/milvus-proto/go-api/v2/schemapb"
|
||||
"github.com/milvus-io/milvus/pkg/util/metric"
|
||||
|
||||
"github.com/stretchr/testify/assert"
|
||||
)
|
||||
|
@ -13,49 +14,49 @@ func Test_ivfBaseChecker_CheckTrain(t *testing.T) {
|
|||
validParams := map[string]string{
|
||||
DIM: strconv.Itoa(128),
|
||||
NLIST: strconv.Itoa(1024),
|
||||
Metric: L2,
|
||||
Metric: metric.L2,
|
||||
}
|
||||
|
||||
p1 := map[string]string{
|
||||
DIM: strconv.Itoa(128),
|
||||
NLIST: strconv.Itoa(1024),
|
||||
Metric: L2,
|
||||
Metric: metric.L2,
|
||||
}
|
||||
p2 := map[string]string{
|
||||
DIM: strconv.Itoa(128),
|
||||
NLIST: strconv.Itoa(1024),
|
||||
Metric: IP,
|
||||
Metric: metric.IP,
|
||||
}
|
||||
p3 := map[string]string{
|
||||
DIM: strconv.Itoa(128),
|
||||
NLIST: strconv.Itoa(1024),
|
||||
Metric: COSINE,
|
||||
Metric: metric.COSINE,
|
||||
}
|
||||
|
||||
p4 := map[string]string{
|
||||
DIM: strconv.Itoa(128),
|
||||
NLIST: strconv.Itoa(1024),
|
||||
Metric: HAMMING,
|
||||
Metric: metric.HAMMING,
|
||||
}
|
||||
p5 := map[string]string{
|
||||
DIM: strconv.Itoa(128),
|
||||
NLIST: strconv.Itoa(1024),
|
||||
Metric: JACCARD,
|
||||
Metric: metric.JACCARD,
|
||||
}
|
||||
p6 := map[string]string{
|
||||
DIM: strconv.Itoa(128),
|
||||
NLIST: strconv.Itoa(1024),
|
||||
Metric: TANIMOTO,
|
||||
Metric: metric.TANIMOTO,
|
||||
}
|
||||
p7 := map[string]string{
|
||||
DIM: strconv.Itoa(128),
|
||||
NLIST: strconv.Itoa(1024),
|
||||
Metric: SUBSTRUCTURE,
|
||||
Metric: metric.SUBSTRUCTURE,
|
||||
}
|
||||
p8 := map[string]string{
|
||||
DIM: strconv.Itoa(128),
|
||||
NLIST: strconv.Itoa(1024),
|
||||
Metric: SUPERSTRUCTURE,
|
||||
Metric: metric.SUPERSTRUCTURE,
|
||||
}
|
||||
|
||||
cases := []struct {
|
||||
|
|
|
@ -5,6 +5,7 @@ import (
|
|||
"testing"
|
||||
|
||||
"github.com/milvus-io/milvus-proto/go-api/v2/schemapb"
|
||||
"github.com/milvus-io/milvus/pkg/util/metric"
|
||||
|
||||
"github.com/stretchr/testify/assert"
|
||||
)
|
||||
|
@ -15,7 +16,7 @@ func Test_ivfPQChecker_CheckTrain(t *testing.T) {
|
|||
NLIST: strconv.Itoa(1024),
|
||||
IVFM: strconv.Itoa(4),
|
||||
NBITS: strconv.Itoa(8),
|
||||
Metric: L2,
|
||||
Metric: metric.L2,
|
||||
}
|
||||
|
||||
paramsNotMultiplier := map[string]string{
|
||||
|
@ -23,21 +24,21 @@ func Test_ivfPQChecker_CheckTrain(t *testing.T) {
|
|||
NLIST: strconv.Itoa(1024),
|
||||
IVFM: strconv.Itoa(5),
|
||||
NBITS: strconv.Itoa(8),
|
||||
Metric: L2,
|
||||
Metric: metric.L2,
|
||||
}
|
||||
|
||||
validParamsWithoutNbits := map[string]string{
|
||||
DIM: strconv.Itoa(128),
|
||||
NLIST: strconv.Itoa(1024),
|
||||
IVFM: strconv.Itoa(4),
|
||||
Metric: L2,
|
||||
Metric: metric.L2,
|
||||
}
|
||||
|
||||
validParamsWithoutDim := map[string]string{
|
||||
NLIST: strconv.Itoa(1024),
|
||||
IVFM: strconv.Itoa(4),
|
||||
NBITS: strconv.Itoa(8),
|
||||
Metric: L2,
|
||||
Metric: metric.L2,
|
||||
}
|
||||
|
||||
invalidParamsDim := copyParams(validParams)
|
||||
|
@ -50,7 +51,7 @@ func Test_ivfPQChecker_CheckTrain(t *testing.T) {
|
|||
DIM: strconv.Itoa(128),
|
||||
NLIST: strconv.Itoa(1024),
|
||||
NBITS: strconv.Itoa(8),
|
||||
Metric: L2,
|
||||
Metric: metric.L2,
|
||||
}
|
||||
|
||||
invalidParamsIVF := copyParams(validParams)
|
||||
|
@ -67,21 +68,21 @@ func Test_ivfPQChecker_CheckTrain(t *testing.T) {
|
|||
NLIST: strconv.Itoa(1024),
|
||||
IVFM: strconv.Itoa(4),
|
||||
NBITS: strconv.Itoa(8),
|
||||
Metric: L2,
|
||||
Metric: metric.L2,
|
||||
}
|
||||
p2 := map[string]string{
|
||||
DIM: strconv.Itoa(128),
|
||||
NLIST: strconv.Itoa(1024),
|
||||
IVFM: strconv.Itoa(4),
|
||||
NBITS: strconv.Itoa(8),
|
||||
Metric: IP,
|
||||
Metric: metric.IP,
|
||||
}
|
||||
p3 := map[string]string{
|
||||
DIM: strconv.Itoa(128),
|
||||
NLIST: strconv.Itoa(1024),
|
||||
IVFM: strconv.Itoa(4),
|
||||
NBITS: strconv.Itoa(8),
|
||||
Metric: COSINE,
|
||||
Metric: metric.COSINE,
|
||||
}
|
||||
|
||||
p4 := map[string]string{
|
||||
|
@ -89,35 +90,35 @@ func Test_ivfPQChecker_CheckTrain(t *testing.T) {
|
|||
NLIST: strconv.Itoa(1024),
|
||||
IVFM: strconv.Itoa(4),
|
||||
NBITS: strconv.Itoa(8),
|
||||
Metric: HAMMING,
|
||||
Metric: metric.HAMMING,
|
||||
}
|
||||
p5 := map[string]string{
|
||||
DIM: strconv.Itoa(128),
|
||||
NLIST: strconv.Itoa(1024),
|
||||
IVFM: strconv.Itoa(4),
|
||||
NBITS: strconv.Itoa(8),
|
||||
Metric: JACCARD,
|
||||
Metric: metric.JACCARD,
|
||||
}
|
||||
p6 := map[string]string{
|
||||
DIM: strconv.Itoa(128),
|
||||
NLIST: strconv.Itoa(1024),
|
||||
IVFM: strconv.Itoa(4),
|
||||
NBITS: strconv.Itoa(8),
|
||||
Metric: TANIMOTO,
|
||||
Metric: metric.TANIMOTO,
|
||||
}
|
||||
p7 := map[string]string{
|
||||
DIM: strconv.Itoa(128),
|
||||
NLIST: strconv.Itoa(1024),
|
||||
IVFM: strconv.Itoa(4),
|
||||
NBITS: strconv.Itoa(8),
|
||||
Metric: SUBSTRUCTURE,
|
||||
Metric: metric.SUBSTRUCTURE,
|
||||
}
|
||||
p8 := map[string]string{
|
||||
DIM: strconv.Itoa(128),
|
||||
NLIST: strconv.Itoa(1024),
|
||||
IVFM: strconv.Itoa(4),
|
||||
NBITS: strconv.Itoa(8),
|
||||
Metric: SUPERSTRUCTURE,
|
||||
Metric: metric.SUPERSTRUCTURE,
|
||||
}
|
||||
|
||||
cases := []struct {
|
||||
|
|
|
@ -5,6 +5,7 @@ import (
|
|||
"testing"
|
||||
|
||||
"github.com/milvus-io/milvus-proto/go-api/v2/schemapb"
|
||||
"github.com/milvus-io/milvus/pkg/util/metric"
|
||||
|
||||
"github.com/stretchr/testify/assert"
|
||||
)
|
||||
|
@ -15,7 +16,7 @@ func Test_ivfSQChecker_CheckTrain(t *testing.T) {
|
|||
DIM: strconv.Itoa(128),
|
||||
NLIST: strconv.Itoa(100),
|
||||
NBITS: strconv.Itoa(8),
|
||||
Metric: L2,
|
||||
Metric: metric.L2,
|
||||
}
|
||||
if withNBits {
|
||||
validParams[NBITS] = strconv.Itoa(DefaultNBits)
|
||||
|
@ -31,50 +32,50 @@ func Test_ivfSQChecker_CheckTrain(t *testing.T) {
|
|||
DIM: strconv.Itoa(128),
|
||||
NLIST: strconv.Itoa(100),
|
||||
NBITS: strconv.Itoa(8),
|
||||
Metric: L2,
|
||||
Metric: metric.L2,
|
||||
}
|
||||
p2 := map[string]string{
|
||||
DIM: strconv.Itoa(128),
|
||||
NLIST: strconv.Itoa(100),
|
||||
NBITS: strconv.Itoa(8),
|
||||
Metric: IP,
|
||||
Metric: metric.IP,
|
||||
}
|
||||
p3 := map[string]string{
|
||||
DIM: strconv.Itoa(128),
|
||||
NLIST: strconv.Itoa(100),
|
||||
NBITS: strconv.Itoa(8),
|
||||
Metric: COSINE,
|
||||
Metric: metric.COSINE,
|
||||
}
|
||||
|
||||
p4 := map[string]string{
|
||||
DIM: strconv.Itoa(128),
|
||||
NLIST: strconv.Itoa(100),
|
||||
NBITS: strconv.Itoa(8),
|
||||
Metric: HAMMING,
|
||||
Metric: metric.HAMMING,
|
||||
}
|
||||
p5 := map[string]string{
|
||||
DIM: strconv.Itoa(128),
|
||||
NLIST: strconv.Itoa(100),
|
||||
NBITS: strconv.Itoa(8),
|
||||
Metric: JACCARD,
|
||||
Metric: metric.JACCARD,
|
||||
}
|
||||
p6 := map[string]string{
|
||||
DIM: strconv.Itoa(128),
|
||||
NLIST: strconv.Itoa(100),
|
||||
NBITS: strconv.Itoa(8),
|
||||
Metric: TANIMOTO,
|
||||
Metric: metric.TANIMOTO,
|
||||
}
|
||||
p7 := map[string]string{
|
||||
DIM: strconv.Itoa(128),
|
||||
NLIST: strconv.Itoa(100),
|
||||
NBITS: strconv.Itoa(8),
|
||||
Metric: SUBSTRUCTURE,
|
||||
Metric: metric.SUBSTRUCTURE,
|
||||
}
|
||||
p8 := map[string]string{
|
||||
DIM: strconv.Itoa(128),
|
||||
NLIST: strconv.Itoa(100),
|
||||
NBITS: strconv.Itoa(8),
|
||||
Metric: SUPERSTRUCTURE,
|
||||
Metric: metric.SUPERSTRUCTURE,
|
||||
}
|
||||
|
||||
cases := []struct {
|
||||
|
|
|
@ -5,6 +5,7 @@ import (
|
|||
"testing"
|
||||
|
||||
"github.com/milvus-io/milvus-proto/go-api/v2/schemapb"
|
||||
"github.com/milvus-io/milvus/pkg/util/metric"
|
||||
|
||||
"github.com/stretchr/testify/assert"
|
||||
)
|
||||
|
@ -16,21 +17,21 @@ func Test_raftIVFPQChecker_CheckTrain(t *testing.T) {
|
|||
NLIST: strconv.Itoa(1024),
|
||||
IVFM: strconv.Itoa(4),
|
||||
NBITS: strconv.Itoa(8),
|
||||
Metric: L2,
|
||||
Metric: metric.L2,
|
||||
}
|
||||
|
||||
validParamsWithoutNbits := map[string]string{
|
||||
DIM: strconv.Itoa(128),
|
||||
NLIST: strconv.Itoa(1024),
|
||||
IVFM: strconv.Itoa(4),
|
||||
Metric: L2,
|
||||
Metric: metric.L2,
|
||||
}
|
||||
|
||||
validParamsWithoutDim := map[string]string{
|
||||
NLIST: strconv.Itoa(1024),
|
||||
IVFM: strconv.Itoa(4),
|
||||
NBITS: strconv.Itoa(8),
|
||||
Metric: L2,
|
||||
Metric: metric.L2,
|
||||
}
|
||||
|
||||
invalidParamsDim := copyParams(validParams)
|
||||
|
@ -43,7 +44,7 @@ func Test_raftIVFPQChecker_CheckTrain(t *testing.T) {
|
|||
DIM: strconv.Itoa(128),
|
||||
NLIST: strconv.Itoa(1024),
|
||||
NBITS: strconv.Itoa(8),
|
||||
Metric: L2,
|
||||
Metric: metric.L2,
|
||||
}
|
||||
|
||||
invalidParamsIVF := copyParams(validParams)
|
||||
|
@ -60,21 +61,21 @@ func Test_raftIVFPQChecker_CheckTrain(t *testing.T) {
|
|||
NLIST: strconv.Itoa(1024),
|
||||
IVFM: strconv.Itoa(4),
|
||||
NBITS: strconv.Itoa(8),
|
||||
Metric: L2,
|
||||
Metric: metric.L2,
|
||||
}
|
||||
p2 := map[string]string{
|
||||
DIM: strconv.Itoa(128),
|
||||
NLIST: strconv.Itoa(1024),
|
||||
IVFM: strconv.Itoa(4),
|
||||
NBITS: strconv.Itoa(8),
|
||||
Metric: IP,
|
||||
Metric: metric.IP,
|
||||
}
|
||||
p3 := map[string]string{
|
||||
DIM: strconv.Itoa(128),
|
||||
NLIST: strconv.Itoa(1024),
|
||||
IVFM: strconv.Itoa(4),
|
||||
NBITS: strconv.Itoa(8),
|
||||
Metric: COSINE,
|
||||
Metric: metric.COSINE,
|
||||
}
|
||||
|
||||
p4 := map[string]string{
|
||||
|
@ -82,35 +83,35 @@ func Test_raftIVFPQChecker_CheckTrain(t *testing.T) {
|
|||
NLIST: strconv.Itoa(1024),
|
||||
IVFM: strconv.Itoa(4),
|
||||
NBITS: strconv.Itoa(8),
|
||||
Metric: HAMMING,
|
||||
Metric: metric.HAMMING,
|
||||
}
|
||||
p5 := map[string]string{
|
||||
DIM: strconv.Itoa(128),
|
||||
NLIST: strconv.Itoa(1024),
|
||||
IVFM: strconv.Itoa(4),
|
||||
NBITS: strconv.Itoa(8),
|
||||
Metric: JACCARD,
|
||||
Metric: metric.JACCARD,
|
||||
}
|
||||
p6 := map[string]string{
|
||||
DIM: strconv.Itoa(128),
|
||||
NLIST: strconv.Itoa(1024),
|
||||
IVFM: strconv.Itoa(4),
|
||||
NBITS: strconv.Itoa(8),
|
||||
Metric: TANIMOTO,
|
||||
Metric: metric.TANIMOTO,
|
||||
}
|
||||
p7 := map[string]string{
|
||||
DIM: strconv.Itoa(128),
|
||||
NLIST: strconv.Itoa(1024),
|
||||
IVFM: strconv.Itoa(4),
|
||||
NBITS: strconv.Itoa(8),
|
||||
Metric: SUBSTRUCTURE,
|
||||
Metric: metric.SUBSTRUCTURE,
|
||||
}
|
||||
p8 := map[string]string{
|
||||
DIM: strconv.Itoa(128),
|
||||
NLIST: strconv.Itoa(1024),
|
||||
IVFM: strconv.Itoa(4),
|
||||
NBITS: strconv.Itoa(8),
|
||||
Metric: SUPERSTRUCTURE,
|
||||
Metric: metric.SUPERSTRUCTURE,
|
||||
}
|
||||
|
||||
cases := []struct {
|
||||
|
|
|
@ -0,0 +1,42 @@
|
|||
// Copyright (C) 2019-2020 Zilliz. All rights reserved.
|
||||
//
|
||||
// Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance
|
||||
// with the License. You may obtain a copy of the License at
|
||||
//
|
||||
// http://www.apache.org/licenses/LICENSE-2.0
|
||||
//
|
||||
// Unless required by applicable law or agreed to in writing, software distributed under the License
|
||||
// is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express
|
||||
// or implied. See the License for the specific language governing permissions and limitations under the License.
|
||||
|
||||
package metric
|
||||
|
||||
// MetricType is an alias of string that names a vector distance metric (e.g. "L2", "IP").
|
||||
type MetricType = string
|
||||
|
||||
// MetricType definitions
|
||||
const (
|
||||
// L2 represents Euclidean distance
|
||||
L2 MetricType = "L2"
|
||||
|
||||
// IP represents inner product distance
|
||||
IP MetricType = "IP"
|
||||
|
||||
// COSINE represents cosine distance
|
||||
COSINE MetricType = "COSINE"
|
||||
|
||||
// HAMMING represents hamming distance
|
||||
HAMMING MetricType = "HAMMING"
|
||||
|
||||
// JACCARD represents jaccard distance
|
||||
JACCARD MetricType = "JACCARD"
|
||||
|
||||
// TANIMOTO represents tanimoto distance
|
||||
TANIMOTO MetricType = "TANIMOTO"
|
||||
|
||||
// SUBSTRUCTURE represents substructure distance
|
||||
SUBSTRUCTURE MetricType = "SUBSTRUCTURE"
|
||||
|
||||
// SUPERSTRUCTURE represents superstructure distance
|
||||
SUPERSTRUCTURE MetricType = "SUPERSTRUCTURE"
|
||||
)
|
|
@ -14,7 +14,7 @@
|
|||
// See the License for the specific language governing permissions and
|
||||
// limitations under the License.
|
||||
|
||||
package distance
|
||||
package metric
|
||||
|
||||
import "strings"
|
||||
|
|
@ -14,7 +14,7 @@
|
|||
// See the License for the specific language governing permissions and
|
||||
// limitations under the License.
|
||||
|
||||
package distance
|
||||
package metric
|
||||
|
||||
import "testing"
|
||||
|
|
@ -36,8 +36,8 @@ import (
|
|||
"github.com/milvus-io/milvus/internal/util/importutil"
|
||||
"github.com/milvus-io/milvus/pkg/common"
|
||||
"github.com/milvus-io/milvus/pkg/log"
|
||||
"github.com/milvus-io/milvus/pkg/util/distance"
|
||||
"github.com/milvus-io/milvus/pkg/util/funcutil"
|
||||
"github.com/milvus-io/milvus/pkg/util/metric"
|
||||
"github.com/milvus-io/milvus/tests/integration"
|
||||
)
|
||||
|
||||
|
@ -164,7 +164,7 @@ func (s *BulkInsertSuite) TestBulkInsert() {
|
|||
CollectionName: collectionName,
|
||||
FieldName: "embeddings",
|
||||
IndexName: "_default",
|
||||
ExtraParams: integration.ConstructIndexParam(dim, integration.IndexHNSW, distance.L2),
|
||||
ExtraParams: integration.ConstructIndexParam(dim, integration.IndexHNSW, metric.L2),
|
||||
})
|
||||
if createIndexStatus.GetErrorCode() != commonpb.ErrorCode_Success {
|
||||
log.Warn("createIndexStatus fail reason", zap.String("reason", createIndexStatus.GetReason()))
|
||||
|
@ -192,9 +192,9 @@ func (s *BulkInsertSuite) TestBulkInsert() {
|
|||
topk := 10
|
||||
roundDecimal := -1
|
||||
|
||||
params := integration.GetSearchParams(integration.IndexHNSW, distance.L2)
|
||||
params := integration.GetSearchParams(integration.IndexHNSW, metric.L2)
|
||||
searchReq := integration.ConstructSearchRequest("", collectionName, expr,
|
||||
"embeddings", schemapb.DataType_FloatVector, nil, distance.L2, params, nq, dim, topk, roundDecimal)
|
||||
"embeddings", schemapb.DataType_FloatVector, nil, metric.L2, params, nq, dim, topk, roundDecimal)
|
||||
|
||||
searchResult, err := c.Proxy.Search(ctx, searchReq)
|
||||
|
||||
|
|
|
@ -29,8 +29,8 @@ import (
|
|||
"github.com/milvus-io/milvus-proto/go-api/v2/milvuspb"
|
||||
"github.com/milvus-io/milvus-proto/go-api/v2/schemapb"
|
||||
"github.com/milvus-io/milvus/pkg/common"
|
||||
"github.com/milvus-io/milvus/pkg/util/distance"
|
||||
"github.com/milvus-io/milvus/pkg/util/funcutil"
|
||||
"github.com/milvus-io/milvus/pkg/util/metric"
|
||||
"github.com/milvus-io/milvus/pkg/util/typeutil"
|
||||
"github.com/milvus-io/milvus/tests/integration"
|
||||
)
|
||||
|
@ -258,7 +258,7 @@ func (s *TestGetVectorSuite) TestGetVector_FLAT() {
|
|||
s.nq = 10
|
||||
s.topK = 10
|
||||
s.indexType = integration.IndexFaissIDMap
|
||||
s.metricType = distance.L2
|
||||
s.metricType = metric.L2
|
||||
s.pkType = schemapb.DataType_Int64
|
||||
s.vecType = schemapb.DataType_FloatVector
|
||||
s.searchFailed = false
|
||||
|
@ -269,7 +269,7 @@ func (s *TestGetVectorSuite) TestGetVector_IVF_FLAT() {
|
|||
s.nq = 10
|
||||
s.topK = 10
|
||||
s.indexType = integration.IndexFaissIvfFlat
|
||||
s.metricType = distance.L2
|
||||
s.metricType = metric.L2
|
||||
s.pkType = schemapb.DataType_Int64
|
||||
s.vecType = schemapb.DataType_FloatVector
|
||||
s.searchFailed = false
|
||||
|
@ -280,7 +280,7 @@ func (s *TestGetVectorSuite) TestGetVector_IVF_PQ() {
|
|||
s.nq = 10
|
||||
s.topK = 10
|
||||
s.indexType = integration.IndexFaissIvfPQ
|
||||
s.metricType = distance.L2
|
||||
s.metricType = metric.L2
|
||||
s.pkType = schemapb.DataType_Int64
|
||||
s.vecType = schemapb.DataType_FloatVector
|
||||
s.searchFailed = true
|
||||
|
@ -291,7 +291,7 @@ func (s *TestGetVectorSuite) TestGetVector_IVF_SQ8() {
|
|||
s.nq = 10
|
||||
s.topK = 10
|
||||
s.indexType = integration.IndexFaissIvfSQ8
|
||||
s.metricType = distance.L2
|
||||
s.metricType = metric.L2
|
||||
s.pkType = schemapb.DataType_Int64
|
||||
s.vecType = schemapb.DataType_FloatVector
|
||||
s.searchFailed = true
|
||||
|
@ -302,7 +302,7 @@ func (s *TestGetVectorSuite) TestGetVector_HNSW() {
|
|||
s.nq = 10
|
||||
s.topK = 10
|
||||
s.indexType = integration.IndexHNSW
|
||||
s.metricType = distance.L2
|
||||
s.metricType = metric.L2
|
||||
s.pkType = schemapb.DataType_Int64
|
||||
s.vecType = schemapb.DataType_FloatVector
|
||||
s.searchFailed = false
|
||||
|
@ -313,7 +313,7 @@ func (s *TestGetVectorSuite) TestGetVector_IP() {
|
|||
s.nq = 10
|
||||
s.topK = 10
|
||||
s.indexType = integration.IndexHNSW
|
||||
s.metricType = distance.IP
|
||||
s.metricType = metric.IP
|
||||
s.pkType = schemapb.DataType_Int64
|
||||
s.vecType = schemapb.DataType_FloatVector
|
||||
s.searchFailed = false
|
||||
|
@ -324,7 +324,7 @@ func (s *TestGetVectorSuite) TestGetVector_StringPK() {
|
|||
s.nq = 10
|
||||
s.topK = 10
|
||||
s.indexType = integration.IndexHNSW
|
||||
s.metricType = distance.L2
|
||||
s.metricType = metric.L2
|
||||
s.pkType = schemapb.DataType_VarChar
|
||||
s.vecType = schemapb.DataType_FloatVector
|
||||
s.searchFailed = false
|
||||
|
@ -335,7 +335,7 @@ func (s *TestGetVectorSuite) TestGetVector_BinaryVector() {
|
|||
s.nq = 10
|
||||
s.topK = 10
|
||||
s.indexType = integration.IndexFaissBinIvfFlat
|
||||
s.metricType = distance.JACCARD
|
||||
s.metricType = metric.JACCARD
|
||||
s.pkType = schemapb.DataType_Int64
|
||||
s.vecType = schemapb.DataType_BinaryVector
|
||||
s.searchFailed = false
|
||||
|
@ -347,7 +347,7 @@ func (s *TestGetVectorSuite) TestGetVector_Big_NQ_TOPK() {
|
|||
s.nq = 10000
|
||||
s.topK = 200
|
||||
s.indexType = integration.IndexHNSW
|
||||
s.metricType = distance.L2
|
||||
s.metricType = metric.L2
|
||||
s.pkType = schemapb.DataType_Int64
|
||||
s.vecType = schemapb.DataType_FloatVector
|
||||
s.searchFailed = false
|
||||
|
|
|
@ -30,8 +30,8 @@ import (
|
|||
"github.com/milvus-io/milvus-proto/go-api/v2/schemapb"
|
||||
"github.com/milvus-io/milvus/pkg/common"
|
||||
"github.com/milvus-io/milvus/pkg/log"
|
||||
"github.com/milvus-io/milvus/pkg/util/distance"
|
||||
"github.com/milvus-io/milvus/pkg/util/funcutil"
|
||||
"github.com/milvus-io/milvus/pkg/util/metric"
|
||||
"github.com/milvus-io/milvus/tests/integration"
|
||||
)
|
||||
|
||||
|
@ -110,7 +110,7 @@ func (s *HelloMilvusSuite) TestHelloMilvus() {
|
|||
CollectionName: collectionName,
|
||||
FieldName: integration.FloatVecField,
|
||||
IndexName: "_default",
|
||||
ExtraParams: integration.ConstructIndexParam(dim, integration.IndexFaissIvfFlat, distance.L2),
|
||||
ExtraParams: integration.ConstructIndexParam(dim, integration.IndexFaissIvfFlat, metric.L2),
|
||||
})
|
||||
if createIndexStatus.GetErrorCode() != commonpb.ErrorCode_Success {
|
||||
log.Warn("createIndexStatus fail reason", zap.String("reason", createIndexStatus.GetReason()))
|
||||
|
@ -138,9 +138,9 @@ func (s *HelloMilvusSuite) TestHelloMilvus() {
|
|||
topk := 10
|
||||
roundDecimal := -1
|
||||
|
||||
params := integration.GetSearchParams(integration.IndexFaissIvfFlat, distance.L2)
|
||||
params := integration.GetSearchParams(integration.IndexFaissIvfFlat, metric.L2)
|
||||
searchReq := integration.ConstructSearchRequest("", collectionName, expr,
|
||||
integration.FloatVecField, schemapb.DataType_FloatVector, nil, distance.L2, params, nq, dim, topk, roundDecimal)
|
||||
integration.FloatVecField, schemapb.DataType_FloatVector, nil, metric.L2, params, nq, dim, topk, roundDecimal)
|
||||
|
||||
searchResult, err := c.Proxy.Search(ctx, searchReq)
|
||||
|
||||
|
|
|
@ -12,8 +12,8 @@ import (
|
|||
"github.com/milvus-io/milvus-proto/go-api/v2/milvuspb"
|
||||
"github.com/milvus-io/milvus-proto/go-api/v2/schemapb"
|
||||
"github.com/milvus-io/milvus/pkg/log"
|
||||
"github.com/milvus-io/milvus/pkg/util/distance"
|
||||
"github.com/milvus-io/milvus/pkg/util/funcutil"
|
||||
"github.com/milvus-io/milvus/pkg/util/metric"
|
||||
"github.com/milvus-io/milvus/tests/integration"
|
||||
)
|
||||
|
||||
|
@ -78,7 +78,7 @@ func (s *GetIndexStatisticsSuite) TestGetIndexStatistics() {
|
|||
CollectionName: collectionName,
|
||||
FieldName: integration.FloatVecField,
|
||||
IndexName: "_default",
|
||||
ExtraParams: integration.ConstructIndexParam(dim, integration.IndexFaissIvfFlat, distance.L2),
|
||||
ExtraParams: integration.ConstructIndexParam(dim, integration.IndexFaissIvfFlat, metric.L2),
|
||||
})
|
||||
if createIndexStatus.GetErrorCode() != commonpb.ErrorCode_Success {
|
||||
log.Warn("createIndexStatus fail reason", zap.String("reason", createIndexStatus.GetReason()))
|
||||
|
|
|
@ -28,9 +28,9 @@ import (
|
|||
"github.com/milvus-io/milvus-proto/go-api/v2/schemapb"
|
||||
"github.com/milvus-io/milvus/pkg/common"
|
||||
"github.com/milvus-io/milvus/pkg/log"
|
||||
"github.com/milvus-io/milvus/pkg/util/distance"
|
||||
"github.com/milvus-io/milvus/pkg/util/funcutil"
|
||||
"github.com/milvus-io/milvus/pkg/util/merr"
|
||||
"github.com/milvus-io/milvus/pkg/util/metric"
|
||||
"github.com/milvus-io/milvus/tests/integration"
|
||||
)
|
||||
|
||||
|
@ -77,7 +77,7 @@ func (s *InsertSuite) TestInsert() {
|
|||
CollectionName: collectionName,
|
||||
FieldName: integration.FloatVecField,
|
||||
IndexName: "_default",
|
||||
ExtraParams: integration.ConstructIndexParam(dim, integration.IndexFaissIvfFlat, distance.IP),
|
||||
ExtraParams: integration.ConstructIndexParam(dim, integration.IndexFaissIvfFlat, metric.IP),
|
||||
})
|
||||
s.NoError(err)
|
||||
err = merr.Error(createIndexStatus)
|
||||
|
|
|
@ -26,7 +26,6 @@ import (
|
|||
|
||||
"github.com/cockroachdb/errors"
|
||||
"github.com/milvus-io/milvus/pkg/common"
|
||||
"github.com/milvus-io/milvus/pkg/util/distance"
|
||||
"github.com/milvus-io/milvus/tests/integration"
|
||||
"github.com/stretchr/testify/suite"
|
||||
|
||||
|
@ -36,6 +35,7 @@ import (
|
|||
"github.com/milvus-io/milvus-proto/go-api/v2/schemapb"
|
||||
"github.com/milvus-io/milvus/pkg/log"
|
||||
"github.com/milvus-io/milvus/pkg/util/funcutil"
|
||||
"github.com/milvus-io/milvus/pkg/util/metric"
|
||||
"go.uber.org/zap"
|
||||
)
|
||||
|
||||
|
@ -776,7 +776,7 @@ func (s *JSONExprSuite) insertFlushIndexLoad(ctx context.Context, dbName, collec
|
|||
},
|
||||
{
|
||||
Key: common.MetricTypeKey,
|
||||
Value: distance.L2,
|
||||
Value: metric.L2,
|
||||
},
|
||||
{
|
||||
Key: common.IndexTypeKey,
|
||||
|
@ -824,9 +824,9 @@ func (s *JSONExprSuite) doSearch(collectionName string, outputField []string, ex
|
|||
topk := 10
|
||||
roundDecimal := -1
|
||||
|
||||
params := integration.GetSearchParams(integration.IndexFaissIvfFlat, distance.L2)
|
||||
params := integration.GetSearchParams(integration.IndexFaissIvfFlat, metric.L2)
|
||||
searchReq := integration.ConstructSearchRequest("", collectionName, expr,
|
||||
integration.FloatVecField, schemapb.DataType_FloatVector, outputField, distance.L2, params, nq, dim, topk, roundDecimal)
|
||||
integration.FloatVecField, schemapb.DataType_FloatVector, outputField, metric.L2, params, nq, dim, topk, roundDecimal)
|
||||
|
||||
searchResult, err := s.Cluster.Proxy.Search(context.Background(), searchReq)
|
||||
|
||||
|
@ -899,9 +899,9 @@ func (s *JSONExprSuite) doSearchWithInvalidExpr(collectionName string, outputFie
|
|||
topk := 10
|
||||
roundDecimal := -1
|
||||
|
||||
params := integration.GetSearchParams(integration.IndexFaissIvfFlat, distance.L2)
|
||||
params := integration.GetSearchParams(integration.IndexFaissIvfFlat, metric.L2)
|
||||
searchReq := integration.ConstructSearchRequest("", collectionName, expr,
|
||||
integration.FloatVecField, schemapb.DataType_FloatVector, outputField, distance.L2, params, nq, dim, topk, roundDecimal)
|
||||
integration.FloatVecField, schemapb.DataType_FloatVector, outputField, metric.L2, params, nq, dim, topk, roundDecimal)
|
||||
|
||||
searchResult, err := s.Cluster.Proxy.Search(context.Background(), searchReq)
|
||||
|
||||
|
|
|
@ -32,8 +32,8 @@ import (
|
|||
"github.com/milvus-io/milvus-proto/go-api/v2/schemapb"
|
||||
"github.com/milvus-io/milvus/pkg/common"
|
||||
"github.com/milvus-io/milvus/pkg/log"
|
||||
"github.com/milvus-io/milvus/pkg/util/distance"
|
||||
"github.com/milvus-io/milvus/pkg/util/funcutil"
|
||||
"github.com/milvus-io/milvus/pkg/util/metric"
|
||||
)
|
||||
|
||||
type MetaWatcherSuite struct {
|
||||
|
@ -271,7 +271,7 @@ func (s *MetaWatcherSuite) TestShowReplicas() {
|
|||
},
|
||||
{
|
||||
Key: common.MetricTypeKey,
|
||||
Value: distance.L2,
|
||||
Value: metric.L2,
|
||||
},
|
||||
{
|
||||
Key: common.IndexTypeKey,
|
||||
|
|
|
@ -29,9 +29,9 @@ import (
|
|||
"github.com/milvus-io/milvus-proto/go-api/v2/schemapb"
|
||||
"github.com/milvus-io/milvus/pkg/common"
|
||||
"github.com/milvus-io/milvus/pkg/log"
|
||||
"github.com/milvus-io/milvus/pkg/util/distance"
|
||||
"github.com/milvus-io/milvus/pkg/util/funcutil"
|
||||
"github.com/milvus-io/milvus/pkg/util/merr"
|
||||
"github.com/milvus-io/milvus/pkg/util/metric"
|
||||
"github.com/milvus-io/milvus/tests/integration"
|
||||
)
|
||||
|
||||
|
@ -109,7 +109,7 @@ func (s *RangeSearchSuite) TestRangeSearchIP() {
|
|||
CollectionName: collectionName,
|
||||
FieldName: integration.FloatVecField,
|
||||
IndexName: "_default",
|
||||
ExtraParams: integration.ConstructIndexParam(dim, integration.IndexFaissIvfFlat, distance.IP),
|
||||
ExtraParams: integration.ConstructIndexParam(dim, integration.IndexFaissIvfFlat, metric.IP),
|
||||
})
|
||||
s.NoError(err)
|
||||
err = merr.Error(createIndexStatus)
|
||||
|
@ -137,12 +137,12 @@ func (s *RangeSearchSuite) TestRangeSearchIP() {
|
|||
radius := 10
|
||||
filter := 20
|
||||
|
||||
params := integration.GetSearchParams(integration.IndexFaissIvfFlat, distance.IP)
|
||||
params := integration.GetSearchParams(integration.IndexFaissIvfFlat, metric.IP)
|
||||
|
||||
// only pass in radius when range search
|
||||
params["radius"] = radius
|
||||
searchReq := integration.ConstructSearchRequest("", collectionName, expr,
|
||||
integration.FloatVecField, schemapb.DataType_FloatVector, nil, distance.IP, params, nq, dim, topk, roundDecimal)
|
||||
integration.FloatVecField, schemapb.DataType_FloatVector, nil, metric.IP, params, nq, dim, topk, roundDecimal)
|
||||
|
||||
searchResult, _ := c.Proxy.Search(ctx, searchReq)
|
||||
|
||||
|
@ -155,7 +155,7 @@ func (s *RangeSearchSuite) TestRangeSearchIP() {
|
|||
// pass in radius and range_filter when range search
|
||||
params["range_filter"] = filter
|
||||
searchReq = integration.ConstructSearchRequest("", collectionName, expr,
|
||||
integration.FloatVecField, schemapb.DataType_FloatVector, nil, distance.IP, params, nq, dim, topk, roundDecimal)
|
||||
integration.FloatVecField, schemapb.DataType_FloatVector, nil, metric.IP, params, nq, dim, topk, roundDecimal)
|
||||
|
||||
searchResult, _ = c.Proxy.Search(ctx, searchReq)
|
||||
|
||||
|
@ -169,7 +169,7 @@ func (s *RangeSearchSuite) TestRangeSearchIP() {
|
|||
params["radius"] = filter
|
||||
params["range_filter"] = radius
|
||||
searchReq = integration.ConstructSearchRequest("", collectionName, expr,
|
||||
integration.FloatVecField, schemapb.DataType_FloatVector, nil, distance.IP, params, nq, dim, topk, roundDecimal)
|
||||
integration.FloatVecField, schemapb.DataType_FloatVector, nil, metric.IP, params, nq, dim, topk, roundDecimal)
|
||||
|
||||
searchResult, _ = c.Proxy.Search(ctx, searchReq)
|
||||
|
||||
|
@ -257,7 +257,7 @@ func (s *RangeSearchSuite) TestRangeSearchL2() {
|
|||
CollectionName: collectionName,
|
||||
FieldName: integration.FloatVecField,
|
||||
IndexName: "_default",
|
||||
ExtraParams: integration.ConstructIndexParam(dim, integration.IndexFaissIvfFlat, distance.L2),
|
||||
ExtraParams: integration.ConstructIndexParam(dim, integration.IndexFaissIvfFlat, metric.L2),
|
||||
})
|
||||
s.NoError(err)
|
||||
err = merr.Error(createIndexStatus)
|
||||
|
@ -285,11 +285,11 @@ func (s *RangeSearchSuite) TestRangeSearchL2() {
|
|||
radius := 20
|
||||
filter := 10
|
||||
|
||||
params := integration.GetSearchParams(integration.IndexFaissIvfFlat, distance.L2)
|
||||
params := integration.GetSearchParams(integration.IndexFaissIvfFlat, metric.L2)
|
||||
// only pass in radius when range search
|
||||
params["radius"] = radius
|
||||
searchReq := integration.ConstructSearchRequest("", collectionName, expr,
|
||||
integration.FloatVecField, schemapb.DataType_FloatVector, nil, distance.L2, params, nq, dim, topk, roundDecimal)
|
||||
integration.FloatVecField, schemapb.DataType_FloatVector, nil, metric.L2, params, nq, dim, topk, roundDecimal)
|
||||
|
||||
searchResult, _ := c.Proxy.Search(ctx, searchReq)
|
||||
|
||||
|
@ -302,7 +302,7 @@ func (s *RangeSearchSuite) TestRangeSearchL2() {
|
|||
// pass in radius and range_filter when range search
|
||||
params["range_filter"] = filter
|
||||
searchReq = integration.ConstructSearchRequest("", collectionName, expr,
|
||||
integration.FloatVecField, schemapb.DataType_FloatVector, nil, distance.L2, params, nq, dim, topk, roundDecimal)
|
||||
integration.FloatVecField, schemapb.DataType_FloatVector, nil, metric.L2, params, nq, dim, topk, roundDecimal)
|
||||
|
||||
searchResult, _ = c.Proxy.Search(ctx, searchReq)
|
||||
|
||||
|
@ -316,7 +316,7 @@ func (s *RangeSearchSuite) TestRangeSearchL2() {
|
|||
params["radius"] = filter
|
||||
params["range_filter"] = radius
|
||||
searchReq = integration.ConstructSearchRequest("", collectionName, expr,
|
||||
integration.FloatVecField, schemapb.DataType_FloatVector, nil, distance.L2, params, nq, dim, topk, roundDecimal)
|
||||
integration.FloatVecField, schemapb.DataType_FloatVector, nil, metric.L2, params, nq, dim, topk, roundDecimal)
|
||||
|
||||
searchResult, _ = c.Proxy.Search(ctx, searchReq)
|
||||
|
||||
|
|
|
@ -27,7 +27,7 @@ import (
|
|||
"github.com/milvus-io/milvus-proto/go-api/v2/milvuspb"
|
||||
"github.com/milvus-io/milvus-proto/go-api/v2/schemapb"
|
||||
"github.com/milvus-io/milvus/pkg/log"
|
||||
"github.com/milvus-io/milvus/pkg/util/distance"
|
||||
"github.com/milvus-io/milvus/pkg/util/metric"
|
||||
"github.com/milvus-io/milvus/pkg/util/paramtable"
|
||||
"github.com/milvus-io/milvus/tests/integration"
|
||||
"github.com/stretchr/testify/suite"
|
||||
|
@ -125,7 +125,7 @@ func (s *RefreshConfigSuite) TestRefreshDefaultIndexName() {
|
|||
_, err = c.Proxy.CreateIndex(ctx, &milvuspb.CreateIndexRequest{
|
||||
CollectionName: collectionName,
|
||||
FieldName: integration.FloatVecField,
|
||||
ExtraParams: integration.ConstructIndexParam(dim, integration.IndexFaissIvfFlat, distance.L2),
|
||||
ExtraParams: integration.ConstructIndexParam(dim, integration.IndexFaissIvfFlat, metric.L2),
|
||||
})
|
||||
s.NoError(err)
|
||||
|
||||
|
|
|
@ -26,9 +26,9 @@ import (
|
|||
"github.com/milvus-io/milvus-proto/go-api/v2/schemapb"
|
||||
"github.com/milvus-io/milvus/pkg/common"
|
||||
"github.com/milvus-io/milvus/pkg/log"
|
||||
"github.com/milvus-io/milvus/pkg/util/distance"
|
||||
"github.com/milvus-io/milvus/pkg/util/funcutil"
|
||||
"github.com/milvus-io/milvus/pkg/util/merr"
|
||||
"github.com/milvus-io/milvus/pkg/util/metric"
|
||||
"github.com/milvus-io/milvus/tests/integration"
|
||||
"github.com/stretchr/testify/suite"
|
||||
"go.uber.org/zap"
|
||||
|
@ -109,7 +109,7 @@ func (s *UpsertSuite) TestUpsert() {
|
|||
CollectionName: collectionName,
|
||||
FieldName: integration.FloatVecField,
|
||||
IndexName: "_default",
|
||||
ExtraParams: integration.ConstructIndexParam(dim, integration.IndexFaissIvfFlat, distance.IP),
|
||||
ExtraParams: integration.ConstructIndexParam(dim, integration.IndexFaissIvfFlat, metric.IP),
|
||||
})
|
||||
s.NoError(err)
|
||||
err = merr.Error(createIndexStatus)
|
||||
|
@ -138,7 +138,7 @@ func (s *UpsertSuite) TestUpsert() {
|
|||
|
||||
params := integration.GetSearchParams(integration.IndexFaissIvfFlat, "")
|
||||
searchReq := integration.ConstructSearchRequest("", collectionName, expr,
|
||||
integration.FloatVecField, schemapb.DataType_FloatVector, nil, distance.IP, params, nq, dim, topk, roundDecimal)
|
||||
integration.FloatVecField, schemapb.DataType_FloatVector, nil, metric.IP, params, nq, dim, topk, roundDecimal)
|
||||
|
||||
searchResult, _ := c.Proxy.Search(ctx, searchReq)
|
||||
|
||||
|
|
|
@ -2667,6 +2667,7 @@ class TestCollectionSearch(TestcaseBase):
|
|||
assert abs(res[0].distances[0] - min(distance_0, distance_1)) <= epsilon
|
||||
|
||||
@pytest.mark.tags(CaseLabel.L2)
|
||||
@pytest.mark.skip("tanimoto obsolete")
|
||||
@pytest.mark.parametrize("index", ["BIN_FLAT", "BIN_IVF_FLAT"])
|
||||
def test_search_binary_tanimoto_flat_index(self, nq, dim, auto_id, _async, index, is_flush):
|
||||
"""
|
||||
|
@ -2706,6 +2707,7 @@ class TestCollectionSearch(TestcaseBase):
|
|||
assert abs(res[0].distances[0] - min(distance_0, distance_1)) <= epsilon
|
||||
|
||||
@pytest.mark.tags(CaseLabel.L2)
|
||||
@pytest.mark.skip("substructure obsolete")
|
||||
@pytest.mark.parametrize("index", ["BIN_FLAT"])
|
||||
def test_search_binary_substructure_flat_index(self, auto_id, _async, index, is_flush):
|
||||
"""
|
||||
|
@ -2742,6 +2744,7 @@ class TestCollectionSearch(TestcaseBase):
|
|||
assert len(res) <= default_limit
|
||||
|
||||
@pytest.mark.tags(CaseLabel.L2)
|
||||
@pytest.mark.skip("superstructure obsolete")
|
||||
@pytest.mark.parametrize("index", ["BIN_FLAT"])
|
||||
def test_search_binary_superstructure_flat_index(self, auto_id, _async, index, is_flush):
|
||||
"""
|
||||
|
@ -6656,6 +6659,7 @@ class TestCollectionRangeSearch(TestcaseBase):
|
|||
"limit": 0})
|
||||
|
||||
@pytest.mark.tags(CaseLabel.L2)
|
||||
@pytest.mark.skip("tanimoto obsolete")
|
||||
@pytest.mark.parametrize("index", ["BIN_FLAT", "BIN_IVF_FLAT"])
|
||||
def test_range_search_binary_tanimoto_flat_index(self, dim, auto_id, _async, index, is_flush):
|
||||
"""
|
||||
|
@ -6711,6 +6715,7 @@ class TestCollectionRangeSearch(TestcaseBase):
|
|||
assert abs(res[0].distances[0] - min(distance_0, distance_1)) <= epsilon
|
||||
|
||||
@pytest.mark.tags(CaseLabel.L2)
|
||||
@pytest.mark.skip("tanimoto obsolete")
|
||||
@pytest.mark.parametrize("index", ["BIN_FLAT", "BIN_IVF_FLAT"])
|
||||
def test_range_search_binary_tanimoto_invalid_params(self, index):
|
||||
"""
|
||||
|
|
|
@ -676,7 +676,7 @@ class TestUtilityBase(TestcaseBase):
|
|||
def metric(self, request):
|
||||
yield request.param
|
||||
|
||||
@pytest.fixture(scope="function", params=["HAMMING", "TANIMOTO"])
|
||||
@pytest.fixture(scope="function", params=["HAMMING", "JACCARD"])
|
||||
def metric_binary(self, request):
|
||||
yield request.param
|
||||
|
||||
|
|
Loading…
Reference in New Issue