Let RAFT_IVF_PQ param accept m=0 (#23134)

Signed-off-by: Yudong Cai <yudong.cai@zilliz.com>
pull/23159/head
Cai Yudong 2023-03-31 11:22:22 +08:00 committed by GitHub
parent 7612dd1fc3
commit 7612c75c47
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
3 changed files with 126 additions and 1 deletions

View File

@ -211,6 +211,60 @@ func newIVFPQConfAdapter() *IVFPQConfAdapter {
return &IVFPQConfAdapter{}
}
// RaftIVFPQConfAdapter checks if a RAFT_IVF_PQ index can be built.
type RaftIVFPQConfAdapter struct {
IVFConfAdapter
}
// CheckTrain checks if ivf-pq index can be built with the specific index parameters.
func (adapter *RaftIVFPQConfAdapter) CheckTrain(params map[string]string) bool {
if !adapter.IVFConfAdapter.CheckTrain(params) {
return false
}
return adapter.checkPQParams(params)
}
func (adapter *RaftIVFPQConfAdapter) checkPQParams(params map[string]string) bool {
dimStr, dimensionExist := params[DIM]
if !dimensionExist {
return false
}
dimension, err := strconv.Atoi(dimStr)
if err != nil { // invalid dimension
return false
}
// nbits can be set to default: 8
nbitsStr, nbitsExist := params[NBITS]
if nbitsExist {
_, err := strconv.Atoi(nbitsStr)
if err != nil { // invalid nbits
return false
}
}
mStr, ok := params[IVFM]
if !ok {
return false
}
m, err := strconv.Atoi(mStr)
if err != nil { // invalid m
return false
}
// here is the only difference with IVF_PQ
if m == 0 {
return true
}
return dimension%m == 0
}
func newRaftIVFPQConfAdapter() *RaftIVFPQConfAdapter {
return &RaftIVFPQConfAdapter{}
}
// IVFSQConfAdapter checks if a IVF_SQ index can be built.
type IVFSQConfAdapter struct {
IVFConfAdapter

View File

@ -47,7 +47,7 @@ func (mgr *ConfAdapterMgrImpl) GetAdapter(indexType string) (ConfAdapter, error)
func (mgr *ConfAdapterMgrImpl) registerConfAdapter() {
mgr.adapters[IndexRaftIvfFlat] = newIVFConfAdapter()
mgr.adapters[IndexRaftIvfPQ] = newIVFPQConfAdapter()
mgr.adapters[IndexRaftIvfPQ] = newRaftIVFPQConfAdapter()
mgr.adapters[IndexFaissIDMap] = newBaseConfAdapter()
mgr.adapters[IndexFaissIvfFlat] = newIVFConfAdapter()
mgr.adapters[IndexFaissIvfPQ] = newIVFPQConfAdapter()

View File

@ -165,6 +165,77 @@ func TestIVFPQConfAdapter_CheckTrain(t *testing.T) {
}
}
func TestRaftIVFPQConfAdapter_CheckTrain(t *testing.T) {
validParams := map[string]string{
DIM: strconv.Itoa(128),
NLIST: strconv.Itoa(1024),
IVFM: strconv.Itoa(4),
NBITS: strconv.Itoa(8),
Metric: L2,
}
validParamsWithoutNbits := map[string]string{
DIM: strconv.Itoa(128),
NLIST: strconv.Itoa(1024),
IVFM: strconv.Itoa(4),
Metric: L2,
}
validParamsWithoutDim := map[string]string{
NLIST: strconv.Itoa(1024),
IVFM: strconv.Itoa(4),
NBITS: strconv.Itoa(8),
Metric: L2,
}
invalidParamsDim := copyParams(validParams)
invalidParamsDim[DIM] = "NAN"
invalidParamsNbits := copyParams(validParams)
invalidParamsNbits[NBITS] = "NAN"
invalidParamsWithoutIVF := map[string]string{
DIM: strconv.Itoa(128),
NLIST: strconv.Itoa(1024),
NBITS: strconv.Itoa(8),
Metric: L2,
}
invalidParamsIVF := copyParams(validParams)
invalidParamsIVF[IVFM] = "NAN"
invalidParamsM := copyParams(validParams)
invalidParamsM[DIM] = strconv.Itoa(65536)
validParamsMzero := copyParams(validParams)
validParamsMzero[IVFM] = "0"
cases := []struct {
params map[string]string
want bool
}{
{validParams, true},
{validParamsWithoutNbits, true},
{invalidIVFParamsMin(), false},
{invalidIVFParamsMax(), false},
{validParamsWithoutDim, false},
{invalidParamsDim, false},
{invalidParamsNbits, false},
{invalidParamsWithoutIVF, false},
{invalidParamsIVF, false},
{invalidParamsM, false},
{validParamsMzero, true},
}
adapter := newRaftIVFPQConfAdapter()
for i, test := range cases {
if got := adapter.CheckTrain(test.params); got != test.want {
t.Log("i:", i, "params", test.params)
t.Errorf("RaftIVFPQConfAdapter.CheckTrain(%v) = %v", test.params, test.want)
}
}
}
func TestIVFSQConfAdapter_CheckTrain(t *testing.T) {
getValidParams := func(withNBits bool) map[string]string {
validParams := map[string]string{