diff --git a/internal/util/indexparamcheck/conf_adapter.go b/internal/util/indexparamcheck/conf_adapter.go index 5a6feb1215..8a2b6f46b5 100644 --- a/internal/util/indexparamcheck/conf_adapter.go +++ b/internal/util/indexparamcheck/conf_adapter.go @@ -211,6 +211,60 @@ func newIVFPQConfAdapter() *IVFPQConfAdapter { return &IVFPQConfAdapter{} } +// RaftIVFPQConfAdapter checks if a RAFT_IVF_PQ index can be built. +type RaftIVFPQConfAdapter struct { + IVFConfAdapter +} + +// CheckTrain checks if ivf-pq index can be built with the specific index parameters. +func (adapter *RaftIVFPQConfAdapter) CheckTrain(params map[string]string) bool { + if !adapter.IVFConfAdapter.CheckTrain(params) { + return false + } + + return adapter.checkPQParams(params) +} + +func (adapter *RaftIVFPQConfAdapter) checkPQParams(params map[string]string) bool { + dimStr, dimensionExist := params[DIM] + if !dimensionExist { + return false + } + + dimension, err := strconv.Atoi(dimStr) + if err != nil { // invalid dimension + return false + } + + // nbits can be set to default: 8 + nbitsStr, nbitsExist := params[NBITS] + if nbitsExist { + _, err := strconv.Atoi(nbitsStr) + if err != nil { // invalid nbits + return false + } + } + + mStr, ok := params[IVFM] + if !ok { + return false + } + m, err := strconv.Atoi(mStr) + if err != nil { // invalid m + return false + } + + // here is the only difference with IVF_PQ + if m == 0 { + return true + } + return dimension%m == 0 +} + +func newRaftIVFPQConfAdapter() *RaftIVFPQConfAdapter { + return &RaftIVFPQConfAdapter{} +} + // IVFSQConfAdapter checks if a IVF_SQ index can be built. type IVFSQConfAdapter struct { IVFConfAdapter diff --git a/internal/util/indexparamcheck/conf_adapter_mgr.go b/internal/util/indexparamcheck/conf_adapter_mgr.go index 04d85d9e02..831e3c8eda 100644 --- a/internal/util/indexparamcheck/conf_adapter_mgr.go +++ b/internal/util/indexparamcheck/conf_adapter_mgr.go @@ -47,7 +47,7 @@ func (mgr *ConfAdapterMgrImpl) GetAdapter(indexType string) (ConfAdapter, error) func (mgr *ConfAdapterMgrImpl) registerConfAdapter() { mgr.adapters[IndexRaftIvfFlat] = newIVFConfAdapter() - mgr.adapters[IndexRaftIvfPQ] = newIVFPQConfAdapter() + mgr.adapters[IndexRaftIvfPQ] = newRaftIVFPQConfAdapter() mgr.adapters[IndexFaissIDMap] = newBaseConfAdapter() mgr.adapters[IndexFaissIvfFlat] = newIVFConfAdapter() mgr.adapters[IndexFaissIvfPQ] = newIVFPQConfAdapter() diff --git a/internal/util/indexparamcheck/conf_adapter_test.go b/internal/util/indexparamcheck/conf_adapter_test.go index 8e1328eaa3..7fc6204a08 100644 --- a/internal/util/indexparamcheck/conf_adapter_test.go +++ b/internal/util/indexparamcheck/conf_adapter_test.go @@ -165,6 +165,77 @@ func TestIVFPQConfAdapter_CheckTrain(t *testing.T) { } } +func TestRaftIVFPQConfAdapter_CheckTrain(t *testing.T) { + validParams := map[string]string{ + DIM: strconv.Itoa(128), + NLIST: strconv.Itoa(1024), + IVFM: strconv.Itoa(4), + NBITS: strconv.Itoa(8), + Metric: L2, + } + + validParamsWithoutNbits := map[string]string{ + DIM: strconv.Itoa(128), + NLIST: strconv.Itoa(1024), + IVFM: strconv.Itoa(4), + Metric: L2, + } + + validParamsWithoutDim := map[string]string{ + NLIST: strconv.Itoa(1024), + IVFM: strconv.Itoa(4), + NBITS: strconv.Itoa(8), + Metric: L2, + } + + invalidParamsDim := copyParams(validParams) + invalidParamsDim[DIM] = "NAN" + + invalidParamsNbits := copyParams(validParams) + invalidParamsNbits[NBITS] = "NAN" + + invalidParamsWithoutIVF := map[string]string{ + DIM: strconv.Itoa(128), + NLIST: strconv.Itoa(1024), + NBITS: strconv.Itoa(8), + Metric: L2, + } + + invalidParamsIVF := copyParams(validParams) + invalidParamsIVF[IVFM] = "NAN" + + invalidParamsM := copyParams(validParams) + invalidParamsM[DIM] = strconv.Itoa(65536) + + validParamsMzero := copyParams(validParams) + validParamsMzero[IVFM] = "0" + + cases := []struct { + params map[string]string + want bool + }{ + {validParams, true}, + {validParamsWithoutNbits, true}, + {invalidIVFParamsMin(), false}, + {invalidIVFParamsMax(), false}, + {validParamsWithoutDim, false}, + {invalidParamsDim, false}, + {invalidParamsNbits, false}, + {invalidParamsWithoutIVF, false}, + {invalidParamsIVF, false}, + {invalidParamsM, false}, + {validParamsMzero, true}, + } + + adapter := newRaftIVFPQConfAdapter() + for i, test := range cases { + if got := adapter.CheckTrain(test.params); got != test.want { + t.Log("i:", i, "params", test.params) + t.Errorf("RaftIVFPQConfAdapter.CheckTrain(%v) = %v", test.params, test.want) + } + } +} + func TestIVFSQConfAdapter_CheckTrain(t *testing.T) { getValidParams := func(withNBits bool) map[string]string { validParams := map[string]string{