Remove calc_distance (#25663)

Signed-off-by: Yudong Cai <yudong.cai@zilliz.com>
pull/25707/head
Cai Yudong 2023-07-18 14:23:20 +08:00 committed by GitHub
parent 9a61c0291b
commit 73512c72fd
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
36 changed files with 240 additions and 1753 deletions

View File

@ -3150,88 +3150,12 @@ func (node *Proxy) AlterAlias(ctx context.Context, request *milvuspb.AlterAliasR
// CalcDistance calculates the distances between vectors.
func (node *Proxy) CalcDistance(ctx context.Context, request *milvuspb.CalcDistanceRequest) (*milvuspb.CalcDistanceResults, error) {
if !node.checkHealthy() {
return &milvuspb.CalcDistanceResults{
Status: unhealthyStatus(),
}, nil
}
ctx, sp := otel.Tracer(typeutil.ProxyRole).Start(ctx, "Proxy-CalcDistance")
defer sp.End()
query := func(ids *milvuspb.VectorIDs) (*milvuspb.QueryResults, error) {
outputFields := []string{ids.FieldName}
queryRequest := &milvuspb.QueryRequest{
DbName: "",
CollectionName: ids.CollectionName,
PartitionNames: ids.PartitionNames,
OutputFields: outputFields,
}
qt := &queryTask{
ctx: ctx,
Condition: NewTaskCondition(ctx),
RetrieveRequest: &internalpb.RetrieveRequest{
Base: commonpbutil.NewMsgBase(
commonpbutil.WithMsgType(commonpb.MsgType_Retrieve),
commonpbutil.WithSourceID(paramtable.GetNodeID()),
),
ReqID: paramtable.GetNodeID(),
},
request: queryRequest,
qc: node.queryCoord,
ids: ids.IdArray,
}
log := log.Ctx(ctx).With(
zap.String("collection", queryRequest.CollectionName),
zap.Any("partitions", queryRequest.PartitionNames),
zap.Any("OutputFields", queryRequest.OutputFields))
err := node.sched.dqQueue.Enqueue(qt)
if err != nil {
log.Error("CalcDistance queryTask failed to enqueue",
zap.Error(err))
return &milvuspb.QueryResults{
Status: &commonpb.Status{
ErrorCode: commonpb.ErrorCode_UnexpectedError,
Reason: err.Error(),
},
}, err
}
log.Debug("CalcDistance queryTask enqueued")
err = qt.WaitToFinish()
if err != nil {
log.Error("CalcDistance queryTask failed to WaitToFinish",
zap.Error(err))
return &milvuspb.QueryResults{
Status: &commonpb.Status{
ErrorCode: commonpb.ErrorCode_UnexpectedError,
Reason: err.Error(),
},
}, err
}
log.Debug("CalcDistance queryTask Done")
return &milvuspb.QueryResults{
Status: qt.result.Status,
FieldsData: qt.result.FieldsData,
}, nil
}
// calcDistanceTask is not a standard task, no need to enqueue
task := &calcDistanceTask{
traceID: sp.SpanContext().TraceID().String(),
queryFunc: query,
}
return task.Execute(ctx, request)
return &milvuspb.CalcDistanceResults{
Status: &commonpb.Status{
ErrorCode: commonpb.ErrorCode_UnexpectedError,
Reason: "interface obsolete",
},
}, nil
}
// FlushAll notifies Proxy to flush all collection's DML messages.

View File

@ -63,9 +63,9 @@ import (
"github.com/milvus-io/milvus/pkg/tracer"
"github.com/milvus-io/milvus/pkg/util"
"github.com/milvus-io/milvus/pkg/util/crypto"
"github.com/milvus-io/milvus/pkg/util/distance"
"github.com/milvus-io/milvus/pkg/util/etcd"
"github.com/milvus-io/milvus/pkg/util/funcutil"
"github.com/milvus-io/milvus/pkg/util/metric"
"github.com/milvus-io/milvus/pkg/util/metricsinfo"
"github.com/milvus-io/milvus/pkg/util/paramtable"
"github.com/milvus-io/milvus/pkg/util/typeutil"
@ -639,7 +639,7 @@ func TestProxy(t *testing.T) {
},
{
Key: common.MetricTypeKey,
Value: distance.L2,
Value: metric.L2,
},
{
Key: common.IndexTypeKey,
@ -1543,7 +1543,7 @@ func TestProxy(t *testing.T) {
Params: []*commonpb.KeyValuePair{
{
Key: common.MetricTypeKey,
Value: distance.L2,
Value: metric.L2,
},
},
})

View File

@ -1,412 +0,0 @@
package proxy
import (
"context"
"fmt"
"github.com/cockroachdb/errors"
"go.uber.org/zap"
"github.com/milvus-io/milvus-proto/go-api/v2/commonpb"
"github.com/milvus-io/milvus-proto/go-api/v2/milvuspb"
"github.com/milvus-io/milvus-proto/go-api/v2/schemapb"
"github.com/milvus-io/milvus/pkg/log"
"github.com/milvus-io/milvus/pkg/util/distance"
"github.com/milvus-io/milvus/pkg/util/funcutil"
"github.com/milvus-io/milvus/pkg/util/typeutil"
)
type calcDistanceTask struct {
traceID string
queryFunc func(ids *milvuspb.VectorIDs) (*milvuspb.QueryResults, error)
}
func (t *calcDistanceTask) arrangeVectorsByIntID(inputIds []int64, sequence map[int64]int, retrievedVectors *schemapb.VectorField) (*schemapb.VectorField, error) {
if retrievedVectors.GetFloatVector() != nil {
floatArr := retrievedVectors.GetFloatVector().GetData()
element := retrievedVectors.GetDim()
result := make([]float32, 0, int64(len(inputIds))*element)
for _, id := range inputIds {
index, ok := sequence[id]
if !ok {
log.Error("id not found in CalcDistance", zap.Int64("id", id))
return nil, errors.New("failed to fetch vectors by id: " + fmt.Sprintln(id))
}
result = append(result, floatArr[int64(index)*element:int64(index+1)*element]...)
}
return &schemapb.VectorField{
Dim: element,
Data: &schemapb.VectorField_FloatVector{
FloatVector: &schemapb.FloatArray{
Data: result,
},
},
}, nil
}
if retrievedVectors.GetBinaryVector() != nil {
binaryArr := retrievedVectors.GetBinaryVector()
singleBitLen := distance.SingleBitLen(retrievedVectors.GetDim())
numBytes := singleBitLen / 8
result := make([]byte, 0, int64(len(inputIds))*numBytes)
for _, id := range inputIds {
index, ok := sequence[id]
if !ok {
log.Error("id not found in CalcDistance", zap.Int64("id", id))
return nil, errors.New("failed to fetch vectors by id: " + fmt.Sprintln(id))
}
result = append(result, binaryArr[int64(index)*numBytes:int64(index+1)*numBytes]...)
}
return &schemapb.VectorField{
Dim: retrievedVectors.GetDim(),
Data: &schemapb.VectorField_BinaryVector{
BinaryVector: result,
},
}, nil
}
return nil, errors.New("unsupported vector type")
}
func (t *calcDistanceTask) arrangeVectorsByStrID(inputIds []string, sequence map[string]int, retrievedVectors *schemapb.VectorField) (*schemapb.VectorField, error) {
if retrievedVectors.GetFloatVector() != nil {
floatArr := retrievedVectors.GetFloatVector().GetData()
element := retrievedVectors.GetDim()
result := make([]float32, 0, int64(len(inputIds))*element)
for _, id := range inputIds {
index, ok := sequence[id]
if !ok {
log.Error("id not found in CalcDistance", zap.String("id", id))
return nil, errors.New("failed to fetch vectors by id: " + fmt.Sprintln(id))
}
result = append(result, floatArr[int64(index)*element:int64(index+1)*element]...)
}
return &schemapb.VectorField{
Dim: element,
Data: &schemapb.VectorField_FloatVector{
FloatVector: &schemapb.FloatArray{
Data: result,
},
},
}, nil
}
if retrievedVectors.GetBinaryVector() != nil {
binaryArr := retrievedVectors.GetBinaryVector()
singleBitLen := distance.SingleBitLen(retrievedVectors.GetDim())
numBytes := singleBitLen / 8
result := make([]byte, 0, int64(len(inputIds))*numBytes)
for _, id := range inputIds {
index, ok := sequence[id]
if !ok {
log.Error("id not found in CalcDistance", zap.String("id", id))
return nil, errors.New("failed to fetch vectors by id: " + fmt.Sprintln(id))
}
result = append(result, binaryArr[int64(index)*numBytes:int64(index+1)*numBytes]...)
}
return &schemapb.VectorField{
Dim: retrievedVectors.GetDim(),
Data: &schemapb.VectorField_BinaryVector{
BinaryVector: result,
},
}, nil
}
return nil, errors.New("unsupported vector type")
}
func (t *calcDistanceTask) Execute(ctx context.Context, request *milvuspb.CalcDistanceRequest) (*milvuspb.CalcDistanceResults, error) {
param, _ := funcutil.GetAttrByKeyFromRepeatedKV("metric", request.GetParams())
metric, err := distance.ValidateMetricType(param)
if err != nil {
return &milvuspb.CalcDistanceResults{
Status: &commonpb.Status{
ErrorCode: commonpb.ErrorCode_UnexpectedError,
Reason: err.Error(),
},
}, nil
}
// the vectors retrieved are random order, we need re-arrange the vectors by the order of input ids
arrangeFunc := func(ids *milvuspb.VectorIDs, retrievedFields []*schemapb.FieldData) (*schemapb.VectorField, error) {
var retrievedIds *schemapb.ScalarField
var retrievedVectors *schemapb.VectorField
isStringID := true
for _, fieldData := range retrievedFields {
if fieldData.FieldName == ids.FieldName {
retrievedVectors = fieldData.GetVectors()
}
if fieldData.Type == schemapb.DataType_Int64 ||
fieldData.Type == schemapb.DataType_VarChar ||
fieldData.Type == schemapb.DataType_String {
retrievedIds = fieldData.GetScalars()
if fieldData.Type == schemapb.DataType_Int64 {
isStringID = false
}
}
}
if retrievedIds == nil || retrievedVectors == nil {
return nil, errors.New("failed to fetch vectors")
}
if isStringID {
dict := make(map[string]int)
for index, id := range retrievedIds.GetStringData().GetData() {
dict[id] = index
}
inputIds := ids.IdArray.GetStrId().GetData()
return t.arrangeVectorsByStrID(inputIds, dict, retrievedVectors)
}
dict := make(map[int64]int)
for index, id := range retrievedIds.GetLongData().GetData() {
dict[id] = index
}
inputIds := ids.IdArray.GetIntId().GetData()
return t.arrangeVectorsByIntID(inputIds, dict, retrievedVectors)
}
log := log.Ctx(ctx)
log.Debug("CalcDistance received",
zap.String("role", typeutil.ProxyRole),
zap.String("metric", metric))
vectorsLeft := request.GetOpLeft().GetDataArray()
opLeft := request.GetOpLeft().GetIdArray()
if opLeft != nil {
log.Debug("OpLeft IdArray not empty, Get vectors by id", zap.String("role", typeutil.ProxyRole))
result, err := t.queryFunc(opLeft)
if err != nil {
log.Warn("Failed to get left vectors by id",
zap.String("role", typeutil.ProxyRole),
zap.Error(err))
return &milvuspb.CalcDistanceResults{
Status: &commonpb.Status{
ErrorCode: commonpb.ErrorCode_UnexpectedError,
Reason: err.Error(),
},
}, nil
}
log.Debug("OpLeft IdArray not empty, Get vectors by id done",
zap.String("role", typeutil.ProxyRole))
vectorsLeft, err = arrangeFunc(opLeft, result.FieldsData)
if err != nil {
log.Debug("Failed to re-arrange left vectors",
zap.String("role", typeutil.ProxyRole),
zap.Error(err))
return &milvuspb.CalcDistanceResults{
Status: &commonpb.Status{
ErrorCode: commonpb.ErrorCode_UnexpectedError,
Reason: err.Error(),
},
}, nil
}
log.Debug("Re-arrange left vectors done",
zap.String("role", typeutil.ProxyRole))
}
if vectorsLeft == nil {
msg := "Left vectors array is empty"
log.Debug(msg,
zap.String("role", typeutil.ProxyRole))
return &milvuspb.CalcDistanceResults{
Status: &commonpb.Status{
ErrorCode: commonpb.ErrorCode_UnexpectedError,
Reason: msg,
},
}, nil
}
vectorsRight := request.GetOpRight().GetDataArray()
opRight := request.GetOpRight().GetIdArray()
if opRight != nil {
log.Debug("OpRight IdArray not empty, Get vectors by id",
zap.String("role", typeutil.ProxyRole))
result, err := t.queryFunc(opRight)
if err != nil {
log.Debug("Failed to get right vectors by id",
zap.String("role", typeutil.ProxyRole),
zap.Error(err))
return &milvuspb.CalcDistanceResults{
Status: &commonpb.Status{
ErrorCode: commonpb.ErrorCode_UnexpectedError,
Reason: err.Error(),
},
}, nil
}
log.Debug("OpRight IdArray not empty, Get vectors by id done",
zap.String("role", typeutil.ProxyRole))
vectorsRight, err = arrangeFunc(opRight, result.FieldsData)
if err != nil {
log.Debug("Failed to re-arrange right vectors",
zap.String("role", typeutil.ProxyRole),
zap.Error(err))
return &milvuspb.CalcDistanceResults{
Status: &commonpb.Status{
ErrorCode: commonpb.ErrorCode_UnexpectedError,
Reason: err.Error(),
},
}, nil
}
log.Debug("Re-arrange right vectors done",
zap.String("role", typeutil.ProxyRole))
}
if vectorsRight == nil {
msg := "Right vectors array is empty"
log.Warn(msg, zap.String("role", typeutil.ProxyRole))
return &milvuspb.CalcDistanceResults{
Status: &commonpb.Status{
ErrorCode: commonpb.ErrorCode_UnexpectedError,
Reason: msg,
},
}, nil
}
if vectorsLeft.GetDim() != vectorsRight.GetDim() {
msg := "Vectors dimension is not equal"
log.Debug(msg, zap.String("role", typeutil.ProxyRole))
return &milvuspb.CalcDistanceResults{
Status: &commonpb.Status{
ErrorCode: commonpb.ErrorCode_UnexpectedError,
Reason: msg,
},
}, nil
}
if vectorsLeft.GetFloatVector() != nil && vectorsRight.GetFloatVector() != nil {
distances, err := distance.CalcFloatDistance(vectorsLeft.GetDim(), vectorsLeft.GetFloatVector().GetData(), vectorsRight.GetFloatVector().GetData(), metric)
if err != nil {
log.Warn("Failed to CalcFloatDistance",
zap.Int64("leftDim", vectorsLeft.GetDim()),
zap.Int("leftLen", len(vectorsLeft.GetFloatVector().GetData())),
zap.Int64("rightDim", vectorsRight.GetDim()),
zap.Int("rightLen", len(vectorsRight.GetFloatVector().GetData())),
zap.String("traceID", t.traceID),
zap.String("role", typeutil.ProxyRole),
zap.Error(err))
return &milvuspb.CalcDistanceResults{
Status: &commonpb.Status{
ErrorCode: commonpb.ErrorCode_UnexpectedError,
Reason: err.Error(),
},
}, nil
}
log.Debug("CalcFloatDistance done",
zap.String("traceID", t.traceID),
zap.String("role", typeutil.ProxyRole))
return &milvuspb.CalcDistanceResults{
Status: &commonpb.Status{ErrorCode: commonpb.ErrorCode_Success, Reason: ""},
Array: &milvuspb.CalcDistanceResults_FloatDist{
FloatDist: &schemapb.FloatArray{
Data: distances,
},
},
}, nil
}
if vectorsLeft.GetBinaryVector() != nil && vectorsRight.GetBinaryVector() != nil {
hamming, err := distance.CalcHammingDistance(vectorsLeft.GetDim(), vectorsLeft.GetBinaryVector(), vectorsRight.GetBinaryVector())
if err != nil {
log.Debug("Failed to CalcHammingDistance",
zap.Int64("leftDim", vectorsLeft.GetDim()),
zap.Int("leftLen", len(vectorsLeft.GetBinaryVector())),
zap.Int64("rightDim", vectorsRight.GetDim()),
zap.Int("rightLen", len(vectorsRight.GetBinaryVector())),
zap.String("role", typeutil.ProxyRole),
zap.Error(err))
return &milvuspb.CalcDistanceResults{
Status: &commonpb.Status{
ErrorCode: commonpb.ErrorCode_UnexpectedError,
Reason: err.Error(),
},
}, nil
}
if metric == distance.HAMMING {
log.Debug("CalcHammingDistance done", zap.String("role", typeutil.ProxyRole))
return &milvuspb.CalcDistanceResults{
Status: &commonpb.Status{ErrorCode: commonpb.ErrorCode_Success, Reason: ""},
Array: &milvuspb.CalcDistanceResults_IntDist{
IntDist: &schemapb.IntArray{
Data: hamming,
},
},
}, nil
}
if metric == distance.TANIMOTO {
tanimoto, err := distance.CalcTanimotoCoefficient(vectorsLeft.GetDim(), hamming)
if err != nil {
log.Warn("Failed to CalcTanimotoCoefficient",
zap.String("role", typeutil.ProxyRole),
zap.Error(err))
return &milvuspb.CalcDistanceResults{
Status: &commonpb.Status{
ErrorCode: commonpb.ErrorCode_UnexpectedError,
Reason: err.Error(),
},
}, nil
}
log.Debug("CalcTanimotoCoefficient done",
zap.String("role", typeutil.ProxyRole))
return &milvuspb.CalcDistanceResults{
Status: &commonpb.Status{ErrorCode: commonpb.ErrorCode_Success, Reason: ""},
Array: &milvuspb.CalcDistanceResults_FloatDist{
FloatDist: &schemapb.FloatArray{
Data: tanimoto,
},
},
}, nil
}
}
err = errors.New("unexpected error")
if (vectorsLeft.GetBinaryVector() != nil && vectorsRight.GetFloatVector() != nil) || (vectorsLeft.GetFloatVector() != nil && vectorsRight.GetBinaryVector() != nil) {
err = errors.New("cannot calculate distance between binary vectors and float vectors")
}
log.Warn("Failed to CalcDistance",
zap.String("role", typeutil.ProxyRole),
zap.Error(err))
return &milvuspb.CalcDistanceResults{
Status: &commonpb.Status{
ErrorCode: commonpb.ErrorCode_UnexpectedError,
Reason: err.Error(),
},
}, nil
}

View File

@ -1,491 +0,0 @@
package proxy
import (
"context"
"testing"
"github.com/cockroachdb/errors"
"github.com/milvus-io/milvus-proto/go-api/v2/commonpb"
"github.com/milvus-io/milvus-proto/go-api/v2/milvuspb"
"github.com/milvus-io/milvus-proto/go-api/v2/schemapb"
"github.com/stretchr/testify/assert"
)
func TestCalcDistanceTask_arrangeVectorsByStrID(t *testing.T) {
task := &calcDistanceTask{}
inputIds := make([]string, 0)
inputIds = append(inputIds, "c")
inputIds = append(inputIds, "b")
inputIds = append(inputIds, "a")
sequence := make(map[string]int)
sequence["a"] = 0
sequence["b"] = 1
sequence["c"] = 2
dim := 16
// float vector
floatValue := make([]float32, 0)
for i := 0; i < dim*3; i++ {
floatValue = append(floatValue, float32(i))
}
retrievedVectors := &schemapb.VectorField{
Dim: int64(dim),
Data: &schemapb.VectorField_FloatVector{
FloatVector: &schemapb.FloatArray{
Data: floatValue,
},
},
}
result, err := task.arrangeVectorsByStrID(inputIds, sequence, retrievedVectors)
assert.NoError(t, err)
floatResult := result.GetFloatVector().GetData()
for i := 0; i < 3; i++ {
for j := 0; j < dim; j++ {
assert.Equal(t, floatValue[dim*sequence[inputIds[i]]+j], floatResult[i*dim+j])
}
}
// binary vector
binaryValue := make([]byte, 0)
for i := 0; i < 3*dim/8; i++ {
binaryValue = append(binaryValue, byte(i))
}
retrievedVectors = &schemapb.VectorField{
Dim: int64(dim),
Data: &schemapb.VectorField_BinaryVector{
BinaryVector: binaryValue,
},
}
result, err = task.arrangeVectorsByStrID(inputIds, sequence, retrievedVectors)
assert.NoError(t, err)
binaryResult := result.GetBinaryVector()
numBytes := dim / 8
for i := 0; i < 3; i++ {
for j := 0; j < numBytes; j++ {
assert.Equal(t, binaryValue[sequence[inputIds[i]]*numBytes+j], binaryResult[i*numBytes+j])
}
}
}
func TestCalcDistanceTask_arrangeVectorsByIntID(t *testing.T) {
task := &calcDistanceTask{}
inputIds := make([]int64, 0)
inputIds = append(inputIds, 2)
inputIds = append(inputIds, 0)
inputIds = append(inputIds, 1)
sequence := make(map[int64]int)
sequence[0] = 0
sequence[1] = 1
sequence[2] = 2
dim := 16
// float vector
floatValue := make([]float32, 0)
for i := 0; i < dim*3; i++ {
floatValue = append(floatValue, float32(i))
}
retrievedVectors := &schemapb.VectorField{
Dim: int64(dim),
Data: &schemapb.VectorField_FloatVector{
FloatVector: &schemapb.FloatArray{
Data: floatValue,
},
},
}
result, err := task.arrangeVectorsByIntID(inputIds, sequence, retrievedVectors)
assert.NoError(t, err)
floatResult := result.GetFloatVector().GetData()
for i := 0; i < 3; i++ {
for j := 0; j < dim; j++ {
assert.Equal(t, floatValue[dim*sequence[inputIds[i]]+j], floatResult[i*dim+j])
}
}
// binary vector
binaryValue := make([]byte, 0)
for i := 0; i < dim*3; i++ {
binaryValue = append(binaryValue, byte(i))
}
retrievedVectors = &schemapb.VectorField{
Dim: int64(dim),
Data: &schemapb.VectorField_BinaryVector{
BinaryVector: binaryValue,
},
}
result, err = task.arrangeVectorsByIntID(inputIds, sequence, retrievedVectors)
assert.NoError(t, err)
binaryResult := result.GetBinaryVector()
numBytes := dim / 8
for i := 0; i < 3; i++ {
for j := 0; j < numBytes; j++ {
assert.Equal(t, binaryValue[sequence[inputIds[i]]*numBytes+j], binaryResult[i*numBytes+j])
}
}
}
func TestCalcDistanceTask_ExecuteFloat(t *testing.T) {
ctx := context.Background()
queryFunc := func(ids *milvuspb.VectorIDs) (*milvuspb.QueryResults, error) {
return nil, errors.New("unexpected error")
}
task := &calcDistanceTask{
traceID: "dummy",
queryFunc: queryFunc,
}
request := &milvuspb.CalcDistanceRequest{
OpLeft: nil,
OpRight: nil,
Params: []*commonpb.KeyValuePair{
{Key: "metric", Value: "L2"},
},
}
// left-op empty
calcResult, err := task.Execute(ctx, request)
assert.NoError(t, err)
assert.Equal(t, commonpb.ErrorCode_UnexpectedError, calcResult.Status.ErrorCode)
request = &milvuspb.CalcDistanceRequest{
OpLeft: &milvuspb.VectorsArray{
Array: &milvuspb.VectorsArray_IdArray{
IdArray: &milvuspb.VectorIDs{},
},
},
OpRight: nil,
Params: []*commonpb.KeyValuePair{
{Key: "metric", Value: "L2"},
},
}
// left-op query error
calcResult, err = task.Execute(ctx, request)
assert.NoError(t, err)
assert.Equal(t, commonpb.ErrorCode_UnexpectedError, calcResult.Status.ErrorCode)
fieldIds := make([]int64, 0)
fieldIds = append(fieldIds, 2)
fieldIds = append(fieldIds, 0)
fieldIds = append(fieldIds, 1)
dim := 8
floatValue := make([]float32, 0)
for i := 0; i < dim*3; i++ {
floatValue = append(floatValue, float32(i))
}
queryFunc = func(ids *milvuspb.VectorIDs) (*milvuspb.QueryResults, error) {
if ids == nil {
return &milvuspb.QueryResults{
Status: &commonpb.Status{
ErrorCode: commonpb.ErrorCode_UnexpectedError,
Reason: "unexpected",
},
}, nil
}
return &milvuspb.QueryResults{
FieldsData: []*schemapb.FieldData{
{
Type: schemapb.DataType_Int64,
FieldName: "id",
Field: &schemapb.FieldData_Scalars{
Scalars: &schemapb.ScalarField{
Data: &schemapb.ScalarField_LongData{
LongData: &schemapb.LongArray{
Data: fieldIds,
},
},
},
},
},
{
Type: schemapb.DataType_FloatVector,
FieldName: "vec",
Field: &schemapb.FieldData_Vectors{
Vectors: &schemapb.VectorField{
Dim: int64(dim),
Data: &schemapb.VectorField_FloatVector{
FloatVector: &schemapb.FloatArray{
Data: floatValue,
},
},
},
},
},
},
}, nil
}
task.queryFunc = queryFunc
calcResult, err = task.Execute(ctx, request)
assert.NoError(t, err)
assert.Equal(t, commonpb.ErrorCode_UnexpectedError, calcResult.Status.ErrorCode)
idArray := &milvuspb.VectorsArray{
Array: &milvuspb.VectorsArray_IdArray{
IdArray: &milvuspb.VectorIDs{
FieldName: "vec",
IdArray: &schemapb.IDs{
IdField: &schemapb.IDs_IntId{
IntId: &schemapb.LongArray{
Data: fieldIds,
},
},
},
},
},
}
request = &milvuspb.CalcDistanceRequest{
OpLeft: idArray,
OpRight: idArray,
Params: []*commonpb.KeyValuePair{
{Key: "metric", Value: "L2"},
},
}
// success
calcResult, err = task.Execute(ctx, request)
assert.NoError(t, err)
assert.Equal(t, commonpb.ErrorCode_Success, calcResult.Status.ErrorCode)
// right-op query error
request.OpRight = nil
calcResult, err = task.Execute(ctx, request)
assert.NoError(t, err)
assert.Equal(t, commonpb.ErrorCode_UnexpectedError, calcResult.Status.ErrorCode)
request.OpRight = &milvuspb.VectorsArray{
Array: &milvuspb.VectorsArray_IdArray{
IdArray: &milvuspb.VectorIDs{
FieldName: "kkk",
IdArray: &schemapb.IDs{
IdField: &schemapb.IDs_IntId{
IntId: &schemapb.LongArray{
Data: fieldIds,
},
},
},
},
},
}
// right-op arrange error
calcResult, err = task.Execute(ctx, request)
assert.NoError(t, err)
assert.Equal(t, commonpb.ErrorCode_UnexpectedError, calcResult.Status.ErrorCode)
request.OpRight = &milvuspb.VectorsArray{
Array: &milvuspb.VectorsArray_DataArray{
DataArray: &schemapb.VectorField{
Dim: 5,
},
},
}
// different dimension
calcResult, err = task.Execute(ctx, request)
assert.NoError(t, err)
assert.Equal(t, commonpb.ErrorCode_UnexpectedError, calcResult.Status.ErrorCode)
request.OpRight = &milvuspb.VectorsArray{
Array: &milvuspb.VectorsArray_DataArray{
DataArray: &schemapb.VectorField{
Dim: int64(dim),
Data: &schemapb.VectorField_FloatVector{
FloatVector: &schemapb.FloatArray{
Data: make([]float32, 0),
},
},
},
},
}
// calcdistance return error
calcResult, err = task.Execute(ctx, request)
assert.NoError(t, err)
assert.Equal(t, commonpb.ErrorCode_UnexpectedError, calcResult.Status.ErrorCode)
}
func TestCalcDistanceTask_ExecuteBinary(t *testing.T) {
ctx := context.Background()
fieldIds := make([]int64, 0)
fieldIds = append(fieldIds, 2)
fieldIds = append(fieldIds, 0)
fieldIds = append(fieldIds, 1)
dim := 16
binaryValue := make([]byte, 0)
for i := 0; i < 3*dim/8; i++ {
binaryValue = append(binaryValue, byte(i))
}
queryFunc := func(ids *milvuspb.VectorIDs) (*milvuspb.QueryResults, error) {
if ids == nil {
return &milvuspb.QueryResults{
Status: &commonpb.Status{
ErrorCode: commonpb.ErrorCode_UnexpectedError,
Reason: "unexpected",
},
}, nil
}
return &milvuspb.QueryResults{
FieldsData: []*schemapb.FieldData{
{
Type: schemapb.DataType_Int64,
FieldName: "id",
Field: &schemapb.FieldData_Scalars{
Scalars: &schemapb.ScalarField{
Data: &schemapb.ScalarField_LongData{
LongData: &schemapb.LongArray{
Data: fieldIds,
},
},
},
},
},
{
Type: schemapb.DataType_FloatVector,
FieldName: "vec",
Field: &schemapb.FieldData_Vectors{
Vectors: &schemapb.VectorField{
Dim: int64(dim),
Data: &schemapb.VectorField_BinaryVector{
BinaryVector: binaryValue,
},
},
},
},
},
}, nil
}
idArray := &milvuspb.VectorsArray{
Array: &milvuspb.VectorsArray_IdArray{
IdArray: &milvuspb.VectorIDs{
FieldName: "vec",
IdArray: &schemapb.IDs{
IdField: &schemapb.IDs_IntId{
IntId: &schemapb.LongArray{
Data: fieldIds,
},
},
},
},
},
}
request := &milvuspb.CalcDistanceRequest{
OpLeft: idArray,
OpRight: idArray,
Params: []*commonpb.KeyValuePair{
{Key: "metric", Value: "HAMMING"},
},
}
task := &calcDistanceTask{
traceID: "dummy",
queryFunc: queryFunc,
}
// success
calcResult, err := task.Execute(ctx, request)
assert.NoError(t, err)
assert.Equal(t, commonpb.ErrorCode_Success, calcResult.Status.ErrorCode)
floatArray := &milvuspb.VectorsArray{
Array: &milvuspb.VectorsArray_DataArray{
DataArray: &schemapb.VectorField{
Dim: int64(dim),
Data: &schemapb.VectorField_FloatVector{},
},
},
}
binaryArray := &milvuspb.VectorsArray{
Array: &milvuspb.VectorsArray_DataArray{
DataArray: &schemapb.VectorField{
Dim: int64(dim),
Data: &schemapb.VectorField_BinaryVector{
BinaryVector: binaryValue,
},
},
},
}
request = &milvuspb.CalcDistanceRequest{
OpLeft: floatArray,
OpRight: binaryArray,
Params: []*commonpb.KeyValuePair{
{Key: "metric", Value: "HAMMING"},
},
}
// float vs binary
calcResult, err = task.Execute(ctx, request)
assert.NoError(t, err)
assert.Equal(t, commonpb.ErrorCode_UnexpectedError, calcResult.Status.ErrorCode)
request = &milvuspb.CalcDistanceRequest{
OpLeft: binaryArray,
OpRight: binaryArray,
Params: []*commonpb.KeyValuePair{
{Key: "metric", Value: "HAMMING"},
},
}
// hamming
calcResult, err = task.Execute(ctx, request)
assert.NoError(t, err)
assert.Equal(t, commonpb.ErrorCode_Success, calcResult.Status.ErrorCode)
request = &milvuspb.CalcDistanceRequest{
OpLeft: binaryArray,
OpRight: binaryArray,
Params: []*commonpb.KeyValuePair{
{Key: "metric", Value: "TANIMOTO"},
},
}
// tanimoto
calcResult, err = task.Execute(ctx, request)
assert.NoError(t, err)
assert.Equal(t, commonpb.ErrorCode_Success, calcResult.Status.ErrorCode)
request = &milvuspb.CalcDistanceRequest{
OpLeft: binaryArray,
OpRight: &milvuspb.VectorsArray{
Array: &milvuspb.VectorsArray_DataArray{
DataArray: &schemapb.VectorField{
Dim: int64(dim),
Data: &schemapb.VectorField_BinaryVector{
BinaryVector: make([]byte, 0),
},
},
},
},
Params: []*commonpb.KeyValuePair{
{Key: "metric", Value: "HAMMING"},
},
}
// hamming error
calcResult, err = task.Execute(ctx, request)
assert.NoError(t, err)
assert.Equal(t, commonpb.ErrorCode_UnexpectedError, calcResult.Status.ErrorCode)
}

View File

@ -25,9 +25,9 @@ import (
"github.com/milvus-io/milvus/pkg/log"
"github.com/milvus-io/milvus/pkg/metrics"
"github.com/milvus-io/milvus/pkg/util/commonpbutil"
"github.com/milvus-io/milvus/pkg/util/distance"
"github.com/milvus-io/milvus/pkg/util/funcutil"
"github.com/milvus-io/milvus/pkg/util/merr"
"github.com/milvus-io/milvus/pkg/util/metric"
"github.com/milvus-io/milvus/pkg/util/paramtable"
"github.com/milvus-io/milvus/pkg/util/timerecord"
"github.com/milvus-io/milvus/pkg/util/tsoutil"
@ -898,7 +898,7 @@ func reduceSearchResultData(ctx context.Context, subSearchResultData []*schemapb
}
ret.Results.TopK = realTopK // realTopK is the topK of the nq-th query
if !distance.PositivelyRelated(metricType) {
if !metric.PositivelyRelated(metricType) {
for k := range ret.Results.Scores {
ret.Results.Scores[k] *= -1
}

View File

@ -37,8 +37,8 @@ import (
"github.com/milvus-io/milvus/internal/proto/querypb"
"github.com/milvus-io/milvus/internal/types"
"github.com/milvus-io/milvus/pkg/common"
"github.com/milvus-io/milvus/pkg/util/distance"
"github.com/milvus-io/milvus/pkg/util/funcutil"
"github.com/milvus-io/milvus/pkg/util/metric"
"github.com/milvus-io/milvus/pkg/util/paramtable"
"github.com/milvus-io/milvus/pkg/util/timerecord"
"github.com/milvus-io/milvus/pkg/util/typeutil"
@ -121,7 +121,7 @@ func getValidSearchParams() []*commonpb.KeyValuePair {
},
{
Key: common.MetricTypeKey,
Value: distance.L2,
Value: metric.L2,
},
{
Key: SearchParamsKey,
@ -1441,7 +1441,7 @@ func TestTaskSearch_reduceSearchResultData(t *testing.T) {
for _, test := range tests {
t.Run(test.description, func(t *testing.T) {
reduced, err := reduceSearchResultData(context.TODO(), results, nq, topk, distance.L2, schemapb.DataType_Int64, test.offset)
reduced, err := reduceSearchResultData(context.TODO(), results, nq, topk, metric.L2, schemapb.DataType_Int64, test.offset)
assert.NoError(t, err)
assert.Equal(t, test.outData, reduced.GetResults().GetIds().GetIntId().GetData())
assert.Equal(t, []int64{test.limit, test.limit}, reduced.GetResults().GetTopks())
@ -1481,7 +1481,7 @@ func TestTaskSearch_reduceSearchResultData(t *testing.T) {
for _, test := range lessThanLimitTests {
t.Run(test.description, func(t *testing.T) {
reduced, err := reduceSearchResultData(context.TODO(), results, nq, topk, distance.L2, schemapb.DataType_Int64, test.offset)
reduced, err := reduceSearchResultData(context.TODO(), results, nq, topk, metric.L2, schemapb.DataType_Int64, test.offset)
assert.NoError(t, err)
assert.Equal(t, test.outData, reduced.GetResults().GetIds().GetIntId().GetData())
assert.Equal(t, []int64{test.outLimit, test.outLimit}, reduced.GetResults().GetTopks())
@ -1505,7 +1505,7 @@ func TestTaskSearch_reduceSearchResultData(t *testing.T) {
results = append(results, r)
}
reduced, err := reduceSearchResultData(context.TODO(), results, nq, topk, distance.L2, schemapb.DataType_Int64, 0)
reduced, err := reduceSearchResultData(context.TODO(), results, nq, topk, metric.L2, schemapb.DataType_Int64, 0)
assert.NoError(t, err)
assert.Equal(t, resultData, reduced.GetResults().GetIds().GetIntId().GetData())
@ -1532,7 +1532,7 @@ func TestTaskSearch_reduceSearchResultData(t *testing.T) {
results = append(results, r)
}
reduced, err := reduceSearchResultData(context.TODO(), results, nq, topk, distance.L2, schemapb.DataType_VarChar, 0)
reduced, err := reduceSearchResultData(context.TODO(), results, nq, topk, metric.L2, schemapb.DataType_VarChar, 0)
assert.NoError(t, err)
assert.Equal(t, resultData, reduced.GetResults().GetIds().GetStrId().GetData())
@ -1717,7 +1717,7 @@ func TestTaskSearch_parseQueryInfo(t *testing.T) {
noSearchParams := getBaseSearchParams()
noSearchParams = append(noSearchParams, &commonpb.KeyValuePair{
Key: common.MetricTypeKey,
Value: distance.L2,
Value: metric.L2,
})
offsetParam := getValidSearchParams()
@ -1775,7 +1775,7 @@ func TestTaskSearch_parseQueryInfo(t *testing.T) {
spNoSearchParams := append(spNoMetricType, &commonpb.KeyValuePair{
Key: common.MetricTypeKey,
Value: distance.L2,
Value: metric.L2,
})
// no roundDecimal is valid

View File

@ -44,9 +44,9 @@ import (
"github.com/milvus-io/milvus/pkg/common"
"github.com/milvus-io/milvus/pkg/mq/msgstream"
"github.com/milvus-io/milvus/pkg/util/commonpbutil"
"github.com/milvus-io/milvus/pkg/util/distance"
"github.com/milvus-io/milvus/pkg/util/funcutil"
"github.com/milvus-io/milvus/pkg/util/merr"
"github.com/milvus-io/milvus/pkg/util/metric"
"github.com/milvus-io/milvus/pkg/util/paramtable"
"github.com/milvus-io/milvus/pkg/util/timerecord"
"github.com/milvus-io/milvus/pkg/util/typeutil"
@ -383,7 +383,7 @@ func constructSearchRequest(
SearchParams: []*commonpb.KeyValuePair{
{
Key: common.MetricTypeKey,
Value: distance.L2,
Value: metric.L2,
},
{
Key: SearchParamsKey,

View File

@ -1,283 +0,0 @@
// Licensed to the LF AI & Data foundation under one
// or more contributor license agreements. See the NOTICE file
// distributed with this work for additional information
// regarding copyright ownership. The ASF licenses this file
// to you under the Apache License, Version 2.0 (the
// "License"); you may not use this file except in compliance
// with the License. You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
package distance
import (
"strings"
"sync"
"github.com/cockroachdb/errors"
)
const (
// L2 represents the Euclidean distance
L2 = "L2"
// IP represents the inner product distance
IP = "IP"
// COSINE represents the cosine distance
COSINE = "COSINE"
// HAMMING represents the hamming distance
HAMMING = "HAMMING"
// TANIMOTO represents the tanimoto distance
TANIMOTO = "TANIMOTO"
// JACCARD in string
JACCARD = "JACCARD"
// SUPERSTRUCTURE in string
SUPERSTRUCTURE = "SUPERSTRUCTURE"
// SUBSTRUCTURE in string
SUBSTRUCTURE = "SUBSTRUCTURE"
)
// ValidateMetricType returns metric text or error
func ValidateMetricType(metric string) (string, error) {
if metric == "" {
err := errors.New("metric type is empty")
return "", err
}
m := strings.ToUpper(metric)
if m == L2 || m == IP || m == HAMMING || m == TANIMOTO {
return m, nil
}
err := errors.New("invalid metric type")
return metric, err
}
// ValidateFloatArrayLength is used validate float vector length
func ValidateFloatArrayLength(dim int64, length int) error {
if length == 0 || int64(length)%dim != 0 {
err := errors.New("invalid float vector length")
return err
}
return nil
}
// CalcL2 returns the Euclidean distance of input vectors
func CalcL2(dim int64, left []float32, lIndex int64, right []float32, rIndex int64) float32 {
var sum float32
lFrom := lIndex * dim
rFrom := rIndex * dim
for i := int64(0); i < dim; i++ {
gap := left[lFrom+i] - right[rFrom+i]
sum += gap * gap
}
return sum
}
// CalcIP returns the inner product distance of input vectors
func CalcIP(dim int64, left []float32, lIndex int64, right []float32, rIndex int64) float32 {
var sum float32
lFrom := lIndex * dim
rFrom := rIndex * dim
for i := int64(0); i < dim; i++ {
sum += left[lFrom+i] * right[rFrom+i]
}
return sum
}
// CalcFFBatch calculate the distance of @left & @right vectors in batch by given @metic, store result in @result
func CalcFFBatch(dim int64, left []float32, lIndex int64, right []float32, metric string, result *[]float32) {
rightNum := int64(len(right)) / dim
for i := int64(0); i < rightNum; i++ {
var distance float32 = -1.0
if metric == L2 {
distance = CalcL2(dim, left, lIndex, right, i)
} else if metric == IP {
distance = CalcIP(dim, left, lIndex, right, i)
}
(*result)[lIndex*rightNum+i] = distance
}
}
// CalcFloatDistance calculate float distance by given metric
// it will checks input, and calculate the distance concurrently
func CalcFloatDistance(dim int64, left, right []float32, metric string) ([]float32, error) {
if dim <= 0 {
err := errors.New("invalid dimension")
return nil, err
}
metricUpper := strings.ToUpper(metric)
if metricUpper != L2 && metricUpper != IP {
err := errors.New("invalid metric type")
return nil, err
}
err := ValidateFloatArrayLength(dim, len(left))
if err != nil {
return nil, err
}
err = ValidateFloatArrayLength(dim, len(right))
if err != nil {
return nil, err
}
leftNum := int64(len(left)) / dim
rightNum := int64(len(right)) / dim
distArray := make([]float32, leftNum*rightNum)
// Multi-threads to calculate distance. TODO: avoid too many go routines
var waitGroup sync.WaitGroup
CalcWorker := func(index int64) {
CalcFFBatch(dim, left, index, right, metricUpper, &distArray)
waitGroup.Done()
}
for i := int64(0); i < leftNum; i++ {
waitGroup.Add(1)
go CalcWorker(i)
}
waitGroup.Wait()
return distArray, nil
}
////////////////////////////////////////////////////////////////////////////////
// SingleBitLen returns the bit length of @dim
func SingleBitLen(dim int64) int64 {
if dim%8 == 0 {
return dim
}
return dim + 8 - dim%8
}
// VectorCount counts bits by @dim & @length
func VectorCount(dim int64, length int) int64 {
singleBitLen := SingleBitLen(dim)
return int64(length*8) / singleBitLen
}
// ValidateBinaryArrayLength validates a binary array of @dim & @length
func ValidateBinaryArrayLength(dim int64, length int) error {
singleBitLen := SingleBitLen(dim)
totalBitLen := int64(length * 8)
if length == 0 || totalBitLen%singleBitLen != 0 {
err := errors.New("invalid binary vector length")
return err
}
return nil
}
// CountOne count 1 of uint8
// For 00000010, return 1
// Fro 11111111, return 8
func CountOne(n uint8) int32 {
count := int32(0)
for n != 0 {
count++
n = n & (n - 1)
}
return count
}
// CalcHamming calculate HAMMING distance
func CalcHamming(dim int64, left []byte, lIndex int64, right []byte, rIndex int64) int32 {
singleBitLen := SingleBitLen(dim)
numBytes := singleBitLen / 8
lFrom := lIndex * numBytes
rFrom := rIndex * numBytes
var hamming int32
for i := int64(0); i < numBytes; i++ {
var xor = left[lFrom+i] ^ right[rFrom+i]
// The dimension "dim" may not be an integer multiple of 8
// For example:
// dim = 11, each vector has 2 uint8 value
// the second uint8, only need to calculate 3 bits, the other 5 bits will be set to 0
if i == numBytes-1 && numBytes*8 > dim {
offset := numBytes*8 - dim
xor = xor & (255 << offset)
}
hamming += CountOne(xor)
}
return hamming
}
// CalcHammingBatch calculate HAMMING distance in batch, results are in @result
func CalcHammingBatch(dim int64, left []byte, lIndex int64, right []byte, result *[]int32) {
rightNum := VectorCount(dim, len(right))
for i := int64(0); i < rightNum; i++ {
hamming := CalcHamming(dim, left, lIndex, right, i)
(*result)[lIndex*rightNum+i] = hamming
}
}
func CalcHammingDistance(dim int64, left, right []byte) ([]int32, error) {
if dim <= 0 {
err := errors.New("invalid dimension")
return nil, err
}
err := ValidateBinaryArrayLength(dim, len(left))
if err != nil {
return nil, err
}
err = ValidateBinaryArrayLength(dim, len(right))
if err != nil {
return nil, err
}
leftNum := VectorCount(dim, len(left))
rightNum := VectorCount(dim, len(right))
distArray := make([]int32, leftNum*rightNum)
// Multi-threads to calculate distance. TODO: avoid too many go routines
var waitGroup sync.WaitGroup
CalcWorker := func(index int64) {
CalcHammingBatch(dim, left, index, right, &distArray)
waitGroup.Done()
}
for i := int64(0); i < leftNum; i++ {
waitGroup.Add(1)
go CalcWorker(i)
}
waitGroup.Wait()
return distArray, nil
}
func CalcTanimotoCoefficient(dim int64, hamming []int32) ([]float32, error) {
if dim <= 0 || len(hamming) == 0 {
err := errors.New("invalid input for tanimoto")
return nil, err
}
array := make([]float32, len(hamming))
for i := 0; i < len(hamming); i++ {
if hamming[i] > int32(dim) {
err := errors.New("invalid hamming for tanimoto")
return nil, err
}
equalBits := int32(dim) - hamming[i]
array[i] = float32(equalBits) / (float32(dim)*2 - float32(equalBits))
}
return array, nil
}

View File

@ -1,291 +0,0 @@
// Licensed to the LF AI & Data foundation under one
// or more contributor license agreements. See the NOTICE file
// distributed with this work for additional information
// regarding copyright ownership. The ASF licenses this file
// to you under the Apache License, Version 2.0 (the
// "License"); you may not use this file except in compliance
// with the License. You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
package distance
import (
"math"
"math/rand"
"testing"
"time"
"github.com/stretchr/testify/assert"
)
const PRECISION = 1e-6
func TestValidateMetricType(t *testing.T) {
invalidMetric := []string{"", "aaa"}
for _, str := range invalidMetric {
_, err := ValidateMetricType(str)
assert.Error(t, err)
}
validMetric := []string{"L2", "ip", "Hamming", "Tanimoto"}
for _, str := range validMetric {
metric, err := ValidateMetricType(str)
assert.NoError(t, err)
assert.True(t, metric == L2 || metric == IP || metric == HAMMING || metric == TANIMOTO)
}
}
func TestValidateFloatArrayLength(t *testing.T) {
err := ValidateFloatArrayLength(3, 12)
assert.NoError(t, err)
err = ValidateFloatArrayLength(5, 11)
assert.Error(t, err)
}
////////////////////////////////////////////////////////////////////////////////
func CreateFloatArray(n, dim int64) []float32 {
rand.Seed(time.Now().UnixNano())
num := n * dim
array := make([]float32, num)
for i := int64(0); i < num; i++ {
array[i] = rand.Float32()
}
return array
}
func DistanceL2(left, right []float32) float32 {
if len(left) != len(right) {
panic("array dimension not equal")
}
var sum float32
for i := 0; i < len(left); i++ {
gap := left[i] - right[i]
sum += gap * gap
}
return sum
}
func DistanceIP(left, right []float32) float32 {
if len(left) != len(right) {
panic("array dimension not equal")
}
var sum float32
for i := 0; i < len(left); i++ {
sum += left[i] * right[i]
}
return sum
}
func Test_CalcL2(t *testing.T) {
var dim int64 = 128
var leftNum int64 = 1
var rightNum int64 = 1
left := CreateFloatArray(leftNum, dim)
right := CreateFloatArray(rightNum, dim)
sum := DistanceL2(left, right)
distance := CalcL2(dim, left, 0, right, 0)
assert.Less(t, math.Abs(float64(sum-distance)), PRECISION)
distance = CalcL2(dim, left, 0, left, 0)
assert.Less(t, float64(distance), PRECISION)
}
func Test_CalcIP(t *testing.T) {
var dim int64 = 128
var leftNum int64 = 1
var rightNum int64 = 1
left := CreateFloatArray(leftNum, dim)
right := CreateFloatArray(rightNum, dim)
sum := DistanceIP(left, right)
distance := CalcIP(dim, left, 0, right, 0)
assert.Less(t, math.Abs(float64(sum-distance)), PRECISION)
}
func Test_CalcFloatDistance(t *testing.T) {
var dim int64 = 128
var leftNum int64 = 10
var rightNum int64 = 5
left := CreateFloatArray(leftNum, dim)
right := CreateFloatArray(rightNum, dim)
// Verify illegal cases
_, err := CalcFloatDistance(dim, left, right, "HAMMIN")
assert.Error(t, err)
_, err = CalcFloatDistance(3, left, right, "L2")
assert.Error(t, err)
_, err = CalcFloatDistance(dim, left, right, "HAMMIN")
assert.Error(t, err)
_, err = CalcFloatDistance(0, left, right, "L2")
assert.Error(t, err)
distances, err := CalcFloatDistance(dim, left, right, "L2")
assert.NoError(t, err)
// Verify the L2 distance algorithm is correct
invalid := CreateFloatArray(rightNum, 10)
_, err = CalcFloatDistance(dim, left, invalid, "L2")
assert.Error(t, err)
for i := int64(0); i < leftNum; i++ {
for j := int64(0); j < rightNum; j++ {
v1 := left[i*dim : (i+1)*dim]
v2 := right[j*dim : (j+1)*dim]
sum := DistanceL2(v1, v2)
assert.Less(t, math.Abs(float64(sum-distances[i*rightNum+j])), PRECISION)
}
}
// Verify the IP distance algorithm is correct
distances, err = CalcFloatDistance(dim, left, right, "IP")
assert.NoError(t, err)
for i := int64(0); i < leftNum; i++ {
for j := int64(0); j < rightNum; j++ {
v1 := left[i*dim : (i+1)*dim]
v2 := right[j*dim : (j+1)*dim]
sum := DistanceIP(v1, v2)
assert.Less(t, math.Abs(float64(sum-distances[i*rightNum+j])), PRECISION)
}
}
}
// //////////////////////////////////////////////////////////////////////////////
func CreateBinaryArray(n, dim int64) []byte {
rand.Seed(time.Now().UnixNano())
num := n * dim / 8
if num*8 < n*dim {
num = num + 1
}
array := make([]byte, num)
for i := int64(0); i < num; i++ {
n := rand.Intn(256)
array[i] = uint8(n)
}
return array
}
func Test_SingleBitLen(t *testing.T) {
n := SingleBitLen(125)
assert.Equal(t, n, int64(128))
n = SingleBitLen(133)
assert.Equal(t, n, int64(136))
}
func Test_VectorCount(t *testing.T) {
n := VectorCount(15, 20)
assert.Equal(t, n, int64(10))
n = VectorCount(8, 3)
assert.Equal(t, n, int64(3))
}
func Test_ValidateBinaryArrayLength(t *testing.T) {
err := ValidateBinaryArrayLength(21, 12)
assert.NoError(t, err)
err = ValidateBinaryArrayLength(21, 11)
assert.Error(t, err)
}
func Test_CountOne(t *testing.T) {
n := CountOne(6)
assert.Equal(t, n, int32(2))
n = CountOne(0)
assert.Equal(t, n, int32(0))
n = CountOne(255)
assert.Equal(t, n, int32(8))
}
func Test_CalcHamming(t *testing.T) {
var dim int64 = 22
// v1 = 00000010 00000110 00001000
v1 := make([]uint8, 3)
v1[0] = 2
v1[1] = 6
v1[2] = 8
// v2 = 00000001 00000111 00011011
v2 := make([]uint8, 3)
v2[0] = 1
v2[1] = 7
v2[2] = 27
n := CalcHamming(dim, v1, 0, v2, 0)
assert.Equal(t, n, int32(4))
}
func Test_CalcHamminDistance(t *testing.T) {
var dim int64 = 125
var leftNum int64 = 2
left := CreateBinaryArray(leftNum, dim)
_, e := CalcHammingDistance(0, left, left)
assert.Error(t, e)
distances, err := CalcHammingDistance(dim, left, left)
assert.NoError(t, err)
n := CalcHamming(dim, left, 0, left, 0)
assert.Equal(t, n, int32(0))
n = CalcHamming(dim, left, 1, left, 1)
assert.Equal(t, n, int32(0))
n = CalcHamming(dim, left, 0, left, 1)
assert.Equal(t, n, distances[1])
n = CalcHamming(dim, left, 1, left, 0)
assert.Equal(t, n, distances[2])
invalid := CreateBinaryArray(leftNum, 200)
_, e = CalcHammingDistance(dim, invalid, left)
assert.Error(t, e)
_, e = CalcHammingDistance(dim, left, invalid)
assert.Error(t, e)
}
func Test_CalcTanimotoCoefficient(t *testing.T) {
var dim int64 = 22
hamming := make([]int32, 2)
hamming[0] = 4
hamming[1] = 17
tanimoto, err := CalcTanimotoCoefficient(dim, hamming)
for i := 0; i < len(hamming); i++ {
realTanimoto := float64(int32(dim)-hamming[i]) / (float64(dim)*2.0 - float64(int32(dim)-hamming[i]))
assert.NoError(t, err)
assert.Less(t, math.Abs(float64(tanimoto[i])-realTanimoto), float64(PRECISION))
}
_, err = CalcTanimotoCoefficient(-1, hamming)
assert.Error(t, err)
_, err = CalcTanimotoCoefficient(3, hamming)
assert.Error(t, err)
}

View File

@ -5,6 +5,7 @@ import (
"testing"
"github.com/milvus-io/milvus-proto/go-api/v2/schemapb"
"github.com/milvus-io/milvus/pkg/util/metric"
"github.com/stretchr/testify/assert"
)
@ -12,10 +13,10 @@ import (
func Test_baseChecker_CheckTrain(t *testing.T) {
validParams := map[string]string{
DIM: strconv.Itoa(128),
Metric: L2,
Metric: metric.L2,
}
paramsWithoutDim := map[string]string{
Metric: L2,
Metric: metric.L2,
}
cases := []struct {
params map[string]string

View File

@ -5,6 +5,7 @@ import (
"testing"
"github.com/milvus-io/milvus-proto/go-api/v2/schemapb"
"github.com/milvus-io/milvus/pkg/util/metric"
"github.com/stretchr/testify/assert"
)
@ -12,44 +13,44 @@ import (
func Test_binFlatChecker_CheckTrain(t *testing.T) {
validParams := map[string]string{
DIM: strconv.Itoa(128),
Metric: JACCARD,
Metric: metric.JACCARD,
}
paramsWithoutDim := map[string]string{
Metric: JACCARD,
Metric: metric.JACCARD,
}
p1 := map[string]string{
DIM: strconv.Itoa(128),
Metric: L2,
Metric: metric.L2,
}
p2 := map[string]string{
DIM: strconv.Itoa(128),
Metric: IP,
Metric: metric.IP,
}
p3 := map[string]string{
DIM: strconv.Itoa(128),
Metric: COSINE,
Metric: metric.COSINE,
}
p4 := map[string]string{
DIM: strconv.Itoa(128),
Metric: HAMMING,
Metric: metric.HAMMING,
}
p5 := map[string]string{
DIM: strconv.Itoa(128),
Metric: JACCARD,
Metric: metric.JACCARD,
}
p6 := map[string]string{
DIM: strconv.Itoa(128),
Metric: TANIMOTO,
Metric: metric.TANIMOTO,
}
p7 := map[string]string{
DIM: strconv.Itoa(128),
Metric: SUBSTRUCTURE,
Metric: metric.SUBSTRUCTURE,
}
p8 := map[string]string{
DIM: strconv.Itoa(128),
Metric: SUPERSTRUCTURE,
Metric: metric.SUPERSTRUCTURE,
}
cases := []struct {

View File

@ -5,6 +5,7 @@ import (
"testing"
"github.com/milvus-io/milvus-proto/go-api/v2/schemapb"
"github.com/milvus-io/milvus/pkg/util/metric"
"github.com/stretchr/testify/assert"
)
@ -15,24 +16,24 @@ func Test_binIVFFlatChecker_CheckTrain(t *testing.T) {
NLIST: strconv.Itoa(100),
IVFM: strconv.Itoa(4),
NBITS: strconv.Itoa(8),
Metric: JACCARD,
Metric: metric.JACCARD,
}
paramsWithoutDim := map[string]string{
NLIST: strconv.Itoa(100),
IVFM: strconv.Itoa(4),
NBITS: strconv.Itoa(8),
Metric: JACCARD,
Metric: metric.JACCARD,
}
invalidParams := copyParams(validParams)
invalidParams[Metric] = L2
invalidParams[Metric] = metric.L2
paramsWithLargeNlist := map[string]string{
DIM: strconv.Itoa(128),
NLIST: strconv.Itoa(MaxNList + 1),
IVFM: strconv.Itoa(4),
NBITS: strconv.Itoa(8),
Metric: JACCARD,
Metric: metric.JACCARD,
}
paramsWithSmallNlist := map[string]string{
@ -40,26 +41,26 @@ func Test_binIVFFlatChecker_CheckTrain(t *testing.T) {
NLIST: strconv.Itoa(MinNList - 1),
IVFM: strconv.Itoa(4),
NBITS: strconv.Itoa(8),
Metric: JACCARD,
Metric: metric.JACCARD,
}
p1 := map[string]string{
DIM: strconv.Itoa(128),
Metric: L2,
Metric: metric.L2,
NLIST: strconv.Itoa(100),
IVFM: strconv.Itoa(4),
NBITS: strconv.Itoa(8),
}
p2 := map[string]string{
DIM: strconv.Itoa(128),
Metric: IP,
Metric: metric.IP,
NLIST: strconv.Itoa(100),
IVFM: strconv.Itoa(4),
NBITS: strconv.Itoa(8),
}
p3 := map[string]string{
DIM: strconv.Itoa(128),
Metric: COSINE,
Metric: metric.COSINE,
NLIST: strconv.Itoa(100),
IVFM: strconv.Itoa(4),
NBITS: strconv.Itoa(8),
@ -67,21 +68,21 @@ func Test_binIVFFlatChecker_CheckTrain(t *testing.T) {
p4 := map[string]string{
DIM: strconv.Itoa(128),
Metric: HAMMING,
Metric: metric.HAMMING,
NLIST: strconv.Itoa(100),
IVFM: strconv.Itoa(4),
NBITS: strconv.Itoa(8),
}
p5 := map[string]string{
DIM: strconv.Itoa(128),
Metric: JACCARD,
Metric: metric.JACCARD,
NLIST: strconv.Itoa(100),
IVFM: strconv.Itoa(4),
NBITS: strconv.Itoa(8),
}
p6 := map[string]string{
DIM: strconv.Itoa(128),
Metric: TANIMOTO,
Metric: metric.TANIMOTO,
NLIST: strconv.Itoa(100),
IVFM: strconv.Itoa(4),
NBITS: strconv.Itoa(8),
@ -89,14 +90,14 @@ func Test_binIVFFlatChecker_CheckTrain(t *testing.T) {
p7 := map[string]string{
DIM: strconv.Itoa(128),
Metric: SUBSTRUCTURE,
Metric: metric.SUBSTRUCTURE,
NLIST: strconv.Itoa(100),
IVFM: strconv.Itoa(4),
NBITS: strconv.Itoa(8),
}
p8 := map[string]string{
DIM: strconv.Itoa(128),
Metric: SUPERSTRUCTURE,
Metric: metric.SUPERSTRUCTURE,
NLIST: strconv.Itoa(100),
IVFM: strconv.Itoa(4),
NBITS: strconv.Itoa(8),

View File

@ -1,32 +1,11 @@
package indexparamcheck
import "github.com/milvus-io/milvus/pkg/common"
import (
"github.com/milvus-io/milvus/pkg/common"
"github.com/milvus-io/milvus/pkg/util/metric"
)
const (
// L2 represents Euclidean distance
L2 = "L2"
// IP represents inner product distance
IP = "IP"
// COSINE represents cosine distance
COSINE = "COSINE"
// HAMMING represents hamming distance
HAMMING = "HAMMING"
// JACCARD represents jaccard distance
JACCARD = "JACCARD"
// TANIMOTO represents tanimoto distance
TANIMOTO = "TANIMOTO"
// SUBSTRUCTURE represents substructure distance
SUBSTRUCTURE = "SUBSTRUCTURE"
// SUPERSTRUCTURE represents superstructure distance
SUPERSTRUCTURE = "SUPERSTRUCTURE"
MinNBits = 1
MaxNBits = 16
DefaultNBits = 8
@ -65,16 +44,17 @@ const (
)
// METRICS is a set of all metrics types supported for float vector.
var METRICS = []string{L2, IP, COSINE} // const
var METRICS = []string{metric.L2, metric.IP, metric.COSINE} // const
// BinIDMapMetrics is a set of all metric types supported for binary vector.
var BinIDMapMetrics = []string{HAMMING, JACCARD, TANIMOTO, SUBSTRUCTURE, SUPERSTRUCTURE} // const
var BinIvfMetrics = []string{HAMMING, JACCARD, TANIMOTO} // const
var HnswMetrics = []string{L2, IP, COSINE, HAMMING, JACCARD} // const
var supportDimPerSubQuantizer = []int{32, 28, 24, 20, 16, 12, 10, 8, 6, 4, 3, 2, 1} // const
var supportSubQuantizer = []int{96, 64, 56, 48, 40, 32, 28, 24, 20, 16, 12, 8, 4, 3, 2, 1} // const
var BinIDMapMetrics = []string{metric.HAMMING, metric.JACCARD, metric.TANIMOTO, metric.SUBSTRUCTURE,
metric.SUPERSTRUCTURE} // const
var BinIvfMetrics = []string{metric.HAMMING, metric.JACCARD, metric.TANIMOTO} // const
var HnswMetrics = []string{metric.L2, metric.IP, metric.COSINE, metric.HAMMING, metric.JACCARD} // const
var supportDimPerSubQuantizer = []int{32, 28, 24, 20, 16, 12, 10, 8, 6, 4, 3, 2, 1} // const
var supportSubQuantizer = []int{96, 64, 56, 48, 40, 32, 28, 24, 20, 16, 12, 8, 4, 3, 2, 1} // const
const (
FloatVectorDefaultMetricType = IP
BinaryVectorDefaultMetricType = JACCARD
FloatVectorDefaultMetricType = metric.IP
BinaryVectorDefaultMetricType = metric.JACCARD
)

View File

@ -5,6 +5,7 @@ import (
"testing"
"github.com/milvus-io/milvus-proto/go-api/v2/schemapb"
"github.com/milvus-io/milvus/pkg/util/metric"
"github.com/stretchr/testify/assert"
)
@ -12,13 +13,13 @@ import (
func Test_diskannChecker_CheckTrain(t *testing.T) {
validParams := map[string]string{
DIM: strconv.Itoa(128),
Metric: L2,
Metric: metric.L2,
}
validParamsBigDim := copyParams(validParams)
validParamsBigDim[DIM] = strconv.Itoa(2048)
invalidParamsWithoutDim := map[string]string{
Metric: L2,
Metric: metric.L2,
}
invalidParamsSmallDim := copyParams(validParams)
@ -26,36 +27,36 @@ func Test_diskannChecker_CheckTrain(t *testing.T) {
p1 := map[string]string{
DIM: strconv.Itoa(128),
Metric: L2,
Metric: metric.L2,
}
p2 := map[string]string{
DIM: strconv.Itoa(128),
Metric: IP,
Metric: metric.IP,
}
p3 := map[string]string{
DIM: strconv.Itoa(128),
Metric: COSINE,
Metric: metric.COSINE,
}
p4 := map[string]string{
DIM: strconv.Itoa(128),
Metric: HAMMING,
Metric: metric.HAMMING,
}
p5 := map[string]string{
DIM: strconv.Itoa(128),
Metric: JACCARD,
Metric: metric.JACCARD,
}
p6 := map[string]string{
DIM: strconv.Itoa(128),
Metric: TANIMOTO,
Metric: metric.TANIMOTO,
}
p7 := map[string]string{
DIM: strconv.Itoa(128),
Metric: SUBSTRUCTURE,
Metric: metric.SUBSTRUCTURE,
}
p8 := map[string]string{
DIM: strconv.Itoa(128),
Metric: SUPERSTRUCTURE,
Metric: metric.SUPERSTRUCTURE,
}
cases := []struct {

View File

@ -4,6 +4,8 @@ import (
"strconv"
"testing"
"github.com/milvus-io/milvus/pkg/util/metric"
"github.com/stretchr/testify/assert"
)
@ -11,36 +13,36 @@ func Test_flatChecker_CheckTrain(t *testing.T) {
p1 := map[string]string{
DIM: strconv.Itoa(128),
Metric: L2,
Metric: metric.L2,
}
p2 := map[string]string{
DIM: strconv.Itoa(128),
Metric: IP,
Metric: metric.IP,
}
p3 := map[string]string{
DIM: strconv.Itoa(128),
Metric: COSINE,
Metric: metric.COSINE,
}
p4 := map[string]string{
DIM: strconv.Itoa(128),
Metric: HAMMING,
Metric: metric.HAMMING,
}
p5 := map[string]string{
DIM: strconv.Itoa(128),
Metric: JACCARD,
Metric: metric.JACCARD,
}
p6 := map[string]string{
DIM: strconv.Itoa(128),
Metric: TANIMOTO,
Metric: metric.TANIMOTO,
}
p7 := map[string]string{
DIM: strconv.Itoa(128),
Metric: SUBSTRUCTURE,
Metric: metric.SUBSTRUCTURE,
}
p8 := map[string]string{
DIM: strconv.Itoa(128),
Metric: SUPERSTRUCTURE,
Metric: metric.SUPERSTRUCTURE,
}
cases := []struct {
params map[string]string

View File

@ -5,6 +5,7 @@ import (
"testing"
"github.com/milvus-io/milvus-proto/go-api/v2/schemapb"
"github.com/milvus-io/milvus/pkg/util/metric"
"github.com/stretchr/testify/assert"
)
@ -15,7 +16,7 @@ func Test_hnswChecker_CheckTrain(t *testing.T) {
DIM: strconv.Itoa(128),
HNSWM: strconv.Itoa(16),
EFConstruction: strconv.Itoa(200),
Metric: L2,
Metric: metric.L2,
}
invalidEfParamsMin := copyParams(validParams)
@ -34,50 +35,50 @@ func Test_hnswChecker_CheckTrain(t *testing.T) {
DIM: strconv.Itoa(128),
HNSWM: strconv.Itoa(16),
EFConstruction: strconv.Itoa(200),
Metric: L2,
Metric: metric.L2,
}
p2 := map[string]string{
DIM: strconv.Itoa(128),
HNSWM: strconv.Itoa(16),
EFConstruction: strconv.Itoa(200),
Metric: IP,
Metric: metric.IP,
}
p3 := map[string]string{
DIM: strconv.Itoa(128),
HNSWM: strconv.Itoa(16),
EFConstruction: strconv.Itoa(200),
Metric: COSINE,
Metric: metric.COSINE,
}
p4 := map[string]string{
DIM: strconv.Itoa(128),
HNSWM: strconv.Itoa(16),
EFConstruction: strconv.Itoa(200),
Metric: HAMMING,
Metric: metric.HAMMING,
}
p5 := map[string]string{
DIM: strconv.Itoa(128),
HNSWM: strconv.Itoa(16),
EFConstruction: strconv.Itoa(200),
Metric: JACCARD,
Metric: metric.JACCARD,
}
p6 := map[string]string{
DIM: strconv.Itoa(128),
HNSWM: strconv.Itoa(16),
EFConstruction: strconv.Itoa(200),
Metric: TANIMOTO,
Metric: metric.TANIMOTO,
}
p7 := map[string]string{
DIM: strconv.Itoa(128),
HNSWM: strconv.Itoa(16),
EFConstruction: strconv.Itoa(200),
Metric: SUBSTRUCTURE,
Metric: metric.SUBSTRUCTURE,
}
p8 := map[string]string{
DIM: strconv.Itoa(128),
HNSWM: strconv.Itoa(16),
EFConstruction: strconv.Itoa(200),
Metric: SUPERSTRUCTURE,
Metric: metric.SUPERSTRUCTURE,
}
cases := []struct {

View File

@ -13,6 +13,8 @@ package indexparamcheck
import (
"strconv"
"github.com/milvus-io/milvus/pkg/util/metric"
)
// TODO: add more test cases which `IndexChecker.CheckTrain` return false,
@ -22,7 +24,7 @@ func invalidIVFParamsMin() map[string]string {
invalidIVFParams := map[string]string{
DIM: strconv.Itoa(128),
NLIST: strconv.Itoa(MinNList - 1),
Metric: L2,
Metric: metric.L2,
}
return invalidIVFParams
}
@ -31,7 +33,7 @@ func invalidIVFParamsMax() map[string]string {
invalidIVFParams := map[string]string{
DIM: strconv.Itoa(128),
NLIST: strconv.Itoa(MaxNList + 1),
Metric: L2,
Metric: metric.L2,
}
return invalidIVFParams
}

View File

@ -5,6 +5,7 @@ import (
"testing"
"github.com/milvus-io/milvus-proto/go-api/v2/schemapb"
"github.com/milvus-io/milvus/pkg/util/metric"
"github.com/stretchr/testify/assert"
)
@ -13,49 +14,49 @@ func Test_ivfBaseChecker_CheckTrain(t *testing.T) {
validParams := map[string]string{
DIM: strconv.Itoa(128),
NLIST: strconv.Itoa(1024),
Metric: L2,
Metric: metric.L2,
}
p1 := map[string]string{
DIM: strconv.Itoa(128),
NLIST: strconv.Itoa(1024),
Metric: L2,
Metric: metric.L2,
}
p2 := map[string]string{
DIM: strconv.Itoa(128),
NLIST: strconv.Itoa(1024),
Metric: IP,
Metric: metric.IP,
}
p3 := map[string]string{
DIM: strconv.Itoa(128),
NLIST: strconv.Itoa(1024),
Metric: COSINE,
Metric: metric.COSINE,
}
p4 := map[string]string{
DIM: strconv.Itoa(128),
NLIST: strconv.Itoa(1024),
Metric: HAMMING,
Metric: metric.HAMMING,
}
p5 := map[string]string{
DIM: strconv.Itoa(128),
NLIST: strconv.Itoa(1024),
Metric: JACCARD,
Metric: metric.JACCARD,
}
p6 := map[string]string{
DIM: strconv.Itoa(128),
NLIST: strconv.Itoa(1024),
Metric: TANIMOTO,
Metric: metric.TANIMOTO,
}
p7 := map[string]string{
DIM: strconv.Itoa(128),
NLIST: strconv.Itoa(1024),
Metric: SUBSTRUCTURE,
Metric: metric.SUBSTRUCTURE,
}
p8 := map[string]string{
DIM: strconv.Itoa(128),
NLIST: strconv.Itoa(1024),
Metric: SUPERSTRUCTURE,
Metric: metric.SUPERSTRUCTURE,
}
cases := []struct {

View File

@ -5,6 +5,7 @@ import (
"testing"
"github.com/milvus-io/milvus-proto/go-api/v2/schemapb"
"github.com/milvus-io/milvus/pkg/util/metric"
"github.com/stretchr/testify/assert"
)
@ -15,7 +16,7 @@ func Test_ivfPQChecker_CheckTrain(t *testing.T) {
NLIST: strconv.Itoa(1024),
IVFM: strconv.Itoa(4),
NBITS: strconv.Itoa(8),
Metric: L2,
Metric: metric.L2,
}
paramsNotMultiplier := map[string]string{
@ -23,21 +24,21 @@ func Test_ivfPQChecker_CheckTrain(t *testing.T) {
NLIST: strconv.Itoa(1024),
IVFM: strconv.Itoa(5),
NBITS: strconv.Itoa(8),
Metric: L2,
Metric: metric.L2,
}
validParamsWithoutNbits := map[string]string{
DIM: strconv.Itoa(128),
NLIST: strconv.Itoa(1024),
IVFM: strconv.Itoa(4),
Metric: L2,
Metric: metric.L2,
}
validParamsWithoutDim := map[string]string{
NLIST: strconv.Itoa(1024),
IVFM: strconv.Itoa(4),
NBITS: strconv.Itoa(8),
Metric: L2,
Metric: metric.L2,
}
invalidParamsDim := copyParams(validParams)
@ -50,7 +51,7 @@ func Test_ivfPQChecker_CheckTrain(t *testing.T) {
DIM: strconv.Itoa(128),
NLIST: strconv.Itoa(1024),
NBITS: strconv.Itoa(8),
Metric: L2,
Metric: metric.L2,
}
invalidParamsIVF := copyParams(validParams)
@ -67,21 +68,21 @@ func Test_ivfPQChecker_CheckTrain(t *testing.T) {
NLIST: strconv.Itoa(1024),
IVFM: strconv.Itoa(4),
NBITS: strconv.Itoa(8),
Metric: L2,
Metric: metric.L2,
}
p2 := map[string]string{
DIM: strconv.Itoa(128),
NLIST: strconv.Itoa(1024),
IVFM: strconv.Itoa(4),
NBITS: strconv.Itoa(8),
Metric: IP,
Metric: metric.IP,
}
p3 := map[string]string{
DIM: strconv.Itoa(128),
NLIST: strconv.Itoa(1024),
IVFM: strconv.Itoa(4),
NBITS: strconv.Itoa(8),
Metric: COSINE,
Metric: metric.COSINE,
}
p4 := map[string]string{
@ -89,35 +90,35 @@ func Test_ivfPQChecker_CheckTrain(t *testing.T) {
NLIST: strconv.Itoa(1024),
IVFM: strconv.Itoa(4),
NBITS: strconv.Itoa(8),
Metric: HAMMING,
Metric: metric.HAMMING,
}
p5 := map[string]string{
DIM: strconv.Itoa(128),
NLIST: strconv.Itoa(1024),
IVFM: strconv.Itoa(4),
NBITS: strconv.Itoa(8),
Metric: JACCARD,
Metric: metric.JACCARD,
}
p6 := map[string]string{
DIM: strconv.Itoa(128),
NLIST: strconv.Itoa(1024),
IVFM: strconv.Itoa(4),
NBITS: strconv.Itoa(8),
Metric: TANIMOTO,
Metric: metric.TANIMOTO,
}
p7 := map[string]string{
DIM: strconv.Itoa(128),
NLIST: strconv.Itoa(1024),
IVFM: strconv.Itoa(4),
NBITS: strconv.Itoa(8),
Metric: SUBSTRUCTURE,
Metric: metric.SUBSTRUCTURE,
}
p8 := map[string]string{
DIM: strconv.Itoa(128),
NLIST: strconv.Itoa(1024),
IVFM: strconv.Itoa(4),
NBITS: strconv.Itoa(8),
Metric: SUPERSTRUCTURE,
Metric: metric.SUPERSTRUCTURE,
}
cases := []struct {

View File

@ -5,6 +5,7 @@ import (
"testing"
"github.com/milvus-io/milvus-proto/go-api/v2/schemapb"
"github.com/milvus-io/milvus/pkg/util/metric"
"github.com/stretchr/testify/assert"
)
@ -15,7 +16,7 @@ func Test_ivfSQChecker_CheckTrain(t *testing.T) {
DIM: strconv.Itoa(128),
NLIST: strconv.Itoa(100),
NBITS: strconv.Itoa(8),
Metric: L2,
Metric: metric.L2,
}
if withNBits {
validParams[NBITS] = strconv.Itoa(DefaultNBits)
@ -31,50 +32,50 @@ func Test_ivfSQChecker_CheckTrain(t *testing.T) {
DIM: strconv.Itoa(128),
NLIST: strconv.Itoa(100),
NBITS: strconv.Itoa(8),
Metric: L2,
Metric: metric.L2,
}
p2 := map[string]string{
DIM: strconv.Itoa(128),
NLIST: strconv.Itoa(100),
NBITS: strconv.Itoa(8),
Metric: IP,
Metric: metric.IP,
}
p3 := map[string]string{
DIM: strconv.Itoa(128),
NLIST: strconv.Itoa(100),
NBITS: strconv.Itoa(8),
Metric: COSINE,
Metric: metric.COSINE,
}
p4 := map[string]string{
DIM: strconv.Itoa(128),
NLIST: strconv.Itoa(100),
NBITS: strconv.Itoa(8),
Metric: HAMMING,
Metric: metric.HAMMING,
}
p5 := map[string]string{
DIM: strconv.Itoa(128),
NLIST: strconv.Itoa(100),
NBITS: strconv.Itoa(8),
Metric: JACCARD,
Metric: metric.JACCARD,
}
p6 := map[string]string{
DIM: strconv.Itoa(128),
NLIST: strconv.Itoa(100),
NBITS: strconv.Itoa(8),
Metric: TANIMOTO,
Metric: metric.TANIMOTO,
}
p7 := map[string]string{
DIM: strconv.Itoa(128),
NLIST: strconv.Itoa(100),
NBITS: strconv.Itoa(8),
Metric: SUBSTRUCTURE,
Metric: metric.SUBSTRUCTURE,
}
p8 := map[string]string{
DIM: strconv.Itoa(128),
NLIST: strconv.Itoa(100),
NBITS: strconv.Itoa(8),
Metric: SUPERSTRUCTURE,
Metric: metric.SUPERSTRUCTURE,
}
cases := []struct {

View File

@ -5,6 +5,7 @@ import (
"testing"
"github.com/milvus-io/milvus-proto/go-api/v2/schemapb"
"github.com/milvus-io/milvus/pkg/util/metric"
"github.com/stretchr/testify/assert"
)
@ -16,21 +17,21 @@ func Test_raftIVFPQChecker_CheckTrain(t *testing.T) {
NLIST: strconv.Itoa(1024),
IVFM: strconv.Itoa(4),
NBITS: strconv.Itoa(8),
Metric: L2,
Metric: metric.L2,
}
validParamsWithoutNbits := map[string]string{
DIM: strconv.Itoa(128),
NLIST: strconv.Itoa(1024),
IVFM: strconv.Itoa(4),
Metric: L2,
Metric: metric.L2,
}
validParamsWithoutDim := map[string]string{
NLIST: strconv.Itoa(1024),
IVFM: strconv.Itoa(4),
NBITS: strconv.Itoa(8),
Metric: L2,
Metric: metric.L2,
}
invalidParamsDim := copyParams(validParams)
@ -43,7 +44,7 @@ func Test_raftIVFPQChecker_CheckTrain(t *testing.T) {
DIM: strconv.Itoa(128),
NLIST: strconv.Itoa(1024),
NBITS: strconv.Itoa(8),
Metric: L2,
Metric: metric.L2,
}
invalidParamsIVF := copyParams(validParams)
@ -60,21 +61,21 @@ func Test_raftIVFPQChecker_CheckTrain(t *testing.T) {
NLIST: strconv.Itoa(1024),
IVFM: strconv.Itoa(4),
NBITS: strconv.Itoa(8),
Metric: L2,
Metric: metric.L2,
}
p2 := map[string]string{
DIM: strconv.Itoa(128),
NLIST: strconv.Itoa(1024),
IVFM: strconv.Itoa(4),
NBITS: strconv.Itoa(8),
Metric: IP,
Metric: metric.IP,
}
p3 := map[string]string{
DIM: strconv.Itoa(128),
NLIST: strconv.Itoa(1024),
IVFM: strconv.Itoa(4),
NBITS: strconv.Itoa(8),
Metric: COSINE,
Metric: metric.COSINE,
}
p4 := map[string]string{
@ -82,35 +83,35 @@ func Test_raftIVFPQChecker_CheckTrain(t *testing.T) {
NLIST: strconv.Itoa(1024),
IVFM: strconv.Itoa(4),
NBITS: strconv.Itoa(8),
Metric: HAMMING,
Metric: metric.HAMMING,
}
p5 := map[string]string{
DIM: strconv.Itoa(128),
NLIST: strconv.Itoa(1024),
IVFM: strconv.Itoa(4),
NBITS: strconv.Itoa(8),
Metric: JACCARD,
Metric: metric.JACCARD,
}
p6 := map[string]string{
DIM: strconv.Itoa(128),
NLIST: strconv.Itoa(1024),
IVFM: strconv.Itoa(4),
NBITS: strconv.Itoa(8),
Metric: TANIMOTO,
Metric: metric.TANIMOTO,
}
p7 := map[string]string{
DIM: strconv.Itoa(128),
NLIST: strconv.Itoa(1024),
IVFM: strconv.Itoa(4),
NBITS: strconv.Itoa(8),
Metric: SUBSTRUCTURE,
Metric: metric.SUBSTRUCTURE,
}
p8 := map[string]string{
DIM: strconv.Itoa(128),
NLIST: strconv.Itoa(1024),
IVFM: strconv.Itoa(4),
NBITS: strconv.Itoa(8),
Metric: SUPERSTRUCTURE,
Metric: metric.SUPERSTRUCTURE,
}
cases := []struct {

View File

@ -0,0 +1,42 @@
// Copyright (C) 2019-2020 Zilliz. All rights reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance
// with the License. You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software distributed under the License
// is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express
// or implied. See the License for the specific language governing permissions and limitations under the License.
package metric
// MetricType string.
type MetricType = string
// MetricType definitions
const (
// L2 represents Euclidean distance
L2 MetricType = "L2"
// IP represents inner product distance
IP MetricType = "IP"
// COSINE represents cosine distance
COSINE MetricType = "COSINE"
// HAMMING represents hamming distance
HAMMING MetricType = "HAMMING"
// JACCARD represents jaccard distance
JACCARD MetricType = "JACCARD"
// TANIMOTO represents tanimoto distance
TANIMOTO MetricType = "TANIMOTO"
// SUBSTRUCTURE represents substructure distance
SUBSTRUCTURE MetricType = "SUBSTRUCTURE"
// SUPERSTRUCTURE represents superstructure distance
SUPERSTRUCTURE MetricType = "SUPERSTRUCTURE"
)

View File

@ -14,7 +14,7 @@
// See the License for the specific language governing permissions and
// limitations under the License.
package distance
package metric
import "strings"

View File

@ -14,7 +14,7 @@
// See the License for the specific language governing permissions and
// limitations under the License.
package distance
package metric
import "testing"

View File

@ -36,8 +36,8 @@ import (
"github.com/milvus-io/milvus/internal/util/importutil"
"github.com/milvus-io/milvus/pkg/common"
"github.com/milvus-io/milvus/pkg/log"
"github.com/milvus-io/milvus/pkg/util/distance"
"github.com/milvus-io/milvus/pkg/util/funcutil"
"github.com/milvus-io/milvus/pkg/util/metric"
"github.com/milvus-io/milvus/tests/integration"
)
@ -164,7 +164,7 @@ func (s *BulkInsertSuite) TestBulkInsert() {
CollectionName: collectionName,
FieldName: "embeddings",
IndexName: "_default",
ExtraParams: integration.ConstructIndexParam(dim, integration.IndexHNSW, distance.L2),
ExtraParams: integration.ConstructIndexParam(dim, integration.IndexHNSW, metric.L2),
})
if createIndexStatus.GetErrorCode() != commonpb.ErrorCode_Success {
log.Warn("createIndexStatus fail reason", zap.String("reason", createIndexStatus.GetReason()))
@ -192,9 +192,9 @@ func (s *BulkInsertSuite) TestBulkInsert() {
topk := 10
roundDecimal := -1
params := integration.GetSearchParams(integration.IndexHNSW, distance.L2)
params := integration.GetSearchParams(integration.IndexHNSW, metric.L2)
searchReq := integration.ConstructSearchRequest("", collectionName, expr,
"embeddings", schemapb.DataType_FloatVector, nil, distance.L2, params, nq, dim, topk, roundDecimal)
"embeddings", schemapb.DataType_FloatVector, nil, metric.L2, params, nq, dim, topk, roundDecimal)
searchResult, err := c.Proxy.Search(ctx, searchReq)

View File

@ -29,8 +29,8 @@ import (
"github.com/milvus-io/milvus-proto/go-api/v2/milvuspb"
"github.com/milvus-io/milvus-proto/go-api/v2/schemapb"
"github.com/milvus-io/milvus/pkg/common"
"github.com/milvus-io/milvus/pkg/util/distance"
"github.com/milvus-io/milvus/pkg/util/funcutil"
"github.com/milvus-io/milvus/pkg/util/metric"
"github.com/milvus-io/milvus/pkg/util/typeutil"
"github.com/milvus-io/milvus/tests/integration"
)
@ -258,7 +258,7 @@ func (s *TestGetVectorSuite) TestGetVector_FLAT() {
s.nq = 10
s.topK = 10
s.indexType = integration.IndexFaissIDMap
s.metricType = distance.L2
s.metricType = metric.L2
s.pkType = schemapb.DataType_Int64
s.vecType = schemapb.DataType_FloatVector
s.searchFailed = false
@ -269,7 +269,7 @@ func (s *TestGetVectorSuite) TestGetVector_IVF_FLAT() {
s.nq = 10
s.topK = 10
s.indexType = integration.IndexFaissIvfFlat
s.metricType = distance.L2
s.metricType = metric.L2
s.pkType = schemapb.DataType_Int64
s.vecType = schemapb.DataType_FloatVector
s.searchFailed = false
@ -280,7 +280,7 @@ func (s *TestGetVectorSuite) TestGetVector_IVF_PQ() {
s.nq = 10
s.topK = 10
s.indexType = integration.IndexFaissIvfPQ
s.metricType = distance.L2
s.metricType = metric.L2
s.pkType = schemapb.DataType_Int64
s.vecType = schemapb.DataType_FloatVector
s.searchFailed = true
@ -291,7 +291,7 @@ func (s *TestGetVectorSuite) TestGetVector_IVF_SQ8() {
s.nq = 10
s.topK = 10
s.indexType = integration.IndexFaissIvfSQ8
s.metricType = distance.L2
s.metricType = metric.L2
s.pkType = schemapb.DataType_Int64
s.vecType = schemapb.DataType_FloatVector
s.searchFailed = true
@ -302,7 +302,7 @@ func (s *TestGetVectorSuite) TestGetVector_HNSW() {
s.nq = 10
s.topK = 10
s.indexType = integration.IndexHNSW
s.metricType = distance.L2
s.metricType = metric.L2
s.pkType = schemapb.DataType_Int64
s.vecType = schemapb.DataType_FloatVector
s.searchFailed = false
@ -313,7 +313,7 @@ func (s *TestGetVectorSuite) TestGetVector_IP() {
s.nq = 10
s.topK = 10
s.indexType = integration.IndexHNSW
s.metricType = distance.IP
s.metricType = metric.IP
s.pkType = schemapb.DataType_Int64
s.vecType = schemapb.DataType_FloatVector
s.searchFailed = false
@ -324,7 +324,7 @@ func (s *TestGetVectorSuite) TestGetVector_StringPK() {
s.nq = 10
s.topK = 10
s.indexType = integration.IndexHNSW
s.metricType = distance.L2
s.metricType = metric.L2
s.pkType = schemapb.DataType_VarChar
s.vecType = schemapb.DataType_FloatVector
s.searchFailed = false
@ -335,7 +335,7 @@ func (s *TestGetVectorSuite) TestGetVector_BinaryVector() {
s.nq = 10
s.topK = 10
s.indexType = integration.IndexFaissBinIvfFlat
s.metricType = distance.JACCARD
s.metricType = metric.JACCARD
s.pkType = schemapb.DataType_Int64
s.vecType = schemapb.DataType_BinaryVector
s.searchFailed = false
@ -347,7 +347,7 @@ func (s *TestGetVectorSuite) TestGetVector_Big_NQ_TOPK() {
s.nq = 10000
s.topK = 200
s.indexType = integration.IndexHNSW
s.metricType = distance.L2
s.metricType = metric.L2
s.pkType = schemapb.DataType_Int64
s.vecType = schemapb.DataType_FloatVector
s.searchFailed = false

View File

@ -30,8 +30,8 @@ import (
"github.com/milvus-io/milvus-proto/go-api/v2/schemapb"
"github.com/milvus-io/milvus/pkg/common"
"github.com/milvus-io/milvus/pkg/log"
"github.com/milvus-io/milvus/pkg/util/distance"
"github.com/milvus-io/milvus/pkg/util/funcutil"
"github.com/milvus-io/milvus/pkg/util/metric"
"github.com/milvus-io/milvus/tests/integration"
)
@ -110,7 +110,7 @@ func (s *HelloMilvusSuite) TestHelloMilvus() {
CollectionName: collectionName,
FieldName: integration.FloatVecField,
IndexName: "_default",
ExtraParams: integration.ConstructIndexParam(dim, integration.IndexFaissIvfFlat, distance.L2),
ExtraParams: integration.ConstructIndexParam(dim, integration.IndexFaissIvfFlat, metric.L2),
})
if createIndexStatus.GetErrorCode() != commonpb.ErrorCode_Success {
log.Warn("createIndexStatus fail reason", zap.String("reason", createIndexStatus.GetReason()))
@ -138,9 +138,9 @@ func (s *HelloMilvusSuite) TestHelloMilvus() {
topk := 10
roundDecimal := -1
params := integration.GetSearchParams(integration.IndexFaissIvfFlat, distance.L2)
params := integration.GetSearchParams(integration.IndexFaissIvfFlat, metric.L2)
searchReq := integration.ConstructSearchRequest("", collectionName, expr,
integration.FloatVecField, schemapb.DataType_FloatVector, nil, distance.L2, params, nq, dim, topk, roundDecimal)
integration.FloatVecField, schemapb.DataType_FloatVector, nil, metric.L2, params, nq, dim, topk, roundDecimal)
searchResult, err := c.Proxy.Search(ctx, searchReq)

View File

@ -12,8 +12,8 @@ import (
"github.com/milvus-io/milvus-proto/go-api/v2/milvuspb"
"github.com/milvus-io/milvus-proto/go-api/v2/schemapb"
"github.com/milvus-io/milvus/pkg/log"
"github.com/milvus-io/milvus/pkg/util/distance"
"github.com/milvus-io/milvus/pkg/util/funcutil"
"github.com/milvus-io/milvus/pkg/util/metric"
"github.com/milvus-io/milvus/tests/integration"
)
@ -78,7 +78,7 @@ func (s *GetIndexStatisticsSuite) TestGetIndexStatistics() {
CollectionName: collectionName,
FieldName: integration.FloatVecField,
IndexName: "_default",
ExtraParams: integration.ConstructIndexParam(dim, integration.IndexFaissIvfFlat, distance.L2),
ExtraParams: integration.ConstructIndexParam(dim, integration.IndexFaissIvfFlat, metric.L2),
})
if createIndexStatus.GetErrorCode() != commonpb.ErrorCode_Success {
log.Warn("createIndexStatus fail reason", zap.String("reason", createIndexStatus.GetReason()))

View File

@ -28,9 +28,9 @@ import (
"github.com/milvus-io/milvus-proto/go-api/v2/schemapb"
"github.com/milvus-io/milvus/pkg/common"
"github.com/milvus-io/milvus/pkg/log"
"github.com/milvus-io/milvus/pkg/util/distance"
"github.com/milvus-io/milvus/pkg/util/funcutil"
"github.com/milvus-io/milvus/pkg/util/merr"
"github.com/milvus-io/milvus/pkg/util/metric"
"github.com/milvus-io/milvus/tests/integration"
)
@ -77,7 +77,7 @@ func (s *InsertSuite) TestInsert() {
CollectionName: collectionName,
FieldName: integration.FloatVecField,
IndexName: "_default",
ExtraParams: integration.ConstructIndexParam(dim, integration.IndexFaissIvfFlat, distance.IP),
ExtraParams: integration.ConstructIndexParam(dim, integration.IndexFaissIvfFlat, metric.IP),
})
s.NoError(err)
err = merr.Error(createIndexStatus)

View File

@ -26,7 +26,6 @@ import (
"github.com/cockroachdb/errors"
"github.com/milvus-io/milvus/pkg/common"
"github.com/milvus-io/milvus/pkg/util/distance"
"github.com/milvus-io/milvus/tests/integration"
"github.com/stretchr/testify/suite"
@ -36,6 +35,7 @@ import (
"github.com/milvus-io/milvus-proto/go-api/v2/schemapb"
"github.com/milvus-io/milvus/pkg/log"
"github.com/milvus-io/milvus/pkg/util/funcutil"
"github.com/milvus-io/milvus/pkg/util/metric"
"go.uber.org/zap"
)
@ -776,7 +776,7 @@ func (s *JSONExprSuite) insertFlushIndexLoad(ctx context.Context, dbName, collec
},
{
Key: common.MetricTypeKey,
Value: distance.L2,
Value: metric.L2,
},
{
Key: common.IndexTypeKey,
@ -824,9 +824,9 @@ func (s *JSONExprSuite) doSearch(collectionName string, outputField []string, ex
topk := 10
roundDecimal := -1
params := integration.GetSearchParams(integration.IndexFaissIvfFlat, distance.L2)
params := integration.GetSearchParams(integration.IndexFaissIvfFlat, metric.L2)
searchReq := integration.ConstructSearchRequest("", collectionName, expr,
integration.FloatVecField, schemapb.DataType_FloatVector, outputField, distance.L2, params, nq, dim, topk, roundDecimal)
integration.FloatVecField, schemapb.DataType_FloatVector, outputField, metric.L2, params, nq, dim, topk, roundDecimal)
searchResult, err := s.Cluster.Proxy.Search(context.Background(), searchReq)
@ -899,9 +899,9 @@ func (s *JSONExprSuite) doSearchWithInvalidExpr(collectionName string, outputFie
topk := 10
roundDecimal := -1
params := integration.GetSearchParams(integration.IndexFaissIvfFlat, distance.L2)
params := integration.GetSearchParams(integration.IndexFaissIvfFlat, metric.L2)
searchReq := integration.ConstructSearchRequest("", collectionName, expr,
integration.FloatVecField, schemapb.DataType_FloatVector, outputField, distance.L2, params, nq, dim, topk, roundDecimal)
integration.FloatVecField, schemapb.DataType_FloatVector, outputField, metric.L2, params, nq, dim, topk, roundDecimal)
searchResult, err := s.Cluster.Proxy.Search(context.Background(), searchReq)

View File

@ -32,8 +32,8 @@ import (
"github.com/milvus-io/milvus-proto/go-api/v2/schemapb"
"github.com/milvus-io/milvus/pkg/common"
"github.com/milvus-io/milvus/pkg/log"
"github.com/milvus-io/milvus/pkg/util/distance"
"github.com/milvus-io/milvus/pkg/util/funcutil"
"github.com/milvus-io/milvus/pkg/util/metric"
)
type MetaWatcherSuite struct {
@ -271,7 +271,7 @@ func (s *MetaWatcherSuite) TestShowReplicas() {
},
{
Key: common.MetricTypeKey,
Value: distance.L2,
Value: metric.L2,
},
{
Key: common.IndexTypeKey,

View File

@ -29,9 +29,9 @@ import (
"github.com/milvus-io/milvus-proto/go-api/v2/schemapb"
"github.com/milvus-io/milvus/pkg/common"
"github.com/milvus-io/milvus/pkg/log"
"github.com/milvus-io/milvus/pkg/util/distance"
"github.com/milvus-io/milvus/pkg/util/funcutil"
"github.com/milvus-io/milvus/pkg/util/merr"
"github.com/milvus-io/milvus/pkg/util/metric"
"github.com/milvus-io/milvus/tests/integration"
)
@ -109,7 +109,7 @@ func (s *RangeSearchSuite) TestRangeSearchIP() {
CollectionName: collectionName,
FieldName: integration.FloatVecField,
IndexName: "_default",
ExtraParams: integration.ConstructIndexParam(dim, integration.IndexFaissIvfFlat, distance.IP),
ExtraParams: integration.ConstructIndexParam(dim, integration.IndexFaissIvfFlat, metric.IP),
})
s.NoError(err)
err = merr.Error(createIndexStatus)
@ -137,12 +137,12 @@ func (s *RangeSearchSuite) TestRangeSearchIP() {
radius := 10
filter := 20
params := integration.GetSearchParams(integration.IndexFaissIvfFlat, distance.IP)
params := integration.GetSearchParams(integration.IndexFaissIvfFlat, metric.IP)
// only pass in radius when range search
params["radius"] = radius
searchReq := integration.ConstructSearchRequest("", collectionName, expr,
integration.FloatVecField, schemapb.DataType_FloatVector, nil, distance.IP, params, nq, dim, topk, roundDecimal)
integration.FloatVecField, schemapb.DataType_FloatVector, nil, metric.IP, params, nq, dim, topk, roundDecimal)
searchResult, _ := c.Proxy.Search(ctx, searchReq)
@ -155,7 +155,7 @@ func (s *RangeSearchSuite) TestRangeSearchIP() {
// pass in radius and range_filter when range search
params["range_filter"] = filter
searchReq = integration.ConstructSearchRequest("", collectionName, expr,
integration.FloatVecField, schemapb.DataType_FloatVector, nil, distance.IP, params, nq, dim, topk, roundDecimal)
integration.FloatVecField, schemapb.DataType_FloatVector, nil, metric.IP, params, nq, dim, topk, roundDecimal)
searchResult, _ = c.Proxy.Search(ctx, searchReq)
@ -169,7 +169,7 @@ func (s *RangeSearchSuite) TestRangeSearchIP() {
params["radius"] = filter
params["range_filter"] = radius
searchReq = integration.ConstructSearchRequest("", collectionName, expr,
integration.FloatVecField, schemapb.DataType_FloatVector, nil, distance.IP, params, nq, dim, topk, roundDecimal)
integration.FloatVecField, schemapb.DataType_FloatVector, nil, metric.IP, params, nq, dim, topk, roundDecimal)
searchResult, _ = c.Proxy.Search(ctx, searchReq)
@ -257,7 +257,7 @@ func (s *RangeSearchSuite) TestRangeSearchL2() {
CollectionName: collectionName,
FieldName: integration.FloatVecField,
IndexName: "_default",
ExtraParams: integration.ConstructIndexParam(dim, integration.IndexFaissIvfFlat, distance.L2),
ExtraParams: integration.ConstructIndexParam(dim, integration.IndexFaissIvfFlat, metric.L2),
})
s.NoError(err)
err = merr.Error(createIndexStatus)
@ -285,11 +285,11 @@ func (s *RangeSearchSuite) TestRangeSearchL2() {
radius := 20
filter := 10
params := integration.GetSearchParams(integration.IndexFaissIvfFlat, distance.L2)
params := integration.GetSearchParams(integration.IndexFaissIvfFlat, metric.L2)
// only pass in radius when range search
params["radius"] = radius
searchReq := integration.ConstructSearchRequest("", collectionName, expr,
integration.FloatVecField, schemapb.DataType_FloatVector, nil, distance.L2, params, nq, dim, topk, roundDecimal)
integration.FloatVecField, schemapb.DataType_FloatVector, nil, metric.L2, params, nq, dim, topk, roundDecimal)
searchResult, _ := c.Proxy.Search(ctx, searchReq)
@ -302,7 +302,7 @@ func (s *RangeSearchSuite) TestRangeSearchL2() {
// pass in radius and range_filter when range search
params["range_filter"] = filter
searchReq = integration.ConstructSearchRequest("", collectionName, expr,
integration.FloatVecField, schemapb.DataType_FloatVector, nil, distance.L2, params, nq, dim, topk, roundDecimal)
integration.FloatVecField, schemapb.DataType_FloatVector, nil, metric.L2, params, nq, dim, topk, roundDecimal)
searchResult, _ = c.Proxy.Search(ctx, searchReq)
@ -316,7 +316,7 @@ func (s *RangeSearchSuite) TestRangeSearchL2() {
params["radius"] = filter
params["range_filter"] = radius
searchReq = integration.ConstructSearchRequest("", collectionName, expr,
integration.FloatVecField, schemapb.DataType_FloatVector, nil, distance.L2, params, nq, dim, topk, roundDecimal)
integration.FloatVecField, schemapb.DataType_FloatVector, nil, metric.L2, params, nq, dim, topk, roundDecimal)
searchResult, _ = c.Proxy.Search(ctx, searchReq)

View File

@ -27,7 +27,7 @@ import (
"github.com/milvus-io/milvus-proto/go-api/v2/milvuspb"
"github.com/milvus-io/milvus-proto/go-api/v2/schemapb"
"github.com/milvus-io/milvus/pkg/log"
"github.com/milvus-io/milvus/pkg/util/distance"
"github.com/milvus-io/milvus/pkg/util/metric"
"github.com/milvus-io/milvus/pkg/util/paramtable"
"github.com/milvus-io/milvus/tests/integration"
"github.com/stretchr/testify/suite"
@ -125,7 +125,7 @@ func (s *RefreshConfigSuite) TestRefreshDefaultIndexName() {
_, err = c.Proxy.CreateIndex(ctx, &milvuspb.CreateIndexRequest{
CollectionName: collectionName,
FieldName: integration.FloatVecField,
ExtraParams: integration.ConstructIndexParam(dim, integration.IndexFaissIvfFlat, distance.L2),
ExtraParams: integration.ConstructIndexParam(dim, integration.IndexFaissIvfFlat, metric.L2),
})
s.NoError(err)

View File

@ -26,9 +26,9 @@ import (
"github.com/milvus-io/milvus-proto/go-api/v2/schemapb"
"github.com/milvus-io/milvus/pkg/common"
"github.com/milvus-io/milvus/pkg/log"
"github.com/milvus-io/milvus/pkg/util/distance"
"github.com/milvus-io/milvus/pkg/util/funcutil"
"github.com/milvus-io/milvus/pkg/util/merr"
"github.com/milvus-io/milvus/pkg/util/metric"
"github.com/milvus-io/milvus/tests/integration"
"github.com/stretchr/testify/suite"
"go.uber.org/zap"
@ -109,7 +109,7 @@ func (s *UpsertSuite) TestUpsert() {
CollectionName: collectionName,
FieldName: integration.FloatVecField,
IndexName: "_default",
ExtraParams: integration.ConstructIndexParam(dim, integration.IndexFaissIvfFlat, distance.IP),
ExtraParams: integration.ConstructIndexParam(dim, integration.IndexFaissIvfFlat, metric.IP),
})
s.NoError(err)
err = merr.Error(createIndexStatus)
@ -138,7 +138,7 @@ func (s *UpsertSuite) TestUpsert() {
params := integration.GetSearchParams(integration.IndexFaissIvfFlat, "")
searchReq := integration.ConstructSearchRequest("", collectionName, expr,
integration.FloatVecField, schemapb.DataType_FloatVector, nil, distance.IP, params, nq, dim, topk, roundDecimal)
integration.FloatVecField, schemapb.DataType_FloatVector, nil, metric.IP, params, nq, dim, topk, roundDecimal)
searchResult, _ := c.Proxy.Search(ctx, searchReq)

View File

@ -2667,6 +2667,7 @@ class TestCollectionSearch(TestcaseBase):
assert abs(res[0].distances[0] - min(distance_0, distance_1)) <= epsilon
@pytest.mark.tags(CaseLabel.L2)
@pytest.mark.skip("tanimoto obsolete")
@pytest.mark.parametrize("index", ["BIN_FLAT", "BIN_IVF_FLAT"])
def test_search_binary_tanimoto_flat_index(self, nq, dim, auto_id, _async, index, is_flush):
"""
@ -2706,6 +2707,7 @@ class TestCollectionSearch(TestcaseBase):
assert abs(res[0].distances[0] - min(distance_0, distance_1)) <= epsilon
@pytest.mark.tags(CaseLabel.L2)
@pytest.mark.skip("substructure obsolete")
@pytest.mark.parametrize("index", ["BIN_FLAT"])
def test_search_binary_substructure_flat_index(self, auto_id, _async, index, is_flush):
"""
@ -2742,6 +2744,7 @@ class TestCollectionSearch(TestcaseBase):
assert len(res) <= default_limit
@pytest.mark.tags(CaseLabel.L2)
@pytest.mark.skip("superstructure obsolete")
@pytest.mark.parametrize("index", ["BIN_FLAT"])
def test_search_binary_superstructure_flat_index(self, auto_id, _async, index, is_flush):
"""
@ -6656,6 +6659,7 @@ class TestCollectionRangeSearch(TestcaseBase):
"limit": 0})
@pytest.mark.tags(CaseLabel.L2)
@pytest.mark.skip("tanimoto obsolete")
@pytest.mark.parametrize("index", ["BIN_FLAT", "BIN_IVF_FLAT"])
def test_range_search_binary_tanimoto_flat_index(self, dim, auto_id, _async, index, is_flush):
"""
@ -6711,6 +6715,7 @@ class TestCollectionRangeSearch(TestcaseBase):
assert abs(res[0].distances[0] - min(distance_0, distance_1)) <= epsilon
@pytest.mark.tags(CaseLabel.L2)
@pytest.mark.skip("tanimoto obsolete")
@pytest.mark.parametrize("index", ["BIN_FLAT", "BIN_IVF_FLAT"])
def test_range_search_binary_tanimoto_invalid_params(self, index):
"""

View File

@ -676,7 +676,7 @@ class TestUtilityBase(TestcaseBase):
def metric(self, request):
yield request.param
@pytest.fixture(scope="function", params=["HAMMING", "TANIMOTO"])
@pytest.fixture(scope="function", params=["HAMMING", "JACCARD"])
def metric_binary(self, request):
yield request.param