mirror of https://github.com/milvus-io/milvus.git
Fix flat index can be created with invalid metric type (#24180)
Signed-off-by: longjiquan <jiquan.long@zilliz.com>pull/24229/head
parent
008285f849
commit
a98c79b6a6
|
@ -45,7 +45,7 @@ func (mgr *indexCheckerMgrImpl) GetChecker(indexType string) (IndexChecker, erro
|
|||
func (mgr *indexCheckerMgrImpl) registerIndexChecker() {
|
||||
mgr.checkers[IndexRaftIvfFlat] = newIVFBaseChecker()
|
||||
mgr.checkers[IndexRaftIvfPQ] = newRaftIVFPQChecker()
|
||||
mgr.checkers[IndexFaissIDMap] = newBaseChecker()
|
||||
mgr.checkers[IndexFaissIDMap] = newFlatChecker()
|
||||
mgr.checkers[IndexFaissIvfFlat] = newIVFBaseChecker()
|
||||
mgr.checkers[IndexFaissIvfPQ] = newIVFPQChecker()
|
||||
mgr.checkers[IndexFaissIvfSQ8] = newIVFSQChecker()
|
||||
|
|
|
@ -32,7 +32,7 @@ func Test_GetConfAdapterMgrInstance(t *testing.T) {
|
|||
adapter, err = adapterMgr.GetChecker(IndexFaissIDMap)
|
||||
assert.Equal(t, nil, err)
|
||||
assert.NotEqual(t, nil, adapter)
|
||||
_, ok = adapter.(*baseChecker)
|
||||
_, ok = adapter.(*flatChecker)
|
||||
assert.Equal(t, true, ok)
|
||||
|
||||
adapter, err = adapterMgr.GetChecker(IndexFaissIvfFlat)
|
||||
|
@ -86,7 +86,7 @@ func TestConfAdapterMgrImpl_GetAdapter(t *testing.T) {
|
|||
adapter, err = adapterMgr.GetChecker(IndexFaissIDMap)
|
||||
assert.Equal(t, nil, err)
|
||||
assert.NotEqual(t, nil, adapter)
|
||||
_, ok = adapter.(*baseChecker)
|
||||
_, ok = adapter.(*flatChecker)
|
||||
assert.Equal(t, true, ok)
|
||||
|
||||
adapter, err = adapterMgr.GetChecker(IndexFaissIvfFlat)
|
||||
|
|
|
@ -0,0 +1,9 @@
|
|||
package indexparamcheck
|
||||
|
||||
type flatChecker struct {
|
||||
floatVectorBaseChecker
|
||||
}
|
||||
|
||||
func newFlatChecker() IndexChecker {
|
||||
return &flatChecker{}
|
||||
}
|
|
@ -0,0 +1,68 @@
|
|||
package indexparamcheck
|
||||
|
||||
import (
|
||||
"strconv"
|
||||
"testing"
|
||||
|
||||
"github.com/stretchr/testify/assert"
|
||||
)
|
||||
|
||||
func Test_flatChecker_CheckTrain(t *testing.T) {
|
||||
|
||||
p1 := map[string]string{
|
||||
DIM: strconv.Itoa(128),
|
||||
Metric: L2,
|
||||
}
|
||||
p2 := map[string]string{
|
||||
DIM: strconv.Itoa(128),
|
||||
Metric: IP,
|
||||
}
|
||||
p3 := map[string]string{
|
||||
DIM: strconv.Itoa(128),
|
||||
Metric: COSINE,
|
||||
}
|
||||
|
||||
p4 := map[string]string{
|
||||
DIM: strconv.Itoa(128),
|
||||
Metric: HAMMING,
|
||||
}
|
||||
p5 := map[string]string{
|
||||
DIM: strconv.Itoa(128),
|
||||
Metric: JACCARD,
|
||||
}
|
||||
p6 := map[string]string{
|
||||
DIM: strconv.Itoa(128),
|
||||
Metric: TANIMOTO,
|
||||
}
|
||||
p7 := map[string]string{
|
||||
DIM: strconv.Itoa(128),
|
||||
Metric: SUBSTRUCTURE,
|
||||
}
|
||||
p8 := map[string]string{
|
||||
DIM: strconv.Itoa(128),
|
||||
Metric: SUPERSTRUCTURE,
|
||||
}
|
||||
cases := []struct {
|
||||
params map[string]string
|
||||
errIsNil bool
|
||||
}{
|
||||
{p1, true},
|
||||
{p2, true},
|
||||
{p3, true},
|
||||
{p4, false},
|
||||
{p5, false},
|
||||
{p6, false},
|
||||
{p7, false},
|
||||
{p8, false},
|
||||
}
|
||||
|
||||
c := newFlatChecker()
|
||||
for _, test := range cases {
|
||||
err := c.CheckTrain(test.params)
|
||||
if test.errIsNil {
|
||||
assert.NoError(t, err)
|
||||
} else {
|
||||
assert.Error(t, err)
|
||||
}
|
||||
}
|
||||
}
|
Loading…
Reference in New Issue