mirror of https://github.com/milvus-io/milvus.git
Enable dimension check in Proxy when create index request received (#16718)
Signed-off-by: dragondriver <jiquan.long@zilliz.com>pull/16757/head
parent
bb9ccbb7e2
commit
2fe8677cbf
|
@ -0,0 +1,43 @@
|
|||
package proxy
|
||||
|
||||
import (
|
||||
"context"
|
||||
|
||||
"github.com/milvus-io/milvus/internal/proto/schemapb"
|
||||
"github.com/milvus-io/milvus/internal/util/typeutil"
|
||||
)
|
||||
|
||||
type getCollectionIDFunc func(ctx context.Context, collectionName string) (typeutil.UniqueID, error)
|
||||
type getCollectionSchemaFunc func(ctx context.Context, collectionName string) (*schemapb.CollectionSchema, error)
|
||||
|
||||
type mockCache struct {
|
||||
Cache
|
||||
getIDFunc getCollectionIDFunc
|
||||
getSchemaFunc getCollectionSchemaFunc
|
||||
}
|
||||
|
||||
func (m *mockCache) GetCollectionID(ctx context.Context, collectionName string) (typeutil.UniqueID, error) {
|
||||
if m.getIDFunc != nil {
|
||||
return m.getIDFunc(ctx, collectionName)
|
||||
}
|
||||
return 0, nil
|
||||
}
|
||||
|
||||
func (m *mockCache) GetCollectionSchema(ctx context.Context, collectionName string) (*schemapb.CollectionSchema, error) {
|
||||
if m.getSchemaFunc != nil {
|
||||
return m.getSchemaFunc(ctx, collectionName)
|
||||
}
|
||||
return nil, nil
|
||||
}
|
||||
|
||||
func (m *mockCache) setGetIDFunc(f getCollectionIDFunc) {
|
||||
m.getIDFunc = f
|
||||
}
|
||||
|
||||
func (m *mockCache) setGetSchemaFunc(f getCollectionSchemaFunc) {
|
||||
m.getSchemaFunc = f
|
||||
}
|
||||
|
||||
func newMockCache() *mockCache {
|
||||
return &mockCache{}
|
||||
}
|
|
@ -1795,36 +1795,13 @@ func (cit *createIndexTask) OnEnqueue() error {
|
|||
return nil
|
||||
}
|
||||
|
||||
func (cit *createIndexTask) PreExecute(ctx context.Context) error {
|
||||
cit.Base.MsgType = commonpb.MsgType_CreateIndex
|
||||
cit.Base.SourceID = Params.ProxyCfg.GetNodeID()
|
||||
|
||||
collName, fieldName := cit.CollectionName, cit.FieldName
|
||||
|
||||
col, err := globalMetaCache.GetCollectionInfo(ctx, collName)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
cit.collectionID = col.collID
|
||||
|
||||
if err := validateCollectionName(collName); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
if err := validateFieldName(fieldName); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
// check index param, not accurate, only some static rules
|
||||
func parseIndexParams(m []*commonpb.KeyValuePair) (map[string]string, error) {
|
||||
indexParams := make(map[string]string)
|
||||
for _, kv := range cit.CreateIndexRequest.ExtraParams {
|
||||
for _, kv := range m {
|
||||
if kv.Key == "params" { // TODO(dragondriver): change `params` to const variable
|
||||
params, err := funcutil.ParseIndexParamsMap(kv.Value)
|
||||
if err != nil {
|
||||
log.Warn("Failed to parse index params",
|
||||
zap.String("params", kv.Value),
|
||||
zap.Error(err))
|
||||
continue
|
||||
return nil, err
|
||||
}
|
||||
for k, v := range params {
|
||||
indexParams[k] = v
|
||||
|
@ -1833,23 +1810,68 @@ func (cit *createIndexTask) PreExecute(ctx context.Context) error {
|
|||
indexParams[kv.Key] = kv.Value
|
||||
}
|
||||
}
|
||||
indexType, exist := indexParams["index_type"] // TODO(dragondriver): change `index_type` to const variable
|
||||
_, exist := indexParams["index_type"] // TODO(dragondriver): change `index_type` to const variable
|
||||
if !exist {
|
||||
indexType = indexparamcheck.IndexFaissIvfPQ // IVF_PQ is the default index type
|
||||
indexParams["index_type"] = indexparamcheck.IndexFaissIvfPQ // IVF_PQ is the default index type
|
||||
}
|
||||
return indexParams, nil
|
||||
}
|
||||
|
||||
//TODO:: add default index type for VarChar type field
|
||||
func (cit *createIndexTask) getIndexedField(ctx context.Context) (*schemapb.FieldSchema, error) {
|
||||
schema, err := globalMetaCache.GetCollectionSchema(ctx, cit.GetCollectionName())
|
||||
if err != nil {
|
||||
log.Error("failed to get collection schema", zap.Error(err))
|
||||
return nil, fmt.Errorf("failed to get collection schema: %s", err)
|
||||
}
|
||||
schemaHelper, err := typeutil.CreateSchemaHelper(schema)
|
||||
if err != nil {
|
||||
log.Error("failed to parse collection schema", zap.Error(err))
|
||||
return nil, fmt.Errorf("failed to parse collection schema: %s", err)
|
||||
}
|
||||
field, err := schemaHelper.GetFieldFromName(cit.GetFieldName())
|
||||
if err != nil {
|
||||
log.Error("create index on non-exist field", zap.Error(err))
|
||||
return nil, fmt.Errorf("cannot create index on non-exist field: %s", cit.GetFieldName())
|
||||
}
|
||||
return field, nil
|
||||
}
|
||||
|
||||
func fillDimension(field *schemapb.FieldSchema, indexParams map[string]string) error {
|
||||
vecDataTypes := []schemapb.DataType{
|
||||
schemapb.DataType_FloatVector,
|
||||
schemapb.DataType_BinaryVector,
|
||||
}
|
||||
if !funcutil.SliceContain(vecDataTypes, field.GetDataType()) {
|
||||
return nil
|
||||
}
|
||||
params := make([]*commonpb.KeyValuePair, 0, len(field.GetTypeParams())+len(field.GetIndexParams()))
|
||||
params = append(params, field.GetTypeParams()...)
|
||||
params = append(params, field.GetIndexParams()...)
|
||||
dimensionInSchema, err := funcutil.GetAttrByKeyFromRepeatedKV("dim", params)
|
||||
if err != nil {
|
||||
return fmt.Errorf("dimension not found in schema")
|
||||
}
|
||||
dimension, exist := indexParams["dim"]
|
||||
if exist {
|
||||
if dimensionInSchema != dimension {
|
||||
return fmt.Errorf("dimension mismatch, dimension in schema: %s, dimension: %s", dimensionInSchema, dimension)
|
||||
}
|
||||
} else {
|
||||
indexParams["dim"] = dimensionInSchema
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func checkTrain(field *schemapb.FieldSchema, indexParams map[string]string) error {
|
||||
indexType := indexParams["index_type"]
|
||||
|
||||
// skip params check of non-vector field.
|
||||
vecDataTypes := []schemapb.DataType{
|
||||
schemapb.DataType_FloatVector,
|
||||
schemapb.DataType_BinaryVector,
|
||||
}
|
||||
|
||||
for _, f := range col.schema.GetFields() {
|
||||
if f.GetName() == fieldName && !funcutil.SliceContain(vecDataTypes, f.GetDataType()) {
|
||||
return indexparamcheck.CheckIndexValid(f.GetDataType(), indexType, indexParams)
|
||||
}
|
||||
if !funcutil.SliceContain(vecDataTypes, field.GetDataType()) {
|
||||
return indexparamcheck.CheckIndexValid(field.GetDataType(), indexType, indexParams)
|
||||
}
|
||||
|
||||
adapter, err := indexparamcheck.GetConfAdapterMgrInstance().GetAdapter(indexType)
|
||||
|
@ -1858,15 +1880,46 @@ func (cit *createIndexTask) PreExecute(ctx context.Context) error {
|
|||
return fmt.Errorf("invalid index type: %s", indexType)
|
||||
}
|
||||
|
||||
if err := fillDimension(field, indexParams); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
ok := adapter.CheckTrain(indexParams)
|
||||
if !ok {
|
||||
log.Warn("Create index with invalid params", zap.Any("index_params", indexParams))
|
||||
return fmt.Errorf("invalid index params: %v", cit.CreateIndexRequest.ExtraParams)
|
||||
return fmt.Errorf("invalid index params: %v", indexParams)
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
func (cit *createIndexTask) PreExecute(ctx context.Context) error {
|
||||
cit.Base.MsgType = commonpb.MsgType_CreateIndex
|
||||
cit.Base.SourceID = Params.ProxyCfg.GetNodeID()
|
||||
|
||||
collName := cit.CollectionName
|
||||
|
||||
collID, err := globalMetaCache.GetCollectionID(ctx, collName)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
cit.collectionID = collID
|
||||
|
||||
field, err := cit.getIndexedField(ctx)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
// check index param, not accurate, only some static rules
|
||||
indexParams, err := parseIndexParams(cit.GetExtraParams())
|
||||
if err != nil {
|
||||
log.Error("failed to parse index params", zap.Error(err))
|
||||
return fmt.Errorf("failed to parse index params: %s", err)
|
||||
}
|
||||
|
||||
return checkTrain(field, indexParams)
|
||||
}
|
||||
|
||||
func (cit *createIndexTask) Execute(ctx context.Context) error {
|
||||
var err error
|
||||
cit.result, err = cit.rootCoord.CreateIndex(ctx, cit.CreateIndexRequest)
|
||||
|
|
|
@ -21,11 +21,14 @@ import (
|
|||
"context"
|
||||
"encoding/binary"
|
||||
"encoding/json"
|
||||
"errors"
|
||||
"math/rand"
|
||||
"strconv"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"github.com/milvus-io/milvus/internal/util/typeutil"
|
||||
|
||||
"github.com/golang/protobuf/proto"
|
||||
"github.com/stretchr/testify/assert"
|
||||
|
||||
|
@ -2169,3 +2172,265 @@ func TestAlterAlias_all(t *testing.T) {
|
|||
assert.NoError(t, task.Execute(ctx))
|
||||
assert.NoError(t, task.PostExecute(ctx))
|
||||
}
|
||||
|
||||
func Test_createIndexTask_getIndexedField(t *testing.T) {
|
||||
collectionName := "test"
|
||||
fieldName := "test"
|
||||
|
||||
cit := &createIndexTask{
|
||||
CreateIndexRequest: &milvuspb.CreateIndexRequest{
|
||||
CollectionName: collectionName,
|
||||
FieldName: fieldName,
|
||||
},
|
||||
}
|
||||
|
||||
t.Run("normal", func(t *testing.T) {
|
||||
cache := newMockCache()
|
||||
cache.setGetSchemaFunc(func(ctx context.Context, collectionName string) (*schemapb.CollectionSchema, error) {
|
||||
return &schemapb.CollectionSchema{
|
||||
Fields: []*schemapb.FieldSchema{
|
||||
{
|
||||
FieldID: 100,
|
||||
Name: fieldName,
|
||||
IsPrimaryKey: false,
|
||||
DataType: schemapb.DataType_FloatVector,
|
||||
TypeParams: nil,
|
||||
IndexParams: []*commonpb.KeyValuePair{
|
||||
{
|
||||
Key: "dim",
|
||||
Value: "128",
|
||||
},
|
||||
},
|
||||
AutoID: false,
|
||||
},
|
||||
},
|
||||
}, nil
|
||||
})
|
||||
globalMetaCache = cache
|
||||
field, err := cit.getIndexedField(context.Background())
|
||||
assert.NoError(t, err)
|
||||
assert.Equal(t, fieldName, field.GetName())
|
||||
})
|
||||
|
||||
t.Run("schema not found", func(t *testing.T) {
|
||||
cache := newMockCache()
|
||||
cache.setGetSchemaFunc(func(ctx context.Context, collectionName string) (*schemapb.CollectionSchema, error) {
|
||||
return nil, errors.New("mock")
|
||||
})
|
||||
globalMetaCache = cache
|
||||
_, err := cit.getIndexedField(context.Background())
|
||||
assert.Error(t, err)
|
||||
})
|
||||
|
||||
t.Run("invalid schema", func(t *testing.T) {
|
||||
cache := newMockCache()
|
||||
cache.setGetSchemaFunc(func(ctx context.Context, collectionName string) (*schemapb.CollectionSchema, error) {
|
||||
return &schemapb.CollectionSchema{
|
||||
Fields: []*schemapb.FieldSchema{
|
||||
{
|
||||
Name: fieldName,
|
||||
},
|
||||
{
|
||||
Name: fieldName, // duplicate
|
||||
},
|
||||
},
|
||||
}, nil
|
||||
})
|
||||
globalMetaCache = cache
|
||||
_, err := cit.getIndexedField(context.Background())
|
||||
assert.Error(t, err)
|
||||
})
|
||||
|
||||
t.Run("field not found", func(t *testing.T) {
|
||||
cache := newMockCache()
|
||||
cache.setGetSchemaFunc(func(ctx context.Context, collectionName string) (*schemapb.CollectionSchema, error) {
|
||||
return &schemapb.CollectionSchema{
|
||||
Fields: []*schemapb.FieldSchema{
|
||||
{
|
||||
Name: fieldName + fieldName,
|
||||
},
|
||||
},
|
||||
}, nil
|
||||
})
|
||||
globalMetaCache = cache
|
||||
_, err := cit.getIndexedField(context.Background())
|
||||
assert.Error(t, err)
|
||||
})
|
||||
}
|
||||
|
||||
func Test_fillDimension(t *testing.T) {
|
||||
t.Run("scalar", func(t *testing.T) {
|
||||
f := &schemapb.FieldSchema{
|
||||
DataType: schemapb.DataType_Int64,
|
||||
}
|
||||
assert.NoError(t, fillDimension(f, nil))
|
||||
})
|
||||
|
||||
t.Run("no dim in schema", func(t *testing.T) {
|
||||
f := &schemapb.FieldSchema{
|
||||
DataType: schemapb.DataType_FloatVector,
|
||||
}
|
||||
assert.Error(t, fillDimension(f, nil))
|
||||
})
|
||||
|
||||
t.Run("dimension mismatch", func(t *testing.T) {
|
||||
f := &schemapb.FieldSchema{
|
||||
DataType: schemapb.DataType_FloatVector,
|
||||
IndexParams: []*commonpb.KeyValuePair{
|
||||
{
|
||||
Key: "dim",
|
||||
Value: "128",
|
||||
},
|
||||
},
|
||||
}
|
||||
assert.Error(t, fillDimension(f, map[string]string{"dim": "8"}))
|
||||
})
|
||||
|
||||
t.Run("normal", func(t *testing.T) {
|
||||
f := &schemapb.FieldSchema{
|
||||
DataType: schemapb.DataType_FloatVector,
|
||||
IndexParams: []*commonpb.KeyValuePair{
|
||||
{
|
||||
Key: "dim",
|
||||
Value: "128",
|
||||
},
|
||||
},
|
||||
}
|
||||
m := map[string]string{}
|
||||
assert.NoError(t, fillDimension(f, m))
|
||||
assert.Equal(t, "128", m["dim"])
|
||||
})
|
||||
}
|
||||
|
||||
func Test_checkTrain(t *testing.T) {
|
||||
t.Run("normal", func(t *testing.T) {
|
||||
f := &schemapb.FieldSchema{
|
||||
DataType: schemapb.DataType_FloatVector,
|
||||
IndexParams: []*commonpb.KeyValuePair{
|
||||
{
|
||||
Key: "dim",
|
||||
Value: "128",
|
||||
},
|
||||
},
|
||||
}
|
||||
m := map[string]string{
|
||||
"index_type": "IVF_FLAT",
|
||||
"nlist": "1024",
|
||||
"metric_type": "L2",
|
||||
}
|
||||
assert.NoError(t, checkTrain(f, m))
|
||||
})
|
||||
|
||||
t.Run("scalar", func(t *testing.T) {
|
||||
f := &schemapb.FieldSchema{
|
||||
DataType: schemapb.DataType_Int64,
|
||||
}
|
||||
m := map[string]string{
|
||||
"index_type": "scalar",
|
||||
}
|
||||
assert.NoError(t, checkTrain(f, m))
|
||||
})
|
||||
|
||||
t.Run("dimension mismatch", func(t *testing.T) {
|
||||
f := &schemapb.FieldSchema{
|
||||
DataType: schemapb.DataType_FloatVector,
|
||||
IndexParams: []*commonpb.KeyValuePair{
|
||||
{
|
||||
Key: "dim",
|
||||
Value: "128",
|
||||
},
|
||||
},
|
||||
}
|
||||
m := map[string]string{
|
||||
"index_type": "IVF_FLAT",
|
||||
"nlist": "1024",
|
||||
"metric_type": "L2",
|
||||
"dim": "8",
|
||||
}
|
||||
assert.Error(t, checkTrain(f, m))
|
||||
})
|
||||
|
||||
t.Run("invalid params", func(t *testing.T) {
|
||||
f := &schemapb.FieldSchema{
|
||||
DataType: schemapb.DataType_FloatVector,
|
||||
IndexParams: []*commonpb.KeyValuePair{
|
||||
{
|
||||
Key: "dim",
|
||||
Value: "128",
|
||||
},
|
||||
},
|
||||
}
|
||||
m := map[string]string{
|
||||
"index_type": "IVF_FLAT",
|
||||
"metric_type": "L2",
|
||||
}
|
||||
assert.Error(t, checkTrain(f, m))
|
||||
})
|
||||
}
|
||||
|
||||
func Test_createIndexTask_PreExecute(t *testing.T) {
|
||||
collectionName := "test"
|
||||
fieldName := "test"
|
||||
|
||||
cit := &createIndexTask{
|
||||
CreateIndexRequest: &milvuspb.CreateIndexRequest{
|
||||
Base: &commonpb.MsgBase{
|
||||
MsgType: commonpb.MsgType_CreateIndex,
|
||||
},
|
||||
CollectionName: collectionName,
|
||||
FieldName: fieldName,
|
||||
},
|
||||
}
|
||||
|
||||
t.Run("normal", func(t *testing.T) {
|
||||
cache := newMockCache()
|
||||
cache.setGetIDFunc(func(ctx context.Context, collectionName string) (typeutil.UniqueID, error) {
|
||||
return 100, nil
|
||||
})
|
||||
cache.setGetSchemaFunc(func(ctx context.Context, collectionName string) (*schemapb.CollectionSchema, error) {
|
||||
return &schemapb.CollectionSchema{
|
||||
Fields: []*schemapb.FieldSchema{
|
||||
{
|
||||
FieldID: 100,
|
||||
Name: fieldName,
|
||||
IsPrimaryKey: false,
|
||||
DataType: schemapb.DataType_FloatVector,
|
||||
TypeParams: nil,
|
||||
IndexParams: []*commonpb.KeyValuePair{
|
||||
{
|
||||
Key: "dim",
|
||||
Value: "128",
|
||||
},
|
||||
},
|
||||
AutoID: false,
|
||||
},
|
||||
},
|
||||
}, nil
|
||||
})
|
||||
globalMetaCache = cache
|
||||
cit.CreateIndexRequest.ExtraParams = []*commonpb.KeyValuePair{
|
||||
{
|
||||
Key: "index_type",
|
||||
Value: "IVF_FLAT",
|
||||
},
|
||||
{
|
||||
Key: "nlist",
|
||||
Value: "1024",
|
||||
},
|
||||
{
|
||||
Key: "metric_type",
|
||||
Value: "L2",
|
||||
},
|
||||
}
|
||||
assert.NoError(t, cit.PreExecute(context.Background()))
|
||||
})
|
||||
|
||||
t.Run("collection not found", func(t *testing.T) {
|
||||
cache := newMockCache()
|
||||
cache.setGetIDFunc(func(ctx context.Context, collectionName string) (typeutil.UniqueID, error) {
|
||||
return 0, errors.New("mock")
|
||||
})
|
||||
globalMetaCache = cache
|
||||
assert.Error(t, cit.PreExecute(context.Background()))
|
||||
})
|
||||
}
|
||||
|
|
|
@ -131,10 +131,9 @@ type BaseConfAdapter struct {
|
|||
|
||||
// CheckTrain check whether the params contains supported metrics types
|
||||
func (adapter *BaseConfAdapter) CheckTrain(params map[string]string) bool {
|
||||
// dimension is specified when create collection
|
||||
//if !CheckIntByRange(params, DIM, DefaultMinDim, DefaultMaxDim) {
|
||||
// return false
|
||||
//}
|
||||
if !CheckIntByRange(params, DIM, DefaultMinDim, DefaultMaxDim) {
|
||||
return false
|
||||
}
|
||||
|
||||
return CheckStrByValues(params, Metric, METRICS)
|
||||
}
|
||||
|
@ -179,8 +178,8 @@ func (adapter *IVFPQConfAdapter) CheckTrain(params map[string]string) bool {
|
|||
|
||||
func (adapter *IVFPQConfAdapter) checkPQParams(params map[string]string) bool {
|
||||
dimStr, dimensionExist := params[DIM]
|
||||
if !dimensionExist { // dimension is specified when creating collection
|
||||
return true
|
||||
if !dimensionExist {
|
||||
return false
|
||||
}
|
||||
|
||||
dimension, err := strconv.Atoi(dimStr)
|
||||
|
@ -260,10 +259,9 @@ type BinIDMAPConfAdapter struct {
|
|||
|
||||
// CheckTrain checks if a binary flat index can be built with the specific parameters.
|
||||
func (adapter *BinIDMAPConfAdapter) CheckTrain(params map[string]string) bool {
|
||||
// dimension is specified when create collection
|
||||
//if !CheckIntByRange(params, DIM, DefaultMinDim, DefaultMaxDim) {
|
||||
// return false
|
||||
//}
|
||||
if !CheckIntByRange(params, DIM, DefaultMinDim, DefaultMaxDim) {
|
||||
return false
|
||||
}
|
||||
|
||||
return CheckStrByValues(params, Metric, BinIDMapMetrics)
|
||||
}
|
||||
|
@ -278,10 +276,9 @@ type BinIVFConfAdapter struct {
|
|||
|
||||
// CheckTrain checks if a binary ivf index can be built with specific parameters.
|
||||
func (adapter *BinIVFConfAdapter) CheckTrain(params map[string]string) bool {
|
||||
// dimension is specified when create collection
|
||||
//if !CheckIntByRange(params, DIM, DefaultMinDim, DefaultMaxDim) {
|
||||
// return false
|
||||
//}
|
||||
if !CheckIntByRange(params, DIM, DefaultMinDim, DefaultMaxDim) {
|
||||
return false
|
||||
}
|
||||
|
||||
if !CheckIntByRange(params, NLIST, MinNList, MaxNList) {
|
||||
return false
|
||||
|
|
|
@ -12,6 +12,7 @@
|
|||
package indexparamcheck
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"strconv"
|
||||
"testing"
|
||||
)
|
||||
|
@ -50,11 +51,15 @@ func TestBaseConfAdapter_CheckTrain(t *testing.T) {
|
|||
DIM: strconv.Itoa(128),
|
||||
Metric: L2,
|
||||
}
|
||||
paramsWithoutDim := map[string]string{
|
||||
Metric: L2,
|
||||
}
|
||||
cases := []struct {
|
||||
params map[string]string
|
||||
want bool
|
||||
}{
|
||||
{validParams, true},
|
||||
{paramsWithoutDim, false},
|
||||
}
|
||||
|
||||
adapter := newBaseConfAdapter()
|
||||
|
@ -141,7 +146,7 @@ func TestIVFPQConfAdapter_CheckTrain(t *testing.T) {
|
|||
{validParamsWithoutNbits, true},
|
||||
{invalidIVFParamsMin(), false},
|
||||
{invalidIVFParamsMax(), false},
|
||||
{validParamsWithoutDim, true},
|
||||
{validParamsWithoutDim, false},
|
||||
{invalidParamsDim, false},
|
||||
{invalidParamsNbits, false},
|
||||
{invalidParamsWithoutIVF, false},
|
||||
|
@ -150,8 +155,9 @@ func TestIVFPQConfAdapter_CheckTrain(t *testing.T) {
|
|||
}
|
||||
|
||||
adapter := newIVFPQConfAdapter()
|
||||
for _, test := range cases {
|
||||
for i, test := range cases {
|
||||
if got := adapter.CheckTrain(test.params); got != test.want {
|
||||
fmt.Printf("i: %d, params: %v\n", i, test.params)
|
||||
t.Errorf("IVFPQConfAdapter.CheckTrain(%v) = %v", test.params, test.want)
|
||||
}
|
||||
}
|
||||
|
@ -187,11 +193,15 @@ func TestBinIDMAPConfAdapter_CheckTrain(t *testing.T) {
|
|||
DIM: strconv.Itoa(128),
|
||||
Metric: JACCARD,
|
||||
}
|
||||
paramsWithoutDim := map[string]string{
|
||||
Metric: JACCARD,
|
||||
}
|
||||
cases := []struct {
|
||||
params map[string]string
|
||||
want bool
|
||||
}{
|
||||
{validParams, true},
|
||||
{paramsWithoutDim, false},
|
||||
}
|
||||
|
||||
adapter := newBinIDMAPConfAdapter()
|
||||
|
@ -211,6 +221,12 @@ func TestBinIVFConfAdapter_CheckTrain(t *testing.T) {
|
|||
NBITS: strconv.Itoa(8),
|
||||
Metric: JACCARD,
|
||||
}
|
||||
paramsWithoutDim := map[string]string{
|
||||
NLIST: strconv.Itoa(100),
|
||||
IVFM: strconv.Itoa(4),
|
||||
NBITS: strconv.Itoa(8),
|
||||
Metric: JACCARD,
|
||||
}
|
||||
|
||||
invalidParams := copyParams(validParams)
|
||||
invalidParams[Metric] = L2
|
||||
|
@ -220,6 +236,7 @@ func TestBinIVFConfAdapter_CheckTrain(t *testing.T) {
|
|||
want bool
|
||||
}{
|
||||
{validParams, true},
|
||||
{paramsWithoutDim, false},
|
||||
{invalidIVFParamsMin(), false},
|
||||
{invalidIVFParamsMax(), false},
|
||||
{invalidParams, false},
|
||||
|
|
Loading…
Reference in New Issue